【王喆-推荐系统】前沿篇-(task2)GNN在Pinterest的应用
学习总结
(1)GraphSAGE 的主要步骤是三步“采样 - 聚合 - 预测”:
- 采样是指在整体图数据上随机确定中心节点,采样 k 阶子图样本。
- 聚合是指利用 GNN 把 k 阶子图样本聚合成中心节点 Embedding。
- 预测是指利用 GNN 做有监督的标签预测或者直接生成节点 Embedding。
(2)3步中的重点在于聚合的 GNN 结构,它使用 CONVOLVE 操作把邻接点 Embedding 聚合起来,跟中心节点上一轮的 Embedding 连接后,利用全连接层生成新的 Embedding。
(3)决定什么任务的是GraphSAGE的输出层:
- 有监督学习:输出层是LR这种二分类模型,input是之前通过GNN学习到的中心节点embedding,output是预测概率的label。
- 无监督学习:可以用像word2vec一样用softmax作为输出层,预测每个点的ID,这样每个点ID的softmax的输出层向量就是该点的embedding,原理和word2vec一致。和上一节YouTube架构中的召回层的视频向量的生成也是一致的:
一、以往方法
deepwalk、Node2vec等Graph embedding方法并没直接处理图结构数据,而是先把图结构数据通过随机游走采样,转为序列数据,然后再用word2vec这类序列数据embedding方法生成Graph embedding。
但是这种“搭桥”的方法对图数据进行采样的时候,破坏了信息原始的结构。
二、GraphSAGE 的主要步骤
Graph Sample and Aggregate,翻译过来叫“图采样和聚集方法”。
GraphSAGE 的过程如上图所示,主要可以分为 3 步:
- 在整体的图数据上,从某一个中心节点开始采样,得到一个 k 阶的子图,示意图中给出的示例是一个二阶子图;
- 有了这个二阶子图,我们可以先利用 GNN 把二阶的邻接点聚合成一阶的邻接点(图 1-2 中绿色的部分),再把一阶的邻接点聚合成这个中心节点(图 1-2 中蓝色的部分);
- 有了聚合好的这个中心节点的 Embedding,我们就可以去完成一个预测任务,比如这个中心节点的标签是被点击的电影,那我们就可以让这个 GNN 完成一个点击率预估任务。
GNN 既可以预测中心节点的标签,比如点击或未点击,也可以单纯训练中心节点的 Embedding。主要步骤就是三个“抽样 - 聚合 - 预测”。
三、GraphSAGE 的模型结构
GraphSAGE 的模型结构到底怎么样?它到底是怎么把一个 k 阶的子图放到 GNN 中去训练,然后生成中心节点的 Embedding 的呢?
上图中处理的样本是一个以点 A 为中心节点的二阶子图,从左到右我们可以看到,点 A 的一阶邻接点包括点 B、点 C 和点 D,从点 B、C、D 再扩散一阶,可以看到点 B 的邻接点是点 A 和点 C,点 C 的邻接点是 A、B、E、F,而点 D 的邻接点是点 A。
3.1 GraphSAGE 的训练过程:
这个 GNN 的输入是二阶邻接点的 Embedding,二阶邻接点的 Embedding 通过一个叫 CONVOLVE 的操作生成了一阶邻接点的 Embedding,然后一阶邻接点的 Embedding 再通过这个 CONVOLVE 的操作生成了目标中心节点的 Embedding,至此完成了整个训练。
3.2 CONVOLVE
CONVOLVE 的中文名是卷积,但这里的卷积并不是严格意义上的数学卷积运算,而是一个由 Aggregate 操作和 Concat 操作组成的复杂操作。
CONVOLVE 操作是由两个步骤组成的:
- 第一步叫 Aggregate 操作,就是图 4 中 gamma 符号代表的操作,它把点 A 的三个邻接点 Embedding 进行了聚合,生成了一个 Embedding hN(A);
- 第二步,我们再把 hN(A) 与点 A 上一轮训练中的 Embedding hA 连接起来,然后通过一个全联接层生成点 A 新的 Embedding。
3.3 第一步的Aggregate操作
就是把多个 Embedding 聚合成一个 Embedding 的操作。比如,我们最开始使用的 Average Pooling,在 DIN 中使用过的 Attention 机制,在序列模型中讲过的基于 GRU 的方法,以及可以把这些 Embedding 聚合起来的 MLP 等等。
四、GraphSAGE 的预测目标
预测节点的标签(如点击or未点击)是一个有监督学习任务;
生成节点的embedding是一个无监督学习任务。
决定什么任务的是GraphSAGE的输出层:
(1)有监督学习:输出层是LR这种二分类模型,input是之前通过GNN学习到的中心节点embedding,output是预测概率的label。
(2)无监督学习:可以用像word2vec一样用softmax作为输出层,预测每个点的ID,这样每个点ID的softmax的输出层向量就是该点的embedding,原理和word2vec一致。和上一节YouTube架构中的召回层的视频向量的生成也是一致的:
五、GraphSAGE 在 Pinterest 推荐系统中的应用
在 PinSAGE 应用的构成中,它没有直接分析图片内容,而只是把图片当作一个节点,利用节点和周围节点的关系生成的图片 Embedding。因此,这个例子可以说明,PinSAGE 某种程度上理解了图片的语义信息,而这些语义信息正是埋藏在 Pinterest 的商品关系图中。
六、作业
使用 GraphSAGE 是为了生成每个节点的 Embedding,那我们有没有办法在 GraphSAGE 中加入物品的其他特征,如物品的价格、种类等等特征,让最终生成的物品 Embedding 中包含这些物品特征的信息呢?
【答】可以在k阶聚合完成后,像wide&deep钟一样,将节点的embedding和物品其他特征拼接后接入全连接层和softmax层得到embedding。
七、课后答疑
(1)在实际公司推荐场景中如果要应用这个算法,数据是通过图数据库来存储吗?能否推荐一个生产环境适合的图数据库?
【答】最近大家提neo4j比较多,https://neo4j.com/,可以研究一下。也可以用spark xgraph直接处理原始数据。
Reference
(1)https://github.com/wzhe06/Reco-papers
(2)《深度学习推荐系统实战》,王喆
文章来源: andyguo.blog.csdn.net,作者:山顶夕景,版权归原作者所有,如需转载,请联系作者。
原文链接:andyguo.blog.csdn.net/article/details/121399789
- 点赞
- 收藏
- 关注作者
评论(0)