SVM的pytorch代码
时间: 2023-09-23 17:13:30 浏览: 102
以下是一个简单的基于PyTorch的SVM的代码示例:
```
import torch
import torch.nn as nn
import torch.optim as optim
class SVM(nn.Module):
def __init__(self, in_features):
super(SVM, self).__init__()
self.linear = nn.Linear(in_features, 1)
def forward(self, x):
return self.linear(x)
def hinge_loss(self, scores, targets):
margin = 1 - scores * targets
return torch.mean(torch.max(torch.zeros_like(margin), margin))
model = SVM(in_features=4)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MarginRankingLoss()
# Training
num_epochs = 10
for epoch in range(num_epochs):
epoch_loss = 0.0
for i, (x, y) in enumerate(train_loader):
optimizer.zero_grad()
scores = model(x)
loss = model.hinge_loss(scores, y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print("Epoch {}, Loss: {:.4f}".format(epoch+1, epoch_loss/len(train_loader)))
# Testing
model.eval()
correct = 0.0
total = 0.0
with torch.no_grad():
for x, y in test_loader:
scores = model(x)
preds = torch.sign(scores).flatten()
correct += torch.sum(preds == y)
total += len(y)
print("Accuracy: {:.2f}".format(correct/total * 100))
```
请注意,这里使用了带有hinge loss的线性SVM来解决二元分类问题。优化器使用了随机梯度下降(SGD)。训练过程中,我们简单地迭代每个minibatch,计算损失并反向传播。在测试集上评估时,我们使用了预测的符号值(正负)来计算准确性指标。
阅读全文