用DeepAR做股票价格预测

举报
darkpard 发表于 2022/09/12 09:39:04 2022/09/12
【摘要】 DeepAR是亚马逊提出的一种时间序列预测算法,通过在大量时间序列上训练自回归递归网络模型,从相关的时间序列中有效学习全局模型,进而对时间序列进行预测。它能应对季节性、周期性等问题。本文探索下该模型在股票预测中的应用。安装包!pip install baostock==0.8.8!pip install mxnet==1.7.0.post2!pip install gluonts==0.9....

DeepAR是亚马逊提出的一种时间序列预测算法,通过在大量时间序列上训练自回归递归网络模型,从相关的时间序列中有效学习全局模型,进而对时间序列进行预测。它能应对季节性、周期性等问题。本文探索下该模型在股票预测中的应用。

  1. 安装包

!pip install baostock==0.8.8
!pip install mxnet==1.7.0.post2
!pip install gluonts==0.9.5
!pip install matplotlib==3.3.1

2. 导入包

    import baostock as bs
    from gluonts.dataset import common
    from gluonts.model import deepar
    from gluonts.mx.trainer import Trainer
    from gluonts.evaluation.backtest import make_evaluation_predictions
    from tqdm.autonotebook import tqdm
    import os
    import matplotlib.pyplot as plt

    3. 数据准备

    3.1. 变量初始化

      stock_list = ['sh.603198', 'sh.600809', 'sh.600519'] #以三个白酒股为例
      train_list = []
      test_list = []
      prediction_length = 20 #预测长度

      3.2. 数据下载函数

        def get_stockdata(code):
            rs = bs.query_history_k_data_plus(code,
                                              'date,close,volume,turn',
                                              start_date='2020-09-01',
                                              end_date='2022-09-01',
                                              frequency='d', adjustflag='2') 
            return rs.get_data()

        3.3. 连上baostock

          lg = bs.login()

          3.4. 数据下载与划分

            for stock_id in stock_list:
                df = get_stockdata(stock_id)
                train_dic = {'start':df.date[0],
                             'target':df.close,
                             'cat':int(stock_id.split('.')[1]),
                             'dynamic_feat':[df.volume, df.turn]}
                    test_dic = {'start':df.date[0],
                            'target':df.close[:-prediction_length],
                            'cat':int(stock_id.split('.')[1]),
                            'dynamic_feat':[df.volume[:-prediction_length], df.turn[:-prediction_length]]}
                train_list.append(train_dic)
                test_list.append(test_dic)

            3.5. 数据封装

              train_data = common.ListDataset(train_list, freq='1d')
              test_data = common.ListDataset(test_list, freq='1d')

              4. 建模与训练

              4.1. 构建训练器

                estimator = deepar.DeepAREstimator(
                    prediction_length=prediction_length,
                    context_length=60,
                    freq='1d',
                    num_layers=2,
                    num_cells=64,
                    trainer=Trainer(epochs=20,
                                    learning_rate=1e-2,
                                    num_batches_per_epoch=32))

                4.2. 模型训练

                  predictor = estimator.train(train_data)

                  图片

                  5. 股价预测

                    forecast_it, ts_it = make_evaluation_predictions(
                        dataset=test_data,
                        predictor=predictor,
                        num_samples=100)

                    6. 模型评估与可视

                    6.1. 数据封装

                      tss = list(tqdm(ts_it, total=len(test_data)))
                      forecasts = list(tqdm(forecast_it, total=len(test_data)))

                      图片

                      6.2. 新建存放文件夹

                        plot_log_path = './plots/'
                        directory = os.path.dirname(plot_log_path)
                        if not os.path.exists(directory):
                            os.makedirs(directory)

                        6.3. 结果可视化

                          def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id):
                              plot_length = 150
                              prediction_intervals = (50, 80)
                              legend = ['observations', 'median prediction'] + [f'{k}% prediction interval' for k in prediction_intervals][::-1]
                              _, ax = plt.subplots(1, 1, figsize=(10, 7))
                              ts_entry[-plot_length:].plot(ax=ax)
                              forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
                              ax.axvline(ts_entry.index[-prediction_length], color='r')
                              plt.legend(legend, loc='upper left')
                              plt.savefig('{}forecast_{}.png'.format(path, sample_id))
                              plt.close()

                          6.4. 将结果存入刚新建的文件夹

                            for i in tqdm(range(len(stock_list))):
                                ts_entry = tss[i]
                                forecast_entry = forecasts[i]
                                plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i)

                            图片

                            6.5. 预测结果

                            在刚刚新建的文件夹下可以看到

                            图片

                            三个股票的预测结果如下:

                            图片

                            图片

                            图片

                            可以看到,只有第三个票的结果勉强可以接受,想要将这种预测方法运用于实战也还需要深入地研究。

                            【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
                            • 点赞
                            • 收藏
                            • 关注作者

                            评论(0

                            0/1000
                            抱歉,系统识别当前为高风险访问,暂不支持该操作

                            全部回复

                            上滑加载中

                            设置昵称

                            在此一键设置昵称,即可参与社区互动!

                            *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

                            *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。