GCANet 图像去雾的 Python 实现
时间: 2024-01-10 07:02:33 浏览: 130
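GCANet（Gated Context Aggregation Network，门控上下文聚合网络）是一种基于深度学习的图像去雾算法，可以使用 Python（PyTorch）实现。需要说明的是，以下示例代码实现的是一个简化的 U-Net 式编码器–解码器网络（带拼接式跳跃连接），并未实现 GCANet 论文中的平滑空洞卷积和门控融合模块，仅供参考：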
```python
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
# Dehazing network definition.
class GCANet(nn.Module):
    """U-Net-style encoder/decoder mapping a 3-channel image to a 3-channel image in (0, 1).

    NOTE(review): despite the name, this network does not implement the smoothed
    dilated convolutions or the gated fusion of the GCANet paper; it is a plain
    encoder/decoder with concatenated skip connections.

    Structure (widths shown for the default ``base_channels=16``):
      * encoder: seven conv3x3+ReLU stages (16 -> 1024 channels), each followed
        by a 2x2 max-pool; the pre-pool activations are kept as skips;
      * bottleneck: conv3x3+ReLU to 2048 channels;
      * decoder: seven stages that resize to the matching skip's size,
        concatenate the skip along channels and apply conv3x3+ReLU;
      * head: conv3x3 to 3 channels, learnable affine ``gamma * x + beta``,
        then a sigmoid.

    Because of the seven 2x down-samplings, input height and width must each
    be at least 128; they do not need to be multiples of 128.

    Args:
        base_channels: width of the first stage; stage ``i`` has
            ``base_channels * 2**i`` channels. Default 16 reproduces the
            original layout.
    """

    def __init__(self, base_channels=16):
        super(GCANet, self).__init__()
        if base_channels < 1:
            raise ValueError('base_channels must be a positive integer')
        # c[i] is the channel count of stage i: 16, 32, ..., 2048 by default.
        c = [base_channels * 2 ** i for i in range(8)]

        # Encoder + bottleneck.
        self.conv1 = nn.Conv2d(3, c[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(c[0], c[1], kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(c[1], c[2], kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(c[2], c[3], kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(c[3], c[4], kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(c[4], c[5], kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(c[5], c[6], kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(c[6], c[7], kernel_size=3, padding=1)

        # Decoder: each conv sees the upsampled features concatenated with the
        # matching encoder skip, so its input width is the sum of both.
        self.conv9 = nn.Conv2d(c[7] + c[6], c[6], kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(c[6] + c[5], c[5], kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(c[5] + c[4], c[4], kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(c[4] + c[3], c[3], kernel_size=3, padding=1)
        self.conv13 = nn.Conv2d(c[3] + c[2], c[2], kernel_size=3, padding=1)
        self.conv14 = nn.Conv2d(c[2] + c[1], c[1], kernel_size=3, padding=1)
        self.conv15 = nn.Conv2d(c[1] + c[0], c[0], kernel_size=3, padding=1)
        self.conv16 = nn.Conv2d(c[0], 3, kernel_size=3, padding=1)

        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Kept for attribute compatibility; forward() resizes to each skip's
        # exact size instead of blindly doubling.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.sigmoid = nn.Sigmoid()
        # gamma starts at 1 (identity gain). A zero start would make the output
        # a constant sigmoid(0) = 0.5 and zero every conv gradient on step one.
        self.gamma = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Map a float32 batch of shape (N, 3, H, W) to (N, 3, H, W) values in (0, 1)."""
        encoder = (self.conv1, self.conv2, self.conv3, self.conv4,
                   self.conv5, self.conv6, self.conv7)
        decoder = (self.conv9, self.conv10, self.conv11, self.conv12,
                   self.conv13, self.conv14, self.conv15)

        skips = []
        for conv in encoder:
            x = self.relu(conv(x))
            skips.append(x)
            x = self.pool(x)
        x = self.relu(self.conv8(x))

        # Deepest skip first.
        for conv, skip in zip(decoder, reversed(skips)):
            # Resize to the skip's exact spatial size: floor-pooling of odd
            # sizes means a plain 2x upsample would not line up.
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
            x = torch.cat([x, skip], dim=1)
            x = self.relu(conv(x))

        x = self.conv16(x)
        x = self.gamma * x + self.beta
        return self.sigmoid(x)
# Dataset that builds (degraded input, clean target) pairs from a folder of images.
class MyDataset(Dataset):
    """Synthesize training pairs from a single folder of images.

    Every image file in ``img_path`` is treated as the clean target. The model
    input is synthesized on the fly:
      1. compute luminance with weights 0.3 / 0.59 / 0.11;
      2. replicate it to three channels;
      3. add Gaussian noise with std 0.1;
      4. clip to [0, 1].

    NOTE(review): this degradation is grayscale + noise, not an atmospheric
    haze model; no real hazy/clean image pairs are read.

    Each item is a pair of float32 tensors shaped (3, H, W) with values in
    [0, 1]: ``(degraded_input, clean_target)``.
    """

    # Lower-case file extensions treated as images; anything else is ignored.
    IMG_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp')

    def __init__(self, img_path):
        self.img_path = img_path
        # Sorted for a reproducible index -> file mapping. Sub-directories and
        # non-image files are skipped because cv2.imread cannot decode them.
        self.img_list = sorted(
            name for name in os.listdir(img_path)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(img_path, name))
        )

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        path = os.path.join(self.img_path, self.img_list[idx])
        img = cv2.imread(path)
        if img is None:
            raise OSError(f'could not read image: {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        gray = img[:, :, 0] * 0.3 + img[:, :, 1] * 0.59 + img[:, :, 2] * 0.11
        img_hazy = np.repeat(gray[:, :, None], 3, axis=2)
        # randn returns float64; cast so the sample stays float32 like the model.
        noise = np.random.randn(*img_hazy.shape).astype(np.float32) * 0.1
        img_hazy = np.clip(img_hazy + noise, 0.0, 1.0).astype(np.float32)

        # HWC -> CHW; contiguous copies so the tensors own compact storage.
        hazy_t = torch.from_numpy(np.ascontiguousarray(img_hazy.transpose(2, 0, 1)))
        clean_t = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        return hazy_t, clean_t
# Training entry point.
def train(*, data_dir='data', batch_size=32, epochs=50, lr=1e-3, save_path='gcanet.pth'):
    """Train GCANet on the images in ``data_dir`` and save its state_dict.

    Args:
        data_dir: folder handed to MyDataset.
        batch_size: DataLoader batch size. The default collate stacks samples,
            so all images in one batch must share the same height and width.
        epochs: number of passes over the dataset.
        lr: Adam learning rate.
        save_path: file the trained state_dict is written to, once, after the
            last epoch.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = MyDataset(data_dir)
    if len(dataset) == 0:
        # Otherwise the loop below is skipped and an untrained model is saved.
        raise ValueError(f'no training images found in {data_dir!r}')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = GCANet().to(device)
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    log_every = 10  # report the mean loss of every 10 batches
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (img_hazy, img_gt) in enumerate(dataloader):
            # Cast explicitly: the weights are float32 and a dataset may yield
            # float64 tensors (e.g. after adding numpy float64 noise).
            img_hazy = img_hazy.to(device, dtype=torch.float32)
            img_gt = img_gt.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            output = model(img_hazy)
            loss = criterion(output, img_gt)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
    torch.save(model.state_dict(), save_path)
# Inference entry point: restore a single image with a trained checkpoint.
def test(*, hazy_path='hazy.jpg', weights_path='gcanet.pth', out_path='dehaze.jpg'):
    """Run the trained model on one image and write the restored result.

    The input gets the same deterministic preprocessing as the training inputs:
    it is reduced to luminance (weights 0.3 / 0.59 / 0.11) and replicated to
    three channels. The network output is used directly as the restored image,
    since training fits the output itself (MSE) to the clean image.

    Args:
        hazy_path: image to restore.
        weights_path: state_dict produced by train().
        out_path: where the restored image is written.

    Raises:
        OSError: if the input cannot be read or the output cannot be written.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCANet().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    img = cv2.imread(hazy_path)
    if img is None:
        raise OSError(f'could not read image: {hazy_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    gray = img[:, :, 0] * 0.3 + img[:, :, 1] * 0.59 + img[:, :, 2] * 0.11
    img_hazy = np.repeat(gray[:, :, None], 3, axis=2)

    # The encoder halves the resolution seven times. Pad H and W up to a
    # multiple of 128 so every skip connection lines up; crop back afterwards.
    h, w = img_hazy.shape[:2]
    multiple = 128
    img_hazy = np.pad(img_hazy, ((0, -h % multiple), (0, -w % multiple), (0, 0)), mode='reflect')

    batch = torch.from_numpy(np.ascontiguousarray(img_hazy.transpose(2, 0, 1)))[None].to(device)
    with torch.no_grad():
        output = model(batch)

    img_dehaze = output[0, :, :h, :w].cpu().numpy().transpose(1, 2, 0)
    img_dehaze = np.clip(img_dehaze, 0, 1)
    img_dehaze = cv2.cvtColor((img_dehaze * 255).round().astype(np.uint8), cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(out_path, img_dehaze):
        raise OSError(f'could not write image: {out_path}')
if __name__ == '__main__':
    # Full pipeline: fit the model first, then restore one image with it.
    for stage in (train, test):
        stage()
```
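需要注意的是，训练数据位于`data`文件夹中，但代码并不读取成对的有雾/无雾图像。它把该文件夹中的每张图像都当作无雾的真值图像，并在线合成对应的"有雾"输入（转为灰度后叠加高斯噪声）。若要用真实的有雾/无雾图像对训练，需要相应修改`MyDataset`。由于`batch_size=32`，同一批次内的图像尺寸必须一致。网络包含 7 次 2 倍下采样，因此图像的高和宽至少应为 128 像素。训练代码会读入这些数据并对模型进行训练。测试代码则对一张有雾图像`hazy.jpg`进行去雾，结果保存为`dehaze.jpg`。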
阅读全文