resnet推理代码
时间: 2024-04-17 22:22:34 浏览: 148
tensorRT部署resnet网络Python、c++源码
ResNet(Residual Network)是一种非常流行的深度学习模型,用于解决图像分类、目标检测等计算机视觉任务。下面是一个简单的ResNet推理代码示例:
```python
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练的ResNet模型
resnet = models.resnet50(pretrained=True)
resnet.eval()
# 输入数据预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载并预处理图像
image = Image.open('image.jpg')
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)
# 使用GPU进行推理(如果可用)
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
resnet.to('cuda')
# 进行推理
with torch.no_grad():
output = resnet(input_batch)
# 输出结果
_, predicted_idx = torch.max(output, 1)
print('Predicted class:', predicted_idx.item())
```
这段代码使用了PyTorch框架,首先加载了预训练的ResNet-50模型,并将其设置为评估模式(eval)。然后,通过预处理将输入图像转换为模型所需的格式。接下来,将图像输入模型进行推理,并获取输出结果。最后,使用argmax函数找到输出中概率最高的类别,并打印出预测结果。
阅读全文