weightedrandomsampler代码python
时间: 2023-09-11 18:04:14 浏览: 36
以下是一个简单的 WeightedRandomSampler 的 Python 代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
def __len__(self):
return len(self.data)
# 假设有 1000 个样本,其中 800 个标签为 0,200 个标签为 1
data = torch.randn(1000, 10)
targets = torch.cat([torch.zeros(800), torch.ones(200)])
# 计算每个类别的权重
class_count = torch.tensor([len(torch.where(targets == t)[0]) for t in torch.unique(targets)])
weight = 1. / class_count.float()
samples_weight = weight[targets.long()]
# 使用 WeightedRandomSampler
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
# 使用 DataLoader
dataset = CustomDataset(data, targets)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
```
在这个例子中,我们首先定义了一个 CustomDataset 类,它继承自 PyTorch 的 Dataset 类。我们的数据集包含 1000 个样本和相应的标签。然后,我们计算每个类别的权重,并将它们传递给 WeightedRandomSampler 类。最后,我们将数据集和采样器传递给 DataLoader,以获得采样后的批次数据。