基于pytorch环境用arcface训练lfw模型,并保存模型,同时测试模型的完整代码
时间: 2023-06-10 09:07:39 浏览: 177
Pytorch——保存训练好的模型参数
5星 · 资源好评率100%
以下是基于 PyTorch 环境下使用 ArcFace 训练 LFW 数据集并保存模型的示例代码,同时也包含了测试模型的部分。
首先,需要安装必要的 PyTorch 和 torchvision 库:
```
pip install torch torchvision
```
接着,我们下载并解压 LFW 数据集,可以从以下链接获取:
http://vis-www.cs.umass.edu/lfw/lfw.tgz
解压后得到的目录结构应该是这样的:
```
lfw/
person1/
image1.jpg
image2.jpg
...
person2/
image1.jpg
image2.jpg
...
...
```
其中,每个人的照片都放在一个单独的目录下,目录名即为人名。
接下来是训练代码:
```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
# 定义一些超参数
batch_size = 32
num_workers = 4
num_epochs = 10
embedding_size = 512
lr = 0.1
momentum = 0.9
weight_decay = 5e-4
num_classes = 5749 # LFW 数据集中的人数
# 定义数据预处理
transform = transforms.Compose([
transforms.Resize((112, 112)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 定义数据集类
class LFWDataset(Dataset):
def __init__(self, root):
self.root = root
self.img_paths = []
self.labels = []
self.class_dict = {}
self.class_idx = 0
# 遍历数据集,获取所有图片路径和标签
for person_name in os.listdir(root):
person_dir = os.path.join(root, person_name)
if not os.path.isdir(person_dir):
continue
img_names = os.listdir(person_dir)
self.class_dict[person_name] = self.class_idx
self.class_idx += 1
for img_name in img_names:
img_path = os.path.join(person_dir, img_name)
self.img_paths.append(img_path)
self.labels.append(self.class_dict[person_name])
def __getitem__(self, index):
img_path = self.img_paths[index]
label = self.labels[index]
img = Image.open(img_path).convert('RGB')
img = transform(img)
return img, label
def __len__(self):
return len(self.img_paths)
# 定义模型
class ArcFace(nn.Module):
def __init__(self, num_classes, embedding_size):
super(ArcFace, self).__init__()
self.num_classes = num_classes
self.embedding_size = embedding_size
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
)
self.fc = nn.Linear(512, embedding_size)
self.fc_arc = nn.Linear(embedding_size, num_classes)
def forward(self, x, labels=None):
x = self.backbone(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
if labels is not None:
w = self.fc_arc.weight
ww = torch.norm(w, dim=1, keepdim=True)
w = w / ww
x_norm = torch.norm(x, dim=1, keepdim=True)
x = x / x_norm
cos_theta = torch.matmul(x, w.transpose(0, 1))
cos_theta = cos_theta.clamp(-1, 1)
theta = torch.acos(cos_theta)
one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, labels.view(-1, 1), 1)
if x.is_cuda:
one_hot = one_hot.cuda()
target_logit = cos_theta * one_hot + (1 - one_hot) * (torch.cos(theta + 0.5))
output = self.fc_arc(target_logit)
else:
output = self.fc_arc(x)
return output
# 创建数据集和 DataLoader
train_dataset = LFWDataset('lfw')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# 创建模型和优化器
model = ArcFace(num_classes, embedding_size)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
# 将模型放入 GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# 开始训练
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_dataloader):
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images, labels)
loss = nn.CrossEntropyLoss()(outputs, labels)
loss.backward()
optimizer.step()
if (i + 1) % 10 == 0:
print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
% (epoch + 1, num_epochs, i + 1, len(train_dataloader), loss.item()))
# 保存模型
torch.save(model.state_dict(), 'arcface_lfw.pth')
```
训练完成后,我们可以使用以下代码来测试模型:
```python
# 加载模型
model = ArcFace(num_classes, embedding_size)
model.load_state_dict(torch.load('arcface_lfw.pth'))
model.to(device)
# 创建测试集
class LFWTestDataset(Dataset):
def __init__(self, pairs_path, root):
self.pairs_path = pairs_path
self.root = root
self.transform = transform
with open(pairs_path) as f:
pairs_lines = f.readlines()
self.pairs = []
for line in pairs_lines[1:]:
pair = line.strip().split('\t')
if len(pair) == 3:
self.pairs.append((os.path.join(root, pair[0], pair[1] + '.jpg'),
os.path.join(root, pair[0], pair[2] + '.jpg')))
elif len(pair) == 4:
self.pairs.append((os.path.join(root, pair[0], pair[1] + '.jpg'),
os.path.join(root, pair[2], pair[3] + '.jpg')))
else:
raise ValueError('Invalid pair: %s' % line.strip())
def __len__(self):
return len(self.pairs)
def __getitem__(self, index):
pair = self.pairs[index]
img1 = Image.open(pair[0]).convert('RGB')
img2 = Image.open(pair[1]).convert('RGB')
img1 = self.transform(img1)
img2 = self.transform(img2)
return img1, img2
# 定义测试函数
def test(model, pairs_path, root):
model.eval()
test_dataset = LFWTestDataset(pairs_path, root)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=num_workers)
correct = 0
total = 0
with torch.no_grad():
for images1, images2 in test_dataloader:
images1 = images1.to(device)
images2 = images2.to(device)
embeddings1 = model(images1)
embeddings2 = model(images2)
distances = torch.sum((embeddings1 - embeddings2) ** 2, dim=1)
for i in range(len(distances)):
if test_dataset.pairs[total + i][0].split('/')[-2] == test_dataset.pairs[total + i][1].split('/')[-2]:
if distances[i] < 1.0:
correct += 1
else:
if distances[i] > 1.0:
correct += 1
total += len(distances)
acc = 100.0 * correct / total
print('Accuracy: %.2f%% (%d/%d)' % (acc, correct, total))
# 测试模型
test(model, 'lfw/pairs.txt', 'lfw')
```
注意,上面的 `test` 函数中,我们使用了 LFW 数据集中的 `pairs.txt` 文件,这个文件包含了若干行对比图片的路径和标签,我们可以根据这些对比图片计算模型的准确率。
完整代码可在以下链接中找到:
https://github.com/JNingWei/arcface-lfw-pytorch
阅读全文