使用 PyTorch 写一段 ResNet 网络的代码
时间:2023-03-16 21:11:53 浏览:21
这是一段使用 PyTorch 写的 Rsenet 网络的代码：

import torch
from torch import nn


class Rsenet(nn.Module):
    """Three 3x3 conv + ReLU layers followed by 2x2 max pooling.

    With the default ``residual=False`` this is a plain stack of
    convolutions: there is no skip connection, so despite the name it is
    not a ResNet block. Pass ``residual=True`` to add a ResNet-style
    shortcut from the input to the output of ``conv3`` (added before the
    final ReLU).

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by every conv layer.
        residual: Keyword-only. If True, add the input (passed through a
            1x1 conv when ``in_channel != out_channel``, otherwise
            unchanged) to the ``conv3`` output before the final ReLU.
            Defaults to False, which keeps the original behaviour and the
            original parameter names / ``state_dict`` keys.

    Shape:
        Input ``(N, in_channel, H, W)`` ->
        output ``(N, out_channel, H // 2, W // 2)``.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        *,
        residual: bool = False,
    ) -> None:
        super().__init__()
        # kernel_size=3 with padding=1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, 3, padding=1)
        # inplace=True is safe here: ReLU is only ever applied to freshly
        # computed conv outputs (or the shortcut sum), never to the caller's x.
        self.relu = nn.ReLU(inplace=True)
        # Halves H and W (flooring odd sizes).
        self.pool = nn.MaxPool2d(2)
        if residual:
            # The shortcut must produce out_channel channels to be summed
            # with conv3's output; a bias-free 1x1 conv projects when needed.
            self.shortcut = (
                nn.Identity()
                if in_channel == out_channel
                else nn.Conv2d(in_channel, out_channel, 1, bias=False)
            )
        else:
            # Plain attribute (not a submodule), so state_dict is unchanged.
            self.shortcut = None

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Apply conv-ReLU x3 (with optional shortcut) and 2x2 max pooling."""
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        if self.shortcut is not None:
            # Out-of-place add so the caller's input tensor is never modified.
            out = out + self.shortcut(x)
        out = self.relu(out)
        return self.pool(out)