pytorch 输入多光谱数据
时间: 2023-08-20 07:05:15 浏览: 185
如果你想输入多光谱数据到 PyTorch 中,可以使用 PyTorch 的数据加载器 `DataLoader` 对数据进行预处理和批处理。以下是一个简单的例子,假设你有一个包含多光谱图像文件的文件夹:
```python
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class MultiSpectralDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.data_files = os.listdir(data_dir)
def __len__(self):
return len(self.data_files)
def __getitem__(self, idx):
data_file = self.data_files[idx]
img = Image.open(os.path.join(self.data_dir, data_file))
img = np.array(img)
# 将图像转换为 PyTorch Tensor
img = torch.from_numpy(img).float()
return img
```
在上面的代码中,我们定义了一个自定义的数据集类 `MultiSpectralDataset`,其中 `__getitem__` 函数将每个多光谱图像文件加载为一个 PyTorch Tensor,并从数据集中返回它。然后,我们可以使用 `DataLoader` 对数据进行批处理和预处理:
```python
dataset = MultiSpectralDataset(data_dir='path/to/data')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch in dataloader:
# 对批量进行处理
pass
```
在上面的代码中,我们创建了一个数据加载器 `dataloader`,它将数据集 `dataset` 分成大小为 4 的批量,并在每个 epoch 时打乱数据的顺序。然后,我们可以在训练循环中使用 `dataloader` 来获取批量数据并进行处理。
阅读全文