pytorch卷积神经网络把一组图分为两类代码
时间: 2024-10-21 18:10:02 浏览: 12
卷积神经网络pytorch简易代码实现
在PyTorch中创建一个卷积神经网络(Convolutional Neural Network, CNN)用于图像分类的基本步骤包括以下几个部分:
1. 导入必要的库:
```python
import torch
import torchvision
import torchvision.transforms as transforms
```
2. 数据预处理:
```python
transform = transforms.Compose([
transforms.Resize((64, 64)), # 图像缩放大小
transforms.RandomHorizontalFlip(), # 随机水平翻转增强数据
transforms.ToTensor(), # 转换成张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化像素值
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
```
3. 创建数据加载器:
```python
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```
4. 构建CNN模型:
```python
model = torchvision.models.resnet18(pretrained=True) # 使用预训练的ResNet18
num_classes = 2 # 分类的类别数,这里是二分类任务
model.fc = torch.nn.Linear(model.fc.in_features, num_classes) # 修改最后一层全连接层
model = model.to(device) # 将模型移到GPU上,如果可用
```
5. 定义损失函数和优化器:
```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
6. 训练过程:
```python
num_epochs = 10
for epoch in range(num_epochs):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
```
7. 测试模型:
```python
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Test Accuracy of the Model: {100 * correct / total}%")
```
阅读全文