pytorch图像检索
时间: 2025-01-05 15:31:11 浏览: 5
### 如何使用 PyTorch 实现图像检索
#### 准备工作
为了构建基于卷积神经网络(CNN)的图像检索系统,首先需要安装必要的库。这些库包括但不限于 `torch`, `torchvision` 和其他辅助工具如 `matplotlib`、`numpy` 等[^3]。
```bash
pip install torch>=0.4.0 torchvision matplotlib numpy scipy pillow urllib3 scikit-image
```
#### 加载预训练模型并提取特征向量
通过加载一个预先训练好的 CNN 模型来作为特征抽取器是一个常见的做法。这里可以选择像 ResNet 这样的经典架构:
```python
import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.resnet50(pretrained=True).to(device)
model.eval()
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
def extract_features(image_path):
img = Image.open(image_path).convert('RGB')
tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
features = model(tensor)
return features.cpu().numpy()[0]
query_feature_vector = extract_features("path_to_query_image.jpg") # 查询图片路径
database_feature_vectors = [] # 数据库中所有图片对应的特征向量列表
for db_img in database_images_paths:
vec = extract_features(db_img)
database_feature_vectors.append(vec)
```
上述代码片段展示了如何利用预训练的ResNet50模型从给定的一张查询图片以及数据库内的多张图片中分别获取它们各自的特征表示[^1]。
#### 计算相似度得分
一旦拥有了两张或多张图片之间的特征表达形式之后,就可以计算彼此间的距离以衡量其相似程度。常用的方法有欧氏距离(Euclidean Distance),余弦相似度(Cosine Similarity)等:
```python
from sklearn.metrics.pairwise import cosine_similarity
similarities = []
for feature_vec in database_feature_vectors:
sim_score = cosine_similarity([feature_vec], [query_feature_vector])
similarities.append(sim_score.item())
sorted_indices = sorted(range(len(database_images_paths)), key=lambda i: similarities[i], reverse=True)[:top_k]
most_similar_images = [database_images_paths[idx] for idx in sorted_indices]
print(f"Top {len(most_similar_images)} most similar images are:")
for img in most_similar_images:
print(img)
```
这段脚本实现了对于每一对比较对象之间应用余弦相似性的评估过程,并最终返回最接近于查询样本的结果集合[^2]。
阅读全文