import torch from torch import nn from d2l import torch as d2l batch_size = 256 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) # 3.7.1. 初始化模型参数 # PyTorch不会隐式地调整输入的形状。因此, # 我们在线性层前定义了展平层(flatten),来调整网络输入的形状 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weight, std=0.01) net.apply(init_weights) # 3.7.2 重新审视Softmax的实现 loss = nn.CrossEntropyLoss(reduction='none') # 3.7.3 优化算法 trainer = torch.optim.SGD(net.parameters(), lr=0.1) # 3.7.4 训练 num_epochs = 10 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer) d2l.plt.show()
时间: 2023-06-11 15:04:44 浏览: 334
d2l-0.15.1-py3-none-any【d2l工具包-人工智能】
这段代码是在Python中使用PyTorch库进行深度学习时的常用代码片段。它调用了PyTorch和d2l(torch)库,并使用d2l提供的函数从Fashion-MNIST数据集中加载训练和测试数据。其中batch_size参数指定了批量大小。
阅读全文