pyqt5实现u2net图像分割界面pytorch代码
时间: 2023-10-01 22:07:52 浏览: 469
下面是一个简单的PyQt5界面实现U2Net图像分割的例子,使用PyTorch实现。
```
import sys
import os
import numpy as np
from PIL import Image
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QFileDialog
from PyQt5.QtGui import QPixmap
import torch
import torchvision.transforms as transforms
from model.u2net import U2NET
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
# 创建UI界面
self.initUI()
# 加载模型
self.model = U2NET()
self.model.load_state_dict(torch.load("u2net.pth", map_location=torch.device('cpu')))
self.model.eval()
def initUI(self):
# 设置窗口标题和大小
self.setWindowTitle("U2Net Image Segmentation")
self.setGeometry(100, 100, 800, 600)
# 创建标签和按钮
self.label = QLabel(self)
self.label.setGeometry(25, 50, 750, 450)
self.label.setStyleSheet("border: 1px solid black;")
self.button = QPushButton("Select Image", self)
self.button.setGeometry(25, 525, 150, 50)
self.button.clicked.connect(self.selectImage)
self.button2 = QPushButton("Segment Image", self)
self.button2.setGeometry(200, 525, 150, 50)
self.button2.clicked.connect(self.segmentImage)
def selectImage(self):
# 打开文件对话框,选择要处理的图像
options = QFileDialog.Options()
options |= QFileDialog.DontUseNativeDialog
fileName, _ = QFileDialog.getOpenFileName(self,"QFileDialog.getOpenFileName()", "","All Files (*);;Images (*.png *.jpg *.jpeg)", options=options)
if fileName:
# 加载图像并显示在标签上
pixmap = QPixmap(fileName)
pixmap = pixmap.scaled(750, 450)
self.label.setPixmap(pixmap)
# 将图像转换为PyTorch tensor格式
self.input_image = Image.open(fileName).convert("RGB")
self.transform = transforms.Compose([transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
self.input_tensor = self.transform(self.input_image).unsqueeze(0)
def segmentImage(self):
# 对选择的图像进行分割
with torch.no_grad():
output_tensor = self.model(self.input_tensor)
# 将输出转换为PIL Image格式
output_tensor = output_tensor.squeeze().numpy()
output_tensor = np.where(output_tensor > 0.5, 1.0, 0.0)
output_image = Image.fromarray((output_tensor * 255).astype(np.uint8)).convert("L")
# 显示分割结果
output_pixmap = QPixmap.fromImage(ImageQt(output_image))
output_pixmap = output_pixmap.scaled(750, 450)
self.label.setPixmap(output_pixmap)
if __name__ == "__main__":
# 创建应用程序和主窗口
app = QApplication(sys.argv)
mainWindow = MainWindow()
mainWindow.show()
sys.exit(app.exec_())
```
在上面的代码中,我们首先创建了一个`MainWindow`类,它继承自`QMainWindow`类,并重写了`initUI`方法来创建UI界面。在`initUI`方法中,我们创建了一个标签和两个按钮,其中一个用于选择图像,另一个用于对图像进行分割。
在选择图像按钮的回调函数`selectImage`中,我们使用`QFileDialog`打开一个文件对话框,让用户选择要处理的图像。然后,我们使用`PIL`库来加载图像,并将其转换为PyTorch tensor格式。在转换过程中,我们使用了`transforms`模块来对图像进行缩放、标准化等预处理操作。
在对图像进行分割的按钮回调函数`segmentImage`中,我们将输入张量传递给已加载的U2Net模型,并得到输出张量。然后,我们将输出张量转换为PIL Image格式,并将其显示在标签上。在转换过程中,我们使用了NumPy来将输出张量转换为二值图像,使用`PIL`库将其转换为灰度图像,并使用`QPixmap`将其转换为Qt图像格式。
最后,我们在`__main__`函数中创建了应用程序和主窗口,并调用`show`方法来显示窗口。
阅读全文