import torch from torch import nn from d2l import torch as d2l batch_size = 256 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) # 3.7.1. 初始化模型参数 # PyTorch不会隐式地调整输入的形状。因此, # 我们在线性层前定义了展平层(flatten),来调整网络输入的形状 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weight, std=0.01) net.apply(init_weights) # 3.7.2 重新审视Softmax的实现 loss = nn.CrossEntropyLoss(reduction='none') # 3.7.3 优化算法 trainer = torch.optim.SGD(net.parameters(), lr=0.1) # 3.7.4 训练 num_epochs = 10 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer) d2l.plt.show()
时间: 2023-06-11 14:04:46 浏览: 149
d2l-0.15.1-py3-none-any【d2l工具包-人工智能】
这段代码导入了PyTorch、PyTorch中的nn模块,以及d2l库中的torch模块。然后使用d2l库中的load_data_fashion_mnist函数,以批次大小为256,加载了Fashion-MNIST数据集中的训练数据和测试数据,分别保存在train_iter和test_iter中。
阅读全文