series.interpolate(limit_area='inside', method='linear').mask(m > thresh).round(2)
时间: 2024-01-12 15:05:22 浏览: 79
这段代码的作用是对一个 pandas Series 进行插值处理,处理方法为线性插值(method='linear'),且只在插值区域内进行插值(limit_area='inside')。然后使用一个掩码(mask)将大于阈值(thresh)的值设置为缺失值(NaN),最后使用 round(2) 方法将结果保留两位小数。
相关问题
分析这个代码class OhemCrossEntropy(nn.Module): def __init__(self, ignore_label=-1, thres=0.7, min_kept=100000, weight=None): super(OhemCrossEntropy, self).__init__() self.thresh = thres self.min_kept = max(1, min_kept) self.ignore_label = ignore_label self.criterion = nn.CrossEntropyLoss( weight=weight, ignore_index=ignore_label, reduction='none' ) def _ce_forward(self, score, target): ph, pw = score.size(2), score.size(3) h, w = target.size(1), target.size(2) if ph != h or pw != w: score = F.interpolate(input=score, size=( h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) loss = self.criterion(score, target) return loss def _ohem_forward(self, score, target, **kwargs): ph, pw = score.size(2), score.size(3) h, w = target.size(1), target.size(2) if ph != h or pw != w: score = F.interpolate(input=score, size=( h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) pred = F.softmax(score, dim=1) pixel_losses = self.criterion(score, target).contiguous().view(-1) mask = target.contiguous().view(-1) != self.ignore_label tmp_target = target.clone() tmp_target[tmp_target == self.ignore_label] = 0 pred = pred.gather(1, tmp_target.unsqueeze(1)) pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() min_value = pred[min(self.min_kept, pred.numel() - 1)] threshold = max(min_value, self.thresh) pixel_losses = pixel_losses[mask][ind] pixel_losses = pixel_losses[pred < threshold] return pixel_losses.mean() def forward(self, score, target): if config.MODEL.NUM_OUTPUTS == 1: score = [score] weights = config.LOSS.BALANCE_WEIGHTS assert len(weights) == len(score) functions = [self._ce_forward] * \ (len(weights) - 1) + [self._ohem_forward] return sum([ w * func(x, target) for (w, x, func) in zip(weights, score, functions) ])
这是一个实现了OHEM(Online Hard Example Mining)的交叉熵损失函数,用于解决深度学习中存在难样本或噪声样本导致训练效果不好的问题。其中,thresh表示像素的softmax预测概率阈值,小于该阈值的像素被认为是难样本;min_kept表示每张图中保留的难样本最小数量;ignore_label表示忽略的标签;weight表示权重。_ce_forward函数实现了普通的交叉熵损失的计算;_ohem_forward函数实现了OHEM的计算。在forward函数中,对于有多个输出的模型,采用了权重平衡的方式进行计算。
def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() if self.batch_resize_range is not None: lower_size = self.batch_resize_range[0] upper_size = self.batch_resize_range[1] if self.global_step <= 4: # do the first few batches with max size to avoid later oom new_resize = upper_size else: new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) if new_resize != x.shape[2]: x = F.interpolate(x, size=new_resize, mode="bicubic") x = x.detach() return x解析
这段代码是一个函数`get_input`,它用于将输入数据batch中的指定键值k取出来,并做一些预处理,最终返回一个张量x。具体来说,该函数的实现包括以下几个步骤:
1. 取出batch中键值为k的数据,并将其赋值给变量x。
2. 检查x的形状是否为3维,如果是,则在最后一维添加一个维度,使其成为4维张量。
3. 将x的维度从(批大小, 高, 宽, 通道数)的顺序改为(批大小, 通道数, 高, 宽)的顺序。
4. 如果batch_resize_range不为None,则对x进行大小调整。具体来说,如果当前训练步数(self.global_step)小于等于4,则将x的大小调整为batch_resize_range的上限,否则将x的大小随机调整到batch_resize_range中的一个大小。调整大小的方法是使用双三次插值方法(mode="bicubic")将x调整到指定大小。
5. 将x转换为浮点数类型,并返回。
阅读全文