PyTorch人体关键点检测代码
时间: 2024-05-01 18:19:08 浏览: 171
以下是使用PyTorch实现的人体关键点检测代码，基于COCO关键点数据集训练一个简化版的Hourglass网络（代码中没有加载任何预训练权重，网络从随机初始化开始训练）。
```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import cv2
import json
import os
class COCODataset(Dataset):
    """Per-person keypoint dataset backed by a COCO-format annotation file.

    Each item corresponds to one entry of the file's ``annotations`` list and
    yields ``(img, keypoints)``:
      * ``img`` is the full image as returned by ``cv2.imread``;
      * ``keypoints`` is a ``(K, 2)`` array of ``(x, y)`` coordinates. The
        visibility flag of each COCO keypoint triplet is dropped.

    Args:
        root_dir: Directory containing the image files.
        ann_file: Path to the COCO keypoint annotation JSON.
        transform: Optional callable ``transform(img, keypoints)`` returning a
            new ``(img, keypoints)`` pair.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        with open(ann_file, encoding='utf-8') as f:
            data = json.load(f)
        self.annotations = data['annotations']
        # COCO annotations reference their image through ``image_id``; the
        # file name itself lives in the top-level ``images`` list.
        self._file_names = {image['id']: image['file_name']
                            for image in data.get('images', [])}

    def __len__(self):
        return len(self.annotations)

    def _image_path(self, idx):
        """Return the on-disk path of the image for annotation ``idx``."""
        annotation = self.annotations[idx]
        # An inline ``file_name`` (flattened annotation files) still wins, so
        # files that worked with the previous lookup keep working.
        file_name = annotation.get('file_name')
        if file_name is None:
            file_name = self._file_names[annotation['image_id']]
        return os.path.join(self.root_dir, file_name)

    def __getitem__(self, idx):
        img_path = self._image_path(idx)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        keypoints = np.array(self.annotations[idx]['keypoints']).reshape(-1, 3)
        keypoints = keypoints[:, :2]  # keep (x, y), drop the visibility flag
        if self.transform:
            img, keypoints = self.transform(img, keypoints)
        return img, keypoints
class ToTensor:
    """Turn an ``(img, keypoints)`` NumPy pair into float32 torch tensors.

    The image axes are reordered from (H, W, C) to (C, H, W) and the pixel
    values are divided by 255; the keypoints are only cast to float32.
    """

    def __call__(self, img, keypoints):
        chw = np.transpose(img, (2, 0, 1))
        image_tensor = torch.from_numpy(chw.astype(np.float32) / 255.0)
        keypoint_tensor = torch.from_numpy(keypoints.astype(np.float32))
        return image_tensor, keypoint_tensor
class RandomHorizontalFlip(object):
    """Randomly mirror an image left-right together with its keypoints.

    Mirroring also swaps anatomical sides, so after reflecting the x
    coordinates each ``(left, right)`` keypoint pair is exchanged. The default
    pairs follow COCO's 17-keypoint order.

    Points at exactly ``(0, 0)`` are treated as unlabeled and left untouched.
    COCO stores keypoints with visibility 0 that way, and the visibility flag
    is not available here.

    Args:
        p: Probability of applying the flip.
        flip_pairs: Iterable of ``(i, j)`` keypoint index pairs to swap after
            flipping. ``None`` selects ``COCO_FLIP_PAIRS``. Pairs that index
            past the number of keypoints are skipped.
    """

    # (left_eye, right_eye), (left_ear, right_ear), ... in COCO keypoint order.
    COCO_FLIP_PAIRS = ((1, 2), (3, 4), (5, 6), (7, 8),
                       (9, 10), (11, 12), (13, 14), (15, 16))

    def __init__(self, p=0.5, flip_pairs=None):
        self.p = p
        self.flip_pairs = tuple(self.COCO_FLIP_PAIRS if flip_pairs is None else flip_pairs)

    def __call__(self, img, keypoints):
        if np.random.rand() < self.p:
            # Same result as cv2.flip(img, 1); the copy keeps the array contiguous.
            img = np.ascontiguousarray(img[:, ::-1])
            width = img.shape[1]
            labeled = np.any(keypoints != 0, axis=1)
            flipped = keypoints.copy()  # never mutate the caller's array
            flipped[labeled, 0] = width - keypoints[labeled, 0]
            num_keypoints = len(flipped)
            for left, right in self.flip_pairs:
                if left < num_keypoints and right < num_keypoints:
                    # Fancy indexing on the right-hand side copies, so this swaps.
                    flipped[[left, right]] = flipped[[right, left]]
            keypoints = flipped
        return img, keypoints
class _HourglassModule(nn.Module):
    """One recursive hourglass: a skip branch plus a pool -> inner -> upsample branch."""

    def __init__(self, depth, channels, num_blocks):
        super().__init__()

        def residual_stack():
            # Only channel-preserving ResidualBlock(c, c) is used, so the
            # tensor shape is unchanged by every stack.
            return nn.Sequential(*[ResidualBlock(channels, channels) for _ in range(num_blocks)])

        self.skip = residual_stack()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down = residual_stack()
        self.inner = (_HourglassModule(depth - 1, channels, num_blocks)
                      if depth > 1 else residual_stack())
        self.up = residual_stack()

    def forward(self, x):
        skip = self.skip(x)
        y = self.up(self.inner(self.down(self.pool(x))))
        # Upsample to the skip branch's exact size so odd spatial sizes
        # (floored by max-pooling) still line up for the addition.
        y = F.interpolate(y, size=skip.shape[-2:], mode='nearest')
        return skip + y


class HourglassNet(nn.Module):
    """Stacked hourglass network that predicts one heatmap per keypoint.

    Architecture:
      * Stem: a stride-2 7x7 conv and a stride-2 max-pool, so the
        hourglasses run at roughly 1/4 of the input resolution.
      * ``num_stacks`` hourglasses, each followed by a 1x1 head that
        predicts ``num_classes`` heatmaps.
      * Every stack except the last feeds its features and predictions back
        into the next stack (intermediate supervision).

    Args:
        num_stacks: Number of hourglass stacks (one output per stack).
        num_blocks: Residual blocks per hourglass stage.
        num_classes: Number of keypoint heatmaps (output channels).
        num_features: Channel width inside the hourglasses.
        depth: Number of 2x down-sampling levels inside each hourglass. The
            input must be large enough that no pooled level becomes empty.

    ``forward(x)`` takes a ``(B, 3, H, W)`` tensor. It returns a list of
    ``num_stacks`` tensors, each of shape ``(B, num_classes, H', W')``, with
    ``H' ~= H / 4`` and ``W' ~= W / 4``.
    """

    def __init__(self, num_stacks=2, num_blocks=4, num_classes=17, num_features=256, depth=4):
        super().__init__()
        self.num_stacks = num_stacks
        self.num_blocks = num_blocks
        self.num_classes = num_classes
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # Explicit 1x1 projection so every ResidualBlock is channel-preserving.
            nn.Conv2d(64, num_features, kernel_size=1),
            ResidualBlock(num_features, num_features),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(num_features, num_features),
        )
        self.hourglass_modules = nn.ModuleList(
            _HourglassModule(depth, num_features, num_blocks) for _ in range(num_stacks))
        self.features = nn.ModuleList(
            nn.Sequential(
                ResidualBlock(num_features, num_features),
                nn.Conv2d(num_features, num_features, kernel_size=1),
                nn.BatchNorm2d(num_features),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_stacks))
        self.heads = nn.ModuleList(
            nn.Conv2d(num_features, num_classes, kernel_size=1) for _ in range(num_stacks))
        # Re-injection into the next stack; nothing follows the last stack.
        self.merge_features = nn.ModuleList(
            nn.Conv2d(num_features, num_features, kernel_size=1) for _ in range(num_stacks - 1))
        self.merge_preds = nn.ModuleList(
            nn.Conv2d(num_classes, num_features, kernel_size=1) for _ in range(num_stacks - 1))

    def forward(self, x):
        x = self.stem(x)
        outputs = []
        for i in range(self.num_stacks):
            features = self.features[i](self.hourglass_modules[i](x))
            heatmaps = self.heads[i](features)
            outputs.append(heatmaps)
            if i < self.num_stacks - 1:
                # Map the predictions back to feature width before adding,
                # so the channel counts match.
                x = x + self.merge_features[i](features) + self.merge_preds[i](heatmaps)
        return outputs
class ResidualBlock(nn.Module):
    """Bottleneck-style residual block mapping ``in_channels`` -> ``out_channels``.

    The main branch is:
    1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv -> BN.

    The shortcut is the identity when the channel counts match. Otherwise it
    is a 1x1 conv + BN projection, so the output really has ``out_channels``
    channels. The final ReLU is applied after the sum (ResNet ordering);
    applying it to the branch alone would only let the block add
    non-negative values. Spatial size is preserved (stride 1, 3x3 padding 1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(num_features=out_channels)
        self.relu3 = nn.ReLU()
        if in_channels == out_channels:
            # Parameter-free: the state_dict matches the old layout for c -> c blocks.
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu3(self.shortcut(x) + out)
def heatmaps_to_keypoints(heatmaps):
    """Return the peak location of every heatmap as an ``(x, y)`` tuple.

    Args:
        heatmaps: Iterable of 2-D arrays, e.g. a ``(K, H, W)`` NumPy array.

    Returns:
        A list with one ``(column, row)`` tuple per heatmap, i.e. ``(x, y)``
        in heatmap-pixel units. Ties resolve to the first maximum in
        row-major order, as ``np.argmax`` does.
    """
    def peak_xy(heatmap):
        row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        return col, row

    return [peak_xy(heatmap) for heatmap in heatmaps]
class _PairCompose(object):
    """Chain transforms that each take and return an ``(img, keypoints)`` pair.

    ``torchvision.transforms.Compose`` calls every transform with a single
    argument, so it cannot drive the two-argument transforms used here.
    """

    def __init__(self, pair_transforms):
        self.pair_transforms = list(pair_transforms)

    def __call__(self, img, keypoints):
        for transform in self.pair_transforms:
            img, keypoints = transform(img, keypoints)
        return img, keypoints


class _Resize(object):
    """Resize an image to ``(width, height)`` and rescale its keypoints to match.

    A fixed size is required so the default DataLoader collate can stack a batch.
    """

    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __call__(self, img, keypoints):
        src_height, src_width = img.shape[:2]
        img = cv2.resize(img, (self.width, self.height))  # cv2 expects (width, height)
        scale = np.array([self.width / src_width, self.height / src_height], dtype=np.float32)
        return img, keypoints.astype(np.float32) * scale


def _keypoint_targets(keypoints, input_size, heatmap_size, sigma=2.0):
    """Render Gaussian target heatmaps and a per-keypoint validity mask.

    Args:
        keypoints: ``(B, K, 2)`` float tensor of ``(x, y)`` in input pixels.
        input_size: ``(height, width)`` of the network input.
        heatmap_size: ``(height, width)`` of one network output.
        sigma: Gaussian standard deviation, in heatmap pixels.

    Returns:
        ``(heatmaps, mask)`` with shapes ``(B, K, H, W)`` and ``(B, K)``.

    A keypoint is valid when it is not at ``(0, 0)`` (COCO's marker for
    unlabeled points) and lies inside the input image. Invalid keypoints get
    an all-zero heatmap.
    """
    in_h, in_w = input_size
    out_h, out_w = heatmap_size
    x, y = keypoints[..., 0], keypoints[..., 1]
    mask = ((x != 0) | (y != 0)) & (x >= 0) & (y >= 0) & (x < in_w) & (y < in_h)
    hx = (x * (out_w / in_w))[..., None, None]
    hy = (y * (out_h / in_h))[..., None, None]
    cols = torch.arange(out_w, device=keypoints.device, dtype=keypoints.dtype).view(1, 1, 1, -1)
    rows = torch.arange(out_h, device=keypoints.device, dtype=keypoints.dtype).view(1, 1, -1, 1)
    heatmaps = torch.exp(-((cols - hx) ** 2 + (rows - hy) ** 2) / (2 * sigma ** 2))
    mask = mask.to(keypoints.dtype)
    return heatmaps * mask[..., None, None], mask


def _masked_mse(pred, target, mask):
    """Mean squared error averaged over the heatmaps of valid keypoints only."""
    per_keypoint = ((pred - target) ** 2).mean(dim=(2, 3))
    return (per_keypoint * mask).sum() / mask.sum().clamp(min=1.0)


def main():
    """Train HourglassNet on COCO person keypoints, then evaluate on val2017.

    Images are resized to a fixed input size so batches can be stacked.
    Targets are Gaussian heatmaps rendered at each output's resolution, so
    the loop does not depend on the network's output stride. Accuracy is the
    fraction of labeled validation keypoints whose predicted peak lies within
    ``pck_threshold`` pixels (in resized-input coordinates).
    """
    input_h, input_w = 256, 256
    pck_threshold = 10.0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device:', device)
    transform = _PairCompose([
        RandomHorizontalFlip(p=0.5),
        _Resize(input_w, input_h),
        ToTensor(),
    ])
    dataset_train = COCODataset(root_dir='path/to/coco/train2017',
                                ann_file='path/to/coco/annotations/person_keypoints_train2017.json',
                                transform=transform)
    dataloader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=4)
    model = HourglassNet(num_stacks=2, num_blocks=4, num_classes=17).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(10):
        print(f'Epoch {epoch + 1}')
        model.train()
        running_loss = 0.0
        seen = 0
        for i, (inputs, targets) in enumerate(dataloader_train):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 0.0
            # Supervise every stack's prediction (intermediate supervision).
            for output in outputs:
                target_maps, mask = _keypoint_targets(targets, (input_h, input_w), output.shape[-2:])
                loss = loss + _masked_mse(output, target_maps, mask)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)
            if i % 10 == 9:
                print(f'Batch {i + 1}, Loss {running_loss / seen:.4f}')
        print(f'Training Loss: {running_loss / max(seen, 1):.4f}')
    model.eval()
    transform = _PairCompose([
        _Resize(input_w, input_h),
        ToTensor(),
    ])
    dataset_val = COCODataset(root_dir='path/to/coco/val2017',
                              ann_file='path/to/coco/annotations/person_keypoints_val2017.json',
                              transform=transform)
    dataloader_val = DataLoader(dataset_val, batch_size=1, shuffle=False)
    num_correct = 0
    num_total = 0
    with torch.no_grad():
        for inputs, targets in dataloader_val:
            inputs = inputs.to(device)
            outputs = model(inputs)
            heatmaps = outputs[-1][0]
            _, valid = _keypoint_targets(targets, (input_h, input_w), heatmaps.shape[-2:])
            valid = valid[0].numpy().astype(bool)
            gt = targets[0].numpy()
            # Peaks are in heatmap pixels; rescale to input pixels before comparing.
            scale = np.array([input_w / heatmaps.shape[-1], input_h / heatmaps.shape[-2]])
            keypoints_pred = heatmaps_to_keypoints(heatmaps.cpu().numpy())
            for j, kp in enumerate(keypoints_pred):
                if not valid[j]:
                    continue  # unlabeled keypoints cannot be scored
                num_total += 1
                if np.linalg.norm(np.asarray(kp, dtype=np.float64) * scale - gt[j]) < pck_threshold:
                    num_correct += 1
    if num_total:
        print(f'Accuracy: {num_correct / num_total:.4f}')
    else:
        print('Accuracy: n/a (no labeled keypoints)')


if __name__ == '__main__':
    main()
```
其中，`COCODataset`类用于加载COCO关键点标注及对应图像，`ToTensor`类用于将图像和关键点转换为PyTorch张量，`RandomHorizontalFlip`类用于随机水平翻转数据增强。`HourglassNet`类实现了堆叠式Hourglass网络，每个堆叠输出一组关键点热图。`heatmaps_to_keypoints`函数取每张热图最大值所在的位置作为关键点坐标（单位为热图像素）。`main`函数负责训练模型，并在验证集上评估关键点准确率。
阅读全文