pytorch检测代码
时间: 2023-09-11 12:11:42 浏览: 41
以下是一个简单的 PyTorch 检测代码示例:
```python
import torch
import torch.nn.functional as F
from PIL import Image
# 定义模型
class Detector(torch.nn.Module):
def __init__(self):
super(Detector, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3)
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)
self.fc1 = torch.nn.Linear(64*22*22, 128)
self.fc2 = torch.nn.Linear(128, 2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 64*22*22)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载模型
model = Detector()
model.load_state_dict(torch.load("detector.pt"))
# 加载图像
img = Image.open("test.jpg").convert("RGB")
# 图像预处理
img = transforms.Resize((224,224))(img)
img = transforms.ToTensor()(img)
img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
# 执行模型推理
model.eval()
with torch.no_grad():
output = model(img.unsqueeze(0))
_, predicted = torch.max(output.data, 1)
# 输出结果
if predicted[0] == 0:
print("Not a cat.")
else:
print("It's a cat!")
```
上述代码中,我们定义了一个简单的卷积神经网络模型 `Detector`,并使用 `torch.load` 方法加载训练好的模型参数。接着,我们使用 Pillow 库中的 `Image` 类加载测试图像,进行预处理,并将其输入到模型中进行推理。最后,根据模型输出的结果进行判断,输出检测结果。