WeightedRandomSampler
Date: 2023-09-11 09:05:05
WeightedRandomSampler is a sampler class in PyTorch (torch.utils.data) that draws sample indices from a dataset with probability proportional to the weight assigned to each sample. This is useful when the dataset is imbalanced and we want the model to see a roughly balanced mix of classes during training.
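As a rough check of the "probability proportional to the weights" claim, the sketch below draws many indices from a two-item weight list and compares the empirical frequencies with each weight divided by the sum of the weights. The weight values, the number of draws, and the seed are arbitrary choices made for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Illustrative weights (assumption): index 1 should be drawn about 3x as often as index 0.
weights = [1.0, 3.0]

# A seeded generator makes the draw reproducible; the seed value is arbitrary.
gen = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True, generator=gen)

# Iterating a sampler yields dataset indices, not the data itself.
indices = torch.tensor(list(sampler))
empirical = torch.bincount(indices, minlength=2).float() / len(indices)
expected = torch.tensor(weights) / sum(weights)

print("empirical:", empirical.tolist())
print("expected: ", expected.tolist())
```

The empirical frequencies should land close to 0.25 and 0.75, with small deviations from run to run if the seed is changed.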
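WeightedRandomSampler takes two required arguments: weights, a sequence with one weight per sample, and num_samples, the number of indices to draw. It also accepts two optional arguments, replacement (default True) and generator. The weights must be non-negative, with at least one of them positive, and do not need to sum to one. num_samples can be smaller than the dataset, which yields a smaller subset on each pass, or larger than the dataset when replacement=True. With replacement=False, each index is drawn at most once, so num_samples cannot exceed the number of samples with non-zero weight.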
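A minimal sketch of these arguments in isolation, without a dataset: the weights below are deliberately unnormalized, one of them is zero, and fewer indices are drawn than there are items. The specific weight values and the seed are assumptions made only for this illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Illustrative, unnormalized weights (assumption): they sum to 10, not 1.
# Index 2 has weight 0, so it should never be drawn.
weights = [5.0, 1.0, 0.0, 2.0, 1.5, 0.5]

gen = torch.Generator().manual_seed(0)  # arbitrary seed for reproducibility

# Draw fewer indices than there are items, without replacement,
# so each index appears at most once.
sampler = WeightedRandomSampler(weights, num_samples=4, replacement=False, generator=gen)
subset = list(sampler)

print(subset)        # four distinct indices in the range 0..5
print(len(sampler))  # equals num_samples
```

Passing replacement=False is what guarantees distinct indices here; with the default replacement=True, the same index may appear several times in one pass.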
Here is an example of using WeightedRandomSampler:
```
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
# create a dataset with 100 samples
dataset = torch.utils.data.TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
# calculate weights for each sample based on the class distribution
class_count = torch.bincount(dataset.tensors[1], minlength=2)
weights = 1.0 / class_count.float()
sample_weights = weights[dataset.tensors[1]]
# create a sampler with the weighted probabilities
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
# create a dataloader using the sampler
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
# iterate over the dataloader and print the class distribution
for i, (x, y) in enumerate(dataloader):
    # minlength=2 keeps both class counts visible even if a batch contains only one class
    print(f"Batch {i}, Class distribution: {torch.bincount(y, minlength=2)}")
```
In this example, we create a dataset with 100 samples, where the second tensor holds the class labels. We then compute a weight for each sample as the inverse of its class frequency, so samples from a rarer class receive higher weights. We pass the sample weights to a WeightedRandomSampler and use it to build a DataLoader. Finally, we iterate over the dataloader and print the class distribution of each batch to check that the sampler is working as expected. Because the labels here are drawn uniformly with torch.randint, the two classes are already close to balanced, so the reweighting has only a small effect; on a skewed dataset, the same weighting pulls each batch toward an even split on average, not exactly.
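To make the rebalancing effect easier to see, here is a variation on a deliberately skewed dataset. This is a sketch under stated assumptions: the 90/10 class split, the 5000 draws, the batch size, and the seed are all hypothetical values chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical skewed dataset: 90 samples of class 0 and 10 of class 1.
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
features = torch.randn(len(labels), 3)
dataset = TensorDataset(features, labels)

# Inverse-frequency weights: each class ends up with the same total weight.
class_count = torch.bincount(labels, minlength=2)
sample_weights = (1.0 / class_count.float())[labels]

# Draw more indices than the dataset holds (allowed with replacement) so the
# empirical class share is stable; 5000 and the seed are arbitrary choices.
gen = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(sample_weights, num_samples=5000, replacement=True, generator=gen)
dataloader = DataLoader(dataset, batch_size=50, sampler=sampler)

# Collect every label the dataloader yields and measure the minority share.
drawn = torch.cat([y for _, y in dataloader])
minority_fraction = (drawn == 1).float().mean().item()
print(f"Minority share in raw data: {labels.float().mean().item():.2f}")
print(f"Minority share as sampled:  {minority_fraction:.2f}")
```

The sampled share should come out near 0.5 even though the raw data is 10% minority. Since only 10 distinct minority samples exist, they are repeated many times, which is the usual trade-off of oversampling with replacement.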