对csv分类数据利用pytorch搭建全连接神经网络
时间: 2023-07-30 21:09:49 浏览: 79
好的,我可以回答你的问题。首先,你需要导入 PyTorch 库和其他必要的库,例如 pandas 和 numpy。然后,你应该准备你的数据,这里假设你有一个名为 data.csv 的数据集,其中包含分类变量和目标变量。接下来,你可以使用 pandas 库读取 csv 文件并将数据分成输入特征和目标变量。然后,你需要将分类变量转换为数字标签。这可以通过使用 LabelEncoder 类来完成,该类将每个分类变量映射到一个唯一的数字标签。接下来,你需要将数据集分成训练集和测试集。这可以通过使用 train_test_split 函数来完成。最后,你可以使用 PyTorch 来搭建全连接神经网络。以下是一个示例代码:
```python
import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# 读取数据集
data = pd.read_csv('data.csv')
# 分离输入特征和目标变量
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values
# 将分类变量转换为数字标签
le = LabelEncoder()
for i in range(X.shape[1]):
if type(X[0][i]) == str:
X[:, i] = le.fit_transform(X[:, i])
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 将 numpy 数组转换为 tensor
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).long()
# 定义全连接神经网络
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层
self.out = torch.nn.Linear(n_hidden, n_output) # 输出层
def forward(self, x):
x = torch.relu(self.hidden(x)) # 激活函数
x = self.out(x)
return x
# 定义模型和优化器
net = Net(n_feature=X_train.shape[1], n_hidden=10, n_output=len(np.unique(y_train)))
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_func = torch.nn.CrossEntropyLoss()
# 训练模型
for epoch in range(100):
out = net(X_train)
loss = loss_func(out, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 输出训练误差
if epoch % 10 == 0:
print('Epoch: %d | Loss: %.4f' % (epoch, loss.item()))
# 测试模型
with torch.no_grad():
out = net(X_test)
predictions = torch.argmax(out, axis=1)
accuracy = (predictions == y_test).sum() / len(y_test)
print('Accuracy:', accuracy)
```
在这个示例中,我们使用了一个具有一个隐藏层的全连接神经网络。你可以根据你的数据集和需要调整输入特征的数量、隐藏层的大小和输出层的数量。训练模型时,我们使用了交叉熵损失函数和随机梯度下降优化器。测试模型时,我们计算了模型的准确度。