Write a PyTorch-based UNet with adjustable per-class weights
Below is a UNet implemented on the PyTorch platform, together with a loss function that lets you adjust the weight of each class:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()
        # Encoder: each stage doubles the channels and halves the resolution.
        self.encoder1 = UNet._block(in_channels, init_features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(
            init_features, init_features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(
            init_features * 2, init_features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(
            init_features * 4, init_features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.center = UNet._block(
            init_features * 8, init_features * 16, name="center")
        # Decoder: upsample, concatenate the matching encoder output
        # (skip connection), then halve the channels with a conv block.
        self.upconv4 = nn.ConvTranspose2d(
            init_features * 16, init_features * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block(
            init_features * 8 * 2, init_features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            init_features * 8, init_features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block(
            init_features * 4 * 2, init_features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            init_features * 4, init_features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block(
            init_features * 2 * 2, init_features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            init_features * 2, init_features, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(
            init_features * 2, init_features, name="dec1")
        self.conv = nn.Conv2d(
            in_channels=init_features, out_channels=out_channels, kernel_size=1)
    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        center = self.center(self.pool4(enc4))
        # Upsample first so the spatial size matches the skip connection,
        # then concatenate along the channel dimension.
        dec4 = self.decoder4(torch.cat([self.upconv4(center), enc4], dim=1))
        dec3 = self.decoder3(torch.cat([self.upconv3(dec4), enc3], dim=1))
        dec2 = self.decoder2(torch.cat([self.upconv2(dec3), enc2], dim=1))
        dec1 = self.decoder1(torch.cat([self.upconv1(dec2), enc1], dim=1))
        # Sigmoid yields per-pixel foreground probabilities (binary output).
        return torch.sigmoid(self.conv(dec1))
    @staticmethod
    def _block(in_channels, features, name):
        # Two 3x3 conv layers, each followed by batch norm and ReLU.
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=features,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=features, out_channels=features,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )
class WeightedCrossEntropyLoss(nn.Module):
    """Binary cross-entropy where each class gets its own weight.

    weight: a sequence [background_weight, foreground_weight].
    """
    def __init__(self, weight=None, reduction='mean'):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.weight = None if weight is None else torch.as_tensor(
            weight, dtype=torch.float32)
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-pixel loss; inputs are sigmoid probabilities, targets are 0/1 masks.
        loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        if self.weight is not None:
            w = self.weight.to(targets.device)
            # Per-pixel weight map: w[1] on foreground pixels, w[0] elsewhere.
            loss = loss * (w[0] * (1.0 - targets) + w[1] * targets)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
# Example usage (dataloader: your DataLoader of (image, mask) batches)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
model = UNet().to(device)
weight = [1.0, 2.0]  # [background weight, foreground weight]
criterion = WeightedCrossEntropyLoss(weight=weight)
optimizer = torch.optim.Adam(model.parameters())
total_step = len(dataloader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print("Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}".format(
            epoch + 1, num_epochs, i + 1, total_step, loss.item()))
```
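The network above has a single sigmoid output, so WeightedCrossEntropyLoss treats the problem as binary (background vs. foreground). For more than two classes, a minimal sketch of the usual alternative is shown below: it assumes the UNet is changed to output num_classes channels of raw logits (i.e. the final sigmoid removed), so that the weight argument of PyTorch's built-in nn.CrossEntropyLoss can supply one weight per class:
```python
import torch
import torch.nn as nn

# Sketch only: assumes a multi-class UNet returning raw logits of shape
# (N, num_classes, H, W), i.e. out_channels=num_classes and no final sigmoid.
num_classes = 3

# One weight per class; larger values make errors on that class cost more.
class_weights = torch.tensor([0.5, 1.0, 2.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy logits and integer label map standing in for model outputs / masks.
logits = torch.randn(2, num_classes, 64, 64)
labels = torch.randint(0, num_classes, (2, 64, 64))
loss = criterion(logits, labels)
print(loss.item())
```
nn.CrossEntropyLoss applies log-softmax internally, which is why the model must return raw logits rather than probabilities in this setup.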
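How to pick the weights is problem-dependent; one common heuristic (a sketch, not tied to any particular dataset) is inverse class frequency over the training masks, using the total / (num_classes * count) rule:
```python
import torch

def inverse_frequency_weights(masks, num_classes):
    """Per-class weights: total_pixels / (num_classes * count_c).

    masks: LongTensor of integer class ids (any shape).
    Rare classes receive proportionally larger weights.
    """
    counts = torch.bincount(masks.flatten(), minlength=num_classes).float()
    counts = counts.clamp(min=1)  # guard against classes absent from the masks
    return counts.sum() / (num_classes * counts)

# Example: a heavily imbalanced binary mask (mostly background).
masks = torch.zeros(4, 64, 64, dtype=torch.long)
masks[:, :8, :8] = 1  # small foreground patch
print(inverse_frequency_weights(masks, num_classes=2))  # approx. [0.51, 32.0]
```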