pytorch PINN求解具有初边值的热传导方程间断问题的代码(含真实解和误差的图像代码)
时间: 2024-02-01 18:17:06 浏览: 310
以下是使用PyTorch实现的PINN求解具有初边值的热传导方程间断问题的代码,包括真实解和误差的图像代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# 设置随机数种子
torch.manual_seed(1234)
np.random.seed(1234)
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 定义热传导方程
class HeatEquation(nn.Module):
def __init__(self):
super(HeatEquation, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, 50),
nn.Tanh(),
nn.Linear(50, 50),
nn.Tanh(),
nn.Linear(50, 50),
nn.Tanh(),
nn.Linear(50, 1)
)
def forward(self, x):
return self.net(x)
# 定义边界条件
def boundary_condition(x):
return x[:,0:1]*(1-x[:,0:1])*x[:,1:2]*(1-x[:,1:2])
# 定义真实解
def analytical_solution(x):
return np.sin(np.pi*x[:,0:1])*np.sin(np.pi*x[:,1:2])
# 定义损失函数
def loss_func(net, x, y, b):
y_pred = net(x)
y_grad = torch.autograd.grad(y_pred.sum(), x, create_graph=True)[0]
y_t = y_grad[:,1:2]
y_x = y_grad[:,0:1]
y_xx = torch.autograd.grad(y_x.sum(), x, create_graph=True)[0][:,0:1]
loss = nn.MSELoss()(y_pred, y) + nn.MSELoss()(y_t, y_xx) + nn.MSELoss()(b, torch.zeros_like(b))
return loss
# 生成训练数据
n = 2000
x = np.random.uniform(low=0.0, high=1.0, size=(n,2)).astype('float32')
y = analytical_solution(torch.from_numpy(x)).detach().numpy()
b = boundary_condition(torch.from_numpy(x)).detach().numpy()
# 转换为张量
x = torch.from_numpy(x).to(device)
y = torch.from_numpy(y).to(device)
b = torch.from_numpy(b).to(device)
# 初始化模型
net = HeatEquation().to(device)
# 定义优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 训练模型
epochs = 5000
for epoch in range(epochs):
optimizer.zero_grad()
loss = loss_func(net, x, y, b)
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
# 绘制真实解和PINN求解的解的比较图像
x_test = np.linspace(0, 1, 100)
y_test = np.linspace(0, 1, 100)
X, Y = np.meshgrid(x_test, y_test)
xy = np.hstack((X.reshape(-1,1), Y.reshape(-1,1)))
xy_tensor = torch.from_numpy(xy).float().to(device)
z_true = analytical_solution(xy_tensor).cpu().detach().numpy().reshape(100, 100)
z_pinns = net(xy_tensor).cpu().detach().numpy().reshape(100, 100)
plt.subplot(121)
plt.pcolor(X, Y, z_true)
plt.colorbar()
plt.title('True Solution')
plt.subplot(122)
plt.pcolor(X, Y, z_pinns)
plt.colorbar()
plt.title('PINNs Solution')
plt.show()
# 绘制误差图像
z_error = np.abs(z_true - z_pinns)
plt.pcolor(X, Y, z_error)
plt.colorbar()
plt.title('Error')
plt.show()
```
在训练过程中,每100个epoch打印一次损失值。训练完成后,绘制真实解和PINN求解的解的比较图像和误差图像。
阅读全文