加载自己图片数据集并加以训练
时间: 2023-07-11 13:12:16 浏览: 90
好的,我来给你讲一下如何加载自己的图片数据集并进行训练。
首先,你需要将图片数据集按照一定的目录结构组织起来。例如,我们可以将不同类别的图片放在不同的目录中,每个目录的名称表示该类别的名称。具体来说,假设你的图片数据集包含两个类别:cat 和 dog,你可以按照如下的目录结构组织数据集:
```
data/
train/
cat/
cat001.jpg
cat002.jpg
...
dog/
dog001.jpg
dog002.jpg
...
val/
cat/
cat101.jpg
cat102.jpg
...
dog/
dog101.jpg
dog102.jpg
...
```
其中,train 目录下包含训练集,val 目录下包含验证集。每个类别的图片都放在对应的目录中。
接下来,我们需要使用 PyTorch 的 ImageFolder 类来加载数据集。ImageFolder 类会自动读取指定目录下的数据,并将数据按照类别进行分类。
```python
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 定义数据变换
data_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.ImageFolder('data/train', transform=data_transform)
val_dataset = datasets.ImageFolder('data/val', transform=data_transform)
# 创建 DataLoader
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
```
在上面的代码中,我们定义了一个数据变换 data_transform,它将图片先缩放到 224x224 的大小,然后进行随机水平翻转,最后转换成 Tensor,并进行归一化。接着,我们使用 ImageFolder 类加载数据集,并传入数据变换。最后,我们创建了两个 DataLoader,分别用于训练集和验证集。
现在,我们可以使用 DataLoader 进行模型训练了。具体来说,我们需要定义一个模型,并定义一个损失函数和一个优化器进行训练。
```python
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 替换最后一层全连接层
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
for inputs, labels in train_dataloader:
# 前向传播
outputs = model(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在验证集上评估模型
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in val_dataloader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch+1, num_epochs, loss.item(), accuracy))
```
在上面的代码中,我们使用了一个预训练的 ResNet-18 模型,并将最后一层全连接层替换成一个二分类的线性层。我们使用交叉熵损失函数和 SGD 优化器进行训练。在每个 epoch 结束时,我们在验证集上评估模型的准确率。
希望这个回答能够帮助到你!
阅读全文
相关推荐
data:image/s3,"s3://crabby-images/c7f95/c7f957a578cbb465f17670ca5ec5de6d8fbcb44e" alt="-"
data:image/s3,"s3://crabby-images/76d5d/76d5dcefc5ad32aa65e7d5f6e5b202b09b84830d" alt="-"
data:image/s3,"s3://crabby-images/c7f95/c7f957a578cbb465f17670ca5ec5de6d8fbcb44e" alt="-"
data:image/s3,"s3://crabby-images/e802a/e802a808507cc67c433d0f14f4478cfc18013243" alt="-"
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/e802a/e802a808507cc67c433d0f14f4478cfc18013243" alt="-"
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""
data:image/s3,"s3://crabby-images/6eee2/6eee29554420e01e83364d49443b3b12df11c8af" alt=""