Defines the PatchGAN discriminator with the specified arguments. class NLayerDiscriminator(nn.Module): def init(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True): super(NLayerDiscriminator, self).init() self.use_parallel = use_parallel if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d self.conv1 = nn.Conv2d(input_nc, ndf, kernel_size=3, padding=1) self.conv_offset1 = nn.Conv2d(ndf, 18, kernel_size=3, stride=1, padding=1) init_offset1 = torch.Tensor(np.zeros([18, ndf, 3, 3])) self.conv_offset1.weight = torch.nn.Parameter(init_offset1) # 初始化为0 self.conv_mask1 = nn.Conv2d(ndf, 9, kernel_size=3, stride=1, padding=1) init_mask1 = torch.Tensor(np.zeros([9, ndf, 3, 3]) + np.array([0.5])) self.conv_mask1.weight = torch.nn.Parameter(init_mask1) # 初始化为0.5 kw = 4 padw = int(np.ceil((kw-1)/2)) nf_mult = 1 for n in range(1, n_layers): nf_mult_prev = nf_mult nf_mult = min(2n, 8) self.sequence2 = [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2n_layers, 8) self.sequence2 += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] self.sequence2 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] if use_sigmoid: self.sequence2 += [nn.Sigmoid()] def forward(self, input): input = self.conv1(input) offset1 = self.conv_offset1(input) mask1 = torch.sigmoid(self.conv_mask1(input)) sequence1 = [ torchvision.ops.deform_conv2d(input=input, offset=offset1, weight=self.conv1.weight, mask=mask1, padding=(1, 1)) 上述代码中出现错误:RuntimeError: Expected weight_c.size(1) * n_weight_grps == input_c.size(1) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.),请问如何解决,给出修改后的代码

时间: 2023-08-22 08:04:22 浏览: 24
这个错误通常是由于卷积层的通道数和输入数据的通道数不匹配导致的。根据错误信息可以看出,这个错误是在第 49 行的 deform_conv2d 函数中出现的。根据代码,这个函数是使用可变形卷积进行特征提取的,因此需要注意卷积层的通道数和输入数据的通道数是否匹配。具体来说,需要检查卷积层的输入通道数和权重矩阵的通道数是否匹配。 下面是修改后的代码,主要是修改了可变形卷积的输入参数,确保通道数匹配: ```python class NLayerDiscriminator(nn.Module): def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True): super(NLayerDiscriminator, self).__init__() self.use_parallel = use_parallel if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d self.conv1 = nn.Conv2d(input_nc, ndf, kernel_size=3, padding=1) self.conv_offset1 = nn.Conv2d(ndf, 18, kernel_size=3, stride=1, padding=1) init_offset1 = torch.Tensor(np.zeros([18, ndf, 3, 3])) self.conv_offset1.weight = torch.nn.Parameter(init_offset1) # 初始化为0 self.conv_mask1 = nn.Conv2d(ndf, 9, kernel_size=3, stride=1, padding=1) init_mask1 = torch.Tensor(np.zeros([9, ndf, 3, 3]) + np.array([0.5])) self.conv_mask1.weight = torch.nn.Parameter(init_mask1) # 初始化为0.5 kw = 4 padw = int(np.ceil((kw-1)/2)) nf_mult = 1 self.sequence2 = [] for n in range(1, n_layers): nf_mult_prev = nf_mult nf_mult = min(2**n, 8) self.sequence2 += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2**n_layers, 8) self.sequence2 += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] self.sequence2 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] if use_sigmoid: self.sequence2 += [nn.Sigmoid()] def forward(self, input): input = self.conv1(input) offset1 = self.conv_offset1(input) mask1 = torch.sigmoid(self.conv_mask1(input)) sequence1 = [ torchvision.ops.deform_conv2d( input=input, offset=offset1, weight=self.conv1.weight, mask=mask1, padding=1, stride=1, dilation=1 ) ] for layer in self.sequence2: sequence1.append(layer(sequence1[-1])) return sequence1[-1] ``` 需要注意的是,由于 deform_conv2d 函数的输入参数有很多,需要根据实际情况进行调整。

相关推荐

可以使用modbus_tk.modbus_rtu.master._compute_crc函数计算Modbus RTU消息的校验码,并使用struct.pack()函数将消息打包为字节流。以下是代码示例: python import serial import time import struct import modbus_tk.defines as cst from modbus_tk import modbus_rtu def pack_modbus_rtu_msg(address, function_code, starting_address, quantity_of_x, output_value=None): # 构造Modbus RTU消息 if output_value is not None: data_len = 2 + 2*len(output_value) else: data_len = 2 msg = struct.pack(">BBHH", address, function_code, starting_address, quantity_of_x) if output_value is not None: for val in output_value: msg += struct.pack(">H", val) # 计算校验码(CRC) crc = modbus_rtu.master._compute_crc(msg) msg += struct.pack(">H", crc) return msg def main(): # 打开两个串口 com1 = serial.Serial(port='com2', baudrate=38400, bytesize=8, parity='N', stopbits=1) com2 = serial.Serial(port='com3', baudrate=38400, bytesize=8, parity='N', stopbits=1) # 创建Modbus RTU主机(master) master = modbus_rtu.RtuMaster(com1) master.set_timeout(1.0) master.set_verbose(True) # 读取保持寄存器数据 red1 = master.execute(1, cst.READ_HOLDING_REGISTERS, 0, 9)[0] new = int(red1 * 0.98) master.execute(1, function_code=cst.WRITE_MULTIPLE_REGISTERS, starting_address=0, output_value=[new]) # 打包Modbus RTU消息并写入com2串口 msg = pack_modbus_rtu_msg(1, cst.READ_HOLDING_REGISTERS, 0, 9) com2.write(msg) # 读取输入寄存器数据 input_regs = master.execute(1, cst.READ_INPUT_REGISTERS, 0, 9) print(input_regs, 1) # 读取输入线圈数据 input_bits = master.execute(1, cst.READ_COILS, 0, 9) print(input_bits, 2) # 读取输出线圈数据 output_bits = master.execute(1, cst.READ_DISCRETE_INPUTS, 0, 9) print(output_bits, 3)

将QT += core QT -= gui CONFIG += c++11 TARGET = UavRectifyLoadLIb CONFIG += console CONFIG -= app_bundle TEMPLATE = app SOURCES += main.cpp # The following define makes your compiler emit warnings if you use # any feature of Qt which as been marked deprecated (the exact warnings # depend on your compiler). Please consult the documentation of the # deprecated API in order to know how to port your code away from it. DEFINES += QT_DEPRECATED_WARNINGS win32{ CONFIG(debug, debug|release){ DESTDIR = $$PWD/../../../../RasterManager/bin/Debug } else{ DESTDIR = $$PWD/../../../../RasterManager/bin/release } INCLUDEPATH += $$PWD/../../../include/gdal1101 DEPENDPATH += $$PWD/../../../include/gdal1101 } else{ CONFIG(debug, debug|release){ DESTDIR = $$PWD/../../../product/release32 } else{ DESTDIR = $$PWD/../../../product/release32 } } # You can also make your code fail to compile if you use deprecated APIs. # In order to do so, uncomment the following line. # You can also select to disable deprecated APIs only up to a certain version of Qt. #DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 unix:!macx: LIBS += -L$$PWD/../../../product/release32/ -lUAVAutoRectifyMt -lUAVAutoRectify -lUAVAutoRectifyFi INCLUDEPATH += $$PWD/../include DEPENDPATH += $$PWD/../include unix:!macx: LIBS += -L$$PWD/../../../lib/opencvf249/ -lopencv_core unix:!macx: LIBS += -L$$PWD/../../../lib/opencvf249/ -lopencv_highgui unix:!macx: LIBS += -L$$PWD/../../../lib/opencvf249/ -lopencv_imgproc INCLUDEPATH += $$PWD/../../../lib/opencvf249 DEPENDPATH += $$PWD/../../../lib/opencvf249 unix:!macx: LIBS += -L$$PWD/../../../../../../../usr/local/lib/ -lopencv_core #unix:!macx: LIBS += -L$$PWD/../../../../../../../usr/local/lib/ -lopencv_highgui unix:!macx: LIBS += -L$$PWD/../../../../../../../usr/local/lib/ -lopencv_imgproc unix:!macx: LIBS += -L$$PWD/../../../../../../../usr/local/lib/ -lopencv_features2d unix:!macx: LIBS += -L$$PWD/../../../../../../../usr/local/lib/ -lopencv_imgcodecs INCLUDEPATH += $$PWD/../../../../../../../usr/local/include DEPENDPATH += $$PWD/../../../../../../../usr/local/include unix:!macx: LIBS += -L$$PWD/../../../product/release32/ -lDEMDriver unix:!macx: LIBS += -L$$PWD/../../../product/release32/ -lProjection unix:!macx: LIBS += -L$$PWD/../../../product/release32/ -lIImage_gC QMAKE_CXXFLAGS_RELEASE = $$QMAKE_CFLAGS_RELEASE_WITH_DEBUGINFO QMAKE_LFLAGS_RELEASE = $$QMAKE_LFLAGS_RELEASE_WITH_DEBUGINFO 转为cmake格式

根据给出的代码,可以将其转化为以下CMake代码: cmake_minimum_required(VERSION 3.5) project(UavRectifyLoadLIb LANGUAGES CXX) set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTORCC ON) set(CMAKE_AUTOUIC ON) find_package(Qt5Core REQUIRED) add_executable(UavRectifyLoadLIb main.cpp ) target_link_libraries(UavRectifyLoadLIb PRIVATE Qt5::Core UAVAutoRectifyMt UAVAutoRectify UAVAutoRectifyFi DEMDriver Projection IImage_gC opencv_core opencv_highgui opencv_imgproc opencv_features2d opencv_imgcodecs ) target_include_directories(UavRectifyLoadLIb PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/gdal1101 ${CMAKE_CURRENT_SOURCE_DIR}/../include ${CMAKE_CURRENT_SOURCE_DIR}/../../../lib/opencvf249 ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../usr/local/include ) if(UNIX AND NOT APPLE) target_link_directories(UavRectifyLoadLIb PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../../product/release32 ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../usr/local/lib ) endif() if(WIN32) if(CMAKE_BUILD_TYPE STREQUAL "Debug") set_target_properties(UavRectifyLoadLIb PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_CURRENT_SOURCE_DIR}/../../../../RasterManager/bin/Debug ) else() set_target_properties(UavRectifyLoadLIb PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_SOURCE_DIR}/../../../../RasterManager/bin/release ) endif() else() if(CMAKE_BUILD_TYPE STREQUAL "Debug") set_target_properties(UavRectifyLoadLIb PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_CURRENT_SOURCE_DIR}/../../../product/release32 ) else() set_target_properties(UavRectifyLoadLIb PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_SOURCE_DIR}/../../../product/release32 ) endif() endif()

最新推荐

2.0 ctia-test-plan-for-wirelesss-device-over-the-air-performance-ver-3-8-1.pdf

This test plan defines general requirements for equipment configurations, laboratory techniques, test methodologies, and evaluation criteria that must be met in order to ensure the accurate, ...

LE Audio协议草案 Basic_Audio_Profile_d09r06.pdf

The Basic Audio Profile (BAP) defines procedures for audio stream control by using the Generic Attribute Profile (GATT) and the Generic Access Profile (GAP) for devices that want to use Bluetooth Low ...

802.11be draft 4.0 wifi7standard

802.11be draft 4.0 wifi7standard

Java基础笔记-8-15

线程存活判断以及线程控制的相关内容。重点学习了isAlive()方法、Join()的三个重载方法、setDaemon()方法,难点在于线程执行过程中对于线程状态的理解和判断

torch_scatter-2.0.9-cp38-cp38-win_amd64.whl.zip

需要配和指定版本torch-1.10.0+cu111使用,请在安装该模块前提前安装torch-1.10.0+cu111以及对应cuda11.1和cudnn

哈希排序等相关算法知识

哈希排序等相关算法知识

混合神经编码调制的设计和训练方法

可在www.sciencedirect.com在线获取ScienceDirectICTExpress 8(2022)25www.elsevier.com/locate/icte混合神经编码调制:设计和训练方法Sung Hoon Lima,Jiyong Hana,Wonjong Noha,Yujae Songb,Sang-WoonJeonc,a大韩民国春川,翰林大学软件学院b韩国龟尾国立技术学院计算机软件工程系,邮编39177c大韩民国安山汉阳大学电子电气工程系接收日期:2021年9月30日;接收日期:2021年12月31日;接受日期:2022年1月30日2022年2月9日在线发布摘要提出了一种由内码和外码组成的混合编码调制方案。外码可以是任何标准的二进制具有有效软解码能力的线性码(例如,低密度奇偶校验(LDPC)码)。内部代码使用深度神经网络(DNN)设计,该深度神经网络获取信道编码比特并输出调制符号。为了训练DNN,我们建议使用损失函数,它是受广义互信息的启发。所得到的星座图被示出优于具有5G标准LDPC码的调制�

利用Pandas库进行数据分析与操作

# 1. 引言 ## 1.1 数据分析的重要性 数据分析在当今信息时代扮演着至关重要的角色。随着信息技术的快速发展和互联网的普及,数据量呈爆炸性增长,如何从海量的数据中提取有价值的信息并进行合理的分析,已成为企业和研究机构的一项重要任务。数据分析不仅可以帮助我们理解数据背后的趋势和规律,还可以为决策提供支持,推动业务发展。 ## 1.2 Pandas库简介 Pandas是Python编程语言中一个强大的数据分析工具库。它提供了高效的数据结构和数据分析功能,为数据处理和数据操作提供强大的支持。Pandas库是基于NumPy库开发的,可以与NumPy、Matplotlib等库结合使用,为数

帮我用C++和Qt写一个m*n的矩阵最小路径和可视化,要求用格子展示,最小路径采用其它的颜色标记

好的,以下是一个基本的实现: 1. 首先需要定义一个矩阵类来存储矩阵数据和计算最小路径和。 ```c++ class Matrix{ public: Matrix(int rows, int cols); ~Matrix(); void setValue(int i, int j, int value); //设置元素的值 int getValue(int i, int j); //获取元素的值 int getRows(); //获取行数 int getCols(); //获取列数 int getMinPathSum(); //获取最

基于android的视频播放器的设计与实现--大学毕业论文.doc

基于android的视频播放器的设计与实现--大学毕业论文.doc