Keras 实现：使用非预训练的 ResNet18 模型对四种天气进行分类（优化器使用 SGD）
时间: 2023-11-03 21:04:09 浏览: 115
好的,下面是使用Keras实现用ResNet18非预训练模型对四种天气分类的代码,使用SGD优化器:
```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10
from sklearn.model_selection import train_test_split
import numpy as np
# Load and preprocess the dataset.
# CIFAR-10 is only a stand-in here (it is not a weather dataset). It has ten
# classes, so keep four of them and relabel those to 0..3 to obtain the
# 4-way problem the model is meant to solve. Replace this section with your
# own four weather categories.
SELECTED_CLASSES = (0, 1, 8, 9)  # placeholder choice of four CIFAR-10 class ids
NUM_CLASSES = len(SELECTED_CLASSES)


def select_classes(images, labels, classes=SELECTED_CLASSES):
    """Return only the samples whose label is in *classes*, relabelled 0..k-1.

    Args:
        images: array of samples indexed along axis 0.
        labels: integer class ids with shape (N,) or (N, 1).
        classes: original class ids to keep; the position of an id in this
            sequence becomes its new label.

    Returns:
        Tuple ``(images_subset, new_labels)`` where ``new_labels`` is a 1-D
        integer array with values in ``range(len(classes))``.
    """
    labels = np.asarray(labels).reshape(-1)
    classes = np.asarray(classes)
    keep = np.isin(labels, classes)
    # Each row of the comparison matrix has exactly one True, at the index of
    # that sample's label within `classes`; argmax turns it into the new label.
    new_labels = np.argmax(labels[keep, None] == classes[None, :], axis=1)
    return images[keep], new_labels


(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
train_images, train_labels = select_classes(train_images, train_labels)
test_images, test_labels = select_classes(test_images, test_labels)
# Scale pixels to [0, 1]; casting to float32 first avoids a float64 copy
# that would double the memory footprint.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)
# Hold out 20% of the training data for validation (fixed seed for reproducibility).
train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.2, random_state=42)
# 定义ResNet18模型
def resnet_block(inputs, num_filters, strides=1, activation='relu'):
    """Basic two-convolution residual block (the block type used by ResNet-18).

    Main path: 3x3 conv (with ``strides``) -> BN -> activation -> 3x3 conv -> BN.
    The main path is added to a shortcut, and the activation is applied to the sum.

    Args:
        inputs: Keras feature-map tensor. Assumes channels-last layout (the
            Keras default), since the channel count is read from
            ``inputs.shape[-1]``.
        num_filters: number of output channels of the block.
        strides: stride of the first conv (and of the shortcut projection).
        activation: activation name passed to ``Activation``.

    Returns:
        Output tensor with ``num_filters`` channels.
    """
    x = Conv2D(num_filters, kernel_size=3, strides=strides, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    x = Conv2D(num_filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    shortcut = inputs
    # A 1x1 projection is needed whenever the main path changes the spatial
    # size OR the channel count; otherwise Add() receives mismatched shapes.
    if strides != 1 or inputs.shape[-1] != num_filters:
        shortcut = Conv2D(num_filters, kernel_size=1, strides=strides, padding='same')(inputs)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    return Activation(activation)(x)
def resnet18(input_shape, num_classes):
    """Build an untrained ResNet-18-style classifier with the Keras functional API.

    Architecture:
    - Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
    - Four stages with 64/128/256/512 filters, two ``resnet_block`` calls each.
    - Global average pooling, then a softmax ``Dense`` layer.

    Note: for small inputs such as 32x32, the stem alone already reduces the
    spatial size by 4x.

    Args:
        input_shape: shape of one input sample, passed to ``Input``.
        num_classes: number of output units of the softmax layer.

    Returns:
        A ``Model`` mapping input images to class probabilities.
    """
    image_in = Input(shape=input_shape)

    h = Conv2D(64, kernel_size=7, strides=2, padding='same')(image_in)
    h = BatchNormalization()(h)
    h = Activation('relu')(h)
    h = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(h)

    # (filters, stride of the stage's first block); stages 2-4 downsample by 2.
    stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
    for width, first_stride in stage_plan:
        h = resnet_block(h, width, strides=first_stride)
        h = resnet_block(h, width)

    pooled = GlobalAveragePooling2D()(h)
    probs = Dense(num_classes, activation='softmax')(pooled)
    return Model(inputs=image_in, outputs=probs)
# Define the model structure and the optimizer.
# The input shape and class count are taken from the prepared arrays, so the
# output layer always matches the width of the one-hot labels. A mismatch
# here makes fit() fail with a shape error.
model = resnet18(input_shape=train_images.shape[1:], num_classes=train_labels.shape[1])
model.compile(
    # Current Keras optimizers reject the old `lr` keyword; use `learning_rate`.
    optimizer=SGD(learning_rate=0.01, momentum=0.9),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Train the model, reporting metrics on the held-out validation split each epoch.
model.fit(train_images, train_labels, batch_size=64, epochs=100, validation_data=(val_images, val_labels))
# Evaluate the model on the test set.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy:", test_acc)
```
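这里使用 CIFAR-10 数据集作为占位示例（它并不是天气数据集）。注意：CIFAR-10 共有 10 个类别。若要把它当作四分类问题，必须先筛选出其中 4 个类别，并把标签重新映射为 0–3。同时，模型的 `num_classes` 必须与标签的类别数一致，否则训练时会因输出维度不匹配而报错。实际使用时，请替换为你自己的四类天气图像数据集。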
阅读全文