用pytorch编写一个全连接分类神经网络,输入为3*2的数组,输出维度为6
时间: 2024-03-05 16:51:52 浏览: 14
好的,以下是一个使用PyTorch编写的全连接分类神经网络,输入为3*2的数组,输出维度为6的示例代码:
```python
import torch
import torch.nn as nn
# 定义全连接分类神经网络模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(6, 16)
self.fc2 = nn.Linear(16, 32)
self.fc3 = nn.Linear(32, 6)
def forward(self, x):
x = x.view(-1, 6)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建模型实例
model = MyModel()
# 定义输入数据
inputs = torch.randn(1, 3, 2)
# 前向传播计算输出
outputs = model(inputs)
print(outputs.shape) # 输出维度为[1, 6]
```
这个模型包含3个全连接层,分别是输入层、隐藏层和输出层。输入层的输入维度为6(即3*2),隐藏层的输出维度为16,第二个隐藏层的输出维度为32,输出层的输出维度为6。在前向传播过程中,输入数据需要通过view方法展开成2维张量,然后分别经过3个全连接层计算。最终输出的张量维度为[1, 6],其中1表示输入数据只有1个样本,6表示输出维度为6。