写一个tensorflow读取csv文件及搭建神经网络训练的代码
时间: 2023-04-04 16:01:35 浏览: 125
可以的,以下是一个读取csv文件并搭建神经网络进行训练的TensorFlow代码示例:
```python
import tensorflow as tf
import numpy as np
# 读取csv文件
filename = 'data.csv'
dataset = tf.data.experimental.CsvDataset(filename, [tf.float32, tf.float32, tf.float32, tf.float32], header=True)
# 将数据集分为训练集和测试集
train_dataset = dataset.take(80).shuffle(100).batch(16)
test_dataset = dataset.skip(80).batch(16)
# 搭建神经网络模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, activation='relu', input_shape=(3,)),
tf.keras.layers.Dense(1)
])
# 编译模型
model.compile(optimizer='adam', loss='mse')
# 训练模型
model.fit(train_dataset, epochs=50)
# 评估模型
loss = model.evaluate(test_dataset)
print('Test Loss: {}'.format(loss))
```
这个代码可以读取名为"data.csv"的csv文件,其中包含4列数据,前3列是特征,最后一列是标签。代码将数据集分为训练集和测试集,搭建了一个包含2个全连接层的神经网络模型,并使用均方误差作为损失函数进行训练和评估。
阅读全文