用 Python 实现灰度图像上色
时间: 2023-09-19 09:09:20 浏览: 100
灰度图像上色可以用深度学习方法实现,例如使用卷积神经网络(Convolutional Neural Network, CNN)。以下是基于 PyTorch 框架的实现代码:输入是灰度图像,输出是对应的彩色图像。
首先导入必要的库:
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
```
定义数据集类,读取灰度图像和对应的彩色图像:
```python
class ImageDataset(Dataset):
    """Paired grayscale/colour image dataset.

    The directory ``data_path`` must contain two sub-directories, ``gray``
    and ``color``, holding identically named files: the file names found in
    ``gray`` are reused to look up the matching target in ``color``.

    Args:
        data_path: Root directory containing the ``gray`` and ``color``
            folders.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sort them so that
        # index -> file is reproducible across runs and platforms.
        self.image_filenames = sorted(
            os.listdir(os.path.join(self.data_path, 'gray'))
        )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_filenames)

    def _load(self, subdir, filename, mode):
        """Open one image, force it into ``mode`` and release the file handle."""
        path = os.path.join(self.data_path, subdir, filename)
        with Image.open(path) as image:
            # convert() returns a fully loaded copy, so closing the source
            # file afterwards is safe.
            return image.convert(mode)

    def __getitem__(self, index):
        """Return the ``(gray_image, color_image)`` pair at ``index``."""
        filename = self.image_filenames[index]
        # Force fixed modes: files on disk may be RGB, RGBA or palette based,
        # which would otherwise produce a channel count the network (1 input
        # channel, 3 output channels) cannot consume.
        gray_image = self._load('gray', filename, 'L')
        color_image = self._load('color', filename, 'RGB')
        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)
        return gray_image, color_image
```
定义卷积神经网络模型,包含一个预处理层(用 1×1 卷积把单通道灰度图像映射为 3 通道特征)、若干个卷积层(下采样)和反卷积层(上采样)、一个输出层(通过 Tanh 将输出的像素值限制在 [-1, 1] 区间内):
```python
class ColorNet(nn.Module):
    """Encoder-decoder CNN that maps a 1-channel image to a 3-channel image.

    Layout:
        * ``preprocess``: 1x1 convolution lifting 1 channel to 3.
        * ``conv1`` .. ``conv7``: 3x3 convolutions; ``conv2``, ``conv4`` and
          ``conv6`` use stride 2 and each halve the spatial resolution.
        * ``deconv1`` .. ``deconv3``: stride-2 transposed convolutions, each
          doubling the spatial resolution.
        * ``deconv4``: stride-1 transposed convolution to 3 channels followed
          by ``Tanh``, so every output value lies in ``[-1, 1]``.

    When the input height and width are multiples of 8 the output has the
    same spatial size as the input.
    """

    @staticmethod
    def _down(in_channels, out_channels, stride=1, kernel_size=3, padding=1):
        """Build a Conv2d -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up(in_channels, out_channels):
        """Build a ConvTranspose2d(stride 2) -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                               stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()
        # Stages are created in the same order as they are applied, which
        # also fixes the order of the parameters in state_dict().
        self.preprocess = self._down(1, 3, kernel_size=1, padding=0)

        # Encoder.
        self.conv1 = self._down(3, 64)
        self.conv2 = self._down(64, 128, stride=2)
        self.conv3 = self._down(128, 128)
        self.conv4 = self._down(128, 256, stride=2)
        self.conv5 = self._down(256, 256)
        self.conv6 = self._down(256, 512, stride=2)
        self.conv7 = self._down(512, 512)

        # Decoder.
        self.deconv1 = self._up(512, 256)
        self.deconv2 = self._up(256, 128)
        self.deconv3 = self._up(128, 64)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the result."""
        stages = (
            self.preprocess,
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.conv5, self.conv6, self.conv7,
            self.deconv1, self.deconv2, self.deconv3, self.deconv4,
        )
        for stage in stages:
            x = stage(x)
        return x
```
定义训练函数:
```python
def train(model, criterion, optimizer, dataloader, device, epoch=None):
    """Train ``model`` for one pass over ``dataloader``.

    Every 10 mini-batches the mean loss of those 10 batches is printed.

    Args:
        model: Network to optimise; switched to training mode.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser that owns the model parameters.
        dataloader: Iterable of batches; element 0 is the input and element 1
            the target of each batch.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index used only in the progress message.
            When omitted, a module-level variable named ``epoch`` is used if
            one exists (the behaviour older callers relied on), else 0.
    """
    if epoch is None:
        # Previously this function read the global ``epoch`` unconditionally
        # and raised NameError when the caller had not defined one.
        epoch = globals().get('epoch', 0)

    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 10 == 9:  # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0
```
定义测试函数,用于测试模型在测试集上的表现:
```python
def test(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and print the mean per-sample loss.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        criterion: Loss callable returning the mean loss of a batch.
        dataloader: Iterable of batches; element 0 is the input and element 1
            the target of each batch.
        device: Device the batch tensors are moved to.

    Returns:
        The loss averaged over all evaluated samples, or ``nan`` when the
        dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            batch_size = inputs.size(0)
            # criterion yields a per-batch mean; weight it by the batch size
            # so that dividing by the sample count gives a true mean even
            # when the last batch is smaller.
            loss_sum += criterion(outputs, labels).item() * batch_size
            num_samples += batch_size

    # An empty loader must not look like a perfect score, nor crash.
    test_loss = loss_sum / num_samples if num_samples else float('nan')
    print('Test Loss: {:.6f}'.format(test_loss))
    return test_loss


# The name starts with "test", so pytest would otherwise try to collect this
# helper as a test case and fail on its parameters.
test.__test__ = False
```
定义训练参数和训练过程:
```python
# Hyper-parameters of the training run.
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 16
LEARNING_RATE = 0.0002
NUM_EPOCHS = 50

# Every image (input and target alike) is resized and turned into a tensor.
data_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Training data is reshuffled each epoch; evaluation data keeps its order.
train_dataset = ImageDataset('data/train', transform=data_transforms)
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_dataset = ImageDataset('data/test', transform=data_transforms)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ColorNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE: the loop variable must stay named ``epoch`` at module level because
# train() reads it for its progress message.
for epoch in range(NUM_EPOCHS):
    train(model, criterion, optimizer, train_dataloader, device)
    test(model, criterion, test_dataloader, device)
```
最后,使用训练好的模型对一张灰度图像进行上色:
```python
# Colorize a single grayscale image with the trained model.
gray_image = Image.open('test_gray.png').convert('L')
gray_image_transformed = data_transforms(gray_image).unsqueeze(0).to(device)

# Inference: use the BatchNorm running statistics and skip autograd.
model.eval()
with torch.no_grad():
    color_image = model(gray_image_transformed)

# The training targets come from ToTensor() and therefore lie in [0, 1], so
# the network is trained to emit that range directly. Mapping the output with
# (x + 1) / 2 would squeeze it into [0.5, 1] and wash the picture out; only
# clamp away the negative part of the Tanh range instead.
color_image = color_image.squeeze(0).clamp(0.0, 1.0)
color_image = color_image.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
Image.fromarray((color_image * 255).astype('uint8')).show()
```
其中 `'test_gray.png'` 是一张灰度图像的文件名。
阅读全文