基于pytorch的深度学习场景识别代码及数据集
时间: 2024-01-30 22:03:48 浏览: 187
以下是基于PyTorch的深度学习场景识别代码及数据集:
## 数据集
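1. Places365-Standard:该数据集包含365种不同的场景类别,训练集约有180万张图片(每个类别最多5000张)。下面的链接是验证集图片(256×256)的压缩包,示例代码使用的正是该验证集;代码还需要标注文件`places365_val.txt`(每行为"图片文件名 类别编号"),请将其放在数据目录下。下载链接:http://data.csail.mit.edu/places/places365/val_256.tar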
2. MIT Indoor:该数据集包含67种室内场景类别,每个类别至少有100张图片;在标准划分中,每个类别使用80张图片训练、20张图片测试。下载链接:http://web.mit.edu/torralba/www/indoor.html
## 代码
以下是一个使用ResNet50模型进行场景分类的PyTorch示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
# Number of scene categories and training hyperparameters
num_classes = 365
batch_size = 32
num_epochs = 10
learning_rate = 0.001

# Preprocessing pipeline: resize, take the central 224x224 crop, convert to
# a tensor, then normalize each channel (the mean/std values are the
# standard ImageNet channel statistics).
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]
)
# Dataset definition
class SceneDataset(Dataset):
    """Image-classification dataset driven by a ``places365_val.txt`` index.

    The index file is read from ``root``; each non-blank line has the form
    ``<image path relative to root> <integer label>``.

    Args:
        root: Directory containing the index file and the listed images.
        transform: Optional callable applied to every loaded RGB image.

    Raises:
        ValueError: If an index line lacks a label or the label is not an
            integer.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.filenames = []
        self.labels = []
        self._read_data()

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        path = os.path.join(self.root, self.filenames[index])
        label = self.labels[index]
        # convert() returns a fully loaded copy, so the underlying file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def _read_data(self):
        """Populate ``self.filenames`` and ``self.labels`` from the index."""
        index_path = os.path.join(self.root, 'places365_val.txt')
        with open(index_path, 'r', encoding='utf-8') as f:
            # Stream the file line by line; no need to hold it all in memory.
            for line_no, raw_line in enumerate(f, start=1):
                entry = raw_line.strip()
                if not entry:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # Split on the LAST whitespace run so that image paths
                # containing spaces are kept intact.
                parts = entry.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        '{}:{}: expected "<filename> <label>", got {!r}'.format(
                            index_path, line_no, entry
                        )
                    )
                filename, label = parts
                self.filenames.append(filename)
                self.labels.append(int(label))
# Build the model: ImageNet-pretrained ResNet-50 with a new classification head
if hasattr(models, 'ResNet50_Weights'):
    # torchvision >= 0.13 deprecates `pretrained=True` in favour of `weights=`.
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    # Older torchvision releases only understand the legacy flag.
    model = models.resnet50(pretrained=True)
# Replace the 1000-way ImageNet classifier with a `num_classes`-way layer.
# This new layer is randomly initialised and must be trained before use.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
# Training function
def train(model, dataloader):
    """Run one training pass over ``dataloader`` and update ``model``.

    Uses the module-level ``optimizer`` and ``criterion``. A running average
    of the loss is printed every 100 batches.

    Args:
        model: The ``nn.Module`` being optimised.
        dataloader: Sized iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The mean loss per batch, or ``0.0`` if ``dataloader`` yielded no
        batches.
    """
    model.train()
    # Send each batch to wherever the model's weights live, so the function
    # keeps working after the caller moves the model to a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(dataloader):
        if device is not None:
            inputs = inputs.to(device)
            labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches = i + 1
        if num_batches % 100 == 0:
            print('Batch [{}/{}], Loss: {:.4f}'.format(i + 1, len(dataloader), total_loss / (i + 1)))
    return total_loss / num_batches if num_batches else 0.0
# Evaluation function
def test(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and report top-1 accuracy.

    Args:
        model: The ``nn.Module`` to evaluate; it is switched to eval mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Top-1 accuracy as a percentage in ``[0, 100]``, or ``0.0`` when the
        dataloader yields no samples.
    """
    model.eval()
    # Send each batch to wherever the model's weights live, so the function
    # keeps working after the caller moves the model to a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            if device is not None:
                inputs = inputs.to(device)
                labels = labels.to(device)
            outputs = model(inputs)
            # Index of the highest score along the class dimension.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Avoid dividing by zero when there is nothing to evaluate.
        print('No samples to evaluate.')
        return 0.0

    accuracy = 100 * correct / total
    print('Accuracy on the test set: %d %%' % accuracy)
    return accuracy
# Load the dataset and evaluate the model.
# Guarded so that importing this module does not start an evaluation run.
if __name__ == '__main__':
    data_root = './val_256/'
    test_dataset = SceneDataset(data_root, transform)
    # Accuracy does not depend on sample order, so do not shuffle; this keeps
    # evaluation runs deterministic.
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # NOTE: `model.fc` was replaced with a freshly initialised layer and
    # `train()` has not been called, so the accuracy printed here reflects an
    # untrained classifier head. Fine-tune on a training split first for a
    # meaningful number.
    test(model, test_dataloader)
```
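在测试过程中,我们使用了`test()`函数来评估模型在测试集上的准确性。注意:示例代码只定义了`train()`函数而没有调用它,新替换的全连接层(`model.fc`)仍是随机初始化的,因此直接运行得到的准确率接近随机猜测的水平;实际使用时应先在训练集上调用`train()`对模型进行微调,再用`test()`进行评估。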
阅读全文