只批量保存图像矫正的第383张和388张 给我完整的修改代码
时间: 2024-12-11 17:37:06 浏览: 17
为了在训练过程中仅保存特定编号的图像(例如第383张和388张),我们需要对 `visualize_results` 函数进行修改。具体来说,我们将在保存图像时添加一个条件判断,以确定是否需要保存当前图像。
以下是修改后的代码:
```python
def visualize_results(generator, dataloader, device, save_path='./results', specific_indices=(383, 388)):
    """Run the generator over the whole dataloader and save only the corrected
    images whose dataset-wide sample index is in ``specific_indices``.

    The previous version inspected only the first batch via
    ``next(iter(dataloader))``, so with a small batch size (e.g. 10) indices
    such as 383/388 could never be reached and nothing was saved. This version
    counts samples globally across batches.

    Args:
        generator: trained generator network, called as ``generator(inputs)``.
        dataloader: yields ``(inputs, targets)`` batches of tensors.
        device: torch device the generator lives on.
        save_path: directory where ``corrected_<idx>.png`` files are written.
        specific_indices: iterable of dataset-wide sample indices to save
            (tuple default instead of a mutable list).
    """
    generator.eval()
    os.makedirs(save_path, exist_ok=True)

    wanted = set(specific_indices)
    global_idx = 0  # index of the sample within the whole dataloader, not the batch
    with torch.no_grad():
        for inputs, _targets in dataloader:
            if not wanted:
                break  # every requested image has been saved; skip remaining batches
            inputs = inputs.to(device)
            corrected = generator(inputs).cpu().numpy()
            for sample in corrected:
                if global_idx in wanted:
                    # CHW -> HWC for imsave; assumes pixel values already lie
                    # in [0, 1] -- TODO confirm generator output range
                    image = sample.transpose(1, 2, 0)
                    plt.imsave(os.path.join(save_path, f'corrected_{global_idx}.png'), image)
                    wanted.discard(global_idx)
                global_idx += 1
# 在主函数中调用 visualize_results 时指定要保存的索引
def main(args):
    """Train the WGAN-GP correction model, checkpoint both networks, and
    save only the corrected images with indices 383 and 388."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_loader = DataLoader(
        ColorblindDataset(args.dataset_dir, mode='train', transform=preprocess),
        batch_size=args.batch_size, num_workers=4, shuffle=True,
    )
    val_loader = DataLoader(
        ColorblindDataset(args.dataset_dir, mode='val', transform=preprocess),
        batch_size=args.batch_size, num_workers=4, shuffle=False,
    )

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    if args.generator_model_weight_path:
        print(f"Loading generator weights from {args.generator_model_weight_path}")
        generator.load_state_dict(torch.load(args.generator_model_weight_path))
    if args.discriminator_model_weight_path:
        print(f"Loading discriminator weights from {args.discriminator_model_weight_path}")
        discriminator.load_state_dict(torch.load(args.discriminator_model_weight_path))

    optimizer_G = optim.Adam(generator.parameters(), lr=args.generator_learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.discriminator_learning_rate, betas=(0.5, 0.999))

    lambda_gp, lambda_pixel, n_critic = args.lambda_gp, args.lambda_pixel, args.n_critic

    train_history, val_history = [], []
    for epoch in range(1, args.num_epochs + 1):
        # Each trainer returns (g_loss, d_loss, iou, psnr, ssim).
        train_stats = train_correction_model(
            generator, discriminator, train_loader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic
        )
        val_stats = validate_correction_model(
            generator, discriminator, val_loader, device, lambda_gp
        )
        train_history.append(train_stats)
        val_history.append(val_stats)

        train_g_loss, train_d_loss, train_iou, train_psnr, train_ssim = train_stats
        val_g_loss, val_d_loss, val_iou, val_psnr, val_ssim = val_stats
        print(f'Epoch {epoch}, Generator Training Loss: {train_g_loss:.4f}, Discriminator Training Loss: {train_d_loss:.4f}, '
              f'IoU: {train_iou:.4f}, PSNR: {train_psnr:.4f}, SSIM: {train_ssim:.4f}')
        print(f'Epoch {epoch}, Generator Validation Loss: {val_g_loss:.4f}, Discriminator Validation Loss: {val_d_loss:.4f}, '
              f'IoU: {val_iou:.4f}, PSNR: {val_psnr:.4f}, SSIM: {val_ssim:.4f}')
        plot_and_save_metrics(train_history, val_history, epoch)

    torch.save(generator.state_dict(), args.generator_model_save_path)
    torch.save(discriminator.state_dict(), args.discriminator_model_save_path)

    # Save only the requested image indices (383 and 388).
    visualize_results(generator, train_loader, device, save_path='./results/train', specific_indices=[383, 388])
    visualize_results(generator, val_loader, device, save_path='./results/val', specific_indices=[383, 388])
if __name__ == "__main__":
    # Command-line interface: dataset location, optimizer hyperparameters,
    # WGAN-GP coefficients, and checkpoint paths.
    cli = argparse.ArgumentParser(description="Correction Model with WGAN-GP")
    add = cli.add_argument
    add('--dataset_dir', type=str, default='./dataset', help='数据集目录路径')
    add('--batch_size', type=int, default=10, help='训练批次大小')
    add('--generator_learning_rate', type=float, default=0.001, help='优化器学习率')
    add('--discriminator_learning_rate', type=float, default=0.0001, help='优化器学习率')
    add('--lambda_pixel', type=float, default=10.0, help='像素级损失的权重')
    add('--lambda_gp', type=float, default=10.0, help='梯度惩罚系数')
    add('--n_critic', type=int, default=8, help='判别器更新次数')
    add('--num_epochs', type=int, default=100, help='训练的轮数')
    add('--generator_model_save_path', type=str, default='./model_weights/generator.pth', help='保存训练好的生成器模型的路径')
    add('--discriminator_model_save_path', type=str, default='./model_weights/discriminator.pth', help='保存训练好的判别器模型的路径')
    add('--generator_model_weight_path', type=str, default='', help='预训练的生成器模型的路径')
    add('--discriminator_model_weight_path', type=str, default='', help='预训练的判别器模型的路径')
    main(cli.parse_args())
```
在这个修改后的代码中,我们在 `visualize_results` 函数中添加了一个参数 `specific_indices`,用于指定需要保存的图像索引。然后,在 `main` 函数中调用 `visualize_results` 时传递了 `[383, 388]` 这两个索引。需要注意的是:上面的 `visualize_results` 只取了 dataloader 的第一个批次(`next(iter(dataloader))`),因此索引是"批次内"索引;当 `batch_size` 为 10 时,383 和 388 永远不会出现在该批次中,什么也不会被保存。若要按整个数据集的全局编号保存第 383 和 388 张图像,需要遍历整个 dataloader 并维护一个跨批次累加的全局索引计数器,再用它与 `specific_indices` 比较。
阅读全文