torch.Storage
时间: 2023-11-01 13:22:56 浏览: 43
`torch.Storage` 是 PyTorch 中的一个底层数据结构,用于存储一维的连续数据块。它是一个抽象类,不能直接实例化,而是通过其子类来实现具体的存储方式,比如 `torch.FloatStorage`、`torch.DoubleStorage` 等。
`torch.Storage` 中包含了一些常用的方法,比如 `size()` 返回存储元素的总数,`fill_(value)` 用给定的值填充存储,`copy_(source)` 将另一个存储中的数据复制到当前存储中等等。
通常情况下,我们不需要手动使用 `torch.Storage`,而是直接使用高级接口,如 `torch.Tensor`。但在一些特殊场景下,例如需要使用共享内存的多进程并行计算,我们可能需要手动创建和使用 `torch.Storage`。
相关问题
torch.storage
是什么?
Torch.storage是PyTorch中用于保存连续Memory块的数据类型。它通常被用来存储矩阵或张量等多维数组的数据。使用Torch.storage可以方便地在不同的进程或设备之间共享数据,同时也可以通过它来实现存储和恢复模型。常见的使用方式是将存储对象中的数据以一定格式写入到硬盘上,以便之后重新读回和使用。
_state = torch.LongTensor(state, dtype=torch.float32) TypeError: new() received an invalid combination of arguments - got (numpy.ndarray, dtype=torch.dtype), but expected one of: * (*, torch.device device) didn't match because some of the keywords were incorrect: dtype * (torch.Storage storage) * (Tensor other) * (tuple of ints size, *, torch.device device) * (object data, *, torch.device device)
这个错误是由于将 `numpy.ndarray` 类型的数据直接传递给 `torch.LongTensor` 的构造函数,而没有指定数据类型。根据错误提示,可以看到 `dtype=torch.float32` 是不合法的参数组合。
为了解决这个问题,我们需要将 `numpy.ndarray` 数据类型转换为 `torch.Tensor` 数据类型,同时指定数据类型为 `torch.float32`。可以使用 `torch.from_numpy()` 函数或者 `torch.tensor()` 函数将 `numpy.ndarray` 数据类型转换为 `torch.Tensor` 数据类型,如下所示:
```
import numpy as np
import torch
state = np.array([1, 2, 3])
_state = torch.from_numpy(state).float() # 或者使用 torch.tensor(state, dtype=torch.float32)
```
这样就可以将 `numpy.ndarray` 类型的数据转换为 `torch.Tensor` 类型,并且指定了数据类型为 `torch.float32`。