import numpy as np
import torch

model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
if not pretrained:
    weights_init(model)
if model_path != '':
    # ------------------------------------------------------#
    #   See the README for the weight file (Baidu Netdisk download).
    # ------------------------------------------------------#
    print('Load weights {}.'.format(model_path))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            print('no!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            print(k)
    # Drop newly added modules: keys whose names do not exist in model_dict are ignored.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    print('\n')
    for k, v in pretrained_dict.items():
        if not (np.shape(model_dict[k]) == np.shape(v)):
            print('no!!!!!!!!!!!!!!!!!!!!!!!!!!!!shape')
            print(k)
    # Drop pretrained weights whose shape differs from the corresponding entry in model_dict.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
    print('!!!!')
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

What does this code mean?
Posted: 2023-06-16 18:05:00 · Views: 150
This code implements loading of pretrained weights. It first builds a Unet model from the given arguments and puts it in training mode with .train(); if pretrained is False, the weights are initialized with weights_init. If model_path is not empty, the pretrained weights stored at that path are loaded.
During loading, the code first checks every key of the pretrained state dict against the model's own state dict: keys that do not exist in the model are reported (the "no!!!!..." prints) and then filtered out, so any layer whose name does not match is ignored. Next, for the remaining keys, it compares the shape of each pretrained tensor with the shape of the corresponding model tensor; entries with mismatched shapes are reported ("no!!!!...shape") and filtered out as well. Finally, the surviving entries are merged into the model's state dict with update() and loaded with load_state_dict(), so only weights that match in both name and shape are actually copied into the model, while everything else keeps its initial values.
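As a side note, the same name-and-shape filtering can be written more compactly as a small helper. The sketch below is not from the repository in question; it assumes only a generic torch.nn.Module and a checkpoint file that stores a plain state dict, and the function name load_matching_weights is a placeholder:

import torch

def load_matching_weights(model: torch.nn.Module, checkpoint_path: str, device: str = 'cpu') -> None:
    # Copy only those checkpoint tensors whose key and shape match the model.
    model_dict = model.state_dict()
    ckpt = torch.load(checkpoint_path, map_location=device)

    matched, skipped = {}, []
    for k, v in ckpt.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            matched[k] = v          # same name and same shape -> keep
        else:
            skipped.append(k)       # missing key or shape mismatch -> ignore

    model_dict.update(matched)      # overwrite only the matching entries
    model.load_state_dict(model_dict)
    print('Loaded {} tensors, skipped {}'.format(len(matched), skipped))

Calling load_matching_weights(model, model_path, device) before training would have the same effect as the snippet above, while making the skipped keys easier to inspect.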