没有合适的资源?快使用搜索试试~ 我知道了~
首页Pytorch学习笔记Day二(全连接层学习)
MNIST数据集用全连接层实现 import torch from torch import nn from torch.nn import functional as F from torch import optim import torchvision from matplotlib import pyplot as plt def plot_curve(data): fig = plt.figure() plt.plot(range(len(data)), data, color='blue') plt.legend(['value'], loc='upper
资源详情
资源评论
资源推荐

Pytorch学习笔记Day二(全连接层学习)
MNIST数据集用全连接层实现
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
def plot_curve(data):
    """Plot a sequence of values against their step index.

    Opens a new matplotlib figure, draws ``data`` as a blue line with the
    x-axis labelled 'step' and the y-axis labelled 'value', then displays
    the figure with ``plt.show()``.

    Args:
        data: Sized sequence of numbers; element ``i`` is drawn at x = ``i``.
    """
    # A fresh figure keeps this curve from being drawn onto an earlier plot.
    plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
def plot_image(img, label, name):
    """Show up to the first six images of a batch in a 2x3 grid.

    Each panel displays ``img[i][0]`` in grayscale with the axis ticks
    removed and the title ``"<name>: <label>"``. Pixel values are mapped
    through ``* 0.3081 + 0.1307``, which inverts the
    ``Normalize((0.1307,), (0.3081,))`` transform applied by the data
    loaders in this file.

    Args:
        img: Indexable batch where ``img[i][0]`` is a 2-D image
            (presumably a ``(N, 1, H, W)`` tensor -- confirm with caller).
        label: Indexable batch whose elements expose ``.item()``.
        name: Text shown before each label in the panel title.
    """
    plt.figure()
    # Clamp to the batch size so a batch with fewer than six samples does
    # not raise an IndexError.
    count = min(6, len(img))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray',
                   interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Adjust spacing once, after every panel has been created.
    plt.tight_layout()
    plt.show()
def one_hot(label, depth=10):
    """Convert integer class labels into one-hot encoded rows.

    Args:
        label: Tensor (or sequence) of integer class indices, each in
            ``[0, depth)``. It is flattened before encoding.
        depth: Number of classes, i.e. the width of every output row.

    Returns:
        A float tensor of shape ``(N, depth)``, where ``N`` is the number of
        labels; row ``i`` holds 1 at column ``label[i]`` and 0 elsewhere.
        The result lives on the same device as the label tensor.
    """
    # as_tensor(...).long() accepts any integer dtype (and plain lists),
    # unlike the legacy torch.LongTensor(...) constructor.
    idx = torch.as_tensor(label).long().reshape(-1, 1)
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    # Write a 1 into the column selected by each row's index.
    out.scatter_(dim=1, index=idx, value=1)
    return out
# Number of samples per mini-batch, shared by both loaders.
batch_size = 512

# step1. load dataset
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split: downloaded into ../data if missing, reshuffled each epoch.
train_set = torchvision.datasets.MNIST(
    '../data', train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)

# Test split: same preprocessing, iterated in a fixed order.
test_set = torchvision.datasets.MNIST(
    '../data', train=False, download=True, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)

# Grab a single training batch (inputs x, labels y) for inspection.
x, y = next(iter(train_loader))
# plot_image(x, y, 'image sample')
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# xw+b
self.fc1 = nn.Linear(28*28, 256)
















安全验证
文档复制为VIP权益,开通VIP直接复制

评论0