python读取csv文件数据用于建立训练集和测试集,csv文件前5列为特征,第六列为标签,编写代码适用于tensorflow2.0以上版本
时间: 2024-04-11 07:28:23 浏览: 137
以下是一个示例代码,用于读取CSV文件并创建训练集和测试集,然后使用TensorFlow 2.0以上版本进行模型训练:
```python
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
# 读取CSV文件
data = pd.read_csv('data.csv')
# 提取特征列和标签列
features = data.iloc[:, :5]
labels = data.iloc[:, 5]
# 划分训练集和测试集
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2)
# 创建TensorFlow的训练数据集和测试数据集
train_dataset = tf.data.Dataset.from_tensor_slices((train_features.values, train_labels.values))
test_dataset = tf.data.Dataset.from_tensor_slices((test_features.values, test_labels.values))
# 定义模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
# 编译模型
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型
model.fit(train_dataset.shuffle(len(train_features)).batch(32),
epochs=10,
validation_data=test_dataset.batch(32))
```
注意:请确保将`data.csv`替换为你的CSV文件的实际路径。此代码假设标签列是二进制分类问题,可以根据实际情况进行修改。
阅读全文