Building a simple neural network to classify odd and even numbers
Date: 2024-03-01 12:55:05
Sure. Below is a simple neural network that classifies integers as odd or even. It is written in Python with the PyTorch deep learning framework.
First, import the required libraries:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
```
Next, define a dataset class that holds the numbers and their labels:
```python
class NumberDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.label = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]  # 1 for odd, 0 for even

    def __getitem__(self, index):
        # Each sample is a 1-element float tensor: input shape [1], label shape [1]
        x = torch.tensor([self.data[index]], dtype=torch.float32)
        y = torch.tensor([self.label[index]], dtype=torch.float32)
        return x, y

    def __len__(self):
        return len(self.data)
```
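Rather than typing the labels by hand, they can be derived with `n % 2`. The sketch below uses a hypothetical `RangeNumberDataset` (not part of the original code) that takes a range of integers, and checks what a `DataLoader` with `batch_size=2` yields: inputs and labels that are both stacked to shape `[2, 1]`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeNumberDataset(Dataset):
    """Hypothetical variant of NumberDataset: integers in [start, end), label 1 for odd, 0 for even."""

    def __init__(self, start, end):
        self.data = list(range(start, end))
        # Derive each label from the number itself instead of typing it by hand
        self.label = [n % 2 for n in self.data]

    def __getitem__(self, index):
        x = torch.tensor([self.data[index]], dtype=torch.float32)
        y = torch.tensor([self.label[index]], dtype=torch.float32)
        return x, y

    def __len__(self):
        return len(self.data)


dataset = RangeNumberDataset(1, 11)  # the same numbers as the original: 1..10
loader = DataLoader(dataset, batch_size=2, shuffle=False)
x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([2, 1]) torch.Size([2, 1])
print(x_batch.squeeze(1).tolist(), y_batch.squeeze(1).tolist())  # [1.0, 2.0] [1.0, 0.0]
```

The default collate function stacks the per-sample `[1]` tensors into a `[batch, 1]` tensor, so the label shape matches a model that outputs one value per sample.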
Then, define the neural network model:
```python
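class NumberNet(nn.Module):
    def __init__(self):
        super(NumberNet, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        # A single output unit: nn.BCELoss expects the prediction to have the
        # same shape as the label, i.e. [batch, 1]
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # Sigmoid turns the output into the probability that the number is odd
        x = torch.sigmoid(self.fc2(x))
        return x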
```
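A common alternative design is to have the network return a raw logit and move the sigmoid into the loss with `nn.BCEWithLogitsLoss`, which the PyTorch documentation describes as more numerically stable than a separate sigmoid followed by `nn.BCELoss`. The sketch below (the `LogitNumberNet` name is hypothetical) checks that both formulations give the same loss on one batch; both require the prediction and the target to have the same shape, here `[batch, 1]`.

```python
import torch
import torch.nn as nn


class LogitNumberNet(nn.Module):
    """Hypothetical variant: a 1 -> 2 -> 1 network that returns an unnormalized logit."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 1)  # one output per sample

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))  # no sigmoid here


torch.manual_seed(0)
model = LogitNumberNet()
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # odd -> 1, even -> 0

logits = model(x)
loss_with_logits = nn.BCEWithLogitsLoss()(logits, y)       # sigmoid applied inside the loss
loss_two_step = nn.BCELoss()(torch.sigmoid(logits), y)     # explicit sigmoid, then BCE
print(loss_with_logits.item(), loss_two_step.item())

# At prediction time the sigmoid must be applied explicitly to get probabilities
probs = torch.sigmoid(logits)
```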
Next, define the training function:
```python
def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()              # clear gradients from the previous batch
        output = model(data)               # forward pass
        loss = criterion(output, target)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        # With 10 samples and batch_size=2 there are only 5 batches,
        # so this logs the first batch of each epoch
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```
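The order of calls inside `train` matters: `loss.backward()` adds to each parameter's `.grad` rather than overwriting it, which is why `optimizer.zero_grad()` runs at the start of every batch. The short check below uses a standalone toy `nn.Linear(1, 1)` model (an assumption for illustration, not the network above) to show that two backward passes without clearing produce exactly twice the gradient, and that `optimizer.step()` then changes the weights.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(1, 1)  # toy stand-in model
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
x = torch.tensor([[3.0]])
y = torch.tensor([[1.0]])

# One backward pass after clearing the gradients
optimizer.zero_grad()
criterion(model(x), y).backward()
grad_once = model.weight.grad.clone()

# A second backward pass WITHOUT zero_grad(): the new gradient is added to the old one
criterion(model(x), y).backward()
grad_twice = model.weight.grad.clone()
print(grad_once.item(), grad_twice.item())

# A normal update: clear, forward, backpropagate, then step
weight_before = model.weight.detach().clone()
optimizer.zero_grad()
criterion(model(x), y).backward()
optimizer.step()
weight_after = model.weight.detach().clone()
```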
Then, define the test function:
```python
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += nn.functional.binary_cross_entropy(output, target, reduction='sum').item()
            # One probability per sample: predict "odd" (1) when it is at least 0.5
            pred = (output >= 0.5).float()
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
```
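For a model whose output is a single probability per sample, the predicted class comes from thresholding at 0.5, not from `argmax`: taking `argmax(dim=1)` over a tensor with one column always returns index 0. The sketch below uses made-up probabilities (not outputs of the trained model) to show both points.

```python
import torch

# Made-up sigmoid outputs and labels, both of shape [batch, 1]
probs = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
target = torch.tensor([[1.0], [0.0], [0.0], [0.0]])

# Threshold at 0.5 to turn probabilities into 0/1 predictions
pred = (probs >= 0.5).float()
correct = pred.eq(target).sum().item()
print(pred.squeeze(1).tolist(), correct)  # [1.0, 0.0, 1.0, 0.0] 3

# argmax over a single column is always 0, whatever the probability
print(probs.argmax(dim=1).tolist())  # [0, 0, 0, 0]
```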
Finally, run training and testing:
```python
if __name__ == '__main__':
    train_dataset = NumberDataset()
    # Note: the test set contains the same ten numbers as the training set,
    # so the reported accuracy measures fit, not generalization
    test_dataset = NumberDataset()
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    model = NumberNet()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    criterion = nn.BCELoss()
    for epoch in range(1, 10):  # epochs 1..9
        train(model, train_loader, optimizer, criterion, epoch)
        test(model, test_loader)
```
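To see which numbers the model gets right, rather than only the overall accuracy, the predictions can be listed per number. The helper `predict_parity` below is hypothetical; it assumes a model that maps a `[batch, 1]` float tensor to `[batch, 1]` probabilities, and here a freshly initialized 1 -> 2 -> 1 sigmoid network stands in for the trained one, so the printed labels are not meaningful results.

```python
import torch
import torch.nn as nn


def predict_parity(model, numbers):
    """Hypothetical helper: return 'odd'/'even' for each integer, using a 0.5 threshold."""
    model.eval()
    with torch.no_grad():
        x = torch.tensor(numbers, dtype=torch.float32).unsqueeze(1)  # shape [N, 1]
        probs = model(x).squeeze(1)                                  # shape [N]
    return ['odd' if p >= 0.5 else 'even' for p in probs.tolist()]


# Untrained stand-in for the trained model: 1 -> 2 -> 1 with a sigmoid output
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1), nn.Sigmoid())

numbers = list(range(1, 11))
predictions = predict_parity(model, numbers)
for n, p in zip(numbers, predictions):
    truth = 'odd' if n % 2 else 'even'
    print(n, p, 'ok' if p == truth else 'wrong')
```

In the original script, calling `predict_parity(model, list(range(1, 11)))` after the training loop would list the trained model's answer for each number.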
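That completes a simple neural network for classifying numbers as odd or even. It is meant only as a demonstration. With the raw number as the only input and two ReLU hidden units, the network computes a piecewise-linear function of the input with at most three pieces, so its decision can switch between odd and even at most three times along the number line, while ten alternating labels would need nine switches; it therefore cannot classify all ten numbers correctly. A practical version would need more data, a larger model or a more informative input representation, and more tuning.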
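One concrete example of a more informative input is the binary digits of the number instead of its raw value. Parity is then simply the lowest bit, so odd and even become linearly separable. The sketch below uses a hypothetical `to_bits` helper and, to make the point without depending on a training run, sets the weights of a single `nn.Linear` layer by hand; it illustrates the representation and is not a trained result.

```python
import torch
import torch.nn as nn


def to_bits(n, width=8):
    """Hypothetical helper: binary digits of n, least significant bit first."""
    return [(n >> i) & 1 for i in range(width)]


numbers = list(range(0, 256))
x = torch.tensor([to_bits(n) for n in numbers], dtype=torch.float32)  # shape [256, 8]
y = torch.tensor([[n % 2] for n in numbers], dtype=torch.float32)     # shape [256, 1]

# A single linear unit that only looks at the lowest bit already separates odd from even
layer = nn.Linear(8, 1)
with torch.no_grad():
    layer.weight.zero_()
    layer.weight[0, 0] = 1.0   # weight on the least significant bit
    layer.bias.fill_(-0.5)     # logit is +0.5 for odd numbers, -0.5 for even ones

pred = (layer(x) > 0).float()
accuracy = pred.eq(y).float().mean().item()
print(accuracy)  # 1.0
```

In practice such a layer would be trained (for example with `nn.BCEWithLogitsLoss`) rather than set by hand; the point is that the input representation does much of the work for this task.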