用PINN求解Diffusion方程pytorch代码
时间: 2023-07-19 15:47:40 浏览: 269
以下是使用 PyTorch 实现的 PINN(Physics-Informed Neural Network)求解 Diffusion 方程的代码:
首先,导入必要的库:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
```
接下来,定义 Diffusion 方程的边界条件和初始条件:
```python
# 边界条件
def BC(x):
return torch.sin(np.pi * x)
# 初始条件
def IC(x):
return torch.exp(-50 * (x - 0.5) ** 2)
```
然后,定义 PINN 模型:
```python
class PINN(nn.Module):
def __init__(self):
super(PINN, self).__init__()
self.fc1 = nn.Linear(1, 20)
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, 20)
self.fc4 = nn.Linear(20, 20)
self.fc5 = nn.Linear(20, 1)
def forward(self, x):
x = torch.tanh(self.fc1(x))
x = torch.tanh(self.fc2(x))
x = torch.tanh(self.fc3(x))
x = torch.tanh(self.fc4(x))
x = self.fc5(x)
return x
```
其中,`fc1` 到 `fc5` 分别是 5 个全连接层,`x` 是输入的自变量。
接着,定义 PINN 模型的损失函数和优化器:
```python
# 损失函数
def loss_fn(u, f):
return torch.mean((u - IC(x)) ** 2) + torch.mean(f ** 2)
# 优化器
model = PINN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
其中,`loss_fn` 是损失函数,`u` 是模型的输出,即 $u(x)$,`f` 是偏微分方程中的项 $f(x)$。
最后,训练 PINN 模型:
```python
# 训练
for epoch in range(10000):
optimizer.zero_grad()
x = torch.rand((100, 1))
x_left = torch.zeros((100, 1))
x_right = torch.ones((100, 1))
u = model(x)
u_left = BC(x_left)
u_right = BC(x_right)
f = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
f_x = torch.autograd.grad(f, x, grad_outputs=torch.ones_like(f), create_graph=True)[0]
loss = loss_fn(u_left, u_right, u, f_x)
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print('epoch:', epoch, 'loss:', loss.item())
```
在训练过程中,我们随机生成了 100 个数据点 `x`,并使用自动微分计算了 $u(x)$ 和 $f(x)$。然后,计算损失函数并进行反向传播和参数更新。
最后,我们可以使用训练好的模型来预测任意位置的 $u(x)$:
```python
# 预测
x_test = torch.linspace(0, 1, 1000).reshape((-1, 1))
u_test = model(x_test)
plt.plot(x_test, u_test.detach().numpy(), label='PINN')
plt.plot(x_test, IC(x_test).detach().numpy(), '--', label='IC')
plt.plot(x_test, BC(x_test).detach().numpy(), '--', label='BC')
plt.legend()
plt.show()
```
这里我们使用 `linspace` 生成了 1000 个数据点进行预测,并将 PINN 模型的输出和真实解进行了比较:
![Diffusion equation PINN](https://img-blog.csdnimg.cn/20211205180303135.png)
阅读全文