来自华为诺亚实验室的ICLR 2020满分论文:基于强化学习的因果发现算法

举报
Noah's Ark LAB 发表于 2020/08/26 17:47:09 2020/08/26
【摘要】 华为诺亚方舟实验室因果研究团队将强化学习应用到打分法的因果发现算法中,最终得到因果图结构。该工作获得了ICLR 2020满分评价,并做口头报告。

因果研究作为下一个潜在的热点,已经吸引了机器学习/深度学习领域的的广泛关注,例如Youshua Bengio和Fei-Fei Li近期都有相关的工作。因果研究中一个经典的问题是“因果发现”问题——从被动可观测的数据中发现潜在的因果图结构。华为诺亚方舟实验室因果研究团队将强化学习应用到打分法的因果发现算法中,通过基于自注意力机制的encoder-decoder神经网络模型探索数据之间的关系,结合因果结构的条件,并使用策略梯度的强化学习算法对神经网络参数进行训练,最终得到因果图结构。在学术界常用的一些数据模型中,该方法在中等规模的图上的表现优于其他方法,包括传统的因果发现算法和近期的基于梯度的算法。同时该方法非常灵活,可以和任意的打分函数结合使用。


该工作获得了ICLR 2020满分评价,并做口头报告。论文地址为:https://arxiv.org/abs/1906.04477  


模型定义和问题

我们假设以下常用的数据生成模型:给定一个有向无环图(DAG),每个节点对应一个随机变量,每个变量的观测值是图中父亲变量的函数加上一个独立的噪声,即

image.png

这里噪声n_i是联合独立的。如果所有的函数都是线性的且噪声是高斯的,则上述模型为标准的线性高斯模型。当函数为线性但噪声为非高斯函数时,上述模型为线性非高斯加性模型(LiNGAM),在一定的条件下是可以识别出真实的DAG。

我们目前考虑所有的变量都是一维的实变量;给定一个合适的打分函数则可以直接扩展到多维变量的情形。在固定的函数和噪声分布下,我们的观测数据是根据上述模型在某个未知的DAG上独立采样得到。因果发现的目的就是使用这些观测的数据来推断真实的因果DAG。


背景介绍

打分法是因果发现算法中一类常用的方法:给每个有向图打分(通常基于观测数据计算得到),然后在所有的DAG中进行搜索取得最好分数的DAG:

image.png

尽管有很多已经深入研究的打分函数,例如基于线性高斯模型的BIC/MDL和BGe分数,但上述问题通常是NP-hard的,因为DAG条件是一个组合问题,并且可能的DAG数量的随着图节点的个数增加而超指数增加。为了解决这个问题,大多数已有方法都依赖于局部启发式算法。例如,贪婪等价搜索(GES)在添加一条边时显式检查DAG约束是否满足。GES在适当的假设和极限数据量的情况下可以找到具全局最优值,但在有限样本的情况下无法得到保证。

最近,也有工作在线性数据模型上对上述的无环条件提出了一个等价的可微分函数,再选择适当的损失函数(例如最小二乘损失),上述问题可以转换为关于带权值的邻接矩阵的连续优化问题。后续的工作也采用ELBO和negative log-likelihood作为损失函数,并使用神经网络对因果关系进行建模。但是很多已有的得分函数没有显式的表示或者是非常复杂的等价损失函数,这样和上述连续的方法结合会比较困难。


基于强化学习的因果发现算法

我们提出一种基于RL的方法来搜索DAG,整体框架图如下所示。基于随机策略的RL可以在给定策略的不确定性信息的情况下自动确定要搜索的位置,同时可以通过奖励信号来及时更新。在合成数据集和真实数据集上的实验表明,基于强化学习的方法大大提高了搜索能力,并且不会影响打分函数的选择。

 


3.1 基于自注意力机制的Encoder-Decoder模型

如上图所示,我们采用Transfomer中基于自注意机制的encoder, 而decoder则是通过建立成对的encoder输出之间的关系来生成图的邻接矩阵。为了得到0-1的邻接矩阵,我们将每个decoder的输出通过logistic-sigmoid函数,然后使用Bernoulli分布进行采样。

我们也尝试了其他的decoder,例如bilinear model以及Transformer中的decoder。我们实验发现上图中decoder的效果最好,可能是因为它的参数量比较少、更容易训练来找到更好的DAG,而基于自注意力机制的encoder已经提供了足够的交互来探索数据之间的因果关系。


3.2 Reward

传统的GES会在每次添加一条边时显式的检查图是否有环,我们使用打分函数和基于有环性质的惩罚项来设计reward,并允许生成的图在每次迭代中变化多条边。具体的形式如下: 

image.png

其中第一项是得分函数,用于衡量给定有向图和观测数据的匹配程度,其他两个正项则衡量某些“DAGness”(给定的有向图距无环的某种度量,例如所有环上的长度之和),lambda_1和lamba_2是惩罚项的权重。通过选择适当的惩罚权重,最大化reward等价于之前打分法的问题的形式。但是两个问题等价并不意味着使用RL来最大化reward就可以直接取得很好的结果:实际中,我们发现较大的惩罚权重可能会妨碍RL的探索,得到的因果图的得分通常比较差,而较小的惩罚值将导致有环的图。同时,不同的打分函数可能具有非常不同的范围,而两个惩罚项的值与打分函数是没有关系的。因此,我们将所有的打分函数调整到一定范围,并为惩罚权重设计一种在线更新策略。详细内容可以参见论文的第5章。


3.3 Actor-Critic优化参数

我们采用策略梯度和随机优化的方法来优化以下目标

image.png

其中A中有向图对应的0-1邻接矩阵。我们使用Actor-Critic来进行训练,同时还加了熵正则项来鼓励探索。尽管策略梯度方法仅在一定条件下能保证局部收敛,但是通过惩罚项系数的设计,在我们的实验中RL算法得到的图都是无环的。


3.4 最终输出

由于我们关注的是寻找得分最好的DAG,而不是policy,因此我们记录了训练过程中生成的所有的有向图,并选择具有最佳reward的图作为输出结果。实际上由于有限的数据,图中会包含一些真图里边不存在的边,因此需要进一步的减枝处理。

我们可以根据损失函数或者打分函数,使用贪婪方法来进行减枝操作。我们删除一个父亲变量并计算相应的结果,如果损失函数或者打分函数效果没有变差或者是在预先设定的范围内,就接受减枝的操作并继续下去。对于线性模型,可以通过和阈值比较的方法来进行减枝。


实验结果

在此工作中,我们使用BIC打分函数,并假设附加性的高斯噪声(实际中噪声可能是非高斯的)。考虑两种情况:不同的噪声方差,等价于negative log-likelihood加上一个对边的个数的惩罚项作为打分函数;以及相等的噪声方差,将得到最小平方损失加上边的个数的惩罚项。它们分别表示为RL-BIC和RL-BIC2。

我们的方法与传统方法(PC,GES,ICA-LiNGAM和CAM)以及最近基于梯度的方法(NOTEARS,DAG-GNN和GraN-DAG)在学术界常用的一些数据集上进行了比较。我们使用三个指标评估学到的图结构:错误发现率(FDR),正确率(TPR)和结构汉明距离(SHD)。SHD是将得到的图转换为真实DAG的边添加,删除和反转操作的最少个数。


4.1 高斯和非高斯噪声的线性数据模型

我们首先考虑12个节点的有向图。图2显示了在一个线性高斯数据集上RL-BIC2的训练过程。我们采用NOTEARS和DAG-GNN在同样的数据集上使用的阈值来做减枝。在这个例子中,RL-BIC2在训练过程中生成683,784个不同的图,远低于12个节点DAG的总数(约5.22 * 10^26)。经过减枝的DAG和真实的图结构完全相同。

表1是我们在LiNGAM和线性高斯数据模型的实验结果。在该实验中,RL-BIC2在两个数据模型上恢复了所有真实的因果图,而RL-BIC的表现稍差。尽管如此,在相同的BIC分数下,RL-BIC在两个数据集上的表现均远好于GES。


4.2 具有高斯过程的非线性模型

我们考虑一种非线性的数据模型,每个因果关系函数是从高斯过程中采样的一个函数。该问题被证明是可识别的,即可以从联合概率分布中识别出真实的图。我们使用和GraN-DAG一样的实验条件:10个节点,40条边的DAG,并考虑1000个观测样本。

实验结果如下表3所示。对于我们的方法,我们将高斯过程回归(GPR)与RBF 核一起使用来建立因果关系模型。虽然观察到的数据是来自于高斯过程采样得到的函数,但这并不能保证具有相同核的GPR可以达到很好的结果。实际上,使用固定的核参数将导致严重的过度拟合,从而导致许多错误的边,这样训练结束最好reward对应的有向图通常不是DAG。为此我们将数据归一化处理,并使用median heuristics来选择核参数。我们两种方法的表现都不错,其中RL-BIC的结果优于其他所有方法。

 



4.3 真实数据集

我们最后考虑Sachs数据集,通过蛋白质和磷脂的表达程度来发现蛋白质信号网络。我们将带有RBF内核的GPR应用于因果关系建模,对数据做归一化并使用基于median heuristics的核参数。我们使用和CAM及Gran-DAG中同样的减枝方法。实验结果见下表。与其他方法相比,RL-BIC和RL-BIC2均取得了不错的结果。


结语

我们使用强化学习来搜索具有最佳分数的DAG,其中actor是基于自注意力机制的encoder-decoder模型,而reward结合了预先给定的得分函数和两个惩罚项来得到无环图。在合成和真实数据集上,该方法均取得了很好的结果。在论文里,我们还展示了该方法在30节点的图上的效果,但是处理大规模的图(超过50个节点)仍然具有挑战性。尽管如此,许多实际的应用(例如Sachs数据集)的变量数都相对较少。此外,有可能将大的因果发现问题分解为较小的问题分别处理,基于先验知识或基于约束的方法也可以用来减少搜索空间。

当前的工作有几个未来改进的方向。在目前的实现中,打分函数的计算比训练神经网络会花费更多的时间,一个更有效率的打分函数将会大大提升目前算法的表现。其他RL算法也可以用来加速训练,例如A3C。此外,我们观察到实验中使用的总迭代次数通常超过了需要的次数,我们也会研究如何进行early stopping。

除了利用深度学习来帮助因果发现,我们的团队还致力于使用因果来增强机器学习、深度学习。我们相信这个方向会有很好的前景,但是目前的挑战也非常大。如果您对我们的研究内容有兴趣(实习或者全职),请发送您的简历到 zhushengyu@huawei.com 。


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。