if (it+1) % args.save_interval == 0: # To save storage space, I only checkpoint the weights of G. # If you'd like to keep weights of G, D, optim_G, optim_D, # please use save() instead of saveG(). attgan.saveG(os.path.join( 'output', args.experiment_name, 'checkpoint', 'weights.{:d}.pth'.format(epoch) )) # attgan.save(os.path.join( # 'output', args.experiment_name, 'checkpoint', 'weights.{:d}.pth'.format(epoch) # )) if (it+1) % args.sample_interval == 0: attgan.eval() with torch.no_grad(): samples = [fixed_img_a] for i, att_b in enumerate(sample_att_b_list): att_b_ = (att_b * 2 - 1) * args.thres_int if i > 0: att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int samples.append(attgan.G(fixed_img_a, att_b_)) samples = torch.cat(samples, dim=3) writer.add_image('sample', vutils.make_grid(samples, nrow=1, normalize=True, range=(-1., 1.)), it+1) vutils.save_image(samples, os.path.join( 'output', args.experiment_name, 'sample_training', 'Epoch_({:d})_({:d}of{:d}).jpg'.format(epoch, it%it_per_epoch+1, it_per_epoch) ), nrow=1, normalize=True, range=(-1., 1.)) it += 1
时间: 2024-02-14 21:29:53 浏览: 123
这段代码包含了两个条件语句块。
第一个条件语句块中的条件 `(it+1) % args.save_interval == 0` 检查是否到了保存模型的间隔。如果条件为真,则执行保存模型的代码块。
在保存模型的代码块中,使用 `attgan.saveG()` 函数将生成器 `G` 的权重保存到指定路径下。`os.path.join()` 函数用于构建保存路径,其中包括实验名称、检查点目录和文件名。文件名使用了格式化字符串来包含当前 epoch 的值。
第二个条件语句块中的条件 `(it+1) % args.sample_interval == 0` 检查是否到了生成样本的间隔。如果条件为真,则执行生成样本的代码块。
在生成样本的代码块中,首先调用 `attgan.eval()` 将模型设置为评估模式,然后使用 `torch.no_grad()` 上下文管理器禁用梯度计算。
接下来,通过循环遍历 `sample_att_b_list` 中的属性 B 值,并根据属性 B 值计算 `att_b_`。这里使用了一系列的数学计算,包括归一化处理和一些额外的操作。
然后,调用 `attgan.G(fixed_img_a, att_b_)` 生成样本图像,并将其添加到 `samples` 列表中。
最后,使用 `writer.add_image()` 将生成的样本图像添加到摘要信息中,使用 `vutils.save_image()` 将样本图像保存到指定路径下。保存路径的构建方式与保存模型时类似,包括实验名称、样本训练目录和文件名。文件名使用了格式化字符串来包含当前 epoch 和迭代次数的值。
最后一行代码 `it += 1` 将迭代次数加1,用于更新迭代次数的计数。
总结起来,这段代码根据条件判断是否到了保存模型和生成样本的间隔。如果满足条件,则执行相应的代码块,保存模型或生成样本。这样可以在训练过程中定期保存模型,并生成样本进行可视化和评估。同时,迭代次数也会被更新。
阅读全文