用自定义模型的方式写出LeNet
时间: 2024-05-09 10:17:09 浏览: 86
LLNet模型实现
5星 · 资源好评率100%
LeNet是一个简单的卷积神经网络,主要用于手写数字识别。下面是使用自定义模型的方式编写LeNet的示例代码:
```python
import tensorflow as tf
# 定义LeNet模型
def custom_lenet(input_shape, num_classes):
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=120, activation='relu'),
tf.keras.layers.Dense(units=84, activation='relu'),
tf.keras.layers.Dense(units=num_classes, activation='softmax')
])
return model
# 设定输入形状和输出类别数
input_shape = (32, 32, 1)
num_classes = 10
# 创建模型
model = custom_lenet(input_shape, num_classes)
# 打印模型结构
model.summary()
```
在这个示例代码中,我们首先定义了一个名为custom_lenet的函数,它接受输入形状input_shape和输出类别数num_classes作为参数,返回一个Sequential模型对象,这个模型按照LeNet的结构定义了一系列卷积层、池化层和全连接层,并使用ReLU和softmax作为激活函数。
然后,我们在主程序中调用custom_lenet函数,传入输入形状和输出类别数,得到一个名为model的Sequential模型对象。最后,我们使用model.summary()打印出模型的结构,以便检查和调试。
阅读全文