torch.load unpickler.load()中super().find_class(mod_name, name)报错Can't get attribute
时间: 2023-06-11 21:10:41 浏览: 468
这个问题可能是由于pickle文件中保存的对象所在的模块没有被正确导入而导致的。在使用`torch.load`函数加载文件时,它使用`pickle`模块来反序列化对象。当pickle模块尝试重新构建对象时,它需要找到对象所在的模块,并使用`__main__`或该模块的名称来封装对象。如果pickle无法找到正确的模块,它会引发一个错误。
为了解决这个问题,您可以尝试以下几个步骤:
1.确保您正在加载的pickle文件是由相同版本的PyTorch生成的。如果您使用的是不同版本的PyTorch,可能会导致pickle文件无法正确地加载。
2.确保您的模型代码中的所有自定义模块都正确导入。如果您的模型代码中存在自定义模块,那么在加载pickle文件之前,您需要将这些模块正确地导入到Python环境中。可以使用`importlib.import_module`函数来导入它们。
3.在加载pickle文件之前,尝试手动导入包含模型的模块。如果pickle无法正确地找到模型所在的模块,您可以尝试手动导入该模块。例如,如果您的模型代码位于`models.py`中,您可以使用`import models`来手动导入该模块。
4.使用`torch.jit`模块来重新导出模型。如果您无法解决pickle加载错误,您可以尝试使用`torch.jit`模块将您的模型重新导入到脚本中。这将使您的模型成为一个独立的脚本,并且可以通过简单的导入来加载它。
相关问题
torch.load unpickler.load()中super().find_class(mod_name, name)报错
这个问题通常是因为 pickle 文件中引用的某些类已经被删除或重命名了。试着使用下面的代码来解决问题:
```python
import sys
import torch
# 新建一个类,继承 torch.serialization.pickle.Unpickler
class CustomUnpickler(torch.serialization.pickle.Unpickler):
def find_class(self, module, name):
if name == 'YourClassName': # 将 YourClassName 替换为实际类名
from your_module import YourClassName # 将 your_module 和 YourClassName 替换为实际的模块名和类名
return YourClassName
return super().find_class(module, name)
# 打开 pickle 文件,并使用自定义的 Unpickler 进行反序列化
with open('your_pickle_file.pkl', 'rb') as f:
obj = CustomUnpickler(f).load()
```
在上述代码中,我们新建了一个类 `CustomUnpickler`,并在其中重写了 `find_class` 方法,以便在 pickle 文件中找不到类定义时,可以使用自定义的方式加载类。你需要将 `YourClassName` 和 `your_module` 替换为实际的类名和模块名。
Traceback (most recent call last): File "C:\yolov5-5.0\detect.py", line 178, in <module> detect() File "C:\yolov5-5.0\detect.py", line 34, in detect model = attempt_load(weights, map_location=device) # load FP32 model File "C:\yolov5-5.0\models\experimental.py", line 118, in attempt_load ckpt = torch.load(w, map_location=map_location) # load File "C:\Users\admin\.conda\envs\gwstu_env\lib\site-packages\torch\serialization.py", line 789, in load return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) File "C:\Users\admin\.conda\envs\gwstu_env\lib\site-packages\torch\serialization.py", line 1131, in _load result = unpickler.load() File "C:\Users\admin\.conda\envs\gwstu_env\lib\site-packages\torch\serialization.py", line 1124, in find_class return super().find_class(mod_name, name) AttributeError: Can't get attribute 'SPPF' on <module 'models.common' from 'C:\\yolov5-5.0\\models\\common.py'>
这是一个Python错误追踪(Traceback),说明在运行 `C:\yolov5-5.0\detect.py` 文件时出现了错误。具体来说,错误是由 `AttributeError` 引发的,原因是无法在 `models.common` 模块中获取名为 `SPPF` 的属性。
可能的原因是在 `models.common` 模块中不存在名为 `SPPF` 的属性。你可以检查一下代码中是否正确地引用了 `models.common` 模块中的属性名。另外,你也可以检查一下 `models.common` 模块是否存在或是否被正确导入。
阅读全文