Long-Tailed Dataset Classification with PyTorch + ResNet18 (Part 1)
The experiments are based on the paper: Class-Balanced Loss Based on Effective Number of Samples
Class-balanced-loss code: https://github.com/vandit15/Class-balanced-loss-pytorch
ResNet18 code reference: https://blog.csdn.net/sunqiande88/article/details/80100891
Building the dataset
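The paper builds the long-tailed CIFAR-10 dataset with the formula $n = n_i \mu^i$, where $i$ is the class index (starting at 0), $n_i$ is the original number of training images in class $i$, and $\mu \in (0, 1)$. The code below uses an imbalance ratio of 100 as an example, i.e. $\mu = (1/100)^{1/9}$, so class 0 keeps all 5000 images and class 9 keeps about 50. The dataset can also be downloaded from a Google Cloud link (access may require a proxy).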
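To make the formula concrete, the per-class sample counts for an imbalance ratio of 100 can be worked out directly. The sketch below is illustrative and not part of loadcifar.py; the helper name `long_tail_counts` and its parameters are assumptions introduced here, and it mirrors the floor-based rounding used in the loading code.

```python
import math

def long_tail_counts(n_max=5000, imbalance_ratio=100, num_classes=10):
    """Per-class counts n_i * mu^i (hypothetical helper, not from the original post)."""
    # mu is chosen so that the last class keeps about n_max / imbalance_ratio images.
    mu = (1 / imbalance_ratio) ** (1 / (num_classes - 1))
    return [int(math.floor(n_max * mu ** i)) for i in range(num_classes)]

counts = long_tail_counts()
for cls, n in enumerate(counts):
    print(f"class {cls}: {n} images")
print("total:", sum(counts))
```

Class 0 keeps all 5000 images and the counts decay geometrically from there. The last class lands right at 5000 / 100 = 50, so floating-point rounding under the floor may decide whether it comes out as 49 or 50.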
loadcifar.py
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
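def unpickle(file):
    import pickle
    # Each CIFAR-10 batch file is a pickled dict; encoding='bytes' gives bytes keys such as b'data'.
    with open(file, 'rb') as fo:
        batch_dict = pickle.load(fo, encoding='bytes')
    return batch_dict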
# Read the data from the source files.
# Returns train_data[12406, 3072] and labels[12406],
# or test_data[10000, 3072] and labels[10000].
def get_data(train=False):
    data = None
    labels = None
    new_data = None
    new_labels = []
    if train:
        for i in range(1, 6):
            batch = unpickle('data/cifar-10-batches-py/data_batch_' + str(i))
            if i == 1:
                data = batch[b'data']
                labels = batch[b'labels']
            else:
                data = np.concatenate([data, batch[b'data']])
                labels = np.concatenate([labels, batch[b'labels']])
        # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
        count = np.zeros(10, dtype=int)
        kept = []  # rows kept for the long-tailed training set
        for i in range(len(labels)):
            label = labels[i]
            # Keep at most floor(5000 * mu^label) images per class, with mu = (1/100)^(1/9).
            if count[label] < int(np.floor(5000 * ((1 / 100) ** (1 / 9)) ** label)):
                count[label] += 1
                kept.append(data[i])
                new_labels.append(label)
        new_labels = np.array(new_labels)
        new_data = np.array(kept).reshape(-1, 3072)
    else:
        batch = unpickle('data/cifar-10-batches-py/test_batch')
        new_data = batch[b'data']
        new_labels = batch[b'labels']
    return new_data, new_labels
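The per-sample loop in get_data keeps, for each class c, the first floor(5000 * mu^c) images it encounters. As a rough alternative, the same selection can be expressed with NumPy index operations. This is a sketch and not the author's code; the helper name `subsample_long_tail` and the synthetic labels used to exercise it are assumptions for illustration.

```python
import numpy as np

def subsample_long_tail(labels, n_max=5000, imbalance_ratio=100, num_classes=10):
    """Return sorted indices of the samples to keep (hypothetical helper)."""
    labels = np.asarray(labels)
    mu = (1 / imbalance_ratio) ** (1 / (num_classes - 1))
    kept = []
    for c in range(num_classes):
        limit = int(np.floor(n_max * mu ** c))
        # First `limit` occurrences of class c, in their original order.
        kept.append(np.flatnonzero(labels == c)[:limit])
    return np.sort(np.concatenate(kept))

# Synthetic stand-in for the 50000 CIFAR-10 training labels (5000 per class).
rng = np.random.default_rng(0)
fake_labels = rng.permutation(np.repeat(np.arange(10), 5000))
idx = subsample_long_tail(fake_labels)
print(np.bincount(fake_labels[idx], minlength=10))
```

With the real arrays, `data[idx]` and `labels[idx]` would give the long-tailed training set; sorting the indices keeps the samples in the same order the loop produces.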
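# Image preprocessing; Compose chains several transform operations together.
# For color images, the color channels are not stationary.
transform = transforms.Compose([
    # ToTensor converts a PIL.Image (RGB) or a numpy.ndarray (H x W x C, dtype uint8)
    # with values from 0 to 255 into a Tensor with values in the range 0 to 1.
    transforms.ToTensor(),
    # Normalize rescales the image data to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])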
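The arithmetic behind the two transforms can be checked without torchvision. The following is a NumPy-only sketch, not the pipeline itself: it reshapes one 3072-value CIFAR-10 row (stored as 1024 red, 1024 green, then 1024 blue values) into an H x W x C image, scales it to [0, 1] as ToTensor does for uint8 input, and applies (x - 0.5) / 0.5 as Normalize does with mean 0.5 and std 0.5. The random row and the helper name `row_to_normalized` are assumptions for illustration.

```python
import numpy as np

def row_to_normalized(row):
    """Mimic the value scaling of ToTensor + Normalize(0.5, 0.5) (hypothetical helper)."""
    # Channel-major row -> (3, 32, 32) -> H x W x C image.
    img = np.asarray(row, dtype=np.uint8).reshape(3, 32, 32).transpose(1, 2, 0)
    scaled = img.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    return (scaled - 0.5) / 0.5              # [0, 1] -> [-1, 1]

rng = np.random.default_rng(0)
fake_row = rng.integers(0, 256, size=3072, dtype=np.uint8)  # stand-in for a real image row
out = row_to_normalized(fake_row)
print(out.shape, out.min(), out.max())
```

ToTensor additionally moves the channel axis first (C x H x W); that step is left out here so the focus stays on the value ranges.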