在 TensorFlow 环境下搭建 ResNet18，加入 SENet（SE 模块）进行模型增强，并使用数据增强，在 CIFAR-10 上训练，使准确率达到 90% 左右；给出完整代码，并打印出最终结果。
时间: 2024-05-16 19:13:28 浏览: 69
ResNet18+SENet 实现cifar10分类 training准确率95.66%,test准确率90.77%
以下是在 TensorFlow 环境下搭建 ResNet18，并加入 SENet（SE 模块）进行模型增强的代码：
```python
import tensorflow as tf
from tensorflow import keras
def resnet_block(inputs, filters, kernel_size, strides):
    """Build a basic two-convolution residual block (ResNet-18/34 style).

    Main path:  Conv(strides) -> BN -> ReLU -> Conv(stride 1) -> BN.
    Shortcut:   identity, or a 1x1 Conv(strides) + BN projection whenever the
                main path changes the spatial size or the channel count, so
                that both tensors can be added.
    The sum is passed through a final ReLU.

    Args:
        inputs: 4-D feature-map tensor. Assumed channels-last (NHWC), since
            the channel count is read from its last axis — confirm if the
            Keras image data format is changed.
        filters: Number of output channels of both main-path convolutions.
        kernel_size: Kernel size of the two main-path convolutions.
        strides: Stride of the first convolution (int or pair).

    Returns:
        The block's output tensor.
    """
    x = keras.layers.Conv2D(filters, kernel_size=kernel_size, strides=strides, padding='same')(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(filters, kernel_size=kernel_size, strides=1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)

    # Normalise `strides` so int and tuple/list forms compare the same way.
    if isinstance(strides, (tuple, list)):
        stride_pair = tuple(strides)
    else:
        stride_pair = (strides, strides)

    shortcut = inputs
    # Project the shortcut whenever the shapes would differ. Checking only
    # `strides == 2` misses stride-1 channel changes and tuple strides, and
    # `add` then fails with a shape mismatch.
    if stride_pair != (1, 1) or inputs.shape[-1] != filters:
        shortcut = keras.layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(inputs)
        shortcut = keras.layers.BatchNormalization()(shortcut)

    x = keras.layers.add([x, shortcut])
    x = keras.layers.ReLU()(x)
    return x
def resnet18(input_shape, num_classes):
    """Assemble a CIFAR-style ResNet-18 classifier as a Keras functional model.

    Structure: 3x3 stem conv (stride 1) -> BN -> ReLU, then eight residual
    blocks (two per stage at widths 64/128/256/512; the first block of each
    stage after the first downsamples with stride 2), global average pooling,
    and a softmax Dense head.

    Args:
        input_shape: Shape of one input image, e.g. (32, 32, 3).
        num_classes: Number of output classes.

    Returns:
        An uncompiled `keras.models.Model`.
    """
    # (filters, strides) for each residual block, in order.
    block_plan = (
        (64, 1), (64, 1),
        (128, 2), (128, 1),
        (256, 2), (256, 1),
        (512, 2), (512, 1),
    )

    image = keras.layers.Input(shape=input_shape)
    features = keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(image)
    features = keras.layers.BatchNormalization()(features)
    features = keras.layers.ReLU()(features)

    for width, step in block_plan:
        features = resnet_block(features, filters=width, kernel_size=3, strides=step)

    features = keras.layers.GlobalAveragePooling2D()(features)
    probabilities = keras.layers.Dense(num_classes, activation='softmax')(features)
    return keras.models.Model(inputs=image, outputs=probabilities)
class SEBlock(keras.layers.Layer):
    """Squeeze-and-Excitation channel re-weighting layer.

    The layer works in three steps:
    - Squeeze: global-average-pool each channel to a single value.
    - Excite: pass the pooled values through Dense(C // reduction_ratio, relu)
      and then Dense(C, sigmoid) to get one weight in (0, 1) per channel.
    - Rescale: multiply the input feature map channel-wise by those weights.

    The output has the same shape as the input. The input is assumed to be
    channels-last (NHWC), since C is read from the last axis.

    Args:
        reduction_ratio: Bottleneck reduction factor for the first Dense layer.
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, reduction_ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        filters = int(input_shape[-1])
        # Keep at least one hidden unit so inputs with fewer channels than
        # `reduction_ratio` do not produce a zero-width Dense layer.
        hidden_units = max(1, filters // self.reduction_ratio)
        self.average_pooling = keras.layers.GlobalAveragePooling2D()
        self.fc1 = keras.layers.Dense(hidden_units, activation='relu')
        self.fc2 = keras.layers.Dense(filters, activation='sigmoid')
        # The channel count is only known here, so the reshape back to
        # (1, 1, C) is created once in build() rather than inside call().
        self.reshape = keras.layers.Reshape((1, 1, filters))
        super().build(input_shape)

    def call(self, inputs):
        x = self.average_pooling(inputs)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.reshape(x)
        # Broadcast (batch, 1, 1, C) weights over the spatial dimensions.
        return inputs * x

    def get_config(self):
        """Return the layer config so models using it can be saved/reloaded."""
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config
def se_resnet18(input_shape, num_classes):
    """Assemble the ResNet-18 classifier with SE layers interleaved.

    The layer sequence is as follows:
    - 3x3 stem conv (stride 1) -> BN -> ReLU.
    - Eight residual blocks (widths 64/128/256/512, two per stage). An
      `SEBlock` is placed before every residual block and once more after
      the last one.
    - Global average pooling and a softmax Dense head.

    NOTE(review): SE is applied to each block's *output* rather than inside
    the residual branch before the addition, as in the original SE-ResNet
    design. This is kept as-is here.

    Args:
        input_shape: Shape of one input image, e.g. (32, 32, 3).
        num_classes: Number of output classes.

    Returns:
        An uncompiled `keras.models.Model`.
    """
    # (filters, strides) for each residual block, in order.
    block_plan = (
        (64, 1), (64, 1),
        (128, 2), (128, 1),
        (256, 2), (256, 1),
        (512, 2), (512, 1),
    )

    image = keras.layers.Input(shape=input_shape)
    features = keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(image)
    features = keras.layers.BatchNormalization()(features)
    features = keras.layers.ReLU()(features)

    for width, step in block_plan:
        features = SEBlock()(features)
        features = resnet_block(features, filters=width, kernel_size=3, strides=step)
    features = SEBlock()(features)

    features = keras.layers.GlobalAveragePooling2D()(features)
    probabilities = keras.layers.Dense(num_classes, activation='softmax')(features)
    return keras.models.Model(inputs=image, outputs=probabilities)
```
接下来，我们使用 CIFAR-10 数据集进行训练，并打印出最终结果：
```python
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Load CIFAR-10 and scale pixel values from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

# Data augmentation: random shifts of up to 10% and horizontal flips,
# applied on the fly to training batches only.
datagen = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)

# Build the SE-ResNet18 model for 32x32 RGB images and 10 classes.
model = se_resnet18(input_shape=(32, 32, 3), num_classes=10)

# Labels are integer class ids (not one-hot), hence the sparse loss.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train. `Model.fit` accepts the augmentation iterator directly;
# `fit_generator` is deprecated in TF 2.x and removed in Keras 3.
# NOTE: the test set doubles as validation data here, so the reported
# validation metrics are not an independent held-out estimate.
history = model.fit(datagen.flow(x_train, y_train, batch_size=128),
                    epochs=50,
                    validation_data=(x_test, y_test),
                    verbose=1)

# Final evaluation on the test set.
scores = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
```
通过以上代码，我们可以得到模型在测试集上的最终准确率。注意：最终准确率会因随机初始化和训练环境而有所波动；若要稳定达到 90% 左右，通常还需要配合学习率衰减等训练技巧。
阅读全文