    ## Duplicate weights for temporal attention and temporal norm
    # new_state_dict = state_dict.copy()
    # for key in state_dict:
    #     if 'blocks' in key and 'attn' in key and 'relative' not in key and 'mask' not in key:
    #         new_key = key.replace('attn','temporal_attn')
    #         if not new_key in state_dict:
    #             new_state_dict[new_key] = state_dict[key]
    #         else:
    #             new_state_dict[new_key] = state_dict[new_key]
    #     if 'blocks' in key and 'norm1' in key and 'relative' not in key and 'mask' not in key:
    #         new_key = key.replace('norm1','temporal_norm')
    #         if not new_key in state_dict:
    #             new_state_dict[new_key] = state_dict[key]
    #         else:
    #             new_state_dict[new_key] = state_dict[new_key]
    # state_dict = new_state_dict
This (commented-out) block is part of loading a pretrained model. Its purpose is to initialize the newly added temporal modules by duplicating the corresponding pretrained spatial weights, so that the temporal attention and temporal norm layers start from the same values as their spatial counterparts. The state dict is first copied; then, for every key under 'blocks' that refers to 'attn' (excluding relative-position and attention-mask entries) or to 'norm1', a new key is derived by replacing 'attn' with 'temporal_attn' or 'norm1' with 'temporal_norm'. If that new key is not already present in the checkpoint, the spatial weight is copied to it; if it is, the existing value is kept. The resulting dictionary then replaces the original state dict.
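As a rough, self-contained sketch of the same renaming logic (the block names and tensor shapes below are made up for illustration; a real checkpoint contains many more entries):

    import torch

    # Toy checkpoint with one spatial attention weight and one norm weight (illustrative keys).
    state_dict = {
        'blocks.0.attn.qkv.weight': torch.randn(9, 3),
        'blocks.0.norm1.weight': torch.ones(3),
    }

    # Copy each spatial weight to the matching temporal key, as the commented-out
    # block does; setdefault reproduces its if/else, since the else branch only
    # re-copies a value that is already present in the copied dict.
    new_state_dict = state_dict.copy()
    for key in state_dict:
        if 'blocks' in key and 'attn' in key and 'relative' not in key and 'mask' not in key:
            new_state_dict.setdefault(key.replace('attn', 'temporal_attn'), state_dict[key])
        if 'blocks' in key and 'norm1' in key and 'relative' not in key and 'mask' not in key:
            new_state_dict.setdefault(key.replace('norm1', 'temporal_norm'), state_dict[key])
    state_dict = new_state_dict

    print(sorted(state_dict))
    # ['blocks.0.attn.qkv.weight', 'blocks.0.norm1.weight',
    #  'blocks.0.temporal_attn.qkv.weight', 'blocks.0.temporal_norm.weight']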
Related questions
    if isinstance(self.pretrained, str):
        self.apply(_init_weights)
        logger = get_root_logger()
        logger.info(f'load model from: {self.pretrained}')

        checkpoint = torch.load(self.pretrained, map_location='cpu')
        state_dict = checkpoint['model']

        # Inflate the 2D patch-embedding kernel to 3D: repeat it along a new temporal
        # axis and divide by the temporal patch size so that a temporally constant
        # input produces the same activations as the original 2D convolution.
        state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1, 1, self.patch_size[0], 1, 1) / self.patch_size[0]

        ## Duplicate weights for temporal attention and temporal norm
        # new_state_dict = state_dict.copy()
        # for key in state_dict:
        #     if 'blocks' in key and 'attn' in key and 'relative' not in key and 'mask' not in key:
        #         new_key = key.replace('attn','temporal_attn')
        #         if not new_key in state_dict:
        #             new_state_dict[new_key] = state_dict[key]
        #         else:
        #             new_state_dict[new_key] = state_dict[new_key]
        #     if 'blocks' in key and 'norm1' in key and 'relative' not in key and 'mask' not in key:
        #         new_key = key.replace('norm1','temporal_norm')
        #         if not new_key in state_dict:
        #             new_state_dict[new_key] = state_dict[key]
        #         else:
        #             new_state_dict[new_key] = state_dict[new_key]
        # state_dict = new_state_dict

        msg = self.load_state_dict(state_dict, strict=False)
        logger.info('Missing keys: {}'.format(msg.missing_keys))
        logger.info('Unexpected keys: {}'.format(msg.unexpected_keys))
        logger.info(f"=> loaded successfully '{self.pretrained}'")

        del checkpoint
        torch.cuda.empty_cache()
    elif self.pretrained is None:
        self.apply(_init_weights)
    else:
        raise TypeError('pretrained must be a str or None')
This code handles loading a pretrained model. If self.pretrained is a string, the weights are first default-initialized with _init_weights, then the checkpoint at that path is loaded onto the CPU and its 'model' state dict is extracted. The 2D patch-embedding kernel from the image-pretrained checkpoint is inflated to 3D: it is repeated along a new temporal axis and divided by the temporal patch size, so a temporally constant input produces the same activations as in the image model. (The commented-out block would additionally copy the spatial attention and norm weights into the temporal branches.) The state dict is then loaded with strict=False, the missing and unexpected keys are logged, and the checkpoint is freed. If self.pretrained is None, only the default initialization is applied; any other type raises a TypeError.
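For the patch-embedding line specifically, here is a minimal sketch of the inflation on a dummy tensor (the sizes and the temporal_patch value are illustrative; Swin-style checkpoints typically store patch_embed.proj.weight as [embed_dim, in_chans, kh, kw]):

    import torch

    embed_dim, in_chans, kh, kw = 96, 3, 4, 4   # illustrative sizes
    temporal_patch = 2                          # stands in for self.patch_size[0]

    weight_2d = torch.randn(embed_dim, in_chans, kh, kw)

    # Add a temporal axis, repeat the 2D kernel along it, and rescale so that a clip
    # whose frames are identical yields the same activations as the image model.
    weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, temporal_patch, 1, 1) / temporal_patch

    print(weight_3d.shape)                                   # torch.Size([96, 3, 2, 4, 4])
    print(torch.allclose(weight_3d.sum(dim=2), weight_2d))   # True

Because the inflated kernel sums back to the original over the temporal axis, the video model initially behaves like the image-pretrained model applied to the temporal average of each patch.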
The Qualcomm Linux kernel driver for the sn65dsi84 chip defines this struct:

    static const struct drm_bridge_funcs sn65dsi83_funcs = {
        .attach = sn65dsi83_attach,
        .atomic_pre_enable = sn65dsi83_atomic_pre_enable,
        .atomic_enable = sn65dsi83_atomic_enable,
        .atomic_disable = sn65dsi83_atomic_disable,
        .atomic_post_disable = sn65dsi83_atomic_post_disable,
        .mode_valid = sn65dsi83_mode_valid,
        .atomic_duplicate_state = drm_atomic_helper_bridge_duplicate_state,
        .atomic_destroy_state = drm_atomic_helper_bridge_destroy_state,
        .atomic_reset = drm_atomic_helper_bridge_reset,
        .atomic_get_input_bus_fmts = sn65dsi83_atomic_get_input_bus_fmts,
    };

In what order are these callbacks executed?
This defines a static, constant struct named sn65dsi83_funcs whose members are function pointers to the driver's implementations of the drm_bridge_funcs callbacks. The order of the initializers has no bearing on when the callbacks run; the DRM core invokes each one at a specific point in the bridge's life cycle. Roughly: .attach is called once when the bridge is attached to its encoder; .atomic_reset, .atomic_duplicate_state and .atomic_destroy_state manage the bridge's atomic state around each commit; .mode_valid and .atomic_get_input_bus_fmts are called while a proposed display configuration is being validated; when the pipeline is enabled, .atomic_pre_enable runs before .atomic_enable, and when it is disabled, .atomic_disable runs before .atomic_post_disable.