I need to implement the full image semantic segmentation pipeline with a SegNet model in PyTorch. The image data and mask data are stored in two separate folders, and both the images and the masks are color images. The pipeline should cover image preprocessing; splitting the data into training, validation, and test sets; training the model; comparing predictions against the ground truth on the test set; validating the model; and predicting on new images. Please implement this for me, thanks!
Sure. Below is a complete PyTorch implementation of a SegNet-style model for image semantic segmentation, covering image preprocessing, dataset splitting, model training, testing, and prediction:
``` python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os
import numpy as np
import random
# Define the SegNet-style model.
# Note: canonical SegNet unpools with saved max-pooling indices; for
# simplicity this version downsamples with MaxPool2d and upsamples with
# transposed convolutions plus U-Net-style concatenation skips.
class SegNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SegNet, self).__init__()
        # Encoder layers (each conv block is followed by 2x2 max pooling)
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Decoder layers (each transposed conv doubles the spatial size)
        self.upconv5 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn5 = nn.BatchNorm2d(512)
        self.upconv4 = nn.ConvTranspose2d(1024, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn4 = nn.BatchNorm2d(256)
        self.upconv3 = nn.ConvTranspose2d(512, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn3 = nn.BatchNorm2d(128)
        self.upconv2 = nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn2 = nn.BatchNorm2d(64)
        # Final projection keeps the spatial size (stride 1), so the four
        # pooling steps are matched by exactly four stride-2 upsampling steps.
        self.upconv1 = nn.Conv2d(128, out_channels, kernel_size=3, padding=1)
    def forward(self, x):
        # Encoder: conv -> BN -> ReLU, downsampling between blocks
        x1 = nn.functional.relu(self.bn1(self.conv1(x)))               # [64, H, W]
        x2 = nn.functional.relu(self.bn2(self.conv2(self.pool(x1))))   # [128, H/2, W/2]
        x3 = nn.functional.relu(self.bn3(self.conv3(self.pool(x2))))   # [256, H/4, W/4]
        x4 = nn.functional.relu(self.bn4(self.conv4(self.pool(x3))))   # [512, H/8, W/8]
        x5 = nn.functional.relu(self.bn5(self.conv5(self.pool(x4))))   # [512, H/16, W/16]
        # Decoder: upsample and concatenate the matching encoder feature map
        x_up5 = nn.functional.relu(self.upbn5(self.upconv5(x5)))                            # [512, H/8, W/8]
        x_up4 = nn.functional.relu(self.upbn4(self.upconv4(torch.cat((x4, x_up5), dim=1)))) # [256, H/4, W/4]
        x_up3 = nn.functional.relu(self.upbn3(self.upconv3(torch.cat((x3, x_up4), dim=1)))) # [128, H/2, W/2]
        x_up2 = nn.functional.relu(self.upbn2(self.upconv2(torch.cat((x2, x_up3), dim=1)))) # [64, H, W]
        x_up1 = self.upconv1(torch.cat((x1, x_up2), dim=1))                                 # [out_channels, H, W]
        return x_up1
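# Optional sanity check (a quick self-test, not part of the pipeline): with
# four pooling steps matched by four stride-2 upsampling steps, a 256x256
# input should come back out at 256x256.
# assert SegNet(3, 1)(torch.randn(1, 3, 256, 256)).shape == (1, 1, 256, 256)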
# Define the custom dataset for image segmentation
class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.img_files = sorted(os.listdir(img_dir))
    def __len__(self):
        return len(self.img_files)
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        mask_path = os.path.join(self.mask_dir, self.img_files[idx].replace(".jpg", ".png"))
        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        if self.transform:
            img = self.transform(img)
        # Resize the mask to the model input size (nearest neighbour keeps the
        # labels crisp) and convert it to a float tensor of shape [1, H, W] in
        # {0, 1}, matching the model output for BCEWithLogitsLoss. This
        # assumes binary masks stored with values 0/255.
        mask = mask.resize((256, 256), resample=Image.NEAREST)
        mask = torch.from_numpy(np.array(mask, dtype=np.float32) / 255.0).unsqueeze(0)
        return img, mask
# Set random seed for reproducibility
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the transforms for preprocessing.
# Note: random geometric augmentation (flips, rotations) is omitted here
# because it would have to be applied jointly to image and mask to keep them
# aligned, and this pipeline transforms only the image.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Create the datasets and dataloaders
train_dataset = SegmentationDataset(img_dir="./train_images", mask_dir="./train_masks", transform=train_transform)
test_dataset = SegmentationDataset(img_dir="./test_images", mask_dir="./test_masks", transform=test_transform)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)
# Initialize the model and optimizer
model = SegNet(in_channels=3, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
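# Note: BCEWithLogitsLoss is used here for binary segmentation; it expects
# float targets of the same shape as the model output ([N, 1, H, W]), which
# is what SegmentationDataset returns for the masks.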
# Define the training function
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
# Define the validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
    return running_loss / len(dataloader)
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_dataloader, criterion, optimizer, device)
    val_loss = validate(model, val_dataloader, criterion, device)
    print("Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}".format(epoch+1, num_epochs, train_loss, val_loss))
# Test the model
model.eval()
test_loss = 0.0
with torch.no_grad():
    for i, data in enumerate(test_dataloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
test_loss /= len(test_dataloader)
print("Test Loss: {:.4f}".format(test_loss))
# Save the model
torch.save(model.state_dict(), "segnet_model.pth")
# Make predictions on new images
def predict(image_path, model_path):
    # Load the trained weights (map_location lets a GPU-trained checkpoint
    # load on a CPU-only machine)
    model = SegNet(in_channels=3, out_channels=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Load and preprocess the image the same way as the test set
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)
    image = image.unsqueeze(0).to(device)
    # Make the prediction
    with torch.no_grad():
        output = model(image)
        output = torch.sigmoid(output)
        output = output.cpu().numpy()
        output = output.squeeze()
    # Threshold the probabilities into a binary 0/255 mask
    threshold = 0.5
    output[output >= threshold] = 255
    output[output < threshold] = 0
    # Convert to PIL and save
    output = Image.fromarray(output.astype(np.uint8), mode="L")
    output.save("predicted_mask.png")

# Example usage:
predict("new_image.jpg", "segnet_model.pth")
```
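The test loop above reports only the loss. For a more direct, quantitative comparison between predictions and ground-truth masks, a pixel-level metric such as mean IoU (intersection over union) is common; below is a minimal sketch under the same assumptions as the code above (binary masks, sigmoid outputs, and the `model`, `test_dataloader`, and `device` already defined):
``` python
def evaluate_iou(model, dataloader, device, threshold=0.5):
    """Mean IoU between thresholded sigmoid outputs and the binary masks."""
    model.eval()
    total_iou, n_batches = 0.0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = (torch.sigmoid(model(inputs)) >= threshold).float()
            # Per-sample intersection and union over the spatial dimensions
            inter = (preds * labels).sum(dim=(1, 2, 3))
            union = ((preds + labels) >= 1).float().sum(dim=(1, 2, 3))
            total_iou += ((inter + 1e-6) / (union + 1e-6)).mean().item()
            n_batches += 1
    return total_iou / n_batches

print("Test mean IoU: {:.4f}".format(evaluate_iou(model, test_dataloader, device)))
```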
Note: before running the code, make sure the data is organized as follows:
```
.
├── train_images
│ ├── image1.jpg
│ ├── image2.jpg
│ └── ...
├── train_masks
│ ├── image1.png
│ ├── image2.png
│ └── ...
├── test_images
│ ├── image1.jpg
│ ├── image2.jpg
│ └── ...
└── test_masks
├── image1.png
├── image2.png
└── ...
```
Here, the `train_images` and `train_masks` folders hold the training images and their segmentation masks, and `test_images` and `test_masks` hold the test images and their masks. Each image and its mask must share the same file name, differing only in the extension; for example, the mask for `image1.jpg` should be `image1.png`. A quick check like the sketch below can confirm the pairing is complete.
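If you are not sure every image has a mask, this minimal sketch (assuming the directory layout above) flags missing masks before training:
``` python
import os

# Check both the training and the test split for unmatched images
for img_dir, mask_dir in [("./train_images", "./train_masks"),
                          ("./test_images", "./test_masks")]:
    for name in os.listdir(img_dir):
        mask_name = name.replace(".jpg", ".png")
        if not os.path.exists(os.path.join(mask_dir, mask_name)):
            print("Missing mask for", os.path.join(img_dir, name))
```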
In addition, place `new_image.jpg` in the same directory as the script. It is used to try out the model's prediction; you can replace it with any other image you want to test.