# -*- coding: utf-8 -*- import matplotlib.pyplot as plt from scipy.io import loadmat from math import sqrt from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score def quota(inv_y, inv_yhat,name): rmse_ = sqrt(mean_squared_error(inv_y, inv_yhat)) mae_=mean_absolute_error(inv_y, inv_yhat) r2_ = r2_score(inv_y, inv_yhat) print(name,'的rmse:', rmse_, ' mae:', mae_, ' R2:', r2_) data0=loadmat('result/vmd_cnn_bilstm_result.mat')['true'] data1=loadmat('result/cnn_bilstm_result.mat')['pred'] data2=loadmat('result/vmd_cnn_bilstm_result.mat')['pred'] data3=loadmat('result/bilstm_result.mat')['pred'] quota(data0,data1,'CNN-BiLSTM') quota(data0,data2,'VMD-CNN-BiLSTM') quota(data0,data3,'BiLSTM') # 画图 plt.figure() plt.plot(data0,'-',label='Real') plt.plot(data1,'-',label='CNN-BiLSTM') plt.plot(data2,'-',label='VMD-VNN-BiLSTM') plt.plot(data3,'-',label='BiLSTM') plt.grid() plt.legend() plt.xlabel('time/h') plt.ylabel('Compare') plt.show()
时间: 2023-08-18 18:08:00 浏览: 121
这段代码是一个 Python 脚本,它包含了一个名为 `quota()` 的函数和一些数据处理和可视化的代码。这个脚本的主要功能是比较不同深度学习模型在某个数据集上的性能表现,并通过可视化的方式展示比较结果。
首先,代码通过 `loadmat()` 函数加载了四个数据文件,分别是真实值 `data0` 和三个模型的预测结果 `data1`、`data2` 和 `data3`。然后,通过调用 `quota()` 函数计算了每个模型在真实值上的均方根误差(RMSE)、平均绝对误差(MAE)和 R2 得分。这些指标可以用来衡量模型的预测能力,RMSE 和 MAE 越小、R2 得分越高表示模型性能越好。
最后,代码使用 Matplotlib 库绘制了一个折线图,将真实值和三个模型的预测结果进行了对比。图表的 x 轴表示时间,y 轴表示比较结果。通过比较不同模型的预测结果,可以直观地了解它们的性能差异。
相关问题
请帮我修改下面代码中的错误# -*- coding: utf-8 -*- """ Created on Sun May 28 18:08:36 2023 @author: lll """ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import brentq from scipy.stats import norm # 定义BS模型计算期权价格的函数 def bs_price(S, K, r, T, sigma, option='call'): d1 = (np.log(S/K) + (r + 0.5*sigma**2)*T) / (sigma*np.sqrt(T)) d2 = d1 - sigma*np.sqrt(T) if option == 'call': price = S*norm.cdf(d1) - K*np.exp(-r*T)*norm.cdf(d2) else: price = K*np.exp(-r*T)*norm.cdf(-d2) - S*norm.cdf(-d1) return price # 定义计算隐含波动率的函数 def implied_vol(S, K, r, T, price, option='call'): def f(sigma): return bs_price(S, K, r, T, sigma, option) - price return brentq(f, 0.001, 10) # 定义计算波动率微笑图形的函数 def smile_vol(S, r, T, vol_list, K_list, option='call'): implied_vol_list = [] for K, vol in zip(K_list, vol_list): price = bs_price(S, K, r, T, vol, option) implied_vol_list.append(implied_vol(S, K, r, T, price, option)) plt.plot(K_list, implied_vol_list) plt.xlabel('Strike') plt.ylabel('Implied Volatility') plt.title('Volatility Smile') plt.show() # 示例代码 S = 100 r = 0.05 T = 1 K_list = np.arange(80, 121, 5) vol_call_list = [0.2, 0.3, 0.4, 0.5, 0.6] vol_put_list = [0.6, 0.5, 0.4, 0.3, 0.2] smile_vol(S, r, T, vol_call_list, K_list, option='call') smile_vol(S, r, T, vol_put_list, K_list, option='put')
这段代码看起来没有明显的语法错误,但是我注意到在函数`smile_vol`中,`vol_list`和`K_list`的长度必须相等,否则会出现错误。
如果你遇到了任何其他错误,请提供错误信息以便进一步调试。
以下是修改后的代码:
```python
# -*- coding: utf-8 -*-
"""
Created on Sun May 28 18:08:36 2023
@author: lll
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq
from scipy.stats import norm
# 定义BS模型计算期权价格的函数
def bs_price(S, K, r, T, sigma, option='call'):
d1 = (np.log(S/K) + (r + 0.5*sigma**2)*T) / (sigma*np.sqrt(T))
d2 = d1 - sigma*np.sqrt(T)
if option == 'call':
price = S*norm.cdf(d1) - K*np.exp(-r*T)*norm.cdf(d2)
else:
price = K*np.exp(-r*T)*norm.cdf(-d2) - S*norm.cdf(-d1)
return price
# 定义计算隐含波动率的函数
def implied_vol(S, K, r, T, price, option='call'):
def f(sigma):
return bs_price(S, K, r, T, sigma, option) - price
return brentq(f, 0.001, 10)
# 定义计算波动率微笑图形的函数
def smile_vol(S, r, T, vol_list, K_list, option='call'):
if len(vol_list) != len(K_list):
raise ValueError("Lengths of vol_list and K_list must be equal.")
implied_vol_list = []
for K, vol in zip(K_list, vol_list):
price = bs_price(S, K, r, T, vol, option)
implied_vol_list.append(implied_vol(S, K, r, T, price, option))
plt.plot(K_list, implied_vol_list)
plt.xlabel('Strike')
plt.ylabel('Implied Volatility')
plt.title('Volatility Smile')
plt.show()
# 示例代码
S = 100
r = 0.05
T = 1
K_list = np.arange(80, 121, 5)
vol_call_list = [0.2, 0.3, 0.4, 0.5, 0.6]
vol_put_list = [0.6, 0.5, 0.4, 0.3, 0.2]
smile_vol(S, r, T, vol_call_list, K_list, option='call')
smile_vol(S, r, T, vol_put_list, K_list, option='put')
```
请将下面代码中的两张波动率微笑的图绘制到一张图中# -- coding: utf-8 -- """ Created on Sun May 28 18:08:36 2023 @author: lll """ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import brentq from scipy.stats import norm # 定义BS模型计算期权价格的函数 def bs_price(S, K, r, T, sigma, option='call'): d1 = (np.log(S/K) + (r + 0.5sigma**2)T) / (sigmanp.sqrt(T)) d2 = d1 - sigmanp.sqrt(T) if option == 'call': price = Snorm.cdf(d1) - Knp.exp(-rT)norm.cdf(d2) else: price = Knp.exp(-rT)norm.cdf(-d2) - Snorm.cdf(-d1) return price # 定义计算隐含波动率的函数 def implied_vol(S, K, r, T, price, option='call'): def f(sigma): return bs_price(S, K, r, T, sigma, option) - price return brentq(f, 0.001, 10) # 定义计算波动率微笑图形的函数 def smile_vol(S, r, T, vol_list, K_list, option='call'): if len(vol_list) != len(K_list): raise ValueError("Lengths of vol_list and K_list must be equal.") implied_vol_list = [] for K, vol in zip(K_list, vol_list): price = bs_price(S, K, r, T, vol, option) implied_vol_list.append(implied_vol(S, K, r, T, price, option)) plt.plot(K_list, implied_vol_list) plt.xlabel('Strike') plt.ylabel('Implied Volatility') plt.title('Volatility Smile') plt.show() # 示例代码 S = 100 r = 0.05 T = 1 K_list = np.arange(80, 121, 9) vol_call_list = [0.2, 0.3, 0.4, 0.5, 0.6] vol_put_list = [0.6, 0.5, 0.4, 0.3, 0.2] smile_vol(S, r, T, vol_call_list, K_list, option='call') smile_vol(S, r, T, vol_put_list, K_list, option='put')
以下是将两张波动率微笑的图绘制到一张图中的代码:
```python
# 示例代码
S = 100
r = 0.05
T = 1
K_list = np.arange(80, 121, 9)
vol_call_list = [0.2, 0.3, 0.4, 0.5, 0.6]
vol_put_list = [0.6, 0.5, 0.4, 0.3, 0.2]
# 绘制波动率微笑图形
implied_vol_call_list = []
implied_vol_put_list = []
for K, vol_call, vol_put in zip(K_list, vol_call_list, vol_put_list):
price_call = bs_price(S, K, r, T, vol_call, option='call')
implied_vol_call_list.append(implied_vol(S, K, r, T, price_call, option='call'))
price_put = bs_price(S, K, r, T, vol_put, option='put')
implied_vol_put_list.append(implied_vol(S, K, r, T, price_put, option='put'))
plt.plot(K_list, implied_vol_call_list, label='Call')
plt.plot(K_list, implied_vol_put_list, label='Put')
plt.xlabel('Strike')
plt.ylabel('Implied Volatility')
plt.title('Volatility Smile')
plt.legend()
plt.show()
```
在这个代码中,我们首先计算出每个期权的隐含波动率,然后分别将call和put的隐含波动率绘制在同一张图上。我们使用 `label` 和 `legend` 来标记每个曲线。注意,这里我们需要同时计算call和put的隐含波动率,因为它们的波动率微笑图形通常是不对称的。
阅读全文