将.csv转换成.record
时间: 2024-04-29 21:22:59 浏览: 168
Nifi-RecordProcessing.pdf
将.csv文件转换成.record文件,需要经过以下步骤:
1. 准备数据
假设我们有一个.csv文件,其中包含了许多图片的标注信息,每一行代表一张图片的信息,包括图片的路径、标注框的坐标、标注框内的物体类别等。例如:
```
/path/to/image1.jpg,xmin1,ymin1,xmax1,ymax1,class1
/path/to/image2.jpg,xmin2,ymin2,xmax2,ymax2,class2
/path/to/image3.jpg,xmin3,ymin3,xmax3,ymax3,class3
...
```
2. 生成.tfrecord文件
使用TensorFlow提供的API,我们可以很容易地将.csv文件转换成.tfrecord文件。具体方法如下:
```python
import tensorflow as tf
# 定义函数,将一行csv数据转换成一个tf.train.Example对象
def create_example(image_path, xmin, ymin, xmax, ymax, class_name):
with tf.io.gfile.GFile(image_path, 'rb') as f:
encoded_image = f.read()
example = tf.train.Example(features=tf.train.Features(feature={
'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_image])),
'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf-8')])),
'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=[xmin])),
'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=[ymin])),
'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=[xmax])),
'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=[ymax])),
'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[class_name.encode('utf-8')])),
'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[class_id])),
}))
return example
# 读取csv文件,生成tf.train.Example对象,写入tfrecord文件
csv_file = '/path/to/csv_file.csv'
output_file = '/path/to/output_file.record'
with tf.io.gfile.GFile(output_file, 'wb') as f:
writer = tf.io.TFRecordWriter(f.name)
with tf.io.gfile.GFile(csv_file, 'r') as csvfile:
reader = csv.reader(csvfile)
for row in reader:
image_path, xmin, ymin, xmax, ymax, class_name = row
example = create_example(image_path, float(xmin), float(ymin), float(xmax), float(ymax), class_name)
writer.write(example.SerializeToString())
writer.close()
```
3. 验证.tfrecord文件
我们可以使用TensorFlow提供的API,验证生成的.tfrecord文件是否正确。具体方法如下:
```python
import tensorflow as tf
# 定义函数,从.tfrecord文件中读取tf.train.Example对象
def parse_example(example_proto):
features = {
'image/encoded': tf.io.FixedLenFeature([], tf.string),
'image/format': tf.io.FixedLenFeature([], tf.string),
'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
'image/object/class/text': tf.io.VarLenFeature(tf.string),
'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}
parsed_features = tf.io.parse_single_example(example_proto, features)
return parsed_features
# 读取.tfrecord文件,验证其中包含的数据是否正确
tfrecord_file = '/path/to/output_file.record'
dataset = tf.data.TFRecordDataset(tfrecord_file)
dataset = dataset.map(parse_example)
for example in dataset.take(10):
print(example)
```
如果输出的结果和原始的.csv文件中的数据一致,就说明生成的.tfrecord文件是正确的。
阅读全文