new_state_dict = state_dict.copy()
# for key in state_dict:
#     if 'blocks' in key and 'attn' in key and 'relative' not in key and 'mask' not in key:
#         new_key = key.replace('attn','temporal_attn')
#         if not new_key in state_dict:
#             new_state_dict[new_key] = state_dict[key]
#         else:
#             new_state_dict[new_key] = state_dict[new_key]
#     if 'blocks' in key and 'norm1' in key and 'relative' not in key and 'mask' not in key:
#         new_key = key.replace('norm1','temporal_norm')
#         if not new_key in state_dict:
#             new_state_dict[new_key] = state_dict[key]
#         else:
#             new_state_dict[new_key] = state_dict[new_key]
# state_dict = new_state_dict
Date: 2023-06-18 13:04:23
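This code copies a PyTorch model's state_dict into a new dictionary, new_state_dict, and then adds entries for certain keys. For every key that contains both 'blocks' and 'attn' but neither 'relative' nor 'mask', it builds a new key by replacing 'attn' with 'temporal_attn'. Likewise, for every key that contains 'blocks' and 'norm1' (again excluding 'relative' and 'mask'), it builds a new key by replacing 'norm1' with 'temporal_norm'. The original keys are not renamed or removed: new_state_dict still contains them, and each new key receives the same weight as its source key, unless that new key already exists in state_dict, in which case its existing value is kept. Finally, state_dict is rebound to new_state_dict; nothing is returned. Note that in the code as posted, everything after the first line is commented out, so only the copy actually runs.

The snippet is most likely meant to transfer the weights of specific layers into the corresponding layers of another model. Here that means initializing newly added temporal attention and normalization layers from the existing attention and norm1 weights, so that a checkpoint lacking those layers can still populate them.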
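A runnable sketch makes this behavior concrete. It is the commented-out loop, uncommented and wrapped in a function. The function name add_temporal_keys, the toy keys, and the string values (standing in for weight tensors) are hypothetical; the filtering and key-mapping logic follows the posted code.

```python
def add_temporal_keys(state_dict):
    """Hypothetical helper: the posted loop, uncommented and wrapped in a function.

    Keeps every original key and adds 'temporal_attn' / 'temporal_norm'
    counterparts that reuse the source key's value, unless the counterpart
    already exists in state_dict (then its existing value is kept).
    """
    # Shallow copy: the dict is new, but the values (e.g. tensors) are shared.
    new_state_dict = state_dict.copy()
    for key in state_dict:
        # Same filters as the original: only 'blocks' keys, skipping
        # relative-position tables and attention masks.
        if 'blocks' not in key or 'relative' in key or 'mask' in key:
            continue
        for old, new in (('attn', 'temporal_attn'), ('norm1', 'temporal_norm')):
            if old in key:
                new_key = key.replace(old, new)
                # Equivalent to the original if/else: prefer an existing value
                # for new_key in state_dict, otherwise reuse the source value.
                new_state_dict[new_key] = state_dict.get(new_key, state_dict[key])
    return new_state_dict


# Toy example: strings stand in for weight tensors.
sd = {
    'blocks.0.attn.qkv.weight': 'W_qkv',
    'blocks.0.attn.relative_position_bias_table': 'B',
    'blocks.0.norm1.weight': 'N1',
    'head.weight': 'H',
}
out = add_temporal_keys(sd)
for k in sorted(out):
    print(k, '->', out[k])
```

Here out keeps all four original keys and gains blocks.0.temporal_attn.qkv.weight and blocks.0.temporal_norm.weight; the relative-position table and head.weight are not duplicated. The single state_dict.get(...) call replaces the original if/else with the same semantics. One quirk: a key that already contains temporal_attn also contains 'attn', so the loop would add a further copy under a temporal_temporal_attn key.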