编写图像特征提取与检索的代码
时间: 2024-05-04 17:22:31 浏览: 11
图像特征提取和检索是计算机视觉领域的一个重要研究方向,常用的方法是使用深度学习模型对图像进行特征提取,然后使用相似性匹配算法进行检索。下面是一个简单的图像特征提取与检索的代码示例。
1. 特征提取
首先,我们需要使用深度学习模型对图像进行特征提取。在这里,我们选择使用预训练的 ResNet50 模型。代码如下:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
# 加载预训练模型
model = models.resnet50(pretrained=True)
# 移除最后一层全连接层
model = torch.nn.Sequential(*list(model.children())[:-1])
# 设置模型为评估模式
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 特征提取函数
def extract_features(image_path):
# 加载图像
image = Image.open(image_path).convert('RGB')
# 预处理图像
image = transform(image)
# 添加批次维度
image = image.unsqueeze(0)
# 获取特征向量
with torch.no_grad():
features = model(image)
# 压缩特征向量
features = features.squeeze()
features = features.numpy()
return features
```
2. 检索
接下来,我们使用相似性匹配算法对图像进行检索。在这里,我们选择使用余弦相似度作为相似性度量。代码如下:
```python
import os
import numpy as np
from PIL import Image
# 图像库
image_dir = 'image_dir'
image_files = os.listdir(image_dir)
# 特征向量库
feature_dir = 'feature_dir'
feature_files = os.listdir(feature_dir)
# 加载特征向量
features = []
for feature_file in feature_files:
feature_path = os.path.join(feature_dir, feature_file)
feature = np.load(feature_path)
features.append(feature)
features = np.array(features)
# 相似性匹配函数
def search_similar_images(image_path, top_k=5):
# 提取特征向量
query_feature = extract_features(image_path)
# 计算余弦相似度
similarities = np.dot(features, query_feature) / (np.linalg.norm(features, axis=1) * np.linalg.norm(query_feature))
# 获取前 top_k 相似图像
indices = similarities.argsort()[::-1][:top_k]
# 输出结果
print('Query image:')
Image.open(image_path).show()
for index in indices:
image_path = os.path.join(image_dir, image_files[index])
print('Similar image:')
Image.open(image_path).show()
```
这样,我们就完成了一个简单的图像特征提取与检索的代码。使用该代码,我们可以从图像库中找到与查询图像最相似的图像。