class ResGCN(torch.nn.Module): def __init__(self, in_channels, out_channels): super(ResGCN, self).__init__() self.conv1 = GCNConv(in_channels, out_channels) self.bn1 = BatchNorm(out_channels) self.conv2 = GCNConv(out_channels, out_channels) self.bn2 = BatchNorm(out_channels) self.shortcut = GCNConv(in_channels, out_channels) self.bn_shortcut = BatchNorm(out_channels) def forward(self, x, edge_index): # 残差结构 identity = x out = F.relu(self.bn1(self.conv1(x, edge_index))) out = self.bn2(self.conv2(out, edge_index)) shortcut = self.bn_shortcut(self.shortcut(identity, edge_index)) out += shortcut out = F.relu(out) return out 根据代码写出数学公式
时间: 2024-01-12 08:04:46 浏览: 109
Residual-Networks.zip_-baijiahao_47W_python residual_python残差网络
首先定义符号:
- $x$: 输入特征矩阵,维度为 $N \times D$,$N$ 表示节点数,$D$ 表示特征维度。
- $A$: 邻接矩阵,维度为 $N \times N$,表示节点之间的连接关系。
- $W$: 权重矩阵,维度为 $D \times D$,表示特征变换的参数。
- $b$: 偏置向量,维度为 $1 \times D$,表示特征变换的参数。
- $\sigma$: 激活函数,这里使用 ReLU。
则 ResGCN 的前向传播可以表示为:
$$
\begin{aligned}
&\text{identity} = x, \\
&h_1 = \sigma(\text{BN}(W_1 x A + b_1)), \\
&h_2 = \text{BN}(W_2 h_1 A + b_2), \\
&\text{shortcut} = \text{BN}(W_s x A + b_s), \\
&\text{output} = \sigma(h_2 + \text{shortcut}),
\end{aligned}
$$
其中 $\text{BN}$ 表示 BatchNorm 操作,$W_1, W_2, W_s$ 分别是第一层 GCNConv、第二层 GCNConv 和 shortcut 的参数矩阵,$b_1, b_2, b_s$ 分别是对应的偏置向量。注意这里的 $A$ 是 GCNConv 的输入,即 $A$ 已经被加上自环,可以看 GCNConv 的定义。
阅读全文