小样本训练的tensorflow代码
时间: 2024-01-21 10:03:30 浏览: 89
小样本训练通常使用的是迁移学习技术,即使用预训练好的模型作为初始模型,在小样本上微调模型。以下是一个基于TensorFlow的小样本训练代码示例:
```python
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
# 加载预训练模型
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# 冻结预训练模型的所有层
for layer in base_model.layers:
layer.trainable = False
# 添加新的全连接层
model = tf.keras.models.Sequential([
base_model,
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1, activation='sigmoid')
])
# 编译模型
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=2e-5),
loss='binary_crossentropy',
metrics=['accuracy'])
# 加载数据
train_data_generator = ImageDataGenerator(rescale=1./255, rotation_range=20, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.1, zoom_range=0.1, horizontal_flip=True, fill_mode='nearest')
train_generator = train_data_generator.flow_from_directory('train_dir', target_size=(224, 224), batch_size=32, class_mode='binary')
validation_data_generator = ImageDataGenerator(rescale=1./255)
validation_generator = validation_data_generator.flow_from_directory('validation_dir', target_size=(224, 224), batch_size=32, class_mode='binary')
# 训练模型
history = model.fit(train_generator, steps_per_epoch=100, epochs=30, validation_data=validation_generator, validation_steps=50)
```
在这个示例中,我们使用了VGG16作为预训练模型,冻结了所有的层,并添加了一个新的全连接层。我们使用ImageDataGenerator来加载和增强数据,使用fit方法来训练模型。在训练过程中,我们使用了步数来控制每个epoch处理的批次数,以适应小样本训练的场景。
阅读全文