用akshare,gym写出多股票强化学习代码,包含模型的保存和模型的应用
时间: 2024-02-25 22:58:50 浏览: 105
在强化学习框架下学习交易
好的,我可以为您提供一个简单的示例代码,用于使用`akshare`库获取多支股票的历史价格数据,并使用`gym`库实现股票交易的强化学习模型。
首先,我们需要安装`akshare`和`gym`库:
```
pip install akshare gym
```
然后,我们可以使用`akshare`库获取多支股票的历史价格数据。这里以获取中国平安(601318.SH)和招商银行(600036.SH)的历史价格数据为例:
```python
import akshare as ak
# 获取历史价格数据
stock_codes = ["601318.SH", "600036.SH"]
df = ak.stock_zh_a_daily(symbol=stock_codes[0], adjust="qfq")
for code in stock_codes[1:]:
df = df.merge(ak.stock_zh_a_daily(symbol=code, adjust="qfq"), how="outer", on="trade_date")
df = df.dropna() # 去除缺失数据
print(df.head())
```
接下来,我们可以使用`gym`库实现股票交易的强化学习模型。这里我们使用一个简单的策略:当股票价格上涨时,买入;当股票价格下跌时,卖出。我们使用一个简单的神经网络作为策略模型,输入为当前股票价格和持仓情况,输出为买入或卖出的决策。
```python
import numpy as np
import gym
class StockTradingEnv(gym.Env):
def __init__(self, df, initial_balance=10000):
super().__init__()
self.df = df
self.initial_balance = initial_balance
self.balance = initial_balance
self.shares = 0
self.current_step = 0
self.max_steps = len(df)
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=(2,))
def reset(self):
self.balance = self.initial_balance
self.shares = 0
self.current_step = 0
return self._get_observation()
def step(self, action):
assert self.action_space.contains(action)
price = self.df.iloc[self.current_step]["close"]
reward = 0
if action == 0: # 买入
if self.balance >= price:
self.balance -= price
self.shares += 1
else:
reward = -10
elif action == 1: # 卖出
if self.shares > 0:
self.balance += price
self.shares -= 1
else:
reward = -10
self.current_step += 1
done = self.current_step >= self.max_steps
if not done:
reward += self._get_reward(price)
return self._get_observation(), reward, done, {}
def _get_observation(self):
price = self.df.iloc[self.current_step]["close"]
return np.array([price, self.shares])
def _get_reward(self, price):
return (self.balance + self.shares * price - self.initial_balance) / self.initial_balance
```
最后,我们可以使用`stable_baselines3`库训练模型并保存模型。这里我们使用`A2C`算法训练模型:
```python
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
env = StockTradingEnv(df)
env = DummyVecEnv([lambda: env])
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=int(1e5))
model.save("model")
```
现在我们已经训练好了模型并将其保存到了`model`文件中。我们可以使用该模型对新的股票数据进行交易:
```python
model = A2C.load("model")
df_test = ak.stock_zh_a_daily(symbol="601318.SH", adjust="qfq")
env_test = StockTradingEnv(df_test)
env_test = DummyVecEnv([lambda: env_test])
obs = env_test.reset()
for i in range(len(df_test)):
action, _ = model.predict(obs)
obs, _, done, _ = env_test.step(action)
if done:
break
print(f"Final balance: {env_test.balance:.2f}")
```
这段代码可以使用之前训练好的模型对中国平安的股票数据进行交易,并输出最终的收益。
阅读全文