with tqdm(dataloader['train']) as td: for batch_data in td: vision = batch_data['vision'].to(self.args.device) audio = batch_data['audio'].to(self.args.device) text = batch_data['text'].to(self.args.device) labels = batch_data['labels']['M'].to(self.args.device) if self.args.train_mode == 'classification': labels = labels.view(-1).long() else: labels = labels.view(-1, 1) # clear gradient ,梯度初始化为0,把loss关于weight的导数变成0 optimizer.zero_grad() # forward 前馈 outputs = model(text, audio, vision)['M'] # compute loss 计算损失 loss = self.criterion(outputs, labels) # backward 反馈 loss.backward() # update 跟新参数 optimizer.step() # store results train_loss += loss.item() y_pred.append(outputs.cpu()) y_true.append(labels.cpu())
时间: 2023-04-11 11:01:21 浏览: 150
这段代码使用了tqdm库来显示进度条,遍历了dataloader中的训练数据。对于每个batch_data,将其中的vision、audio、text和labels分别转移到设备上,并根据训练模式对labels进行不同的处理。如果是分类模式,则将labels展平并转换为long类型;否则将labels展开成一列。
阅读全文