tf训练MNIST包括自定义DataGenerater
时间: 2024-04-29 20:22:47 浏览: 9
MNIST是一个非常流行的手写数字识别数据集,许多深度学习框架都提供了预处理好的MNIST数据集,但是在实际应用中,我们可能需要自己定义数据生成器来处理自己的数据集。下面是一个使用TensorFlow训练MNIST并自定义数据生成器的示例代码:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 下载MNIST数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 将数据集归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
# 定义数据生成器
def data_generator(images, labels, batch_size):
while True:
# 随机打乱数据
indices = tf.random.shuffle(tf.range(len(images)))
images = tf.gather(images, indices)
labels = tf.gather(labels, indices)
# 分批处理数据
for i in range(0, len(images), batch_size):
yield images[i:i+batch_size], labels[i:i+batch_size]
# 定义模型
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
batch_size = 32
train_generator = data_generator(train_images, train_labels, batch_size)
test_generator = data_generator(test_images, test_labels, batch_size)
model.fit(train_generator, epochs=5, steps_per_epoch=len(train_images)//batch_size,
validation_data=test_generator, validation_steps=len(test_images)//batch_size)
# 评估模型
model.evaluate(test_images, test_labels)
```
在上面的代码中,我们首先下载了MNIST数据集,并将数据集归一化。然后定义了一个数据生成器`data_generator`,该生成器可以无限循环地生成数据,并随机打乱数据。我们使用`yield`语句来生成数据,这样我们就可以在训练过程中不停地生成新的数据。接着,我们定义了一个简单的神经网络模型,并编译模型。最后,我们使用`fit`函数训练模型,并使用`evaluate`函数评估模型。需要注意的是,在使用`fit`函数训练模型时,我们传入的是数据生成器而不是原始数据集。这样做的好处是可以在不占用过多内存的情况下训练模型。