model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train() if not pretrained: weights_init(model) if model_path != '': # ------------------------------------------------------# # 权值文件请看README,百度网盘下载 # ------------------------------------------------------# print('Load weights {}.'.format(model_path)) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(device) model_dict = model.state_dict() pretrained_dict = torch.load(model_path, map_location=device) # pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} for k, v in pretrained_dict.items(): if k not in model_dict: print('no!!!!!!!!!!!!!!!!!!!!!!!!!!!!') print(k) pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 排除新加入的模块 与 model_dict不同名的层都被忽略 print('\n') for k, v in pretrained_dict.items(): if not (np.shape(model_dict[k]) == np.shape(v)): print('no!!!!!!!!!!!!!!!!!!!!!!!!!!!!shape') print(k) pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} # 排除bufen_dict里与 model_dict 不同形状的权重 print('!!!!') model_dict.update(pretrained_dict) model.load_state_dict(model_dict)详细解释以及库函数参数
时间: 2023-06-16 18:04:47 浏览: 209
tf_unet-master.zip_unet tf_unet 分割_unet数据集_vgg unet_语义分割
这段代码是用来加载预训练的权重文件并将其应用到模型中的。
首先,通过 `Unet()` 函数创建了一个名为 `model` 的 Unet 模型,并根据传入的参数 `pretrained` 和 `backbone` 是否为 `True` 或者有值,来决定是否加载预训练模型的权重。如果 `pretrained` 为 `False`,则调用 `weights_init()` 函数对模型进行权重初始化。
接下来,如果传入的 `model_path` 不为 `''`,则说明要加载预训练的权重文件。通过 `torch.load()` 函数加载权重文件,并通过 `map_location` 参数指定权重文件应该在哪个设备上运行。然后,将模型的状态字典存储在 `model_dict` 变量中,将预训练模型的状态字典存储在 `pretrained_dict` 变量中。
对于预训练模型中的每个键-值对,代码会检查其是否存在于模型中。如果不存在,则输出 "no!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 和该键。然后,通过列表推导式,将 `pretrained_dict` 中键不存在于 `model_dict` 中的键-值对移除。接着,对于每个键-值对,代码会检查其形状是否与模型中的相同。如果不同,则输出 "no!!!!!!!!!!!!!!!!!!!!!!!!!!!!shape" 和该键。然后,通过列表推导式,将 `pretrained_dict` 中形状与 `model_dict` 不同的键-值对移除。最后,将 `pretrained_dict` 中的键-值对合并到 `model_dict` 中,并通过 `model.load_state_dict()` 函数将新的状态字典加载到模型中。
阅读全文