采用 PSNR(Peak Signal-to-Noise Ratio,峰值信噪比)和 SSIM(Structural Similarity Index,结构相似性指数)两个指标,评价一个将黑白人像照片彩色化的算法,并给出完整代码。
时间: 2024-05-16 13:12:29 浏览: 133
这里是一个简单的彩色化算法,使用PyTorch实现,并且使用PSNR和SSIM两个指标进行评价:
```python
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
try:
    # scikit-image >= 0.16 provides the metrics in skimage.metrics; the old
    # skimage.measure.compare_* names were removed in 0.18.
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:  # scikit-image < 0.16
    from skimage.measure import compare_psnr, compare_ssim
# Colorization model: grayscale image in, colour image of the same size out.
class ColorizationModel(nn.Module):
    """Fully-convolutional encoder/decoder for image colorization.

    The encoder (``conv1``-``conv26``) halves the resolution six times with
    stride-2 convolutions (conv2, conv4, conv7, conv10, conv15, conv21). The
    decoder (``conv27``-``conv34``) only doubles it three times with transposed
    convolutions (conv27, conv30, conv32). The raw decoder output is therefore
    much smaller than the input (32x32 for a 224x224 input). ``forward``
    bilinearly resizes the prediction back to the input's height and width so
    it can be compared pixel-for-pixel with the target image by the loss and
    by PSNR/SSIM.

    Args:
        in_channels: channels of the input image; 1 for a grayscale image.
        out_channels: channels of the predicted image. Defaults to 3 so the
            prediction has the same layout as the RGB images it is scored
            against; a 2-channel output cannot be compared with RGB targets.

    The layer attributes keep their ``conv1`` ... ``conv34`` names so the
    ``state_dict`` keys are the same as before.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super(ColorizationModel, self).__init__()
        # Encoder (all kernels 3x3, padding 1; stride 2 marks a downsampling step).
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv14 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv15 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv16 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv17 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv18 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv19 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv20 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv21 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv22 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv23 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv24 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv25 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv26 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        # Decoder. ConvTranspose2d(kernel 4, stride 2, padding 1) exactly
        # doubles the height and width.
        self.conv27 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.conv28 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv29 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv30 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.conv31 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv32 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv34 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map ``x`` (N, in_channels, H, W) to a prediction (N, out_channels, H, W) in [0, 1]."""
        size = x.shape[-2:]
        # conv1 ... conv32 are each followed by a ReLU.
        for index in range(1, 33):
            x = nn.functional.relu(getattr(self, 'conv{}'.format(index))(x))
        # NOTE(review): tanh (rather than ReLU) before the output layer is
        # unusual; kept as in the original design.
        x = torch.tanh(self.conv33(x))
        x = torch.sigmoid(self.conv34(x))
        if x.shape[-2:] != size:
            # interpolate() needs a batch dimension; Conv2d also accepts
            # unbatched (C, H, W) input, so add/remove one around the resize.
            unbatched = x.dim() == 3
            if unbatched:
                x = x.unsqueeze(0)
            # Bilinear weights are non-negative and sum to 1, so the [0, 1]
            # range produced by the sigmoid is preserved.
            x = nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)
            if unbatched:
                x = x.squeeze(0)
        return x
# Training: one pass (epoch) over the training data.
def train(model, train_loader, criterion, optimizer, device=None):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: network mapping a batch of inputs to predictions.
        train_loader: iterable of ``(inputs, targets)`` tensor batches.
        criterion: loss called as ``criterion(outputs, targets)``; assumed to
            average over the batch (PyTorch's default ``reduction='mean'``).
        optimizer: optimizer over ``model``'s parameters.
        device: optional device each batch is moved to before the forward
            pass; ``None`` leaves batches where the loader put them.

    Returns:
        The epoch loss, with each batch's loss weighted by its batch size, or
        ``nan`` when the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for inputs, targets in train_loader:
        if device is not None:
            inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float, so no autograd graph is kept alive.
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        seen += batch
    return loss_sum / seen if seen else float('nan')
# Evaluation: average PSNR and SSIM over every image in the test set.
def test(model, test_loader, device=None):
    """Evaluate ``model`` and return ``(mean_psnr, mean_ssim)`` over all test images.

    Every image of every batch is scored, and the sums are divided by the
    number of images, so the result does not depend on the loader's batch size.

    Both metrics use ``data_range=1``, i.e. they assume pixel values in [0, 1].
    Without it, scikit-image infers the range from the float dtype, which
    misstates SSIM for [0, 1] images (and newer releases may refuse floats
    without it). SSIM is computed per channel and averaged over channels.
    That is what scikit-image does for multichannel images, and it avoids the
    ``multichannel``/``channel_axis`` keyword, which differs across versions.

    Args:
        model: network whose output has the same shape as the targets.
        test_loader: iterable of ``(inputs, targets)`` batches with targets
            shaped (N, C, H, W).
        device: optional device for the forward pass; ``None`` leaves the
            inputs where the loader put them.

    Returns:
        ``(psnr, ssim)`` averaged over all images, or ``(nan, nan)`` when the
        loader yields no images.
    """
    model.eval()
    psnr_total = 0.0
    ssim_total = 0.0
    count = 0
    with torch.no_grad():  # inference only: no autograd graph is needed
        for inputs, targets in test_loader:
            if device is not None:
                inputs = inputs.to(device)
            predictions = model(inputs).cpu().numpy()
            references = targets.cpu().numpy()
            for ref, pred in zip(references, predictions):  # each (C, H, W)
                # PSNR is layout-independent; it also raises if shapes differ.
                psnr_total += compare_psnr(ref, pred, data_range=1)
                channel_ssims = [compare_ssim(r, p, data_range=1) for r, p in zip(ref, pred)]
                ssim_total += sum(channel_ssims) / len(channel_ssims)
                count += 1
    if count == 0:
        return float('nan'), float('nan')
    return psnr_total / count, ssim_total / count
# Data loading: train/test loaders that yield (grayscale input, RGB target) pairs.
class _GrayColorPairs(torch.utils.data.Dataset):
    """Turn an ``(rgb_tensor, label)`` dataset into ``(gray_tensor, rgb_tensor)`` pairs.

    The grayscale input is derived from the already-transformed RGB tensor, so
    random crops and flips applied to the RGB image are shared by the input and
    the target. The class label is discarded.
    """

    def __init__(self, base):
        self.base = base
        # ITU-R 601-2 luma weights (the ones PIL uses for Image.convert('L')).
        # They sum to 1, so a [0, 1] RGB image gives a [0, 1] grayscale image.
        self.luma = torch.tensor([0.299, 0.587, 0.114]).view(3, 1, 1)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        rgb, _label = self.base[index]
        gray = (rgb * self.luma).sum(dim=0, keepdim=True)
        return gray, rgb


def load_data(data_dir, batch_size):
    """Build the training and test loaders for colorization.

    Images are read with torchvision ``ImageFolder``. ``data_dir`` must
    therefore contain ``train`` and ``test`` directories, each holding at
    least one sub-directory of images.

    Each batch is ``(gray, rgb)``: ``gray`` is (N, 1, 224, 224) and ``rgb`` is
    (N, 3, 224, 224), both with values in [0, 1]. No ``Normalize`` step is
    applied. The targets must stay in [0, 1] to match the model's sigmoid
    output and the ``data_range=1`` used for PSNR/SSIM.

    Args:
        data_dir: root directory (str or path-like) containing ``train``/``test``.
        batch_size: batch size of the training loader.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles. The test
        loader uses batches of one image, in a fixed order.
    """
    import os

    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    train_dataset = _GrayColorPairs(
        datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transforms))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = _GrayColorPairs(
        datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=test_transforms))
    # One image per batch, so evaluation code that scores only the first image
    # of each batch still sees every test image.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    return train_loader, test_loader
# Hyper-parameters: data location, batch size, learning rate, epoch count.
data_dir, batch_size, lr, num_epochs = './data', 64, 0.001, 10

# Data pipeline first, then the network together with its loss and optimizer.
train_loader, test_loader = load_data(data_dir, batch_size)
model = ColorizationModel()
criterion, optimizer = nn.MSELoss(), optim.Adam(model.parameters(), lr=lr)

# Alternate one training pass with a PSNR/SSIM evaluation on the test split,
# reporting 1-based epoch numbers.
for epoch in range(num_epochs):
    train(model, train_loader, criterion, optimizer)
    psnr, ssim = test(model, test_loader)
    print('Epoch [{}/{}], PSNR: {:.4f}, SSIM: {:.4f}'.format(
        epoch + 1,
        num_epochs,
        psnr,
        ssim,
    ))

# Persist the learned parameters (the state_dict), not the module object itself.
torch.save(model.state_dict(), 'colorization_model.pth')
```
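注:此代码仅为示例,实际应用中还可以进行更多的优化和改进。运行前需安装 PyTorch、torchvision、scikit-image 和 Pillow,并按 torchvision `ImageFolder` 的要求组织数据:图片需放在 `./data/train/<子目录>/` 和 `./data/test/<子目录>/` 下(子目录名任意)。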
阅读全文