pytorch 用PINN求解具有分段间断初值的一维sod激波的代码
时间: 2023-12-10 19:39:00 浏览: 225
PyTorch搭建一维线性回归模型(二)
以下是使用 PyTorch 实现求解一维 Sod 激波方程的代码:
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
# 定义模型
class PINN(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(2, 50)
self.fc2 = torch.nn.Linear(50, 50)
self.fc3 = torch.nn.Linear(50, 50)
self.fc4 = torch.nn.Linear(50, 50)
self.fc5 = torch.nn.Linear(50, 50)
self.fc6 = torch.nn.Linear(50, 1)
def forward(self, x):
x = torch.tanh(self.fc1(x))
x = torch.tanh(self.fc2(x))
x = torch.tanh(self.fc3(x))
x = torch.tanh(self.fc4(x))
x = torch.tanh(self.fc5(x))
x = self.fc6(x)
return x
# 定义边界条件
def boundary_loss(model, xb, yb):
pred = model(xb)
return torch.mean((pred - yb)**2)
# 定义 PDE 方程
def pde_loss(model, x):
gamma = 1.4
rho_l = 1.0
u_l = 0
p_l = 1
rho_r = 0.125
u_r = 0
p_r = 0.1
x_l = 0
x_r = 0.5
t = 0.2
x = torch.tensor(x, dtype=torch.float32, requires_grad=True)
t = torch.tensor(t, dtype=torch.float32, requires_grad=True)
# 初始条件
if x <= 0.5:
rho = rho_l
u = u_l
p = p_l
else:
rho = rho_r
u = u_r
p = p_r
# 求解状态量
a = torch.sqrt(gamma*p/rho)
M = u/a
F = torch.cat((rho*u, rho*u**2+p, u*(p/(gamma-1)+0.5*rho*u**2+rho*u*a*(M+1)), (p/(gamma-1)+0.5*rho*u**2+rho*u*a*(M+1))), dim=0)
G = torch.tensor([0, 0, 0, 0], dtype=torch.float32)
# 求解梯度
grad_F = torch.autograd.grad(F, x, create_graph=True)[0]
grad_G = torch.autograd.grad(G, x, create_graph=True)[0]
# 求解 PDE 方程的残差
res = grad_F - grad_G
loss = torch.mean(res**2)
return loss
# 数据准备
x = np.linspace(0, 0.5, 100)
y = np.zeros_like(x)
xb = np.array([[0], [0.5]])
yb = np.array([[1, 0, 1], [0.125, 0, 0.1]])
# 模型训练
model = PINN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for i in range(20000):
optimizer.zero_grad()
loss = boundary_loss(model, xb, yb) + pde_loss(model, x)
loss.backward()
optimizer.step()
if i % 1000 == 0:
print("Epoch: {}, Loss: {}".format(i, loss.item()))
# 模型预测
x_test = np.linspace(0, 0.5, 1000)
y_test = model(torch.tensor(np.vstack((x_test, y_test)).T, dtype=torch.float32)).detach().numpy()
# 结果可视化
plt.plot(x, y)
plt.plot(x_test, y_test)
plt.show()
```
其中,`PINN` 类定义了一个 6 层的全连接神经网络,用于求解 Sod 激波方程。`boundary_loss` 函数计算边界条件的损失,`pde_loss` 函数计算 PDE 方程的残差。最后使用 Adam 优化器对模型进行训练,并将训练得到的结果进行可视化。
阅读全文