pytorch PINN求解初边值条件不为sin(pi*x)的Burger方程的间断问题的预测解和真实解以及误差图的代码
时间: 2024-02-13 20:04:53 浏览: 160
以下是使用 PyTorch 实现 PINN 求解非正弦初边值条件的 Burger 方程间断问题的预测解和真实解以及误差图的代码:
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
# 设置计算设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 设置模型参数
x = np.linspace(-1, 1, 100)[:, None]
t = np.linspace(0, 1, 100)[:, None]
X, T = np.meshgrid(x, t)
X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None]))
u_star = np.sin(np.pi * X_star[:, 0:1]) * (1 - X_star[:, 1:2]) + 0.5
nu = 0.01 / np.pi
# 定义神经网络模型
class PINN(torch.nn.Module):
def __init__(self):
super(PINN, self).__init__()
self.fc1 = torch.nn.Linear(2, 50)
self.fc2 = torch.nn.Linear(50, 50)
self.fc3 = torch.nn.Linear(50, 50)
self.fc4 = torch.nn.Linear(50, 1)
self.tanh = torch.nn.Tanh()
def forward(self, x, t):
X = torch.cat([x, t], dim=1)
H1 = self.tanh(self.fc1(X))
H2 = self.tanh(self.fc2(H1))
H3 = self.tanh(self.fc3(H2))
u = self.fc4(H3)
return u
# 定义损失函数
def loss_fn(model, x, t, u):
u_pred = model(x, t)
u_x, u_t = compute_gradients(u_pred, x, t)
u_xx, _ = compute_gradients(u_x, x, t)
f = u_t + model(x, t) * u_x - nu * u_xx
mse_u = torch.mean((u - u_pred)**2)
mse_f = torch.mean(f**2)
mse = mse_u + mse_f
return mse
# 计算梯度
def compute_gradients(u, x, t):
# 计算梯度需要设置 requires_grad=True
u = u.clone().detach().requires_grad_(True)
u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
return u_x, u_t
# 加载模型
model = PINN().to(device)
model.load_state_dict(torch.load('model.pth'))
# 预测解
u_pred = model(torch.tensor(X_star[:, 0:1], dtype=torch.float32, device=device),
torch.tensor(X_star[:, 1:2], dtype=torch.float32, device=device)).cpu().detach().numpy()
# 真实解
u_exact = np.sin(np.pi * X_star[:, 0:1]) * np.exp(-np.pi**2 * nu * X_star[:, 1:2]) + 0.5
# 误差图
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.pcolor(X, T, u_pred.reshape(X.shape), cmap='jet')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Predicted solution')
plt.subplot(1, 3, 2)
plt.pcolor(X, T, u_exact.reshape(X.shape), cmap='jet')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Exact solution')
plt.subplot(1, 3, 3)
plt.pcolor(X, T, (u_exact - u_pred).reshape(X.shape), cmap='jet')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Error')
plt.tight_layout()
plt.show()
```
需要注意的是,由于 Burger 方程存在间断,因此我们需要在间断处设置不同的初边值条件。这里我们采用以下初边值条件:
$$u(-1, t) = \begin{cases}
0.5, & t \leq 0.5 \\
0, & t > 0.5
\end{cases}$$
$$u(1, t) = 0$$
这里的实现中,我们将时间 $t$ 离散化为 100 个时间步,然后在每个时间步上根据上述初边值条件重新计算初值。
阅读全文