def train(model, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate, weight_decay, batch_size): 帮我加个注释
时间: 2023-06-19 22:04:06 浏览: 135
# 定义训练函数,接收模型、训练特征、训练标签、测试特征、测试标签、训练轮数、学习率、权重衰减和批大小等参数
def train(model, train_features, train_labels, test_features, test_labels,
num_epochs, learning_rate, weight_decay, batch_size):
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# 定义训练和测试数据集
train_dataset = TensorDataset(train_features, train_labels)
test_dataset = TensorDataset(test_features, test_labels)
# 定义训练和测试数据集的加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 开始训练模型
for epoch in range(num_epochs):
# 训练模式
model.train()
train_loss = 0.0
train_acc = 0.0
for features, labels in train_loader:
# 将特征和标签放入设备中
features = features.to(device)
labels = labels.to(device)
# 清空优化器的梯度
optimizer.zero_grad()
# 前向传播
outputs = model(features)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
optimizer.step()
# 计算训练损失
train_loss += loss.item() * features.size(0)
# 计算训练准确率
_, preds = torch.max(outputs, 1)
train_acc += torch.sum(preds == labels.data)
# 计算平均训练损失和准确率
train_loss = train_loss / len(train_features)
train_acc = train_acc / len(train_features)
# 测试模式
model.eval()
test_loss = 0.0
test_acc = 0.0
with torch.no_grad():
for features, labels in test_loader:
# 将特征和标签放入设备中
features = features.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(features)
# 计算损失
loss = criterion(outputs, labels)
# 计算测试损失
test_loss += loss.item() * features.size(0)
# 计算测试准确率
_, preds = torch.max(outputs, 1)
test_acc += torch.sum(preds == labels.data)
# 计算平均测试损失和准确率
test_loss = test_loss / len(test_features)
test_acc = test_acc / len(test_features)
# 输出训练和测试结果
print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
.format(epoch+1, num_epochs, train_loss, train_acc, test_loss, test_acc))
# 返回训练好的模型
return model
阅读全文