pytorch现一维CNN
时间: 2023-08-13 22:05:46 浏览: 156
对于一维卷积神经网络(1D CNN)的实现,PyTorch提供了相应的模块torch.nn.Conv1d。以下是一个示例的一维CNN模型的代码:
```python
import torch
import torch.nn as nn
class OneDCNN(nn.Module):
def __init__(self, input_size, num_classes):
super(OneDCNN, self).__init__()
self.conv1 = nn.Conv1d(input_size, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.maxpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建一个随机输入
input_size = 10
batch_size = 16
seq_length = 50
num_classes = 5
input_data = torch.randn(batch_size, input_size, seq_length)
# 创建模型实例并进行前向传播
model = OneDCNN(input_size, num_classes)
output = model(input_data)
print(output.shape) # 输出为(batch_size, num_classes)
```
在这个示例中,我们定义了一个名为OneDCNN的继承自nn.Module的类,该类实现了一个简单的一维CNN模型。模型包含一个卷积层、ReLU激活函数、最大池化层和全连接层。在forward方法中,我们将输入数据通过卷积层、激活函数、池化层和全连接层依次传递,并返回最终的输出。
需要注意的是,输入数据的维度为(batch_size, input_size, seq_length),其中batch_size表示批量大小,input_size表示输入特征的维度,seq_length表示序列长度。输出的维度为(batch_size, num_classes),其中num_classes表示分类的类别数。
阅读全文