def generate_output(args,epoch, model, gen_dataset, disp_uncertainty=True,startPoint=500, endPoint=3500): if args.save_fig: # Turn on evaluation mode which disables dropout. model.eval() hidden = model.init_hidden(1) outSeq = [] upperlim95 = [] lowerlim95 = [] with torch.no_grad(): for i in range(endPoint): if i>=startPoint: # if disp_uncertainty and epoch > 40: # outs = [] # model.train() # for i in range(20): # out_, hidden_ = model.forward(out+0.01*Variable(torch.randn(out.size())).cuda(),hidden,noise=True) # outs.append(out_) # model.eval() # outs = torch.cat(outs,dim=0) # out_mean = torch.mean(outs,dim=0) # [bsz * feature_dim] # out_std = torch.std(outs,dim=0) # [bsz * feature_dim] # upperlim95.append(out_mean + 2.58*out_std/np.sqrt(20)) # lowerlim95.append(out_mean - 2.58*out_std/np.sqrt(20)) out, hidden = model.forward(out, hidden) #print(out_mean,out) else: out, hidden = model.forward(gen_dataset[i].unsqueeze(0), hidden) outSeq.append(out.data.cpu()[0][0].unsqueeze(0)) outSeq = torch.cat(outSeq,dim=0) # [seqLength * feature_dim] target= preprocess_data.reconstruct(gen_dataset.cpu(), TimeseriesData.mean, TimeseriesData.std) outSeq = preprocess_data.reconstruct(outSeq, TimeseriesData.mean, TimeseriesData.std)
时间: 2024-02-14 10:09:42 浏览: 23
这是一个用于生成模型输出的函数,参数包括 args(包含一些配置信息)、epoch(当前训练轮数)、model(模型)、gen_dataset(用于生成模型输出的数据集)、disp_uncertainty(是否显示不确定性)、startPoint(生成模型输出的起始位置)和 endPoint(生成模型输出的结束位置)。
如果 args 中的 save_fig 参数为 True,则将模型置于评估模式,禁用 dropout。然后定义隐藏状态 hidden,并初始化 outSeq、upperlim95 和 lowerlim95 为空列表。
接下来,循环 endPoint 次,其中 i 从 0 到 endPoint-1,每次循环都会进行如下操作:
- 如果 i 大于等于 startPoint,则调用 model.forward() 函数,将上一次的输出 out 和隐藏状态 hidden 作为输入,并获取当前模型的输出 out 和隐藏状态 hidden。如果 disp_uncertainty 为 True 且 epoch 大于 40,则进行不确定性估计,并将结果存储到 upperlim95 和 lowerlim95 中。
- 如果 i 小于 startPoint,则调用 model.forward() 函数,将 gen_dataset[i] 作为输入,并获取当前模型的输出 out 和隐藏状态 hidden。
- 将 out.data.cpu()[0][0].unsqueeze(0)(即输出的第一个元素)添加到 outSeq 列表中。
循环结束后,将 outSeq 拼接成一个张量,然后使用 preprocess_data.reconstruct() 函数将 gen_dataset 和 outSeq 进行反归一化处理,得到原始数据的预测值和模型输出。最后将 outSeq 返回。