pytorch实现自定义模块
时间: 2023-10-23 08:29:44 浏览: 44
在 PyTorch 中,我们可以通过继承 `nn.Module` 类来实现自定义模块。下面是一个简单的例子,展示了如何实现一个简单的全连接层模块。
```python
import torch.nn as nn
class MyLinear(nn.Module):
def __init__(self, input_size, output_size):
super(MyLinear, self).__init__()
self.weight = nn.Parameter(torch.Tensor(output_size, input_size))
self.bias = nn.Parameter(torch.Tensor(output_size))
nn.init.xavier_uniform_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
x = x.view(x.size(0), -1) # 将输入数据展平,以适应全连接层
out = torch.mm(x, self.weight.t()) + self.bias
return out
```
在上述代码中,我们定义了一个名为 `MyLinear` 的自定义模块,该模块继承自 `nn.Module` 类。我们在 `__init__` 方法中定义了模块的参数,包括权重和偏置,这些参数都是 `nn.Parameter` 类型,它们会自动被注册到模型参数列表中。我们还使用 `nn.init` 模块中的函数来初始化权重和偏置。在 `forward` 方法中,我们对输入数据进行展平操作,并使用 `torch.mm` 函数进行矩阵乘法运算,最后加上偏置即可。
使用自定义模块和内置模块的方式是一致的。例如,我们可以这样使用上述定义的自定义全连接层模块:
```python
import torch
# 创建一个 MyLinear 实例
linear = MyLinear(784, 10)
# 随机生成一个大小为 (1, 784) 的张量
x = torch.randn(1, 784)
# 使用自定义模块进行前向计算
out = linear(x)
print(out.size()) # 输出:torch.Size([1, 10])
```
这里我们创建了一个输入大小为 784,输出大小为 10 的全连接层模块,并对一个大小为 (1, 784) 的随机输入数据进行前向计算,得到了大小为 (1, 10) 的输出。