用akshare写出股票强化学习代码,画图显示买卖点和收益
时间: 2023-07-19 22:23:52 浏览: 149
好的,以下是使用 akshare 库编写的股票强化学习示例程序,并通过绘图展示买卖点和收益:
```python
import akshare as ak
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Fetch the forward-adjusted ("qfq") price history of the chosen stock.
# NOTE: ak.stock_zh_a_hist expects the bare six-digit code (e.g. '600519');
# the exchange-prefixed form ('sh600519') is used by other akshare endpoints.
stock_code = '600519'
stock_data = ak.stock_zh_a_hist(symbol=stock_code, adjust="qfq", start_date='20180101')
# Index the rows by trading date so the series plot on a time axis, then drop
# the now-redundant date column.
stock_data.index = pd.to_datetime(stock_data['日期'])
stock_data = stock_data.drop(['日期'], axis=1)
# Compute the RSI technical indicator
def calc_RSI(data, time_period=14):
    """Return the Relative Strength Index (RSI) of a price series.

    Gains and losses are averaged with a simple rolling mean over
    ``time_period`` consecutive price changes.

    Args:
        data: pandas Series of prices in chronological order.
        time_period: Number of price changes in each averaging window.

    Returns:
        A Series aligned with ``data``. The first ``time_period`` entries are
        NaN because a full window of price changes is not yet available. A
        window containing only gains yields 100; a window with no price
        movement at all yields NaN (0 / 0).
    """
    delta = data.diff()
    # clip() keeps the leading NaN produced by diff(), so the first window is
    # not padded with a fabricated zero change (where(..., 0) would do that).
    gain = delta.clip(lower=0)
    loss = (-delta).clip(lower=0)
    avg_gain = gain.rolling(time_period).mean()
    avg_loss = loss.rolling(time_period).mean()
    # avg_loss == 0 with avg_gain > 0 gives RS == inf, which maps to RSI == 100.
    RS = avg_gain / avg_loss
    RSI = 100 - (100 / (1 + RS))
    return RSI
# Store the RSI of the closing price as an extra column; get_state reads it as a state feature.
stock_data['RSI'] = calc_RSI(stock_data['收盘'])
# Define the state representation and the action-selection policy
def get_state(stock_data, t, window=5):
    """Build the state vector for row ``t`` from the trailing ``window`` rows.

    Args:
        stock_data: DataFrame with '收盘' (close) and 'RSI' columns.
        t: Positional index of the most recent row included in the state.
        window: Number of consecutive rows, ending at ``t``, to include.

    Returns:
        1-D float array of length ``2 * window`` laid out as
        ``[close, RSI, close, RSI, ...]`` from the oldest row to row ``t``.

    Raises:
        ValueError: If ``window`` is smaller than 1.
        IndexError: If the requested rows are not all inside ``stock_data``.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    first = t - window + 1
    # Reject out-of-range requests explicitly: a negative position would
    # otherwise wrap around silently and read rows from the end of the frame.
    if first < 0 or t >= len(stock_data):
        raise IndexError(
            f"rows {first}..{t} are outside a frame of length {len(stock_data)}"
        )
    features = stock_data.iloc[first:t + 1][['收盘', 'RSI']]
    # Row-major flattening interleaves close and RSI row by row.
    return features.to_numpy(dtype=float).ravel()
def get_action(q_values, state, eps=0.1):
    """Pick an action with an epsilon-greedy policy.

    Args:
        q_values: Q-table; ``q_values[state]`` must return the sequence of
            action values for ``state``.
        state: Key used to look up the action values.
        eps: Probability of choosing a uniformly random action.

    Returns:
        Integer index of the chosen action.
    """
    action_values = q_values[state]
    if np.random.uniform() < eps:
        # Explore: sample over however many actions the table actually holds
        # instead of assuming there are exactly three.
        return int(np.random.randint(0, len(action_values)))
    # Exploit: take the action with the highest current estimate.
    return int(np.argmax(action_values))
# Define the reinforcement-learning process (tabular Q-learning)
alpha = 0.1     # learning rate
gamma = 0.99    # discount factor
epsilon = 0.1   # exploration probability


class _BucketedQTable:
    """Q-table that is indexed directly with the raw state from get_state.

    The state is a sequence laid out as [close, RSI, close, RSI, ...] from the
    oldest to the newest row. Those continuous values cannot index an array,
    so each state is mapped to a discrete row built from two features: the
    bucket of the newest RSI value and the sign of the close-price change
    across the window. ``table[state]`` returns that row as a writable view,
    so ``table[state][action] += x`` updates the table in place.
    """

    RSI_BUCKETS = 10  # RSI range [0, 100] split into equal-width buckets
    N_TRENDS = 3      # close-price change over the window: down / flat / up

    def __init__(self, n_actions=3):
        self.n_actions = n_actions
        # One extra RSI bucket collects states whose newest RSI is NaN
        # (the warm-up rows of the indicator).
        n_rows = (self.RSI_BUCKETS + 1) * self.N_TRENDS
        self.table = np.zeros((n_rows, n_actions))

    def index_of(self, state):
        """Return the row index of the discrete bucket that ``state`` falls in."""
        first_close, last_close, last_rsi = state[0], state[-2], state[-1]
        if np.isnan(last_rsi):
            rsi_bucket = self.RSI_BUCKETS
        else:
            # RSI == 100 would land one past the last bucket, so clamp it.
            rsi_bucket = min(int(last_rsi * self.RSI_BUCKETS // 100), self.RSI_BUCKETS - 1)
        trend = int(last_close > first_close) - int(last_close < first_close) + 1
        return rsi_bucket * self.N_TRENDS + trend

    def __getitem__(self, state):
        return self.table[self.index_of(state)]


q_values = _BucketedQTable(n_actions=3)
rewards = []  # total reward collected in each episode

# The price history is identical in every episode, so build the per-row states
# once up front. stock_data itself is left untouched: the later profit
# calculation and the plot need the full history.
first_t = 5
last_t = len(stock_data) - 1
closes = stock_data['收盘'].to_numpy(dtype=float)
states = {t: tuple(get_state(stock_data, t)) for t in range(first_t, last_t + 1)}

for episode in range(500):
    total_reward = 0
    # Walk forward through the history one trading day at a time.
    for t in range(first_t, last_t):
        state = states[t]
        next_state = states[t + 1]
        action = get_action(q_values, state, epsilon)
        # Same action meaning as calc_profit: 0 = flat, 1 = long, otherwise short.
        if action == 0:
            position = 0
        elif action == 1:
            position = 1
        else:
            position = -1
        # Reward is the next-day return earned by the chosen position.
        reward = position * (closes[t + 1] - closes[t]) / closes[t]
        q_row = q_values[state]
        q_row[action] += alpha * (reward + gamma * np.max(q_values[next_state]) - q_row[action])
        total_reward += reward
    rewards.append(total_reward)
# Derive the positions and per-step profit implied by the learned Q-values
def calc_profit(stock_data, q_values):
    """Replay the greedy policy over the price history.

    For every row ``i`` from 5 up to the second-to-last row, the greedy action
    for the state at ``i`` decides the position held until row ``i + 1``.

    Args:
        stock_data: DataFrame with '收盘' (close) and 'RSI' columns.
        q_values: Q-table; ``q_values[tuple(state)]`` must return the action
            values for the state vector produced by get_state.

    Returns:
        ``(positions, profits)``: two lists with one entry per evaluated row.
        A position is 0 (flat, action 0), 1 (long, action 1) or -1 (short, any
        other action); the profit is the close-price change from ``i`` to
        ``i + 1`` multiplied by that position.
    """
    # Pull the close column out once instead of doing an iloc lookup per row.
    closes = stock_data['收盘'].to_numpy(dtype=float)
    positions = []
    profits = []
    for i in range(5, len(stock_data) - 1):
        state = get_state(stock_data, i)
        action = np.argmax(q_values[tuple(state)])
        if action == 0:
            position = 0
        elif action == 1:
            position = 1
        else:
            position = -1
        positions.append(position)
        profits.append(position * (closes[i + 1] - closes[i]) if position else 0)
    return positions, profits
positions, profits = calc_profit(stock_data, q_values)

# Plot the price with buy/sell markers, the position held and the cumulative profit.
# calc_profit returns one entry per row starting at positional index 5, so
# attach the matching dates; plotting bare lists would put them on an integer
# x-axis that does not line up with the date-indexed price curve.
trade_dates = stock_data.index[5:5 + len(positions)]
position_series = pd.Series(positions, index=trade_dates, dtype=float)
cumulative_profit = pd.Series(np.cumsum(profits), index=trade_dates, dtype=float)

# A rise in position marks a buy, a fall marks a sell. The first entry has no
# predecessor, so it is compared against an implicit flat (0) starting position.
position_change = position_series.diff().fillna(position_series)
buy_dates = position_change.index[position_change > 0]
sell_dates = position_change.index[position_change < 0]

fig, ax1 = plt.subplots()
ax1.plot(stock_data['收盘'], color='black')
ax1.scatter(buy_dates, stock_data.loc[buy_dates, '收盘'],
            marker='^', color='green', label='Buy', zorder=3)
ax1.scatter(sell_dates, stock_data.loc[sell_dates, '收盘'],
            marker='v', color='orange', label='Sell', zorder=3)
ax1.set_ylabel('Price')
ax1.legend(loc='upper left')
ax2 = ax1.twinx()
ax2.plot(position_series, color='blue')
ax2.set_ylabel('Position')
ax3 = ax1.twinx()
# Shift the third y-axis outwards so it does not sit on top of the second one.
ax3.spines["right"].set_position(("axes", 1.2))
ax3.plot(cumulative_profit, color='red')
ax3.set_ylabel('Profit')
# Leave room on the right-hand side for the shifted axis.
fig.subplots_adjust(right=0.75)
plt.show()
```
这段程序使用 akshare 库获取股票数据,并计算 RSI 指标作为状态的一部分(状态由最近 5 个交易日的收盘价和 RSI 组成)。随后定义强化学习过程,使用 Q-learning 算法更新 Q 值,并记录每个回合的总奖励。最后根据学到的 Q 值计算持仓和收益,并用图表展示。
阅读全文