@torch.no_grad() def main(): global args, save_path args = parser.parse_args() data_dir = Path(args.data) print("=> fetching img pairs in '{}'".format(args.data)) if args.output is None: save_path = data_dir/'flow' else: save_path = Path(args.output) print('=> will save everything to {}'.format(save_path)) save_path.makedirs_p() # Data loading code input_transform = transforms.Compose([transforms.Normalize(mean=[0,0,0], std=[255,255,255]) ])
时间: 2024-04-08 07:32:24 浏览: 157
这段代码定义了一个名为 `main` 的函数,并使用装饰器 `@torch.no_grad()` 来禁用梯度计算,以减少内存消耗。
`main` 函数首先使用 `parser.parse_args()` 解析命令行参数,并将结果保存在全局变量 `args` 中。
接下来,代码通过 `Path` 类将命令行参数中的数据文件夹路径转换为 `data_dir` 对象,并打印一条消息指示正在获取图像对。
然后,代码根据命令行参数中的输出文件夹路径创建一个目录,并将其保存在全局变量 `save_path` 中。如果未指定输出文件夹路径,则默认创建一个名为 "flow" 的子目录。
接着,代码创建了一个名为 `input_transform` 的图像转换对象。在这个例子中,它只应用了一个归一化操作:`Normalize(mean=[0,0,0], std=[255,255,255])` 会把每个通道的像素值除以 255,即把原本位于 [0, 255] 范围内的像素值缩放到 [0, 1] 范围。
以上就是这段代码的主要内容和功能。如果有更多问题,可以继续提问。
相关问题
下面代码转化为paddle2.2.2代码 :from __future__ import division import os, time, scipy.io import torch import torch.nn as nn import torch.optim as optim import numpy as np import glob import cv2 import argparse from PIL import Image from skimage.measure import compare_psnr,compare_ssim from tensorboardX import SummaryWriter from models import RViDeNet from utils import * parser = argparse.ArgumentParser(description='Pretrain denoising model') parser.add_argument('--gpu_id', dest='gpu_id', type=int, default=0, help='gpu id') parser.add_argument('--num_epochs', dest='num_epochs', type=int, default=33, help='num_epochs') parser.add_argument('--patch_size', dest='patch_size', type=int, default=128, help='patch_size') parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='batch_size') args = parser.parse_args() os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) save_dir = './pretrain_model' if not os.path.isdir(save_dir): os.makedirs(save_dir) gt_paths1 = glob.glob('./data/SRVD_data/raw_clean/MOT17-02_raw/*.tiff') gt_paths2 = glob.glob('./data/SRVD_data/raw_clean/MOT17-09_raw/*.tiff') gt_paths3 = glob.glob('./data/SRVD_data/raw_clean/MOT17-10_raw/*.tiff') gt_paths4 = glob.glob('./data/SRVD_data/raw_clean/MOT17-11_raw/*.tiff') gt_paths = gt_paths1 + gt_paths2 + gt_paths3 + gt_paths4 ps = args.patch_size # patch size for training batch_size = args.batch_size # batch size for training
把原始代码转换为PaddlePaddle 2.2.2 的代码如下:
```python
"""Pretraining script for the RViDeNet denoising model (PaddlePaddle 2.x port)."""
import os
import time
import argparse
import glob
import cv2
import numpy as np
from PIL import Image
from skimage.measure import compare_psnr, compare_ssim
from tensorboardX import SummaryWriter
from models import RViDeNet
from utils import *
import paddle
from paddle import nn
# NOTE: `from paddle.optimizer import optim` is an ImportError —
# paddle.optimizer has no `optim` symbol. Import the module itself.
import paddle.optimizer as optim

parser = argparse.ArgumentParser(description='Pretrain denoising model')
parser.add_argument('--gpu_id', dest='gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--num_epochs', dest='num_epochs', type=int, default=33, help='num_epochs')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=128, help='patch_size')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='batch_size')
args = parser.parse_args()

# Select the device AFTER parsing args so --gpu_id is honoured
# (the original called set_device('gpu') before parsing and ignored gpu_id).
paddle.set_device('gpu:{}'.format(args.gpu_id))

save_dir = './pretrain_model'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# Clean ground-truth frames from the four SRVD MOT17 sequences.
gt_paths1 = glob.glob('./data/SRVD_data/raw_clean/MOT17-02_raw/*.tiff')
gt_paths2 = glob.glob('./data/SRVD_data/raw_clean/MOT17-09_raw/*.tiff')
gt_paths3 = glob.glob('./data/SRVD_data/raw_clean/MOT17-10_raw/*.tiff')
gt_paths4 = glob.glob('./data/SRVD_data/raw_clean/MOT17-11_raw/*.tiff')
gt_paths = gt_paths1 + gt_paths2 + gt_paths3 + gt_paths4

ps = args.patch_size          # patch size for training
batch_size = args.batch_size  # batch size for training
num_epochs = args.num_epochs

# NOTE(review): DatasetDenoising is presumably provided by `utils` — confirm.
train_dataset = DatasetDenoising(gt_paths, ps=ps)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

model = RViDeNet()
model.train()
optimizer = optim.Adam(learning_rate=1e-4, parameters=model.parameters())
writer = SummaryWriter()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    epoch_loss = 0.0
    # paddle.io.DataLoader is iterated directly; calling it (`train_loader()`)
    # raises TypeError in paddle 2.x. The loader already yields tensors, so
    # the paddle.to_tensor round-trips are unnecessary.
    for i, (noisy_patches, gt_patches) in enumerate(train_loader):
        output = model(noisy_patches)
        loss = nn.functional.mse_loss(output, gt_patches)
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a Python float, not a live tensor, so the graph
        # is not retained across iterations.
        epoch_loss += float(loss)
    epoch_time = time.time() - epoch_start_time
    epoch_loss = epoch_loss / len(train_loader)
    print("Epoch [{}/{}] Loss: {:.5f} [{:.2f}s]".format(epoch + 1, num_epochs, epoch_loss, epoch_time))
    writer.add_scalar("Loss/train", epoch_loss, epoch + 1)
    # Checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        model_path = os.path.join(save_dir, 'RViDeNet_epoch{}.pdparams'.format(epoch + 1))
        paddle.save(model.state_dict(), model_path)
        print("Saving model to: {}".format(model_path))
writer.close()
```
代码不完整,请从 `return epoch_g_loss,` 这一行往下继续。
以下是 `main` 函数中从 `return epoch_g_loss, ...` 开始的部分:
```python
...
return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
def validate_correction_model(generator, discriminator, dataloader, device, lambda_gp):
    """Run one validation pass of the WGAN correction model.

    Evaluates the generator and discriminator on `dataloader` without
    gradients and returns the per-epoch averages of generator loss,
    discriminator loss, IoU, PSNR and SSIM.
    """
    generator.eval()
    discriminator.eval()
    totals = {'g': 0.0, 'd': 0.0, 'iou': 0.0, 'psnr': 0.0, 'ssim': 0.0}
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            fake = generator(inputs)
            real_score = discriminator(targets)
            fake_score = discriminator(fake)
            # WGAN critic loss: push real scores up, fake scores down.
            totals['d'] += (torch.mean(fake_score) - torch.mean(real_score)).item()
            # Generator's adversarial loss is the negated fake score.
            totals['g'] += (-torch.mean(fake_score)).item()
            totals['iou'] += compute_iou(fake, targets)
            totals['psnr'] += compute_psnr(fake, targets)
            totals['ssim'] += compute_ssim(fake, targets)
    n_batches = len(dataloader)
    return (totals['g'] / n_batches,
            totals['d'] / n_batches,
            totals['iou'] / n_batches,
            totals['psnr'] / n_batches,
            totals['ssim'] / n_batches)
def visualize_results(generator, dataloader, device, num_images=10, save_path='./results'):
    """Render original / simulated-colorblind / corrected image triplets.

    Takes one batch from `dataloader`, runs it through `generator`, and
    saves a 3-row comparison figure inside `save_path`.
    """
    generator.eval()
    inputs, targets = next(iter(dataloader))
    inputs = inputs.to(device)
    targets = targets.to(device)
    with torch.no_grad():
        corrected_images = generator(inputs)
    inputs = inputs.cpu().numpy()
    targets = targets.cpu().numpy()
    corrected_images = corrected_images.cpu().numpy()
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # Don't index past the end of a batch smaller than num_images.
    num_images = min(num_images, inputs.shape[0])
    plt.figure(figsize=(20, 10))
    for i in range(num_images):
        # Row 1: ground-truth originals.
        plt.subplot(3, num_images, i + 1)
        plt.imshow(targets[i].transpose(1, 2, 0))  # CHW -> HWC for imshow
        plt.title("Original")
        plt.axis('off')
        # Row 2: simulated colorblind inputs.
        plt.subplot(3, num_images, i + 1 + num_images)
        plt.imshow(inputs[i].transpose(1, 2, 0))
        plt.title("Simulated Colorblind")
        plt.axis('off')
        # Row 3: generator corrections.
        plt.subplot(3, num_images, i + 1 + 2 * num_images)
        plt.imshow(corrected_images[i].transpose(1, 2, 0))
        plt.title("Corrected")
        plt.axis('off')
    plt.tight_layout()
    # Save inside the directory created above; the original wrote to
    # '<save_path>_visualization.png', leaving the created directory empty.
    plt.savefig(os.path.join(save_path, 'visualization.png'))
    plt.show()
def plot_and_save_metrics(train_metrics, val_metrics, epoch, path='./metrics_plots'):
    """Plot and save train/validation curves for all tracked metrics.

    `train_metrics` and `val_metrics` are lists of per-epoch tuples
    (g_loss, d_loss, iou, psnr, ssim). One PNG per metric is written
    under `path`. Replaces five copy-pasted plotting stanzas with a
    single data-driven loop.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    epochs = np.arange(1, epoch + 1)
    # Transpose the per-epoch tuples into one series per metric.
    train_series = list(zip(*train_metrics))
    val_series = list(zip(*val_metrics))
    # (legend/title name, y-axis label, output file name) per metric,
    # in the same order as the tuples above.
    specs = [
        ('Generator Loss', 'Loss', 'generator_loss.png'),
        ('Discriminator Loss', 'Loss', 'discriminator_loss.png'),
        ('IoU', 'IoU', 'iou.png'),
        ('PSNR', 'PSNR', 'psnr.png'),
        ('SSIM', 'SSIM', 'ssim.png'),
    ]
    for (name, ylabel, filename), train_values, val_values in zip(specs, train_series, val_series):
        _plot_metric(epochs, train_values, val_values, name, ylabel, f'{path}/{filename}')


def _plot_metric(epochs, train_values, val_values, name, ylabel, out_file):
    """Plot one train/val curve pair and save the figure to out_file."""
    plt.figure()
    plt.plot(epochs, train_values, label=f'Training {name}')
    plt.plot(epochs, val_values, label=f'Validation {name}')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(f'{name} over Epochs')
    plt.legend()
    plt.savefig(out_file)
    plt.close()
def main(args):
    """Train the WGAN-GP colorblind-correction model end to end.

    Builds the train/val datasets and loaders, optionally restores
    pretrained generator/discriminator weights, runs training and
    validation for args.num_epochs epochs, saves metric plots each
    epoch, then stores the final weights and renders sample results.
    """
    # Resize to the network's fixed input size; ToTensor scales pixels to [0, 1].
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
    val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    if args.generator_model_weight_path:
        print(f"Loading generator weights from {args.generator_model_weight_path}")
        # map_location lets CUDA-trained checkpoints load on CPU-only hosts.
        generator.load_state_dict(torch.load(args.generator_model_weight_path, map_location=device))
    if args.discriminator_model_weight_path:
        print(f"Loading discriminator weights from {args.discriminator_model_weight_path}")
        discriminator.load_state_dict(torch.load(args.discriminator_model_weight_path, map_location=device))
    # betas=(0.5, 0.999) is the customary Adam setting for WGAN-GP training.
    optimizer_G = optim.Adam(generator.parameters(), lr=args.generator_learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.discriminator_learning_rate, betas=(0.5, 0.999))
    lambda_gp = args.lambda_gp
    lambda_pixel = args.lambda_pixel
    n_critic = args.n_critic
    train_metrics = []
    val_metrics = []
    for epoch in range(1, args.num_epochs + 1):
        train_g_loss, train_d_loss, train_iou, train_psnr, train_ssim = train_correction_model(
            generator, discriminator, train_loader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic)
        val_g_loss, val_d_loss, val_iou, val_psnr, val_ssim = validate_correction_model(
            generator, discriminator, val_loader, device, lambda_gp)
        train_metrics.append((train_g_loss, train_d_loss, train_iou, train_psnr, train_ssim))
        val_metrics.append((val_g_loss, val_d_loss, val_iou, val_psnr, val_ssim))
        print(f'Epoch {epoch}, Generator Training Loss: {train_g_loss:.4f}, Discriminator Training Loss: {train_d_loss:.4f}, '
              f'IoU: {train_iou:.4f}, PSNR: {train_psnr:.4f}, SSIM: {train_ssim:.4f}')
        print(f'Epoch {epoch}, Generator Validation Loss: {val_g_loss:.4f}, Discriminator Validation Loss: {val_d_loss:.4f}, '
              f'IoU: {val_iou:.4f}, PSNR: {val_psnr:.4f}, SSIM: {val_ssim:.4f}')
        plot_and_save_metrics(train_metrics, val_metrics, epoch)
    # Ensure checkpoint directories exist before saving (the defaults point
    # into ./model_weights/, which torch.save does not create).
    for save_file in (args.generator_model_save_path, args.discriminator_model_save_path):
        parent = os.path.dirname(save_file)
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(generator.state_dict(), args.generator_model_save_path)
    torch.save(discriminator.state_dict(), args.discriminator_model_save_path)
    visualize_results(generator, train_loader, device, save_path='./results/train')
    visualize_results(generator, val_loader, device, save_path='./results/val')
# Script entry point: build the CLI, parse arguments and start training.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Correction Model with WGAN-GP")
    # Dataset location and batching.
    parser.add_argument('--dataset_dir', type=str, default='./dataset', help='数据集目录路径')
    parser.add_argument('--batch_size', type=int, default=10, help='训练批次大小')
    # Separate learning rates for the generator and the discriminator.
    parser.add_argument('--generator_learning_rate', type=float, default=0.001, help='优化器学习率')
    parser.add_argument('--discriminator_learning_rate', type=float, default=0.0001, help='优化器学习率')
    # Loss weights: pixel-level reconstruction term and WGAN-GP gradient penalty.
    parser.add_argument('--lambda_pixel', type=float, default=10.0, help='像素级损失的权重')
    parser.add_argument('--lambda_gp', type=float, default=10.0, help='梯度惩罚系数')
    # Discriminator (critic) updates per generator update.
    parser.add_argument('--n_critic', type=int, default=8, help='判别器更新次数')
    parser.add_argument('--num_epochs', type=int, default=100, help='训练的轮数')
    # Where to store the final weights after training.
    parser.add_argument('--generator_model_save_path', type=str, default='./model_weights/generator.pth', help='保存训练好的生成器模型的路径')
    parser.add_argument('--discriminator_model_save_path', type=str, default='./model_weights/discriminator.pth', help='保存训练好的判别器模型的路径')
    # Optional pretrained weights to resume from (empty string = train from scratch).
    parser.add_argument('--generator_model_weight_path', type=str, default='', help='预训练的生成器模型的路径')
    parser.add_argument('--discriminator_model_weight_path', type=str, default='', help='预训练的判别器模型的路径')
    args = parser.parse_args()
    main(args)
```
这段代码包括了 `validate_correction_model`、`visualize_results` 和 `plot_and_save_metrics` 函数,以及 `main` 函数的剩余部分。这些函数分别用于验证模型性能、可视化结果和绘制并保存指标图。`main` 函数则负责整体的训练流程控制。
阅读全文
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)