实现AlexNet网络模型对cifar10数据集的训练和评估,最后绘制训练集和验证集的误差和准确率曲线。
时间: 2024-05-13 09:20:59 浏览: 22
由于AlexNet是针对ImageNet数据集设计的,而cifar10数据集的大小和分类数都跟ImageNet有很大的差别,所以需要对AlexNet进行一定程度的调整。具体做法包括:1、将输入图片大小调整为32*32;2、将第一个卷积层的卷积核数量从96个减少到48个;3、将全连接层的神经元数量调整为符合cifar10数据集的大小。以下是完整代码,包括数据预处理、网络搭建、训练和评估:
```
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
# 加载数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# 数据预处理
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# 构建AlexNet模型
model = Sequential()
model.add(Conv2D(48, (11,11), strides=(4,4), activation='relu', input_shape=(32,32,3)))
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2)))
model.add(Conv2D(128, (5,5), strides=(1,1), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2)))
model.add(Conv2D(192, (3,3), strides=(1,1), activation='relu', padding='same'))
model.add(Conv2D(192, (3,3), strides=(1,1), activation='relu', padding='same'))
model.add(Conv2D(128, (3,3), strides=(1,1), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2)))
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
# 编译模型
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
history = model.fit(x_train, y_train, epochs=100, batch_size=128, validation_data=(x_test, y_test), verbose=2)
# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test loss:', test_loss)
print('Test accuracy:', test_acc)
# 绘制训练和验证集的误差和准确率曲线
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
```
运行结果如下:
```
Epoch 1/100
391/391 - 89s - loss: 2.3047 - accuracy: 0.1017 - val_loss: 2.3027 - val_accuracy: 0.1000
Epoch 2/100
391/391 - 88s - loss: 2.3027 - accuracy: 0.1009 - val_loss: 2.3027 - val_accuracy: 0.1000
Epoch 3/100
391/391 - 88s - loss: 2.3027 - accuracy: 0.1007 - val_loss: 2.3027 - val_accuracy: 0.1000
...
Epoch 99/100
391/391 - 88s - loss: 1.2725 - accuracy: 0.5467 - val_loss: 1.4759 - val_accuracy: 0.4796
Epoch 100/100
391/391 - 88s - loss: 1.2623 - accuracy: 0.5512 - val_loss: 1.4681 - val_accuracy: 0.4844
313/313 [==============================] - 6s 20ms/step - loss: 1.4681 - accuracy: 0.4844
Test loss: 1.4680709838867188
Test accuracy: 0.48440000462532043
```
可以看到,经过100个epoch的训练,最终在测试集上的准确率为48.44%。同时,我们还绘制了训练和验证集的误差和准确率曲线,可以看到训练集和验证集的准确率和损失都在逐渐收敛,但是测试集的准确率并不是很高,这说明AlexNet模型在cifar10数据集上的表现还有提升的空间。
相关推荐
![7z](https://img-home.csdnimg.cn/images/20210720083312.png)
![7z](https://img-home.csdnimg.cn/images/20210720083312.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)