再写一个行人属性识别训练图片代码
时间: 2024-02-21 07:02:12 浏览: 126
好的,以下是一个使用 PyTorch 框架、用图片训练行人属性识别模型的简单示例代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
# Dataset definition
class PedestrianDataset(Dataset):
    """Image dataset whose integer class label is encoded in each file name.

    The label is the part of the file name after the last underscore and
    before the first following dot, e.g. ``person_0001_1.jpg`` -> ``1``.

    Args:
        root_dir: Directory that directly contains the image files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that an
        # index always maps to the same file across runs and machines.
        self.image_list = sorted(os.listdir(root_dir))

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.image_list)

    @staticmethod
    def _parse_label(file_name):
        """Extract the integer class label from a file name.

        Only the base name is inspected, so underscores or dots in the
        directory part of a path can never leak into the label.

        Raises:
            ValueError: If the file name does not end in ``_<int>.<ext>``.
        """
        base_name = os.path.basename(file_name)
        token = base_name.split('_')[-1].split('.')[0]
        try:
            return int(token)
        except ValueError as err:
            raise ValueError(
                'Cannot parse an integer label from file name {!r}'.format(base_name)
            ) from err

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': int}`` for the sample at ``idx``."""
        file_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, file_name)
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load so the handle can be released right away.
        # Converting to RGB also guarantees 3 channels for grayscale/RGBA files.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # An int label lets the default DataLoader collation build an int64
        # tensor, which is what torch.nn.CrossEntropyLoss expects as target.
        label = self._parse_label(file_name)
        if self.transform:
            image = self.transform(image)
        return {'image': image, 'label': label}
# Data preprocessing
_IMAGE_SIZE = (224, 224)
# Per-channel mean / std used to normalize the RGB tensor.
_NORMALIZE = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training pipeline: random augmentation followed by resize, tensor
# conversion and normalization.
data_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Validation pipeline: identical to training minus the random operations, so
# the reported accuracy is deterministic and measured on unaugmented images.
val_transforms = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Load the training and validation sets
train_dataset = PedestrianDataset(root_dir='path/to/train_data', transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = PedestrianDataset(root_dir='path/to/val_data', transform=val_transforms)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Model definition
# Small CNN: two conv -> ReLU -> max-pool stages, then a two-layer classifier.
_feature_layers = [
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
    torch.nn.Conv2d(16, 32, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
]
_classifier_layers = [
    torch.nn.Flatten(),
    # 32 * 54 * 54 is the flattened feature size for a 3x224x224 input:
    # 224 -> 222 (conv) -> 111 (pool) -> 109 (conv) -> 54 (pool).
    torch.nn.Linear(32 * 54 * 54, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
]
model = torch.nn.Sequential(*_feature_layers, *_classifier_layers)
# Loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def _to_label_tensor(labels):
    """Return a batch of class labels as a 1-D int64 tensor.

    CrossEntropyLoss needs integer class indices. The default DataLoader
    collation leaves string labels as a plain list of ``str``, so those are
    converted here; labels that are already a tensor are only cast.
    """
    if torch.is_tensor(labels):
        return labels.long()
    return torch.tensor([int(label) for label in labels], dtype=torch.long)


# Train the model, validating after every epoch
for epoch in range(10):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch_data in train_loader:
        images = batch_data['image']
        labels = _to_label_tensor(batch_data['label'])
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_data in val_loader:
            images = batch_data['image']
            labels = _to_label_tensor(batch_data['label'])
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report the mean training loss of the epoch rather than only the last
    # batch, and guard both ratios so an empty loader cannot crash the loop.
    mean_loss = running_loss / num_batches if num_batches else float('nan')
    accuracy = 100 * correct / total if total else 0.0
    print('Epoch: {} | Loss: {:.4f} | Validation Accuracy: {:.2f}%'.format(epoch + 1, mean_loss, accuracy))
```
注意,这只是一个简单的模板:它假设标签取自文件名中最后一个下划线之后、扩展名之前的部分,并且模型只输出 2 个类别;实际应用中需要根据具体情况(例如属性数量、文件命名方式、网络结构)进行调整和改进。
阅读全文