这句代码 week_feature=week_feature.unsqueeze(1).repeat(1, n_time, n_node, 1)报错'numpy.ndarray' object has no attribute 'unsqueeze',为什么产生呢,怎么解决,pytorch版本
时间: 2024-03-26 13:39:25 浏览: 67
详解pytorch中squeeze()和unsqueeze()函数介绍
这个错误提示表明你正在使用 NumPy 数组而不是 PyTorch 张量,而 NumPy 数组没有 `unsqueeze` 方法。要使用 PyTorch 的 `unsqueeze` 方法,你需要将 NumPy 数组转换为 PyTorch 张量。你可以使用 `torch.from_numpy()` 方法将 NumPy 数组转换为 PyTorch 张量。修改代码如下:
```python
import torch
week_feature = torch.from_numpy(week_feature)
week_feature = week_feature.unsqueeze(1).repeat(1, n_time, n_node, 1)
```
这应该可以解决你的问题。注意,这里假设你已经将 NumPy 数组 `week_feature` 定义好了。此外,要注意 PyTorch 版本,这个方法在 PyTorch 0.4.0 及以上版本中可用。
阅读全文