Convert ndarrays in sample to Tensors
时间: 2024-03-21 10:43:26 浏览: 19
如果你正在编写 PyTorch 模型的数据集类,你可以使用 `torch.Tensor` 将样本中的 NumPy 数组转换为张量。这里是一个简单的示例:
```python
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
sample = self.data[index]
tensor_sample = {
'input': torch.Tensor(sample['input']),
'target': torch.Tensor(sample['target'])
}
return tensor_sample
def __len__(self):
return len(self.data)
```
在这个例子中,我们假设 `data` 是一个包含 NumPy 数组的列表。在 `__getitem__` 方法中,我们将样本中的 `'input'` 和 `'target'` 数组转换为张量,并将它们存储在一个字典中,然后返回该字典。这使得我们可以在模型中轻松地使用样本张量。
请注意,这只是一个简单的示例,具体实现可能因数据类型和形状而异。