输出利用resnet18预训练cifar10并保存训练权重成.h5文件的python代码
时间: 2024-05-09 22:20:21 浏览: 244
经cifar-10数据集训练后的权重文件
以下是利用resnet18预训练cifar10并保存训练权重成.h5文件的Python代码:
```
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# 加载ResNet18模型,并去掉最后一层
resnet_model = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
# 添加自定义分类头
x = resnet_model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
# 构建完整模型
model = Model(inputs=resnet_model.input, outputs=predictions)
# 冻结ResNet18预训练部分的权重
for layer in resnet_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, batch_size=32, epochs=10, validation_data=(x_test, y_test))
# 保存训练权重成.h5文件
model.save_weights('resnet18_cifar10.h5')
```
注:以上代码仅为示例,可能需要适当调整以适应不同的数据集和任务。
阅读全文