PyTorch 花朵识别代码
时间: 2023-09-01 07:08:43 浏览: 43
以下是使用 PyTorch 实现花朵识别的示例代码：
首先，我们需要导入必要的包：
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
```
然后，我们定义数据的预处理方式：将图像缩放到 224×224，转换为张量，并使用 ImageNet 的均值和标准差对每个通道进行归一化：
```python
# Per-image preprocessing: resize to a fixed 224x224 resolution, convert the
# image to a tensor, then normalize each RGB channel with the commonly used
# ImageNet mean / standard deviation.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
```
接下来，我们定义数据集和数据加载器。`ImageFolder` 要求每个类别的图像存放在以类别名命名的子目录中（例如 `path/to/train/dataset/rose/xxx.jpg`）：
```python
# One ImageFolder per split; both splits go through the same `transform`.
train_data = datasets.ImageFolder(root='path/to/train/dataset', transform=transform)
test_data = datasets.ImageFolder(root='path/to/test/dataset', transform=transform)

# The training loader reshuffles every epoch; the test loader keeps a fixed order.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)
```
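然后，我们定义模型。该网络由 5 个卷积块组成，每个卷积块末尾的 2×2 最大池化都会把特征图的宽和高减半，因此 224×224 的输入最终得到 512×7×7 的特征图，这就是分类器第一层输入维度为 `512 * 7 * 7` 的原因：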
```python
class FlowerNet(nn.Module):
    """VGG-style convolutional classifier for flower images.

    Five conv blocks (Conv3x3 -> BatchNorm -> ReLU -> MaxPool 2x2) widen the
    channels 3 -> 64 -> 128 -> 256 -> 512 -> 512 while halving the spatial size
    each time, so a 224x224 input yields a 512x7x7 feature map. A fully
    connected head with dropout maps the flattened features to class logits.

    Args:
        num_classes: Number of output logits. Defaults to 102; set it to the
            number of classes in your dataset.
        hidden_dim: Width of the two hidden fully connected layers
            (default 4096).
    """

    def __init__(self, num_classes=102, hidden_dim=4096):
        super().__init__()

        # Build the five identical conv blocks. Module order (and therefore the
        # Sequential indices / state_dict keys) matches a hand-written list.
        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256, 512, 512):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
            in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Resamples whatever spatial size the conv stack produces to 7x7 so the
        # flattened length is always 512 * 7 * 7. For 224x224 inputs the feature
        # map is already 7x7 and this is an identity op; it only matters for
        # other input sizes. It has no parameters.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        x is an (N, 3, H, W) float tensor. H and W must be at least 32, because
        each of the five 2x2 max-pools must leave at least one pixel.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 512, 7, 7) -> (N, 512 * 7 * 7)
        return self.classifier(x)
```
接下来，我们定义损失函数和优化器：
```python
# Cross-entropy over the raw logits produced by the network.
criterion = nn.CrossEntropyLoss()

# Instantiate the network and optimize all of its parameters with Adam.
model = FlowerNet()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
```
然后，我们进行训练和测试：
```python
# Train for num_epochs epochs; after each epoch, evaluate on the test set and
# print the mean per-sample loss and accuracy for both splits.
num_epochs = 10
for epoch in range(num_epochs):
    # ---- Training pass --------------------------------------------------
    train_loss = 0.0
    train_acc = 0.0
    model.train()  # dropout active, BatchNorm uses per-batch statistics
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Assumes the criterion's default reduction='mean': weight the batch
        # mean by the batch size to accumulate a per-sample sum (the last
        # batch may be smaller than batch_size).
        train_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        # .item() keeps the running count a plain Python number. Dividing an
        # integer tensor by an int floor-divides on older PyTorch releases.
        train_acc += torch.sum(preds == labels).item()
    train_loss = train_loss / len(train_data)
    train_acc = train_acc / len(train_data)
    print('Epoch: {} \tTraining Loss: {:.6f} \tTraining Accuracy: {:.6f}'.format(
        epoch+1, train_loss, train_acc))

    # ---- Evaluation pass ------------------------------------------------
    model.eval()  # dropout off, BatchNorm uses running statistics
    test_loss = 0.0
    test_acc = 0.0
    # Gradients are never used during evaluation. Disabling autograd avoids
    # building a computation graph for every batch, which saves memory and time.
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            test_acc += torch.sum(preds == labels).item()
    test_loss = test_loss / len(test_data)
    test_acc = test_acc / len(test_data)
    print('Epoch: {} \tTesting Loss: {:.6f} \tTesting Accuracy: {:.6f}'.format(
        epoch+1, test_loss, test_acc))
```
以上就是使用 PyTorch 实现花朵识别的示例代码。请根据实际情况替换代码中的数据集路径和其他参数；特别注意，模型最后一层的输出类别数（代码中为 102）需要与数据集中的类别数一致。