tensorflow2.0 csv文件
时间: 2023-09-29 11:08:01 浏览: 93
在 TensorFlow 2.0 中,你可以使用 `tf.data` 模块来读取 CSV 文件。以下是一个示例代码,它演示了如何使用 `tf.data.experimental.CsvDataset()` 函数从 CSV 文件中读取数据:
```python
import tensorflow as tf
# 定义 CSV 文件的列名和默认值
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
# 读取 CSV 文件
def load_data(file_path):
dataset = tf.data.experimental.CsvDataset(
file_path,
record_defaults=DEFAULTS,
header=True,
field_delim=','
)
# 将数据转换为 (features, label) 对
def _parse_csv_row(*vals):
features = dict(zip(CSV_COLUMN_NAMES, vals))
label = features.pop('Species')
return features, tf.argmax(tf.equal(label, SPECIES))
return dataset.map(_parse_csv_row)
# 加载训练数据和测试数据
train_data = load_data('train.csv')
test_data = load_data('test.csv')
```
其中,`tf.data.experimental.CsvDataset()` 函数的参数解释如下:
- `file_path`:CSV 文件路径;
- `record_defaults`:每个列的默认值;
- `header`:是否包含标题行;
- `field_delim`:分隔符。
接下来,你可以使用 `train_data` 和 `test_data` 来训练和测试模型。
阅读全文