Writing a Convolutional Neural Network in Python
Posted: 2023-07-20 13:35:38
Sure. Below is a simple convolutional neural network (CNN) written in Python with TensorFlow/Keras, which classifies handwritten digits from the MNIST dataset.
First, import the required libraries and load the dataset:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
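# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1] and store them as float32 (Keras' default float type)
x_train = (x_train / 255.0).astype('float32')
x_test = (x_test / 255.0).astype('float32')
# Add a channel axis, (N, 28, 28) -> (N, 28, 28, 1), the layout Conv2D expects
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# Build the training and test pipelines (training set shuffled, batch size 32)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)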
```
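As a quick sanity check on the pipelines above, the arithmetic below counts how many batches `batch(32)` yields per pass. It assumes the standard MNIST split (60,000 training and 10,000 test images) and the default `drop_remainder=False`, under which the final partial batch is kept; `batch_stats` is a hypothetical helper, not a TensorFlow function.

```python
def batch_stats(num_examples, batch_size):
    """Return (number of batches, size of the last batch), keeping a final
    partial batch as tf.data's batch() does by default (drop_remainder=False)."""
    if num_examples == 0:
        return 0, 0
    full, remainder = divmod(num_examples, batch_size)
    if remainder == 0:
        return full, batch_size
    return full + 1, remainder


# Standard MNIST split: 60,000 training images, 10,000 test images
print(batch_stats(60000, 32))  # (1875, 32): every training batch is full
print(batch_stats(10000, 32))  # (313, 16): the last test batch holds 16 images
```

So one training epoch in the loop further down performs 1,875 gradient updates.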
Next, define a CNN with two convolutional layers (each followed by max pooling) and two fully connected (dense) layers:
```python
class MyCNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Two convolution + max-pooling stages; the input shape is inferred on the first call
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.pool1 = tf.keras.layers.MaxPooling2D()
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation='relu')
        self.pool2 = tf.keras.layers.MaxPooling2D()
        # Flatten the feature maps, then two dense layers; softmax outputs class probabilities
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128, activation='relu')
        self.fc2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return self.fc2(x)
```
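To make the two layer types concrete, here is a pure-Python sketch of what a single-channel `Conv2D` with Keras' defaults (`padding='valid'`, stride 1) computes, followed by the default 2×2, stride-2 `MaxPooling2D`. Like Keras, it computes cross-correlation (the kernel is not flipped). It is illustrative only: the real layers also handle multiple input and output channels, a bias term, and batching, and the function names here are hypothetical.

```python
def conv2d_valid(image, kernel):
    """Single-channel 2D cross-correlation with 'valid' padding and stride 1."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Sum of element-wise products between the kernel and the window at (i, j)
            total = 0
            for di in range(kh):
                for dj in range(kw):
                    total += image[i + di][j + dj] * kernel[di][dj]
            row.append(total)
        out.append(row)
    return out


def relu(feature_map):
    # Element-wise max(0, v), the activation used by conv1/conv2
    return [[max(0, v) for v in row] for row in feature_map]


def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2; an odd trailing row or column is dropped."""
    out = []
    for i in range(0, len(feature_map) - 1, 2):
        row = []
        for j in range(0, len(feature_map[0]) - 1, 2):
            row.append(max(feature_map[i][j], feature_map[i][j + 1],
                           feature_map[i + 1][j], feature_map[i + 1][j + 1]))
        out.append(row)
    return out


image = [[1, 2, 3, 0],
         [0, 1, 2, 3],
         [3, 0, 1, 2],
         [2, 3, 0, 1]]
kernel = [[1, 0, 0],
          [0, 1, 0],
          [0, 0, 1]]
feature = conv2d_valid(image, kernel)
print(feature)                      # [[3, 6], [0, 3]]
print(max_pool_2x2(relu(feature)))  # [[6]]
```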
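Note that each input sample has shape `(28, 28, 1)`, because MNIST images are 28x28 grayscale images with a single channel. A subclassed model such as `MyCNN` does not need this shape declared up front: it builds its layers from the shape of the first batch it is called on.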
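Following that `(28, 28, 1)` input through the layers: with the defaults used above, each 3×3 convolution shrinks height and width by 2, and each pooling layer halves them, rounding down. The sketch below traces the output shape and counts trainable parameters (weights plus biases) per layer by hand; the helper names are hypothetical, and the total is what one would expect `model.summary()` to report once the model has been built.

```python
def conv_out(size, kernel=3):
    # 'valid' padding, stride 1: output = input - kernel + 1
    return size - kernel + 1


def pool_out(size, pool=2):
    # Non-overlapping pooling (stride = pool size) rounds down
    return size // pool


def mycnn_layer_summary(height=28, width=28, channels=1):
    """Return (layer name, output shape, trainable params) for each MyCNN layer."""
    rows = []
    h, w, c = height, width, channels
    for conv_name, pool_name, filters in (("conv1", "pool1", 32), ("conv2", "pool2", 64)):
        # 3x3 kernels spanning all c input channels, plus one bias per filter
        params = 3 * 3 * c * filters + filters
        h, w, c = conv_out(h), conv_out(w), filters
        rows.append((conv_name, (h, w, c), params))
        # Pooling has no trainable parameters
        h, w = pool_out(h), pool_out(w)
        rows.append((pool_name, (h, w, c), 0))
    flat = h * w * c
    rows.append(("flatten", (flat,), 0))
    # Dense layers: one weight per input-output pair, plus one bias per output unit
    rows.append(("fc1", (128,), flat * 128 + 128))
    rows.append(("fc2", (10,), 128 * 10 + 10))
    return rows


summary = mycnn_layer_summary()
for name, shape, params in summary:
    print(f"{name:8s} {str(shape):15s} {params}")
print("total trainable params:", sum(p for _, _, p in summary))  # 225034
```

Most of the parameters (204,928) sit in `fc1`, because the flattened 5×5×64 feature map has 1,600 values.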
Finally, write the training and evaluation code:
```python
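# Create the model instance
model = MyCNN()
# Loss and optimizer; fc2 already applies softmax, so the default from_logits=False is correct
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
# Tracks the mean training loss over each epoch
train_loss = tf.keras.metrics.Mean()
# Train the model
for epoch in range(5):
    train_loss.reset_state()
    for images, labels in train_ds:
        with tf.GradientTape() as tape:
            predictions = model(images)
            loss = loss_fn(labels, predictions)
        # Backpropagate and update the weights
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss.update_state(loss)
    print('Epoch {}, loss={:.4f}'.format(epoch + 1, float(train_loss.result())))
# Evaluate the model on the test set
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
for images, labels in test_ds:
    predictions = model(images)
    accuracy.update_state(labels, predictions)
print('Test accuracy: {:.4f}'.format(float(accuracy.result())))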
```
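Because `fc2` already applies softmax, the model outputs probabilities, and `SparseCategoricalCrossentropy` is used with its default `from_logits=False`. For integer labels, this loss is the batch mean of -log of the probability assigned to the true class. The pure-Python sketch below computes it by hand; it leaves out details such as the clipping Keras applies to avoid log(0), and the function names are illustrative only.

```python
import math


def softmax(logits):
    # Subtracting the max improves numerical stability without changing the result
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(p[true class]) over a batch.

    labels: list of integer class ids; probs: list of probability vectors.
    """
    losses = [-math.log(p[y]) for y, p in zip(labels, probs)]
    return sum(losses) / len(losses)


# Two examples with 3 classes (MNIST has 10)
probs = [softmax([2.0, 1.0, 0.1]), softmax([0.5, 2.5, 0.0])]
labels = [0, 1]
print(sparse_categorical_crossentropy(labels, probs))
```

The loss shrinks toward 0 as the probability on the correct class approaches 1, which is what the gradient updates in the training loop push toward.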
Here the model is trained for 5 epochs and then evaluated on the test set, where its classification accuracy is computed.
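`SparseCategoricalAccuracy` compares the index of the highest predicted probability with the integer label and reports the fraction of matches, accumulated across every `update_state` call until it is reset. A minimal pure-Python equivalent of that accumulation, using a hypothetical class name, could look like this:

```python
class RunningSparseAccuracy:
    """Accumulates the fraction of examples whose argmax prediction equals the label."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, labels, probs):
        for y, p in zip(labels, probs):
            # argmax: index of the highest predicted probability
            predicted = max(range(len(p)), key=p.__getitem__)
            self.correct += int(predicted == y)
            self.total += 1

    def result(self):
        return self.correct / self.total if self.total else 0.0


acc = RunningSparseAccuracy()
acc.update_state([0, 1], [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]])  # 1 of 2 correct
acc.update_state([2], [[0.1, 0.1, 0.8]])                    # 1 of 1 correct
print(acc.result())  # 0.6666666666666666
```

Accumulating counts rather than averaging per-batch accuracies keeps the result exact even when the last test batch is smaller than the others.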