train_ds = tf.data.Dataset.from_tensor_slices( (train_x_all, train_target)).shuffle(20000).batch(32)
时间: 2024-04-27 14:22:29 浏览: 212
浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点
这段代码使用 TensorFlow 的 Dataset API,将训练数据集 train_x_all 和标签 train_target 分别转化为 Tensorflow 中的张量,并创建一个数据集对象。接着使用 shuffle 方法将数据集中的数据进行随机打乱,参数 20000 表示打乱数据时使用的缓冲区大小。最后使用 batch 方法将数据集中的数据按照指定的批次大小进行分批处理,参数 32 表示每个批次的大小为 32。这样就得到了一个可以用于训练神经网络的数据集对象 train_ds。
阅读全文