How do I convert a (256, 256) uint8 MNIST dataset into tensors and save them?
Posted: 2024-05-04 14:19:13
You can use the `numpy` and `tensorflow` libraries to convert the MNIST dataset into tensors and save them to disk. Example code:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load the dataset (Keras MNIST: uint8 arrays of shape (N, 28, 28))
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert to tensors and normalize pixel values to [0, 1].
# x[..., np.newaxis] adds a channel axis for any image size,
# e.g. (N, 28, 28) -> (N, 28, 28, 1) or (N, 256, 256) -> (N, 256, 256, 1).
x_train = tf.convert_to_tensor(x_train[..., np.newaxis] / 255.0, dtype=tf.float32)
y_train = tf.convert_to_tensor(y_train, dtype=tf.int32)
x_test = tf.convert_to_tensor(x_test[..., np.newaxis] / 255.0, dtype=tf.float32)
y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)

# Save in TFRecord format
def _bytes_feature(value):
    # BytesList needs raw bytes; it cannot unpack a string from an EagerTensor
    if isinstance(value, tf.Tensor):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    # Int64List needs Python ints, so unwrap scalar tensors first
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

def to_tfrecord(x, y, output_file):
    with tf.io.TFRecordWriter(output_file) as writer:
        for i in range(x.shape[0]):
            feature = {
                # serialize_tensor keeps dtype and shape, so parse_tensor can restore it
                'image': _bytes_feature(tf.io.serialize_tensor(x[i])),
                'label': _int64_feature(y[i]),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

to_tfrecord(x_train, y_train, 'train.tfrecord')
to_tfrecord(x_test, y_test, 'test.tfrecord')
```
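The code above converts the MNIST dataset to tensors and normalizes the pixel values, then writes the training and test sets to disk in TFRecord format. The two helper functions `_bytes_feature` and `_int64_feature` wrap each serialized image tensor and each label in a `tf.train.Feature`, which is the form a `tf.train.Example` expects. Note that `mnist.load_data()` returns 28×28 images; to apply this to your own (256, 256) uint8 arrays, load them in place of the Keras data and make sure the step that adds the channel axis yields shape `(N, 256, 256, 1)`.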
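To check that the saved records can be read back as tensors, they can be parsed with `tf.data.TFRecordDataset`, `tf.io.parse_single_example` and `tf.io.parse_tensor`. The following is a minimal sketch, assuming the same `'image'`/`'label'` keys and float32 images used by `to_tfrecord`. The names `write_examples`, `FEATURE_SPEC`, `parse_example`, `load_tfrecord` and the file `demo.tfrecord` are hypothetical, and the random 256×256 array is a synthetic stand-in for the question's data, not MNIST.

```python
import numpy as np
import tensorflow as tf

# Hypothetical helper: same feature layout ('image', 'label') as to_tfrecord
def write_examples(images, labels, path):
    with tf.io.TFRecordWriter(path) as writer:
        for img, lbl in zip(images, labels):
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(img).numpy()])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[int(lbl)])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# Feature spec matching the keys written above (hypothetical name)
FEATURE_SPEC = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized, image_shape):
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    # parse_tensor must be told the dtype that was serialized (float32 here)
    image = tf.io.parse_tensor(parsed['image'], out_type=tf.float32)
    # parse_tensor loses the static shape; restore it for downstream layers
    image = tf.ensure_shape(image, image_shape)
    return image, tf.cast(parsed['label'], tf.int32)

def load_tfrecord(path, image_shape):
    return tf.data.TFRecordDataset(path).map(lambda s: parse_example(s, image_shape))

# Synthetic stand-in for a (256, 256) uint8 dataset (random pixels, not MNIST)
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 256, 256), dtype=np.uint8)
labels = np.array([3, 1, 4, 1])
x = tf.convert_to_tensor((raw[..., np.newaxis] / 255.0).astype(np.float32))

write_examples(x, labels, 'demo.tfrecord')
ds = load_tfrecord('demo.tfrecord', (256, 256, 1))
for image, label in ds.take(1):
    print(image.shape, image.dtype, label.numpy())
```

Because `serialize_tensor` stores the dtype and shape alongside the raw data, `parse_tensor` only needs the dtype; `tf.ensure_shape` is used here so that the resulting dataset has a known static shape, which Keras models generally expect.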