请基于pytorch写一个猫狗识别的程序
时间: 2023-03-31 09:03:03 浏览: 138
你好,我可以回答这个问题。以下是基于 PyTorch 的猫狗识别程序的代码示例:
```
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# 定义模型
class CatDogClassifier(nn.Module):
def __init__(self):
super(CatDogClassifier, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 28 * 28, 500)
self.fc2 = nn.Linear(500, 2)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = self.pool(torch.relu(self.conv3(x)))
x = x.view(-1, 64 * 28 * 28)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载模型
model = CatDogClassifier()
model.load_state_dict(torch.load('catdog_classifier.pth'))
# 定义预处理函数
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
])
# 定义分类函数
def classify(image_path):
image = Image.open(image_path)
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze()
with torch.no_grad():
output = model(image_tensor)
_, predicted = torch.max(output.data, 1)
return predicted.item()
# 使用示例
image_path = 'cat.jpg'
result = classify(image_path)
if result == :
print('这是一只猫')
else:
print('这是一只狗')
```
以上是基于 PyTorch 的猫狗识别程序的代码示例,希望能对你有所帮助。
阅读全文