基于pytorch的图像识别
时间: 2023-07-11 17:49:49 浏览: 46
PyTorch 是一个广泛用于深度学习的框架,它提供了许多用于图像识别的工具和模型。在 PyTorch 中,我们可以使用 torchvision 包来加载和处理图像数据集,并使用预先训练好的模型进行图像分类和识别。
下面是一个基于 PyTorch 的图像分类示例代码,可以帮助你更好地理解:
```
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
# 定义数据集的路径和预处理方式
data_path = "/data"
transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集和测试集
trainset = torchvision.datasets.ImageFolder(root=data_path + "/train", transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testset = torchvision.datasets.ImageFolder(root=data_path + "/test", transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
# 定义模型,这里以 resnet18 为例
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(10): # 对整个数据集训练 10 次
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 200 == 199: # 每 200 个 mini-batch 输出一次 loss
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
# 在测试集上测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (
100 * correct / total))
```
这个示例代码使用了 torchvision 包加载了一个图像数据集,使用了一个预先训练好的 ResNet18 模型进行分类,使用了 SGD 优化器和 CrossEntropyLoss 损失函数进行训练。你可以根据自己的需求修改这个代码并尝试不同的模型和超参数。