使用使用pytorch实现论文中的实现论文中的unet网络网络
设计神经网络的一般步骤:设计神经网络的一般步骤:
1. 设计框架
2. 设计骨干网络
Unet网络设计的步骤:网络设计的步骤:
1. 设计Unet网络工厂模式
2. 设计编解码结构
3. 设计卷积模块
4. unet实例模块
Unet网络最重要的特征:网络最重要的特征:
1. 编解码结构。
2. 解码结构,比FCN更加完善,采用连接方式。
3. 本质是一个框架,编码部分可以使用很多图像分类网络。
示例代码:
import torch
import torch.nn as nn
class Unet(nn.Module):
#初始化参数:Encoder,Decoder,bridge
#bridge默认值为无,如果有参数传入,则用该参数替换None
def __init__(self,Encoder,Decoder,bridge = None):
super(Unet,self).__init__()
self.encoder = Encoder(encoder_blocks)
self.decoder = Decoder(decoder_blocks)
self.bridge = bridge
def forward(self,x):
res = self.encoder(x)
out,skip = res[0],res[1,:] if bridge is not None:
out = bridge(out)
out = self.decoder(out,skip)
return out
#设计编码模块
class Encoder(nn.Module):
def __init__(self,blocks):
super(Encoder,self).__init__()
#assert:断言函数,避免出现参数错误
assert len(blocks) > 0
#nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数
self.blocks = nn.Modulelist(blocks)
def forward(self,x):
skip = [] for i in range(len(self.blocks) - 1):
x = self.blocks[i](x)
skip.append(x)
res = [self.block[i+1](x)] #列表之间可以通过+号拼接
res += skip
return res
#设计Decoder模块
class Decoder(nn.Module):
def __init__(self,blocks):
super(Decoder, self).__init__()
assert len(blocks) > 0
self.blocks = nn.Modulelist(blocks)
def ceter_crop(self,skips,x):
_,_,height1,width1 = skips.shape()
_,_,height2,width2 = x.shape()
评论0