将每一个函数都进行逐行分析
时间: 2024-11-10 14:25:15 浏览: 6
### 函数 `make_data_loader`
#### 功能
该函数用于创建一个PyTorch的数据加载器(`DataLoader`),并返回该加载器。
#### 参数
- `spec`: 数据集配置字典。
- `tag`: 标签字符串,通常为 `'train'` 或 `'val'`。
#### 实现细节
1. **检查 `spec` 是否为空**:
```python
if spec is None:
return None
```
如果 `spec` 为空,则直接返回 `None`。
2. **创建数据集**:
```python
dataset = datasets.make(spec['D:\资料\代码\McASSR\McASSR-main\datasets\created_data.mat'])
```
使用 `datasets.make` 方法从指定路径加载数据集。
3. **包装数据集**:
```python
dataset = datasets.make(spec['wrapper'], args={'created_data.mat': dataset})
```
使用 `datasets.make` 方法对数据集进行包装,传递原始数据集作为参数。
4. **日志记录**:
```python
log('{} dataset: size={}'.format(tag, len(dataset)))
```
记录数据集的大小。
5. **记录数据集样本形状**:
```python
for k, v in dataset[0].items():
log(' {}: shape={}'.format(k, tuple(v.shape)))
```
遍历数据集中第一个样本的所有键值对,并记录每个键对应的张量形状。
6. **创建数据加载器**:
```python
loader = DataLoader(
dataset,
batch_size=spec['batch_size'],
shuffle=(tag == 'train'),
num_workers=8,
pin_memory=True
)
```
创建 `DataLoader`,设置批量大小、是否打乱数据、工作线程数和是否使用内存锁定。
7. **返回数据加载器**:
```python
return loader
```
### 函数 `make_data_loaders`
#### 功能
该函数用于创建训练和验证数据加载器,并返回这两个加载器。
#### 参数
无
#### 实现细节
1. **创建训练数据加载器**:
```python
train_loader = make_data_loader(config.get('train_dataset'), tag='train')
```
调用 `make_data_loader` 函数,传入训练数据集配置和标签 `'train'`。
2. **创建验证数据加载器**:
```python
val_loader = make_data_loader(config.get('val_dataset'), tag='val')
```
调用 `make_data_loader` 函数,传入验证数据集配置和标签 `'val'`。
3. **返回数据加载器**:
```python
return train_loader, val_loader
```
### 函数 `prepare_training`
#### 功能
该函数用于准备训练所需的模型、优化器、起始轮次和学习率调度器。
#### 参数
无
#### 实现细节
1. **检查是否恢复训练**:
```python
if config.get('resume') is not None:
sv_file = torch.load(config['resume'])
model = models.make(sv_file['model'], load_sd=True).cuda()
optimizer = utils.make_optimizer(model.parameters(), sv_file['optimizer'], load_sd=True)
epoch_start = sv_file['epoch'] + 1
if config.get('multi_step_lr') is None:
lr_scheduler = None
else:
lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
for _ in range(epoch_start - 1):
lr_scheduler.step()
```
如果配置中指定了恢复训练的文件路径,则加载模型、优化器状态和起始轮次,并初始化或恢复学习率调度器。
2. **如果未恢复训练**:
```python
else:
model = models.make(config['model']).cuda()
optimizer = utils.make_optimizer(model.parameters(), config['optimizer'])
epoch_start = 1
if config.get('multi_step_lr') is None:
lr_scheduler = None
else:
lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
```
如果没有恢复训练,初始化新的模型、优化器和学习率调度器,并设置起始轮次为1。
3. **记录模型参数数量**:
```python
log('model: #params={}'.format(utils.compute_num_params(model, text=True)))
```
4. **返回训练所需对象**:
```python
return model, optimizer, epoch_start, lr_scheduler
```
### 函数 `train`
#### 功能
该函数用于执行单个训练轮次,计算损失并更新模型参数。
#### 参数
- `train_loader`: 训练数据加载器。
- `model`: 模型。
- `optimizer`: 优化器。
- `epoch`: 当前轮次。
#### 实现细节
1. **设置模型为训练模式**:
```python
model.train()
```
2. **定义损失函数**:
```python
loss_fn = nn.L1Loss()
```
3. **初始化平均损失计算器**:
```python
train_loss = utils.Averager()
```
4. **定义评估指标函数**:
```python
metric_fn = utils.calc_psnr
```
5. **获取数据归一化参数**:
```python
data_norm = config['data_norm']
t = data_norm['inp']
inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
t = data_norm['ref']
ref_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
ref_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
t = data_norm['gt']
gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()
t = data_norm['ref']
ref_hr_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
ref_hr_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()
```
6. **计算每个轮次的迭代次数**:
```python
num_dataset = 800
iter_per_epoch = int(num_dataset / config.get('train_dataset')['batch_size'] * config.get('train_dataset')['dataset']['args']['repeat'])
iteration = 0
```
7. **遍历训练数据加载器**:
```python
for batch in tqdm(train_loader, leave=False, desc='train'):
for k, v in batch.items():
batch[k] = v.cuda()
inp = (batch['inp'] - inp_sub) / inp_div
ref = (batch['ref'] - ref_sub) / ref_div
ref_hr = (batch['ref_hr'] - ref_hr_sub) / ref_hr_div
pred, ref_loss = model(inp, batch['inp_hr_coord'], batch['inp_cell'], ref, ref_hr)
gt = (batch['gt'] - gt_sub) / gt_div
loss_pred = loss_fn(pred, gt)
loss_ref = loss_fn(ref_loss, ref_hr)
loss = loss_pred * 0.7 + loss_ref * 0.3
psnr = metric_fn(pred, gt)
writer.add_scalars('loss', {'total_loss': loss.item()}, (epoch-1)*iter_per_epoch + iteration)
writer.add_scalars('psnr', {'train': psnr}, (epoch-1)*iter_per_epoch + iteration)
iteration += 1
train_loss.add(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred = None; loss = None
```
8. **返回平均训练损失**:
```python
return train_loss.item()
```
### 函数 `main`
#### 功能
该函数是主函数,负责整个训练流程的管理和控制。
#### 参数
- `config_`: 配置字典。
- `save_path`: 保存路径。
#### 实现细节
1. **全局变量赋值**:
```python
global config, log, writer
config = config_
log, writer = utils.set_save_path(save_path, remove=False)
```
2. **保存配置文件**:
```python
with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
yaml.dump(config, f, sort_keys=False)
```
3. **处理数据归一化参数**:
```python
if config.get('data_norm') is None:
config['data_norm'] = {
'inp': {'sub': [0], 'div': [1]},
'ref': {'sub': [0], 'div': [1]},
'gt': {'sub': [0], 'div': [1]}
}
```
4. **创建数据加载器**:
```python
train_loader, val_loader = make_data_loaders()
```
5. **准备训练**:
```python
model, optimizer, epoch_start, lr_scheduler = prepare_training()
```
6. **多GPU支持**:
```python
n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
if n_gpus > 1:
model = nn.parallel.DataParallel(model)
```
7. **训练循环**:
```python
epoch_max = config['epoch_max']
epoch_val = config.get('epoch_val')
epoch_save = config.get('epoch_save')
max_val_v = -1e18
timer = utils.Timer()
for epoch in range(epoch_start, epoch_max + 1):
t_epoch_start = timer.t()
log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
train_loss = train(train_loader, model, optimizer, epoch)
if lr_scheduler is not None:
lr_scheduler.step()
log_info.append('train: loss={:.4f}'.format(train_loss))
if n_gpus > 1:
model_ = model.module
else:
model_ = model
model_spec = config['model']
model_spec['sd'] = model_.state_dict()
optimizer_spec = config['optimizer']
optimizer_spec['sd'] = optimizer.state_dict()
sv_file = {
'model': model_spec,
'optimizer': optimizer_spec,
'epoch': epoch
}
torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))
if (epoch_save is not None) and (epoch % epoch_save == 0):
torch.save(sv_file, os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))
if (epoch_val is not None) and (epoch % epoch_val == 0):
if n_gpus > 1 and (config.get('eval_bsize') is not None):
model_ = model.module
else:
model_ = model
val_res = eval_psnr(val_loader, model_, data_norm=config['data_norm'], eval_type=config.get('eval_type'), eval_bsize=config.get('eval_bsize'))
log_info.append('val: psnr={:.4f}'.format(val_res))
if val_res > max_val_v:
max_val_v = val_res
torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))
t = timer.t()
prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
t_epoch = utils.time_text(t - t_epoch_start)
t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))
log(', '.join(log_info))
writer.flush()
```
8. **主程序入口**:
```python
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config')
parser.add_argument('--name', default=None)
parser.add_argument('--tag', default=None)
parser.add_argument('--gpu', default='0')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with open(args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
print('config loaded.')
save_name = args.name if args.name is not None else '_' + args.config.split('/')[-1][:-len('.yaml')]
if args.tag is not None:
save_name += '_' + args.tag
save_path = os.path.join('./save', save_name)
main(config, save_path)
```
### 总结
这个脚本实现了一个完整的深度学习训练流程,包括数据加载、模型准备、训练、验证和保存。主要功能如下:
- **数据加载**: `make_data_loader` 和 `make_data_loaders` 函数负责创建训练和验证数据加载器。
- **模型准备**: `prepare_training` 函数负责初始化模型、优化器和学习率调度器。
- **训练**: `train` 函数执行单个训练轮次,计算损失并更新模型参数。
- **主流程管理**: `main` 函数负责整个训练流程的管理和控制,包括训练、验证和模型保存。
阅读全文