Tianshou/PyTorch SAC for continuous action spaces: code that saves each trained network as a .pth file and uses writer.add to output a graph of each network
Below is PyTorch-based SAC code for a continuous action space that saves each network's trained model to a .pth file and uses writer.add_scalar / writer.add_graph to log training curves and the network graphs. Note that this is a plain PyTorch implementation rather than one built on the Tianshou library:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import gym
import numpy as np
# Q network: maps a (state, action) pair to a scalar Q value
class QNet(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x
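# NOTE: the original answer references PolicyNet but never defines it. The class
# below is a minimal squashed-Gaussian policy sketch (an assumption, not the
# original author's network): forward() returns (mean, log_std) and sample()
# returns (action, log_prob, mean), matching how select_action and update use it.
from torch.distributions import Normal

class PolicyNet(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim, log_std_min=-20, log_std_max=2):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, state):
        x = nn.functional.relu(self.fc1(state))
        x = nn.functional.relu(self.fc2(x))
        mean = self.mean_head(x)
        log_std = torch.clamp(self.log_std_head(x), self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x = normal.rsample()        # reparameterized sample
        action = torch.tanh(x)      # squash to [-1, 1] (scaling to the env's range is omitted, as in the original snippet)
        # tanh change-of-variables correction of the log-probability
        log_prob = normal.log_prob(x) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob, torch.tanh(mean)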
# SAC agent
class SAC:
    def __init__(self, state_dim, action_dim, hidden_dim, gamma, tau, alpha, device):
        self.q_net1 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.q_net2 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_q_net1 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_q_net2 = QNet(state_dim, action_dim, hidden_dim).to(device)
        # start the target critics from the same weights as the online critics
        self.target_q_net1.load_state_dict(self.q_net1.state_dict())
        self.target_q_net2.load_state_dict(self.q_net2.state_dict())
        self.policy_net = PolicyNet(state_dim, action_dim, hidden_dim).to(device)
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.device = device
        self.step = 0
        self.writer = SummaryWriter()

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        action, _, _ = self.policy_net.sample(state)
        return action.cpu().detach().numpy()[0]
    def update(self, replay_buffer, batch_size):
        # sample a random batch from the replay buffer
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state).to(self.device)
        action = torch.FloatTensor(action).to(self.device)
        next_state = torch.FloatTensor(next_state).to(self.device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
        # update the Q networks: the SAC target takes the minimum of the two target
        # critics at a freshly sampled next action, minus the entropy bonus
        with torch.no_grad():
            next_action, next_log_prob, _ = self.policy_net.sample(next_state)
            target_q = torch.min(
                self.target_q_net1(next_state, next_action),
                self.target_q_net2(next_state, next_action)
            ) - self.alpha * next_log_prob
            target_q_value = reward + (1 - done) * self.gamma * target_q
        q_value_loss1 = nn.functional.mse_loss(self.q_net1(state, action), target_q_value)
        q_value_loss2 = nn.functional.mse_loss(self.q_net2(state, action), target_q_value)
        self.writer.add_scalar('Loss/Q1', q_value_loss1, global_step=self.step)
        self.writer.add_scalar('Loss/Q2', q_value_loss2, global_step=self.step)
        self.q_optim1.zero_grad()
        q_value_loss1.backward()
        self.q_optim1.step()
        self.q_optim2.zero_grad()
        q_value_loss2.backward()
        self.q_optim2.step()
        # update the policy network
        new_action, log_prob, _ = self.policy_net.sample(state)
        q1_new = self.q_net1(state, new_action)
        q2_new = self.q_net2(state, new_action)
        q_new = torch.min(q1_new, q2_new)
        policy_loss = (self.alpha * log_prob - q_new).mean()
        self.writer.add_scalar('Loss/Policy', policy_loss, global_step=self.step)
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        # soft-update the target Q networks
        self.soft_update(self.target_q_net1, self.q_net1)
        self.soft_update(self.target_q_net2, self.q_net2)

    def soft_update(self, target_net, eval_net):
        for target_param, param in zip(target_net.parameters(), eval_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    def train(self, env, replay_buffer, batch_size, episodes, steps_per_episode):
        self.q_optim1 = optim.Adam(self.q_net1.parameters(), lr=3e-4)
        self.q_optim2 = optim.Adam(self.q_net2.parameters(), lr=3e-4)
        self.policy_optim = optim.Adam(self.policy_net.parameters(), lr=3e-4)
        for episode in range(episodes):
            state = env.reset()  # classic gym API: reset() returns only the observation
            episode_reward = 0
            for step in range(steps_per_episode):
                self.step = episode * steps_per_episode + step
                action = self.select_action(state)
                next_state, reward, done, _ = env.step(action)
                replay_buffer.add(state, action, next_state, reward, done)
                state = next_state
                episode_reward += reward
                if len(replay_buffer) > batch_size:
                    self.update(replay_buffer, batch_size)
                if done:
                    break
            self.writer.add_scalar('Reward', episode_reward, global_step=episode)
            print(f'Episode {episode} reward: {episode_reward}')
        # save each trained network as a .pth file
        torch.save(self.q_net1.state_dict(), 'q_net1.pth')
        torch.save(self.q_net2.state_dict(), 'q_net2.pth')
        torch.save(self.policy_net.state_dict(), 'policy_net.pth')
        # write the network graphs to TensorBoard (dummy inputs need a batch dimension)
        state = torch.FloatTensor(env.reset()).unsqueeze(0).to(self.device)
        action = torch.FloatTensor(env.action_space.sample()).unsqueeze(0).to(self.device)
        self.writer.add_graph(self.q_net1, (state, action))
        self.writer.add_graph(self.q_net2, (state, action))
        self.writer.add_graph(self.policy_net, state)
```
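The code above also relies on a `ReplayBuffer` that the original answer never defines. A minimal sketch compatible with the calls used here (`add`, `sample`, `len`) might look like the following; treat it as an assumption rather than the original author's implementation:
```python
import random
from collections import deque

import numpy as np

class ReplayBuffer:
    """Fixed-size FIFO buffer of (state, action, next_state, reward, done) transitions."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, next_state, reward, done):
        self.buffer.append((state, action, next_state, reward, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, next_state, reward, done = map(np.stack, zip(*batch))
        return state, action, next_state, reward, done

    def __len__(self):
        return len(self.buffer)
```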
Calling the SAC class's train method starts training, saves each network's trained weights to a .pth file, and writes the network graphs to TensorBoard:
```python
env = gym.make('Pendulum-v0')  # renamed to 'Pendulum-v1' in newer gym/gymnasium releases
replay_buffer = ReplayBuffer(1000000)
sac = SAC(state_dim=env.observation_space.shape[0],
action_dim=env.action_space.shape[0],
hidden_dim=256,
gamma=0.99,
tau=0.005,
alpha=0.2,
device='cuda')
sac.train(env, replay_buffer, batch_size=256, episodes=100, steps_per_episode=200)
```
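After training, the saved .pth files can be reloaded with `load_state_dict`. A brief sketch, reusing the classes defined above:
```python
# Rebuild the policy network and load the saved weights
policy_net = PolicyNet(state_dim=env.observation_space.shape[0],
                       action_dim=env.action_space.shape[0],
                       hidden_dim=256)
policy_net.load_state_dict(torch.load('policy_net.pth', map_location='cpu'))
policy_net.eval()
```
The logged losses, rewards, and network graphs land in the `runs/` directory that SummaryWriter uses by default and can be viewed with `tensorboard --logdir runs`.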