cnn代码
时间: 2023-07-02 19:13:53 浏览: 56
下面是一个简单的CNN模型的示例代码,用于对MNIST手写数字数据集进行分类:
```python
import tensorflow as tf
from tensorflow.keras import layers
# 定义CNN模型结构
model = tf.keras.Sequential([
layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=5,
validation_data=(test_images, test_labels))
```
以上代码中,我们使用了TensorFlow的Keras API来定义CNN模型结构。该模型包括了三个卷积层和两个全连接层,其中使用了ReLU激活函数,最后一层使用了softmax激活函数来进行多分类。我们使用了Adam优化器和交叉熵损失函数,并在训练过程中计算了准确率作为评估指标。