def train(): Dtr, Val, Dte = load_data() print('train...') epoch_num = 30 best_model = None min_epochs = 5 min_val_loss = 5 model = cnn().to(device) optimizer = optim.Adam(model.parameters(), lr=0.0008) criterion = nn.CrossEntropyLoss().to(device) # criterion = nn.BCELoss().to(device) for epoch in tqdm(range(epoch_num), ascii=True): train_loss = [] for batch_idx, (data, target) in enumerate(Dtr, 0): try: data, target = Variable(data).to(device), Variable(target.long()).to(device) # target = target.view(target.shape[0], -1) # print(target) optimizer.zero_grad() output = model(data) # print(output) loss = criterion(output, target) loss.backward() optimizer.step() train_loss.append(loss.cpu().item()) except: continue # validation val_loss = get_val_loss(model, Val) model.train() if epoch + 1 > min_epochs and val_loss < min_val_loss: min_val_loss = val_loss best_model = copy.deepcopy(model) torch.save(best_model.state_dict(), r"E:\dataset\Airbnb\training_data\model\cnn.pkl")
时间: 2024-01-13 21:05:00 浏览: 160
UART.zip_reset_uart dtr_uart.h
这段代码是一个简单的CNN模型的训练过程。它的输入是经过处理的数据集 Dtr, Val, Dte,其中 Dtr 是训练集,Val 是验证集,Dte 是测试集。模型的优化器采用 Adam 算法,损失函数采用交叉熵损失。
在每个 epoch 中,模型会对训练集 Dtr 进行迭代,计算损失并更新参数。同时,模型会在验证集 Val 上计算损失,如果当前的验证损失比之前最小的验证损失要小,则更新最小验证损失和最佳模型参数。当 epoch 数量达到最大值或者验证损失连续 min_epochs 次未下降时,训练过程会结束。
最后,最佳模型的参数会被保存到本地文件中。
阅读全文