keras使用cnn实现fashion-mnist数据集识别
时间: 2024-10-24 16:14:20 浏览: 42
Keras是一个高级神经网络API,它可以在TensorFlow等深度学习库之上运行。如果你想使用Keras的卷积神经网络(Convolutional Neural Network,CNN)来识别Fashion-MNIST数据集,这是一个包含10种服装类别(如T恤、牛仔裤、鞋等)的手写数字图像的数据集,可以按照以下步骤操作:
1. **导入必要的库**:
```python
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
```
2. **加载Fashion-MNIST数据**:
```python
(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data()
```
3. **预处理数据**:
- 将图像转换为通道在最后的维度
```python
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]
```
4. **构建CNN模型**:
```python
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(), # 展平为一维数组以便输入全连接层
layers.Dense(128, activation='relu'),
layers.Dropout(0.5), # 防止过拟合
layers.Dense(10, activation='softmax') # 输出层,10个节点对应10个类别
])
```
5. **编译模型**:
```python
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
```
6. **训练模型**:
```python
model.fit(train_images, train_labels, epochs=10, validation_split=0.1)
```
7. **评估模型性能**:
```python
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'Test accuracy: {test_acc}')
```
阅读全文