VGG16图像检索 pytorch
时间: 2023-12-20 19:32:41 浏览: 162
以下是使用PyTorch实现VGG16图像检索的示例代码:
```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的VGG16模型
vgg16 = models.vgg16(pretrained=True)
# 设置模型为评估模式
vgg16.eval()
# 图像预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载并预处理图像
image = Image.open('image.jpg')
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)
# 使用GPU加速(如果可用)
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
vgg16.to('cuda')
# 前向传播
with torch.no_grad():
output = vgg16(input_batch)
# 获取预测结果
_, predicted_idx = torch.max(output, 1)
predicted_label = predicted_idx.item()
# 加载ImageNet标签
with open('imagenet_labels.txt') as f:
labels = f.readlines()
# 打印预测结果
print('Predicted label:', labels[predicted_label])
```
请注意,上述代码中的`image.jpg`是待检索的图像文件,`imagenet_labels.txt`是ImageNet数据集的标签文件。
阅读全文