in __check_input__ assert edge_index.dtype == torch.long
时间: 2024-05-31 10:10:36 浏览: 308
这是一个Python断言(assertion),用于在代码中检查某些条件是否为真,如果不为真,则抛出AssertionError异常。在这个例子中,代码在检查输入的边缘索引(edge_index)是否是torch.long类型,如果不是,就会抛出AssertionError异常。这个断言是为了确保代码的正确性和稳定性。
相关问题
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, add_self_loops=True, flow="source_to_target", dtype=None): fill_value = 2. if improved else 1. if isinstance(edge_index, SparseTensor): assert flow in ["source_to_target"] adj_t = edge_index if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) if add_self_loops: adj_t = fill_diag(adj_t, fill_value) deg = sparsesum(adj_t, dim=1) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1)) adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) return adj_t else: assert flow in ["source_to_target", "target_to_source"] num_nodes = maybe_num_nodes(edge_index, num_nodes) if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) if add_self_loops: edge_index, tmp_edge_weight = add_remaining_self_loops( edge_index, edge_weight, fill_value, num_nodes) assert tmp_edge_weight is not None edge_weight = tmp_edge_weight row, col = edge_index[0], edge_index[1] idx = col if flow == "source_to_target" else row deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
这段代码是一个用于计算GCN中归一化邻接矩阵的函数。下面是对代码的解读:
输入参数:
- edge_index:图的边索引,可以是一个包含两行的长为2的Tensor,表示边的起点和终点的索引。也可以是一个SparseTensor对象,表示稀疏的边索引。
- edge_weight:边的权重,可选参数,默认为None。如果不提供权重,则默认为全1.
- num_nodes:节点的数量,可选参数,默认为None。如果不提供数量,则通过edge_index推断得到。
- improved:布尔值,表示是否采用改进的归一化方式,默认为False。
- add_self_loops:布尔值,表示是否添加自环,默认为True。
- flow:字符串,表示信息传播的方向,默认为"source_to_target",即从源节点到目标节点。
函数内部逻辑:
1. 根据是否是SparseTensor对象,判断是稀疏还是稠密的边索引。
2. 对于稀疏边索引,首先判断信息传播方向是否为"source_to_target",然后进行按行归一化的处理。具体操作包括:
- 填充值的选择根据是否采用改进的方式而定。
- 如果边索引没有值,则填充为1。
- 如果需要添加自环,则对稀疏边索引进行填充对角线操作。
- 计算每个节点的度矩阵。
- 计算度矩阵的逆平方根。
- 对邻接矩阵进行按行归一化。
3. 对于稠密边索引,首先根据信息传播方向进行处理。具体操作包括:
- 如果需要添加自环,则对边索引和权重进行添加自环操作。
- 根据信息传播方向,选择相应的索引进行节点度计算。
- 计算度矩阵的逆平方根。
- 对边索引进行按元素归一化。
最终返回归一化后的邻接矩阵或者边索引。
Focal 损失函数代码如下:def focal_loss(input_values, gamma): """Computes the focal loss""" p = torch.exp(-input_values) loss = (1 - p) ** gamma * input_values return loss.mean() class FocalLoss(nn.Module): def __init__(self, weight=None, gamma=0.): super(FocalLoss, self).__init__() assert gamma >= 0 self.gamma = gamma self.weight = weight def forward(self, input, target): return focal_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), self.gamma) LDAM损失函数代码如下:class LDAMLoss(nn.Module): def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).__init__() m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) m_list = m_list * (max_m / np.max(m_list)) m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list assert s > 0 self.s = s self.weight = weight def forward(self, x, target): index = torch.zeros_like(x, dtype=torch.uint8) index.scatter_(1, target.data.view(-1, 1), 1) index_float = index.type(torch.cuda.FloatTensor) batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1)) batch_m = batch_m.view((-1, 1)) x_m = x - batch_m output = torch.where(index, x_m, x) return F.cross_entropy(self.s*output, target, weight=self.weight) Large Margin aware Focal (LMF) 损失函数是 Focal 损失函数和 LDAM损失函数加权的线性组合,帮我用pytorch代码实现LMF损失函数
以下是LMF损失函数的PyTorch实现:
```
class LMF_Loss(nn.Module):
def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30, gamma=2.0, alpha=0.25):
super(LMF_Loss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.focal_loss = FocalLoss(weight=weight, gamma=gamma)
self.ldam_loss = LDAMLoss(cls_num_list, max_m=max_m, weight=weight, s=s)
def forward(self, x, target):
focal_loss = self.focal_loss(x, target)
ldam_loss = self.ldam_loss(x, target)
lmf_loss = self.alpha * focal_loss + (1 - self.alpha) * ldam_loss.pow(self.gamma)
return lmf_loss.mean()
```
其中,`gamma`为Focal Loss的超参数,`alpha`为LMF Loss中Focal Loss的权重。其他参数含义与原始的Focal Loss和LDAM Loss相同。
阅读全文