支持向量机多分类pytorch代码
时间: 2024-11-11 18:14:22 浏览: 30
svm支持向量机python代码
支持向量机(SVM)是一种监督学习算法,主要用于二分类任务。在PyTorch库中,我们可以使用`sklearn`库结合PyTorch的模型进行多分类,因为`sklearn`有现成的支持向量机多类别分类模块。以下是一个简化的例子,展示如何使用`sklearn`的SVC在PyTorch环境中进行多分类:
```python
import torch
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 加载数据集(例如Iris)
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# 将标签转换为独热编码
y_train_one_hot = torch.from_numpy(to_categorical(y_train, num_classes=len(set(y))))
y_test_one_hot = torch.from_numpy(to_categorical(y_test, num_classes=len(set(y))))
# 使用sklearn的SVC进行多分类
svc = SVC(kernel='linear') # 可以尝试其他核函数,如'poly', 'rbf'
svc.fit(X_train, y_train_one_hot.argmax(axis=1)) # 训练模型
# 预测
y_pred = svc.predict(X_test)
y_pred_one_hot = torch.from_numpy(to_categorical(svc.predict(X_test), num_classes=len(set(y))))
# 计算准确率
accuracy = (y_pred == y_test).mean().item()
print("Accuracy:", accuracy)
```
请注意,这个例子仅用于演示,实际应用中可能需要对数据进行更多的预处理步骤,并可能需要调整模型超参数以达到更好的性能。
阅读全文