TensorFlow 实现 AlexNet 进行 CIFAR-10 数据集分类
时间: 2023-12-22 17:05:50 浏览: 81
以下是使用 TensorFlow 实现 AlexNet、对 CIFAR-10 数据集进行分类的代码:
```python
import os
import pickle

import numpy as np
import tensorflow as tf
# Load the CIFAR-10 dataset
def load_cifar10_data(data_dir):
    """Load the CIFAR-10 pickled batch files found in ``data_dir``.

    Every file whose name contains ``'data_batch'`` is treated as a training
    batch, and every file whose name contains ``'test_batch'`` as a test
    batch; all other files (e.g. ``batches.meta``) are ignored. Each batch is
    unpickled into a dict read through its ``b'data'`` and ``b'labels'`` keys.

    Files are visited in sorted-name order, so the concatenated training set
    has a deterministic order instead of depending on ``os.listdir``.

    Args:
        data_dir: Directory containing the extracted batch files.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``: the
        ``b'data'`` arrays and ``b'labels'`` sequences of each split,
        concatenated along axis 0.

    Raises:
        ValueError: if no training batch or no test batch file is found.
    """
    train_data, train_labels = [], []
    test_data, test_labels = [], []
    for filename in sorted(os.listdir(data_dir)):
        if 'data_batch' in filename:
            data_parts, label_parts = train_data, train_labels
        elif 'test_batch' in filename:
            data_parts, label_parts = test_data, test_labels
        else:
            continue
        with open(os.path.join(data_dir, filename), 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # trusted copies of the dataset.
            # encoding='bytes' decodes Python 2 8-bit strings as bytes,
            # which is why the keys below are b'data' / b'labels'.
            batch = pickle.load(f, encoding='bytes')
        data_parts.append(batch[b'data'])
        label_parts.append(batch[b'labels'])
    if not train_data or not test_data:
        raise ValueError(
            f"no CIFAR-10 'data_batch' and 'test_batch' files found in {data_dir!r}"
        )
    return (
        np.concatenate(train_data, axis=0),
        np.concatenate(train_labels, axis=0),
        np.concatenate(test_data, axis=0),
        np.concatenate(test_labels, axis=0),
    )
# Preprocess the data
def preprocess_data(train_data, test_data):
    """Scale pixel values to [0, 1] and standardize with training statistics.

    Both splits are cast to float32 and divided by 255. Then every feature
    (each position along the non-sample axes) is shifted by the training-set
    mean and divided by the training-set standard deviation, both computed
    over axis 0. The test split reuses the training statistics, so no test
    information enters the preprocessing.

    Features whose training-set standard deviation is zero (constant
    features) are only mean-centred, instead of being divided by zero
    (which would yield NaN/inf).

    Args:
        train_data: Training samples, sample index on axis 0.
        test_data: Test samples with the same per-sample shape.

    Returns:
        ``(train_data, test_data)`` as standardized float32 arrays.
    """
    train_data = train_data.astype('float32') / 255
    test_data = test_data.astype('float32') / 255
    mean = np.mean(train_data, axis=0)
    std = np.std(train_data, axis=0)
    # Guard constant features; ones_like keeps the float32 dtype.
    std = np.where(std == 0, np.ones_like(std), std)
    train_data = (train_data - mean) / std
    test_data = (test_data - mean) / std
    return train_data, test_data
# Define the AlexNet model
def alexnet(input_shape, num_classes, pool_padding="same"):
    """Build an AlexNet-style Keras classifier.

    Layer stack:
    - five convolutional layers (96, 256, 384, 384, 256 filters), each
      followed by BatchNormalization;
    - three 3x3/stride-2 max-pooling stages;
    - two 4096-unit ReLU dense layers with 0.5 dropout;
    - a softmax output layer.

    Args:
        input_shape: Shape of one input image, e.g. ``(32, 32, 3)``. Height
            and width must be at least 11, because the first 11x11
            convolution is unpadded.
        num_classes: Number of output classes (size of the softmax layer).
        pool_padding: Padding mode of the three MaxPooling2D layers.
            ``"same"`` (default) keeps the model buildable for small images
            such as 32x32 CIFAR-10 inputs. ``"valid"`` reproduces the older
            layer sizes for large inputs (e.g. 224x224), but cannot be built
            for 32x32 inputs.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    # Why not "valid" pooling by default:
    # - a 32x32 input shrinks to 6x6 after the stride-4 11x11 convolution;
    # - it shrinks to 2x2 after the first 3x3/2 pooling;
    # - the second 3x3/2 pooling would then have a negative output size,
    #   so the model could not be built.
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=96, kernel_size=(11, 11), strides=(4, 4), activation='relu', input_shape=input_shape),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding=pool_padding),
        tf.keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1), activation='relu', padding="same"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding=pool_padding),
        tf.keras.layers.Conv2D(filters=384, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(filters=384, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding=pool_padding),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model
# Load the data
train_data, train_labels, test_data, test_labels = load_cifar10_data('cifar-10-batches-py')
# Each CIFAR-10 batch row is a flat 3072-byte vector laid out as:
# - 1024 red values, then 1024 green, then 1024 blue;
# - each channel is a 32x32 plane stored in row-major order.
# Rebuild (N, 32, 32, 3) images so the data matches the model's input_shape.
train_data = train_data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
test_data = test_data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
# Preprocess the data
train_data, test_data = preprocess_data(train_data, test_data)
# Convert the labels to one-hot encoding
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
# Build the model
model = alexnet((32, 32, 3), 10)
# Compile the model. `lr` is a deprecated alias of `learning_rate` that
# recent Keras optimizers reject.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Train the model.
# NOTE(review): the test set doubles as validation data, so the reported
# val_* metrics are test metrics. Hold out part of the training set instead
# if they are used for tuning or early stopping.
model.fit(train_data, train_labels, batch_size=128, epochs=100, validation_data=(test_data, test_labels))
# Evaluate the model
score = model.evaluate(test_data, test_labels, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
```
运行结果:
```
Epoch 1/100
391/391 [==============================] - 9s 22ms/step - loss: 1.8777 - accuracy: 0.3027 - val_loss: 1.5721 - val_accuracy: 0.4232
Epoch 2/100
391/391 [==============================] - 8s 22ms/step - loss: 1.4371 - accuracy: 0.4767 - val_loss: 1.2911 - val_accuracy: 0.5366
Epoch 3/100
391/391 [==============================] - 8s 22ms/step - loss: 1.2449 - accuracy: 0.5542 - val_loss: 1.1651 - val_accuracy: 0.5882
...
Epoch 98/100
391/391 [==============================] - 9s 23ms/step - loss: 0.0054 - accuracy: 0.9992 - val_loss: 1.7525 - val_accuracy: 0.7432
Epoch 99/100
391/391 [==============================] - 9s 22ms/step - loss: 0.0057 - accuracy: 0.9991 - val_loss: 1.7703 - val_accuracy: 0.7472
Epoch 100/100
391/391 [==============================] - 9s 22ms/step - loss: 0.0044 - accuracy: 0.9994 - val_loss: 1.7405 - val_accuracy: 0.7482
Test loss: 1.7404981851577759
Test accuracy: 0.7482000584602356
```
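可以看到,模型在测试集上的准确率为 74.82%。需要注意的是:
- 训练集准确率已接近 99.9%,而测试集只有约 75%,且验证损失远高于训练损失,说明模型存在明显的过拟合。
- 代码把测试集同时作为验证集传给了 `fit`,因此最后一个 epoch 的 val_accuracy 与测试准确率完全相同。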
阅读全文