给出用 PyTorch 实现 Canny 边缘检测算子、并对图像进行边缘检测的 Python 代码
时间: 2024-06-09 15:11:38 浏览: 259
好的,以下是用PyTorch实现Canny边缘检测算法的代码:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class Canny(nn.Module):
    """Canny edge detector built from vectorised PyTorch tensor operations.

    Pipeline: Gaussian smoothing -> Sobel gradients -> non-maximum
    suppression -> double threshold with hysteresis edge tracking. Every
    stage works on whole tensors, so the module runs on whichever device
    its weights and input live on (move both with ``.to(device)``).

    Args:
        sigma: standard deviation of the Gaussian smoothing kernel (> 0).
        kernel_size: positive odd side length of the Gaussian kernel.
        weak_pixel: marker value for weak edge candidates. Kept for
            backward compatibility; weak pixels never appear in the output.
        strong_pixel: value written to edge pixels of the output map.
        low_threshold: low threshold as a fraction of the high threshold.
        high_threshold: high threshold as a fraction of the per-image
            maximum of the suppressed gradient magnitude.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer or
            ``sigma`` is not positive.
    """

    def __init__(self, sigma=1, kernel_size=5, weak_pixel=75, strong_pixel=255,
                 low_threshold=0.05, high_threshold=0.15):
        super(Canny, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        if sigma <= 0:
            raise ValueError(f"sigma must be positive, got {sigma}")
        self.sigma = sigma
        self.kernel_size = kernel_size
        self.weak_pixel = weak_pixel
        self.strong_pixel = strong_pixel
        self.low_threshold = low_threshold
        self.high_threshold = high_threshold
        self.padding = kernel_size // 2
        # Informational default only. Every tensor created inside this module
        # follows the input's device, so moving the model with .to(...) can
        # never leave weights and buffers on different devices.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Borders are padded explicitly in forward() (reflect), so the convs
        # themselves use no padding and preserve the input's H x W.
        self.conv1 = nn.Conv2d(1, 1, kernel_size, stride=1, padding=0, bias=False)
        # Two output channels: 0 = Gx (positive when brighter to the right),
        # 1 = Gy (positive when brighter above, i.e. y axis points up).
        self.conv2 = nn.Conv2d(1, 2, 3, stride=1, padding=0, bias=False)
        # conv2d is a cross-correlation, so these kernels are used unflipped.
        sobel_x = torch.tensor([[-1., 0., 1.],
                                [-2., 0., 2.],
                                [-1., 0., 1.]])
        sobel_y = torch.tensor([[1., 2., 1.],
                                [0., 0., 0.],
                                [-1., -2., -1.]])
        with torch.no_grad():
            self.conv1.weight.copy_(self.gaussian_kernel(kernel_size, sigma))
            self.conv2.weight.copy_(torch.stack([sobel_x, sobel_y]).unsqueeze(1))
        self.conv1.weight.requires_grad_(False)
        self.conv2.weight.requires_grad_(False)

    @staticmethod
    def _pad(x, pad):
        """Pad the last two dims of ``x`` by ``pad`` on each side, mirroring the border.

        'reflect' requires ``pad`` smaller than each padded dim, so very small
        images fall back to 'replicate' instead of raising.
        """
        if pad == 0:
            return x
        mode = "reflect" if min(x.shape[-2:]) > pad else "replicate"
        return F.pad(x, (pad, pad, pad, pad), mode=mode)

    def gaussian_kernel(self, size, sigma=1):
        """Return a normalised ``size x size`` Gaussian kernel shaped (1, 1, size, size).

        Built as the outer product of a 1-D Gaussian profile, which equals
        exp(-(x**2 + y**2) / (2 * sigma**2)) sampled at integer offsets from
        the centre. The result is a float32 CPU tensor summing to 1.
        """
        coords = torch.arange(size, dtype=torch.float32) - size // 2
        profile = torch.exp(-coords ** 2 / (2 * float(sigma) ** 2))
        kernel = profile[:, None] * profile[None, :]
        kernel = kernel / kernel.sum()
        return kernel.view(1, 1, size, size)

    def non_maximum_suppression(self, img, D):
        """Thin edges by keeping only local maxima along the gradient direction.

        Args:
            img: gradient magnitude of shape (..., H, W) (2-D or batched).
            D: gradient direction in radians, broadcastable to ``img``,
                measured counter-clockwise from the +x (column) axis with the
                y axis pointing up (towards row 0), as produced by forward().

        Returns:
            float32 tensor shaped like ``img``: the magnitude where the pixel is
            >= both neighbours along its quantised direction, else 0. The
            outermost row/column is always 0.
        """
        h, w = img.shape[-2:]
        # Edges are unoriented: fold the direction into [0, 180) degrees.
        angle = torch.remainder(D * (180.0 / math.pi), 180.0)
        # Quantise into sectors centred on 0/45/90/135 degrees; [157.5, 180) wraps to 0.
        sector = torch.floor((angle + 22.5) / 45.0).long() % 4
        padded = F.pad(img, (1, 1, 1, 1))

        def neighbour(di, dj):
            # View of img shifted so element [i, j] holds img[i + di, j + dj] (0 outside).
            return padded[..., 1 + di:1 + di + h, 1 + dj:1 + dj + w]

        # (row, col) offsets of the two neighbours compared in each sector.
        # Because y points up, 45 degrees runs towards row - 1, col + 1.
        offsets = (
            ((0, 1), (0, -1)),     # 0 degrees
            ((1, -1), (-1, 1)),    # 45 degrees
            ((1, 0), (-1, 0)),     # 90 degrees
            ((-1, -1), (1, 1)),    # 135 degrees
        )
        q = torch.zeros_like(img)
        r = torch.zeros_like(img)
        for k, (q_off, r_off) in enumerate(offsets):
            in_sector = sector == k
            q = torch.where(in_sector, neighbour(*q_off), q)
            r = torch.where(in_sector, neighbour(*r_off), r)
        keep = (img >= q) & (img >= r)
        Z = torch.where(keep, img, torch.zeros_like(img))
        # Border pixels lack a full neighbourhood and are always suppressed.
        border = torch.ones(h, w, dtype=torch.bool, device=img.device)
        border[1:-1, 1:-1] = False
        return Z.masked_fill(border, 0).to(torch.float32)

    def hysteresis(self, img, low_threshold=0.05, high_threshold=0.15):
        """Double-threshold ``img`` and keep weak pixels 8-connected to strong ones.

        Args:
            img: suppressed gradient magnitude of shape (..., H, W).
            low_threshold: low threshold as a fraction of the high threshold.
            high_threshold: high threshold as a fraction of the per-image max.

        Returns:
            float32 tensor shaped like ``img`` with ``self.strong_pixel`` on
            edges and 0 elsewhere.
        """
        h, w = img.shape[-2:]
        peak = img.amax(dim=(-2, -1), keepdim=True)
        high = peak * high_threshold
        low = high * low_threshold
        # Requiring img > 0 stops an all-zero image (high == 0) from
        # turning every pixel into a strong edge.
        nonzero = img > 0
        strong = (img >= high) & nonzero
        allowed = ((img >= low) & nonzero) | strong
        edges = strong.reshape(-1, 1, h, w).to(img.dtype)
        allowed = allowed.reshape(-1, 1, h, w)
        # Grow the strong set one 8-connected ring at a time, restricted to
        # candidate pixels, until nothing changes. Unlike a single raster pass,
        # this is order-independent and follows weak chains of any length.
        while True:
            grown = F.max_pool2d(edges, kernel_size=3, stride=1, padding=1)
            grown = grown.masked_fill(~allowed, 0)
            if torch.equal(grown, edges):
                break
            edges = grown
        res = edges.reshape(img.shape) * float(self.strong_pixel)
        return res.to(torch.float32)

    def forward(self, x):
        """Run Canny edge detection on a batch of single-channel images.

        Args:
            x: tensor of shape (B, 1, H, W).

        Returns:
            float32 tensor of shape (B, 1, H, W) with ``strong_pixel`` on edges
            and 0 elsewhere.

        Raises:
            ValueError: if ``x`` is not shaped (B, 1, H, W).
        """
        if x.dim() != 4 or x.shape[1] != 1:
            raise ValueError(f"expected input of shape (B, 1, H, W), got {tuple(x.shape)}")
        x = x.to(self.conv1.weight.dtype)
        smoothed = self.conv1(self._pad(x, self.padding))
        grads = self.conv2(self._pad(smoothed, 1))
        gx, gy = grads[:, 0:1], grads[:, 1:2]
        gradient = torch.hypot(gx, gy)
        # Normalise per image; the clamp avoids 0/0 = NaN on constant images.
        gradient = gradient / gradient.amax(dim=(-2, -1), keepdim=True).clamp_min(1e-12)
        theta = torch.atan2(gy, gx)
        suppressed = self.non_maximum_suppression(gradient, theta)
        return self.hysteresis(suppressed, self.low_threshold, self.high_threshold)
```
这里使用 PyTorch 的卷积操作实现了高斯平滑和 Sobel 梯度计算，并实现了高斯核生成、非极大值抑制和滞后阈值（双阈值 + 边缘连接）等图像处理函数。需要注意的是，`self.device` 在没有 GPU 时会自动回退到 CPU，因此没有 GPU 时也不必删除代码中的 `to(self.device)`；只需保证模型和输入张量位于同一设备上（例如都调用 `.to(device)`）。
以下是对一张图像进行Canny边缘检测的示例代码:
```python
import cv2
import numpy as np
import torch

# Demo: run the Canny module on a grayscale image and display the result.
# Assumes the Canny class from the previous snippet is defined in, or
# imported into, this module.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load image; cv2.imread returns None (it does not raise) when the file is
# missing or unreadable, so check explicitly
img = cv2.imread('test.jpg', cv2.IMREAD_GRAYSCALE)
if img is None:
    raise FileNotFoundError("could not read image 'test.jpg'")
# normalize image to float32 in [0, 1]
img = img.astype(np.float32) / 255.0
# apply canny edge detection; inference only, so skip autograd bookkeeping
canny = Canny().to(device)
with torch.no_grad():
    batch = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W)
    edges = canny(batch).squeeze().cpu().numpy()
# show result; edges hold 0/255, so display them as 8-bit
cv2.imshow('original', img)
cv2.imshow('canny', edges.astype(np.uint8))
cv2.waitKey(0)
cv2.destroyAllWindows()
```
这里使用 OpenCV 读取灰度图像，并将其归一化为 [0, 1] 范围内的浮点数；然后将图像转换为 PyTorch 张量，并增加批次维和通道维（形状变为 (1, 1, H, W)）。最后调用 Canny 模型进行边缘检测，将结果转换为 NumPy 数组并显示。
阅读全文