用torch编写的双路特征提取代码,一路用来提取高光谱空间特征,一路用来提取谱间特征,然后用concat融合两种特征,最后经过一层二维卷积使得融合后的特征图维度减半
时间: 2024-03-09 18:50:11 浏览: 97
以下是一个使用 PyTorch 编写的双路特征提取代码,其中一路用于提取高光谱空间特征,一路用于提取谱间特征,然后使用 `torch.cat` 进行特征融合,最后通过一层二维卷积使融合后的特征图尺寸减半:
```python
import torch.nn as nn
class DualFeatureExtractor(nn.Module):
def __init__(self):
super(DualFeatureExtractor, self).__init__()
# 第一路:高光谱空间特征提取
self.hsi_conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.hsi_bn1 = nn.BatchNorm2d(32)
self.hsi_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.hsi_bn2 = nn.BatchNorm2d(64)
self.hsi_pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 第二路:谱间特征提取
self.spec_conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.spec_bn1 = nn.BatchNorm2d(32)
self.spec_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.spec_bn2 = nn.BatchNorm2d(64)
self.spec_pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 特征融合
self.conv_cat = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
self.bn_cat = nn.BatchNorm2d(64)
def forward(self, x_hsi, x_spec):
# 高光谱空间特征提取
x_hsi = self.hsi_conv1(x_hsi)
x_hsi = self.hsi_bn1(x_hsi)
x_hsi = nn.functional.relu(x_hsi)
x_hsi = self.hsi_conv2(x_hsi)
x_hsi = self.hsi_bn2(x_hsi)
x_hsi = nn.functional.relu(x_hsi)
x_hsi = self.hsi_pool(x_hsi)
# 谱间特征提取
x_spec = self.spec_conv1(x_spec)
x_spec = self.spec_bn1(x_spec)
x_spec = nn.functional.relu(x_spec)
x_spec = self.spec_conv2(x_spec)
x_spec = self.spec_bn2(x_spec)
x_spec = nn.functional.relu(x_spec)
x_spec = self.spec_pool(x_spec)
# 特征融合
x = torch.cat((x_hsi, x_spec), dim=1)
x = self.conv_cat(x)
x = self.bn_cat(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)
return x
```
在这个示例中,我们定义了一个名为 `DualFeatureExtractor` 的类,它包含两个子模块 `hsi_conv1`、`hsi_conv2` 和 `hsi_pool`,用于提取高光谱空间特征,以及子模块 `spec_conv1`、`spec_conv2` 和 `spec_pool`,用于提取谱间特征。然后,我们使用 `torch.cat` 将两个特征图连接在一起,形成一个更大的特征图。最后,我们使用 `conv_cat` 和 `bn_cat` 对融合后的特征图进行卷积和批量归一化,使其尺寸减半。
阅读全文