pytorch代码,2维数据集train是ndarray 格式,添加一个特征维度将其3维,转换为张量
时间: 2024-03-28 22:41:56 浏览: 139
PyTorch中Tensor的维度变换实现
5星 · 资源好评率100%
可以使用`torch.from_numpy()`函数将numpy数组转换为PyTorch张量,并使用`.unsqueeze()`函数在第三个维度上添加一个维度,将2维数据集转换为3维张量。以下是代码示例:
``` python
import numpy as np
import torch
# 创建2维数据集
train = np.array([[1, 2], [3, 4], [5, 6]])
# 将数据集转换为张量,添加一个维度
train_tensor = torch.from_numpy(train).unsqueeze(2)
print(train_tensor.shape) # 输出 torch.Size([3, 2, 1])
```
其中,`unsqueeze(2)`中的2表示在第三个维度上添加一个维度,将原来的2维数据集转换为3维张量。
阅读全文