pytorch图片识别
时间: 2023-11-30 07:43:12 浏览: 38
PyTorch是一个基于Python的科学计算库,它支持GPU加速,提供了两个高级功能: 1)具有强大的GPU加速的张量计算(如NumPy) 2)包含自动求导系统的深度神经网络。因此,PyTorch非常适合用于图像识别任务。以下是一个简单的PyTorch图像识别的例子:
```python
import torch
import torchvision
from torchvision import transforms
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 将模型设置为评估模式
model.eval()
# 加载并预处理图像
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
image = Image.open("image.jpg")
image = transform(image)
image = image.unsqueeze(0)
# 使用模型进行预测
with torch.no_grad():
output = model(image)
# 获取预测结果
_, predicted = torch.max(output.data, 1)
print(predicted)
```
上述代码使用了预训练的ResNet-18模型对一张图像进行分类。首先,我们加载了预训练模型,并将其设置为评估模式。然后,我们加载并预处理了一张图像。最后,我们使用模型进行预测,并获取预测结果。