PyTorch Fashion MNIST classification source code
Time: 2023-10-28 13:06:10
The following example classifies the Fashion MNIST dataset with PyTorch:
```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Convert images to tensors and normalize pixel values from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Download and load the training and test datasets
trainset = datasets.FashionMNIST('~/.pytorch/FashionMNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testset = datasets.FashionMNIST('~/.pytorch/FashionMNIST_data/', download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)  # evaluation does not need shuffling

# Define the network: fully connected layers mapping 784 inputs to 10 class logits
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# CrossEntropyLoss applies log-softmax internally, so the model outputs raw logits
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
epochs = 10
for e in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        # Flatten each 1x28x28 image into a 784-long vector
        images = images.view(images.shape[0], -1)
        # Forward pass, backward pass, and parameter update
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {e + 1}/{epochs}, training loss: {running_loss / len(trainloader):.4f}")

# Evaluate the trained network on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        # Flatten each 1x28x28 image into a 784-long vector
        images = images.view(images.shape[0], -1)
        logits = model(images)
        # The predicted class is the index of the largest logit
        predicted = logits.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
# Divide by the number of images (not batches) so the smaller last batch is weighted correctly
print(f"Test accuracy: {correct / total:.4f}")
```
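Two details of the evaluation step can be checked without PyTorch. First, taking `torch.exp` of the logits does not change which class has the largest score, so the predicted class can be read directly from the logits. Second, the 10,000 test images split into 156 full batches of 64 plus one batch of 16, so averaging per-batch accuracies gives those last 16 images the same weight as a full batch. The pure-Python sketch below uses made-up logits and per-batch counts (hypothetical values, not results from the model above); `argmax` and `batch_count` are illustrative helpers, not library functions.

```python
import math

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical raw logits for one image (10 classes), as a final nn.Linear(64, 10) would emit.
logits = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4, 0.9, 2.4, -2.0, 0.2]

# exp() is strictly increasing, so it never changes which class scores highest.
assert argmax(logits) == argmax([math.exp(z) for z in logits])
print(argmax(logits))  # 2

def batch_count(n_samples, batch_size):
    """Number of batches a DataLoader yields with drop_last=False (the default)."""
    return math.ceil(n_samples / batch_size)

# Hypothetical per-batch results as (number correct, batch size):
# 156 full batches of 64 images plus one final batch of 16 images.
batches = [(58, 64)] * 156 + [(4, 16)]

# Averaging batch accuracies weights the 16-image batch like a 64-image batch.
mean_of_batch_means = sum(c / n for c, n in batches) / len(batches)
# Counting correct predictions over all images weights every image equally.
overall = sum(c for c, _ in batches) / sum(n for _, n in batches)

print(f"{mean_of_batch_means:.4f}")  # 0.9021
print(f"{overall:.4f}")              # 0.9052
```

In this made-up example the final batch happens to score worse, so the mean of per-batch means comes out lower than the true accuracy; with other data the bias can go either way, which is why dividing the total number of correct predictions by the total number of images is the safer aggregate.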
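This code uses a simple feedforward neural network made of four fully connected layers with ReLU activations between them. Training uses stochastic gradient descent (SGD) with a cross-entropy loss, and the trained model's accuracy is then measured on the test set. Fashion MNIST is a popular image classification dataset created by Zalando's research team. It contains 60,000 training images and 10,000 test images; each is a 28x28-pixel grayscale image labeled with one of 10 clothing categories.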
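To make the size of this four-layer network concrete, the short sketch below counts its trainable parameters by hand. It assumes the layer widths from the example (784 -> 256 -> 128 -> 64 -> 10, where 784 = 28 x 28 flattened pixels); `LAYER_SIZES` and `linear_param_count` are hypothetical names used only for illustration. Each `nn.Linear(in_features, out_features)` holds an `out_features x in_features` weight matrix plus `out_features` biases, and the ReLU layers add no parameters.

```python
# Layer widths of the fully connected model: 784 -> 256 -> 128 -> 64 -> 10.
LAYER_SIZES = [784, 256, 128, 64, 10]

def linear_param_count(in_features, out_features):
    """Weights plus biases of one fully connected (nn.Linear-style) layer."""
    return in_features * out_features + out_features

# Pair each layer's input width with its output width.
per_layer = [linear_param_count(a, b) for a, b in zip(LAYER_SIZES, LAYER_SIZES[1:])]
total = sum(per_layer)

print(per_layer)  # [200960, 32896, 8256, 650]
print(total)      # 242762
```

With the PyTorch model itself, `sum(p.numel() for p in model.parameters())` should report the same total. Most of the parameters sit in the first layer, since it connects every one of the 784 input pixels to 256 hidden units.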