猫狗训练模型代码实例
时间: 2024-05-01 15:15:41 浏览: 14
这篇博客使用TensorFlow框架,使用预训练模型进行猫狗分类。代码在Github上已经公开,可以从[这里](https://github.com/jmhIcoding/dogsVScats)获取。
使用预训练模型进行微调的代码如下,其中包括了数据处理、模型构建和训练三个部分:
数据处理部分:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def get_dataset(dataset_name, split_name, dataset_dir, file_pattern):
    """Return the requested split of a named dataset.

    Args:
        dataset_name: one of 'imagenet', 'flowers', 'cifar10', 'mnist',
            'cats_vs_dogs'.
        split_name: data split to load (e.g. 'train', 'validation').
        dataset_dir: directory holding the dataset files.
        file_pattern: pattern with a '%s' placeholder for the split name.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
    # Lazy dispatch table: each entry is only evaluated when selected, so an
    # unknown name never touches the per-dataset loader modules.
    loaders = {
        'imagenet': lambda: dataset.get_split(split_name, dataset_dir, file_pattern),
        'flowers': lambda: flowers.get_split(split_name, dataset_dir, file_pattern),
        'cifar10': lambda: cifar10.get_split(split_name, dataset_dir, file_pattern),
        'mnist': lambda: mnist.get_split(split_name, dataset_dir, file_pattern),
        'cats_vs_dogs': lambda: dogs.get_split(split_name, dataset_dir, file_pattern),
    }
    if dataset_name not in loaders:
        raise ValueError('Invalid dataset name %s.' % dataset_name)
    return loaders[dataset_name]()
def load_batch(dataset, batch_size, height, width, is_training=True):
    """Draw one batch of preprocessed (image, label) pairs from *dataset*.

    Images go through the Inception preprocessing pipeline (resize and,
    in training mode, augmentation); batching uses four reader threads
    with a bounded queue. Shuffling is enabled only while training.
    """
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=is_training,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    raw_image, label = provider.get(['image', 'label'])
    processed_image = inception_preprocessing.preprocess_image(
        raw_image, height, width, is_training=is_training)
    image_batch, label_batch = tf.train.batch(
        [processed_image, label],
        batch_size=batch_size,
        num_threads=4,
        capacity=5 * batch_size)
    return image_batch, label_batch
```
模型构建部分:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def build_model(inputs, num_classes, is_training=True, scope='vgg_16'):
    """Build a VGG-16 style classification network with TF-slim.

    Args:
        inputs: batch of input images. fc6 is a 7x7 VALID convolution, so
            the feature map after pool5 is presumably 7x7 (i.e. 224x224
            input) — confirm against the caller.
        num_classes: number of output channels of the final fc8 layer.
        is_training: when True, dropout is active in fc6/fc7.
        scope: variable scope under which all layers are created.

    Returns:
        net: the raw class logits from fc8 (no activation applied).
        end_points: dict mapping layer names to their output tensors,
            gathered from the shared outputs collection.
    """
    with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        # Route every conv/fc/pool output into one named collection so it
        # can be turned into the end_points dict at the end.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=end_points_collection):
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            # conv2 .. conv4 blocks are elided in the original listing.
            ...
            net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
            net = slim.max_pool2d(net, [2, 2], scope='pool5')
            # fc6/fc7 are implemented as convolutions (fully-convolutional VGG).
            net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
            net = slim.dropout(net, 0.5, is_training=is_training,
                               scope='dropout6')
            net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
            net = slim.dropout(net, 0.5, is_training=is_training,
                               scope='dropout7')
            # Final classifier layer: plain logits, no activation or norm.
            net = slim.conv2d(net, num_classes, [1, 1],
                              activation_fn=None,
                              normalizer_fn=None,
                              scope='fc8')
        end_points = slim.utils.convert_collection_to_dict(end_points_collection)
        return net, end_points
```
训练部分:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def run_training(dataset_name, train_dir, dataset_dir, num_classes=2, batch_size=32, num_epochs=10, initial_learning_rate=0.0001):
    """Train the classifier with slim's managed training loop.

    Args:
        dataset_name: key understood by get_dataset (e.g. 'cats_vs_dogs').
        train_dir: directory where checkpoints and summaries are written.
        dataset_dir: directory containing the TFRecord shards.
        num_classes: number of target classes (default 2: cats vs dogs).
        batch_size: examples per training step.
        num_epochs: passes over the dataset; converted to a step count below.
        initial_learning_rate: starting rate of the exponential decay schedule.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)
        # Load the 'train' split; '%s_*.tfrecord' is filled with the split name.
        dataset = get_dataset(dataset_name, 'train', dataset_dir, '%s_*.tfrecord')
        images, labels = load_batch(dataset, batch_size=batch_size, height=224, width=224, is_training=True)
        # Build the network in training mode (dropout enabled).
        logits, end_points = build_model(images, num_classes=num_classes, is_training=True)
        # Softmax cross-entropy on one-hot labels; slim registers the loss so
        # get_total_loss() picks it up (plus any regularization losses).
        one_hot_labels = slim.one_hot_encoding(labels, num_classes)
        slim.losses.softmax_cross_entropy(logits=logits, onehot_labels=one_hot_labels)
        total_loss = slim.losses.get_total_loss()
        # Adam with staircase exponential decay: rate x0.96 every 1000 steps.
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            initial_learning_rate,
            global_step,
            decay_steps=1000,
            decay_rate=0.96,
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)
        # Run slim's training loop; number_of_steps approximates num_epochs
        # full passes over the dataset at the given batch size.
        saver = tf.train.Saver(tf.global_variables())
        slim.learning.train(
            train_op,
            train_dir,
            log_every_n_steps=1,
            save_summaries_secs=20,
            saver=saver,
            number_of_steps=num_epochs * dataset.num_samples // batch_size,
            save_interval_secs=120)

if __name__ == '__main__':
    # Example invocation; replace dataset_dir with the real TFRecord location.
    run_training('cats_vs_dogs', '/tmp/cats_vs_dogs', '/path/to/dataset')
```