Implement a linear SVM model for a binary classification task and use autograd in PyTorch to optimize it.
Posted: 2024-05-03 12:23:20
Sure, here's an example implementation of a linear SVM model for binary classification using PyTorch and autograd for optimization:
```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # make the random initialization reproducible


class LinearSVM(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # Parameters of the decision function f(x) = w . x + b
        self.weights = nn.Parameter(torch.randn(input_dim))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        # x has shape (N, input_dim); a matrix-vector product handles the whole
        # batch, whereas torch.dot only accepts two 1-D tensors
        linear_output = x @ self.weights + self.bias
        return linear_output


# Define training data and labels (+1 / -1 as floats, as the hinge loss expects)
train_data = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
train_labels = torch.tensor([1.0, 1.0, -1.0, -1.0])

# Initialize model and optimizer
svm = LinearSVM(input_dim=2)
optimizer = optim.SGD(svm.parameters(), lr=0.01)

# Define training loop
num_epochs = 1000
for epoch in range(num_epochs):
    svm.train()
    optimizer.zero_grad()
    output = svm(train_data)
    # Hinge loss: mean of max(0, 1 - y * f(x))
    loss = torch.mean(torch.clamp(1 - train_labels * output, min=0))
    loss.backward()  # autograd computes gradients for weights and bias
    optimizer.step()

# Evaluate model on test data
test_data = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
svm.eval()
test_predictions = torch.sign(svm(test_data)).detach().numpy()
print(test_predictions)
```
In this example, we define a `LinearSVM` class that inherits from `nn.Module` and computes the linear decision function $f(x) = w \cdot x + b$ directly from its parameters (the same computation as a single linear layer with one output). We use `nn.Parameter` to register the weight vector and the bias as trainable parameters, which are then optimized with the `optim.SGD` optimizer.
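If you prefer to express the model literally as a single linear layer, the same decision function can be built on `nn.Linear` with one output unit. The sketch below is a hypothetical variant (`LinearSVMLayer` is not part of the original answer); it also checks that the batched forward pass matches the hand-written $x \cdot w + b$, which for a batch has to be a matrix-vector product rather than `torch.dot`, since `torch.dot` only accepts 1-D tensors:

```python
import torch
import torch.nn as nn


class LinearSVMLayer(nn.Module):
    """Hypothetical variant of LinearSVM built on nn.Linear."""

    def __init__(self, input_dim: int):
        super().__init__()
        # One output unit: the decision value f(x) = w . x + b
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, input_dim) -> (N, 1) -> (N,): one score per sample
        return self.linear(x).squeeze(-1)


torch.manual_seed(0)
model = LinearSVMLayer(input_dim=2)
x = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
scores = model(x)

# The same computation written out with the layer's own weight (1, 2) and bias (1,)
manual = x @ model.linear.weight.squeeze(0) + model.linear.bias
print(scores.shape)
print(torch.allclose(scores, manual))
```

`nn.Linear` stores its weight with shape `(out_features, in_features)`, which is why the manual version squeezes it to a 1-D vector before the product.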
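In the training loop, we compute the hinge loss, the mean of $\max(0, 1 - y_i f(x_i))$ over the training points, and call `loss.backward()` so that autograd computes its gradients with respect to the weights and bias. We then update the parameters with the optimizer's `step` method. Note that this loss has no regularization term, so the model is a hinge-loss linear classifier rather than a full soft-margin SVM: the standard SVM objective also adds an L2 penalty on the weights, and that penalty is what pushes the solution toward the maximum-margin boundary.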
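The soft-margin objective $\lambda \lVert w \rVert^2 + \frac{1}{N}\sum_i \max(0, 1 - y_i(w \cdot x_i + b))$ can be sketched as a small loss function. `soft_margin_svm_loss` is a hypothetical helper and `lam=0.01` an assumed regularization strength; the parameters are hand-picked (not trained) so that the loss and the autograd gradients can be checked by hand:

```python
import torch


def soft_margin_svm_loss(scores, labels, weights, lam=0.01):
    """Mean hinge loss plus an L2 penalty on the weights (hypothetical helper)."""
    # Hinge term: max(0, 1 - y * f(x)), zero once a point is beyond the margin
    hinge = torch.clamp(1 - labels * scores, min=0).mean()
    # L2 penalty on w only; the bias is conventionally left unregularized
    return hinge + lam * (weights ** 2).sum()


X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
y = torch.tensor([1.0, 1.0, -1.0, -1.0])

# Hand-picked parameters: every score is 0.5*x1 - 0.5*x2 = -0.5 on this data
w = torch.tensor([0.5, -0.5], requires_grad=True)
b = torch.zeros(1, requires_grad=True)

loss = soft_margin_svm_loss(X @ w + b, y, w, lam=0.01)
loss.backward()  # autograd fills w.grad and b.grad

# Hinge values are 1.5, 1.5, 0.5, 0.5 (mean 1.0); penalty is 0.01 * 0.5 = 0.005
print(loss.item())
# All points are inside the margin, so d(hinge)/dw = -mean(y * x) = (1, 1);
# the penalty adds 2 * lam * w = (0.01, -0.01), and d(hinge)/db = -mean(y) = 0
print(w.grad, b.grad)
```

To train with it, the plain hinge-loss line in the training loop can be replaced by a call to this function with `svm.weights` passed as `weights`; `lam` then trades a wider margin against training-set violations.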
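Finally, we evaluate the trained model on some test data by passing it through the model and taking the sign of the output, since a linear SVM assigns the class according to which side of the decision boundary a point falls on. Note that `torch.sign` returns 0 for a point lying exactly on the boundary. We call `detach().numpy()` to drop the autograd graph and convert the output to a NumPy array for easier inspection.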
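A small sketch of the boundary case, using hand-picked (not trained) parameters $w = (-1, -1)$, $b = 6$ that separate the four training points. With these values the third test point scores exactly 0, so `torch.sign` yields 0 instead of a class label; `predict_labels`, which sends ties to +1, is a hypothetical convention rather than part of the original code:

```python
import torch


def predict_labels(scores: torch.Tensor) -> torch.Tensor:
    """Map decision values to {-1, +1}; ties (score == 0) go to +1 by assumption."""
    return torch.where(scores >= 0, torch.tensor(1.0), torch.tensor(-1.0))


# Hypothetical fixed parameters standing in for a trained model
w = torch.tensor([-1.0, -1.0])
b = torch.tensor([6.0])
test_data = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])

with torch.no_grad():  # inference only, so no autograd graph is built
    scores = test_data @ w + b  # scores are 4, 2, 0, -2

print(torch.sign(scores))      # third entry is 0: the point lies on the boundary
print(predict_labels(scores))  # third entry becomes +1 under the tie convention
```

Wrapping inference in `torch.no_grad()` is an alternative to calling `detach()` afterwards: no graph is recorded in the first place.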
Note: This is only a minimal example of a linear SVM in PyTorch using autograd. In practice, you may want a more robust, well-tested implementation such as LIBLINEAR or scikit-learn (whose `LinearSVC` is built on LIBLINEAR).
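For comparison, a minimal scikit-learn sketch on the same toy data, assuming scikit-learn is installed. It uses `SVC(kernel="linear")` (backed by LIBSVM) rather than `LinearSVC`, because LIBLINEAR also regularizes the intercept, which can distort the fit on tiny, uncentered data like this; `C=10.0` is an assumed value:

```python
import numpy as np
from sklearn.svm import SVC

# Same toy training data as the PyTorch example
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
y = np.array([1, 1, -1, -1])

# kernel="linear" gives the decision function w . x + b;
# larger C penalizes margin violations more heavily (closer to a hard margin)
clf = SVC(kernel="linear", C=10.0)
clf.fit(X, y)

print(clf.coef_, clf.intercept_)  # learned w and b
print(clf.predict(X))             # predicted labels for the training points
```

`clf.coef_` and `clf.intercept_` hold the learned $w$ and $b$, which can be compared with `svm.weights` and `svm.bias` from the PyTorch model.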