Python CNN 图像分类
时间: 2023-08-22 11:10:44 浏览: 108
Python CNN图像分类通常包括以下几个步骤:
1. 导入所需库
2. 定义数据集类
3. 初始化模型
4. 训练模型
5. 测试模型
下面是一个基于 PyTorch 的简单示例代码。它以 torchvision 预训练的 ResNet-18 作为 CNN,只给出上述各步骤对应的类和函数;创建 DataLoader 以及依次调用这些函数的主流程需要自行补充:
```python
import numpy as np
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
# 定义数据集类
class ClassifyDataset(Dataset):
    """Image-classification dataset that takes each image's label from its parent directory.

    ``data_file`` is a text file listing image paths relative to ``root_path``. The class
    vocabulary is the set of sub-directories of ``root_path/train``. It is sorted so that a
    class name maps to the same label index on every run, machine and split.
    """

    def __init__(self, root_path, data_file, img_size=256, transform=None):
        """
        Args:
            root_path: dataset root; must contain a ``train`` directory with one
                sub-directory per class.
            data_file: text file of image paths relative to ``root_path``, parsed
                with ``np.loadtxt`` (so whitespace separates entries).
            img_size: side length images are resized to when ``transform`` is None.
            transform: optional callable applied to each RGB PIL image. It replaces the
                default ``Resize((img_size, img_size)) + ToTensor()`` pipeline.
        """
        # np.str was removed in NumPy 1.24; ndmin=1 keeps a one-line list from
        # collapsing into a 0-d array (which has no len()).
        self.data_files = np.loadtxt(data_file, dtype=str, ndmin=1)
        self.root_path = root_path
        train_dir = os.path.join(root_path, 'train')
        # os.listdir order is filesystem-dependent. Sorting makes label indices reproducible.
        # Skipping non-directories keeps stray files (e.g. .DS_Store) from becoming classes.
        self.class_list = sorted(
            name for name in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, name))
        )
        if transform is None:
            transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor()])
        self.transforms = transform

    def _label_from_path(self, path):
        """Return the class index for *path*, derived from the name of its parent directory.

        Raises:
            ValueError: if the parent directory is not one of ``self.class_list``.
        """
        # Accept both '/' and '\\' separators so lists written on Windows or POSIX both work.
        label_name = path.replace('\\', '/').split('/')[-2]
        return self.class_list.index(label_name)

    def __getitem__(self, item):
        """Return ``(transformed_image, label_tensor)`` for the *item*-th listed path."""
        path = os.path.join(self.root_path, self.data_files[item])
        # pil_loader opens the file in a `with` block (no leaked handle) and converts to
        # RGB, so grayscale/RGBA files still produce 3-channel inputs.
        img = torchvision.datasets.folder.pil_loader(path)
        label = self._label_from_path(path)
        return self.transforms(img), torch.tensor(label)

    def __len__(self):
        """Return the number of listed image paths."""
        return len(self.data_files)
# 初始化模型
def init_model(num_classes=None, device=None):
    """Build a ResNet-18 classifier with a new output layer, plus its loss and optimizer.

    Args:
        num_classes: number of outputs of the new final layer. When None, falls back
            to ``len(CLASSES)``, which must then be defined at module level elsewhere.
        device: where to place the model. When None, uses CUDA if it is available,
            otherwise the CPU.

    Returns:
        ``(model, criterion, optimizer)``: the model on ``device``, an
        ``nn.CrossEntropyLoss``, and SGD (lr=0.001, momentum=0.9) over all model parameters.
    """
    if num_classes is None:
        num_classes = len(CLASSES)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # pretrained=True loads ImageNet weights. NOTE(review): newer torchvision versions
    # deprecate this flag in favour of `weights=`; switch once the minimum version allows.
    model = torchvision.models.resnet18(pretrained=True)
    # Replace the ImageNet classification head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Build the optimizer after the device move so it references the final parameters.
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, criterion, optimizer
# 训练模型
def train_model(model, criterion, optimizer, trainloader, epochs=25):
    """Train *model* in place for *epochs* passes over *trainloader*, printing per-epoch loss.

    Each batch is indexed as ``(inputs, labels)``. Both tensors are moved to the device
    holding the model's parameters.

    Args:
        model: the network to train; switched to training mode.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        trainloader: iterable of ``(inputs, labels)`` batches.
        epochs: number of passes over ``trainloader``.

    The printed value is the mean per-sample loss of the epoch, or ``nan`` when the
    loader yields no samples.
    """
    device = next(model.parameters()).device
    # Enable dropout and batch-norm updates in case the model was left in eval mode.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_samples = 0
        for data in trainloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Assumes criterion returns the batch mean (CrossEntropyLoss's default
            # reduction), so weight by batch size and later divide by the sample count.
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size
        epoch_loss = running_loss / num_samples if num_samples else float('nan')
        print('Epoch %d loss: %.3f' % (epoch + 1, epoch_loss))
# 测试模型
def test_model(model, testloader):
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].cuda(), data[1].cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
```
阅读全文