写一个torch人脸识别的简单程序
时间: 2023-02-08 18:33:02 浏览: 165
Python-pytorch实现的人脸检测和人脸识别
3星 · 编辑精心推荐
首先,你需要有一个人脸识别数据集,其次需要安装 PyTorch。在这个例子中,我们将使用一个预训练的 ResNet-34 模型来提取人脸特征,然后使用这些特征来训练一个线性分类器。
``` python
import torch
import torchvision.models as models
from torch import nn
from torchvision import transforms
# 加载预训练的 ResNet-34 模型
model = models.resnet34(pretrained=True)
# 取出模型最后一层(全连接层)
model = nn.Sequential(*list(model.children())[:-1])
# 预处理图像数据
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 读取图像,提取特征
img = Image.open("example.jpg")
img_tensor = transform(img).unsqueeze(0)
features = model(img_tensor)
# 训练线性分类器
classifier = nn.Linear(512, num_classes)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 迭代训练
for epoch in range(num_epochs):
# 计算分类概率
logits = classifier(features)
loss = criterion(logits, labels)
# 更新参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在新图像上进行预测
img = Image.open("test.jpg")
img_tensor = transform(img).unsqueeze(0)
features = model(img_tensor)
logits = classifier(features)
predicted_label = torch.argmax(logits)
```
这只是一个简单的例子,你还需要自己准备数据集并调整模型参数以
阅读全文