基于pytorch的nsl-kdd数据集训练和测试损失函数和准确率代码
时间: 2023-12-10 18:42:03 浏览: 163
以下是基于PyTorch的NSL-KDD数据集训练和测试的损失函数和准确率代码示例:
``` python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import NSLKDD
from torchvision import transforms
# 定义损失函数
loss_fn = nn.CrossEntropyLoss()
# 定义准确率计算函数
def accuracy(output, target):
with torch.no_grad():
_, preds = torch.max(output, dim=1)
correct = torch.sum(preds == target)
acc = correct.float() / target.shape[0]
return acc
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(41, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 5)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# 加载NSL-KDD数据集
train_dataset = NSLKDD(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = NSLKDD(root='./data', train=False, download=True, transform=transforms.ToTensor())
# 定义数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# 初始化模型、优化器
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
for epoch in range(10):
train_loss, train_acc = 0.0, 0.0
model.train()
for data, target in train_loader:
optimizer.zero_grad()
output = model(data.view(-1, 41))
loss = loss_fn(output, target)
acc = accuracy(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item() * data.shape[0]
train_acc += acc.item() * data.shape[0]
train_loss /= len(train_loader.dataset)
train_acc /= len(train_loader.dataset)
# 测试模型
test_loss, test_acc = 0.0, 0.0
model.eval()
with torch.no_grad():
for data, target in test_loader:
output = model(data.view(-1, 41))
loss = loss_fn(output, target)
acc = accuracy(output, target)
test_loss += loss.item() * data.shape[0]
test_acc += acc.item() * data.shape[0]
test_loss /= len(test_loader.dataset)
test_acc /= len(test_loader.dataset)
# 打印训练和测试结果
print(f'Epoch {epoch + 1}/{10}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
```
这里使用了PyTorch的基本元素,包括数据集、数据加载器、模型、损失函数、优化器、训练和测试循环等。在训练和测试循环中,我们计算了每个批次的损失和准确率,并将其累加到总损失和准确率上。最后,我们将总损失和准确率除以数据集大小,得到平均损失和准确率。
阅读全文