用DeepAR做股票价格预测
【摘要】 DeepAR是亚马逊提出的一种时间序列预测算法,通过在大量时间序列上训练自回归递归网络模型,从相关的时间序列中有效学习全局模型,进而对时间序列进行预测。它能应对季节性、周期性等问题。本文探索下该模型在股票预测中的应用。安装包!pip install baostock==0.8.8!pip install mxnet==1.7.0.post2!pip install gluonts==0.9....
DeepAR是亚马逊提出的一种时间序列预测算法,通过在大量时间序列上训练自回归递归网络模型,从相关的时间序列中有效学习全局模型,进而对时间序列进行预测。它能应对季节性、周期性等问题。本文探索下该模型在股票预测中的应用。
-
安装包
!pip install baostock==0.8.8
!pip install mxnet==1.7.0.post2
!pip install gluonts==0.9.5
!pip install matplotlib==3.3.1
2. 导入包
import baostock as bs
from gluonts.dataset import common
from gluonts.model import deepar
from gluonts.mx.trainer import Trainer
from gluonts.evaluation.backtest import make_evaluation_predictions
from tqdm.autonotebook import tqdm
import os
import matplotlib.pyplot as plt
3. 数据准备
3.1. 变量初始化
stock_list = ['sh.603198', 'sh.600809', 'sh.600519'] #以三个白酒股为例
train_list = []
test_list = []
prediction_length = 20 #预测长度
3.2. 数据下载函数
def get_stockdata(code):
rs = bs.query_history_k_data_plus(code,
'date,close,volume,turn',
start_date='2020-09-01',
end_date='2022-09-01',
frequency='d', adjustflag='2')
return rs.get_data()
3.3. 连上baostock
lg = bs.login()
3.4. 数据下载与划分
for stock_id in stock_list:
df = get_stockdata(stock_id)
train_dic = {'start':df.date[0],
'target':df.close,
'cat':int(stock_id.split('.')[1]),
'dynamic_feat':[df.volume, df.turn]}
test_dic = {'start':df.date[0],
'target':df.close[:-prediction_length],
'cat':int(stock_id.split('.')[1]),
'dynamic_feat':[df.volume[:-prediction_length], df.turn[:-prediction_length]]}
train_list.append(train_dic)
test_list.append(test_dic)
3.5. 数据封装
train_data = common.ListDataset(train_list, freq='1d')
test_data = common.ListDataset(test_list, freq='1d')
4. 建模与训练
4.1. 构建训练器
estimator = deepar.DeepAREstimator(
prediction_length=prediction_length,
context_length=60,
freq='1d',
num_layers=2,
num_cells=64,
trainer=Trainer(epochs=20,
learning_rate=1e-2,
num_batches_per_epoch=32))
4.2. 模型训练
predictor = estimator.train(train_data)
5. 股价预测
forecast_it, ts_it = make_evaluation_predictions(
dataset=test_data,
predictor=predictor,
num_samples=100)
6. 模型评估与可视
6.1. 数据封装
tss = list(tqdm(ts_it, total=len(test_data)))
forecasts = list(tqdm(forecast_it, total=len(test_data)))
6.2. 新建存放文件夹
plot_log_path = './plots/'
directory = os.path.dirname(plot_log_path)
if not os.path.exists(directory):
os.makedirs(directory)
6.3. 结果可视化
def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id):
plot_length = 150
prediction_intervals = (50, 80)
legend = ['observations', 'median prediction'] + [f'{k}% prediction interval' for k in prediction_intervals][::-1]
_, ax = plt.subplots(1, 1, figsize=(10, 7))
ts_entry[-plot_length:].plot(ax=ax)
forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
ax.axvline(ts_entry.index[-prediction_length], color='r')
plt.legend(legend, loc='upper left')
plt.savefig('{}forecast_{}.png'.format(path, sample_id))
plt.close()
6.4. 将结果存入刚新建的文件夹
for i in tqdm(range(len(stock_list))):
ts_entry = tss[i]
forecast_entry = forecasts[i]
plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i)
6.5. 预测结果
在刚刚新建的文件夹下可以看到
三个股票的预测结果如下:
可以看到,只有第三个票的结果勉强可以接受,想要将这种预测方法运用于实战也还需要深入地研究。
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)