请基于 Python 3.7 和 TensorFlow 2.1 编写一段 LeNet 代码
时间: 2023-06-01 18:04:05 浏览: 101
基于 Python 语言的 TensorFlow 源码
3星 · 编辑精心推荐
import tensorflow as tf
class LeNet(tf.keras.Model):
    """LeNet-5 style convolutional classifier written with Keras model subclassing.

    Layer stack (in call order):
      Conv2D(6 filters, 5x5, ReLU) -> MaxPooling2D(2x2)
      -> Conv2D(16 filters, 5x5, ReLU) -> MaxPooling2D(2x2)
      -> Flatten -> Dense(120, ReLU) -> Dense(84, ReLU)
      -> Dense(num_classes, softmax)

    Args:
        num_classes: Number of units in the final softmax layer, i.e. the
            number of output classes. Defaults to 10, the value the layer
            previously hard-coded.
        **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``
            (for example ``name``).
    """

    def __init__(self, num_classes=10, **kwargs):
        super(LeNet, self).__init__(**kwargs)
        self.num_classes = num_classes

        # Feature extractor. The Conv2D default padding is 'valid', so a
        # 32x32 input shrinks as follows:
        # 28x28 (conv1) -> 14x14 (pool1) -> 10x10 (conv2) -> 5x5 (pool2).
        # NOTE(review): in a subclassed Model the `input_shape` kwarg on a layer
        # likely has no effect. The shape is fixed by build() or the first call.
        # It is kept only so the layer configuration is unchanged.
        self.conv1 = tf.keras.layers.Conv2D(
            filters=6, kernel_size=(5, 5), activation='relu',
            input_shape=(32, 32, 1))
        self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.conv2 = tf.keras.layers.Conv2D(
            filters=16, kernel_size=(5, 5), activation='relu')
        self.pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

        # Classifier head: flatten the feature maps, then two hidden dense
        # layers, then a softmax over `num_classes` outputs.
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(units=120, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units=84, activation='relu')
        self.output_layer = tf.keras.layers.Dense(
            units=num_classes, activation='softmax')

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: A batch of images. Presumably channels-last
                (batch, height, width, channels), e.g. (batch, 32, 32, 1).
                TODO confirm against callers.

        Returns:
            Softmax class probabilities of shape (batch, num_classes).
        """
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.output_layer(x)
        return output
# Smoke test: create the network, build its weights for batches of 32x32
# single-channel images (batch dimension left open as None), and print
# the per-layer summary.
model = LeNet()
model.build((None, 32, 32, 1))
model.summary()
阅读全文