PyTorch 图像风格迁移(Neural Style Transfer)代码
时间: 2023-07-29 14:06:01 浏览: 180
下面是一个简单的 PyTorch 图像风格迁移代码示例(运行前请在当前目录准备好内容图片 content.jpg 和风格图片 style.jpg):
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import copy
from PIL import Image
# Parameters: input/output image paths, working resolution and loss weights.
content_img_path, style_img_path = 'content.jpg', 'style.jpg'
output_img_path = 'output.jpg'
image_size = 512  # side length (pixels) of the square crop fed to the network
# Relative weighting of the two loss terms: the style term is weighted far
# more heavily than the content term.
style_weight = 1_000_000
content_weight = 1
# Preprocessing: resize the shorter side to `image_size`, center-crop to a
# square and convert to a float tensor with values in [0, 1].
# ImageNet mean/std normalization is intentionally NOT applied here: the model
# is built with a leading `Normalization` module that already does it, and the
# optimization clamps pixels to [0, 1] and converts them back with
# `ToPILImage`, both of which assume un-normalized values.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])
# Image loading
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it and apply `transform`.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the PIL image. Its result is
            given a leading batch dimension with ``unsqueeze(0)``, so it must
            return a tensor. When omitted, the (resized) PIL image is returned.
        max_size: If given, rescale so the longer side becomes ``max_size``
            pixels (this also enlarges images smaller than ``max_size``).
        shape: If given, a ``(width, height)`` pair to resize to; applied
            after ``max_size``.

    Returns:
        ``transform(image).unsqueeze(0)`` when ``transform`` is given (a
        ``(1, C, H, W)`` tensor, assuming ``transform`` yields ``(C, H, W)``),
        otherwise a PIL image. The image is always converted to 3-channel RGB.
    """
    # Open in a context manager so the file handle is released. convert()
    # loads the pixel data into an independent copy and turns grayscale /
    # RGBA / palette images into the 3 channels VGG and Normalization expect.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if max_size:
        scale = max_size / max(image.size)
        size = (round(scale * image.size[0]), round(scale * image.size[1]))
        image = image.resize(size, Image.LANCZOS)
    if shape:
        image = image.resize(shape, Image.LANCZOS)
    if transform:
        image = transform(image).unsqueeze(0)
    return image
# Model components
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    The input is returned unchanged, so the module can be inserted into an
    ``nn.Sequential``; after each forward pass the scalar loss is available as
    ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # Detached so the target acts as a constant outside the autograd graph.
        self.target = target.detach()

    def forward(self, input):
        mse = nn.functional.mse_loss(input, self.target)
        self.loss = mse
        return input
class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices.

    The Gram matrix of ``target_feature`` is computed once at construction.
    Each forward pass compares it with the Gram matrix of the current input,
    stores the result in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        self.loss = nn.functional.mse_loss(gram_matrix(input), self.target)
        return input
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature map.

    Args:
        input: Tensor of shape ``(batch, channels, height, width)``.

    Returns:
        A ``(batch * channels, batch * channels)`` tensor of inner products
        between the flattened feature maps, divided by the total number of
        elements in ``input``.
    """
    batch_size, channel, height, width = input.size()
    # reshape (not view) so non-contiguous inputs such as channels_last or
    # transposed feature maps also work; contiguous inputs still get a
    # zero-copy view.
    features = input.reshape(batch_size * channel, height * width)
    G = torch.mm(features, features.t())
    return G.div(batch_size * channel * height * width)
class Normalization(nn.Module):
    """Normalize an image batch channel-wise: ``(img - mean) / std``.

    Args:
        mean: Per-channel means (a sequence of floats).
        std: Per-channel standard deviations (a sequence of floats).
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Reshape to (C, 1, 1) so the statistics broadcast over the channel
        # dimension of an (N, C, H, W) or (C, H, W) image. Registered as
        # non-persistent buffers so `model.to(device)` moves them together
        # with the network, without adding entries to the state_dict.
        self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor(std).view(-1, 1, 1), persistent=False)

    def forward(self, img):
        return (img - self.mean) / self.std
cnn = models.vgg19(pretrained=True).features.eval()
# Only the input image is optimized; freezing VGG avoids computing and storing
# gradients for its weights on every optimization step.
cnn.requires_grad_(False)
# Rebuild VGG-19's feature extractor as a new nn.Sequential with pass-through
# loss modules inserted right after the selected conv layers.
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_losses = []
style_losses = []
# Load each reference image once instead of once per loss layer. `transform`
# is passed so load_image returns a batched tensor rather than a PIL image.
content_img = load_image(content_img_path, transform)
style_img = load_image(style_img_path, transform)
model = nn.Sequential(Normalization(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
i = 0  # index of the most recent conv layer; used to name the layers after it
for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        name = 'conv_{}'.format(i)
    elif isinstance(layer, nn.ReLU):
        name = 'relu_{}'.format(i)
        # An in-place ReLU would overwrite the conv output that the preceding
        # loss module passes through and that autograd saved for backward.
        layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
        name = 'pool_{}'.format(i)
    elif isinstance(layer, nn.BatchNorm2d):
        name = 'bn_{}'.format(i)
    else:
        raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
    model.add_module(name, layer)
    if name in content_layers:
        # Targets are constants; no autograd graph is needed to compute them.
        with torch.no_grad():
            target = model(content_img)
        content_loss = ContentLoss(target)
        model.add_module("content_loss_{}".format(i), content_loss)
        content_losses.append(content_loss)
    if name in style_layers:
        with torch.no_grad():
            target_feature = model(style_img)
        style_loss = StyleLoss(target_feature)
        model.add_module("style_loss_{}".format(i), style_loss)
        style_losses.append(style_loss)
    # Once every requested loss module is in place, the remaining VGG layers
    # would only add work to each forward pass without affecting any loss.
    if len(content_losses) == len(content_layers) and len(style_losses) == len(style_layers):
        break
# Optimize the pixels of a copy of the content image with L-BFGS.
# `transform` is passed so load_image returns a tensor (not a PIL image).
input_img = load_image(content_img_path, transform).clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])
num_steps = 300  # stop once the closure has been evaluated more than this many times
run = [0]  # closure-call counter; a list so the nested function can mutate it


def closure():
    """Evaluate the weighted style + content loss and its gradient for L-BFGS."""
    # Keep pixels in the valid [0, 1] range; done outside autograd so the
    # clamp is not recorded in the graph.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    optimizer.zero_grad()
    model(input_img)  # forward pass sets `.loss` on every loss module
    style_score = style_weight * sum(sl.loss for sl in style_losses)
    content_score = content_weight * sum(cl.loss for cl in content_losses)
    loss = style_score + content_score
    loss.backward()
    run[0] += 1
    if run[0] % 50 == 0:
        print("run {}".format(run[0]))
        print("Style Loss: {:.4f} Content Loss: {:.4f}".format(
            style_score.item(), content_score.item()))
        print()
    return loss


while run[0] <= num_steps:
    optimizer.step(closure)
# Save the result: clamp to the valid pixel range and convert to a PIL image.
# detach() replaces the discouraged `.data`; cpu() lets ToPILImage handle a
# tensor that was moved to a GPU.
output_img = input_img.detach().clamp(0, 1).cpu()
output_image = transforms.ToPILImage()(output_img.squeeze(0))
output_image.save(output_img_path)
```
这个代码示例使用预训练的 VGG-19 模型提取特征,定义了内容损失、风格损失(基于 Gram 矩阵)和归一化(Normalization)层,通过反向传播计算梯度,并使用 L-BFGS 优化器直接优化输入图像的像素。你可以根据自己的需求调整参数(如 style_weight、content_weight、迭代次数)或更换模型。
阅读全文