给我一个使用预训练的resnet提取图像特征的python代码
时间: 2024-02-18 22:59:43 浏览: 229
使用预训练的resnet提取图像特征的Python代码如下:
```python
import torch
import torchvision.models as models
# 加载预训练的resnet18模型
resnet18 = models.resnet18(pretrained=True)
# 将模型设置为评估模式
resnet18.eval()
# 加载一张图片
img = Image.open('image.jpg')
# 对图片进行预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img_tensor = preprocess(img)
# 增加一个维度,变成batch size为1的4D张量
img_tensor = img_tensor.unsqueeze(0)
# 使用resnet18提取图像特征
with torch.no_grad():
features = resnet18(img_tensor)
# 将特征张量转换为NumPy数组
features = features.squeeze().numpy()
```
这段代码首先加载了预训练的resnet18模型,然后加载一张图片并进行预处理。接着将图片转换为张量,并使用resnet18模型提取特征。最后将特征张量转换为NumPy数组。
阅读全文