昇腾AI4S图机器学习:DGL消息传递接口的PyG替换

举报
AI4S_NPU 发表于 2025/06/16 00:42:31 2025/06/16
【摘要】 背景介绍DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在API设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾NPU对PyG图机器学习库的支持亲和度更高,因此有些时候需要做DGL接口的PyG替换。SE3Transformer在RFdiffusion蛋白质设计模型中(GitHub ...

背景介

DGL (Deep Graph Learning) PyG (Pytorch Geometric) 是两个主流的络库,它API设计和底层实现上有一定差异,在不同景下,研究人会使用不同的依赖库,昇NPUPyG机器学习库的支持和度更高,因此有些候需要做DGL接口的PyG

image001.png

SE3TransformerRFdiffusion蛋白质设计模型中(GitHub - RosettaCommons/RFdiffusion: Code for running RFdiffusion)作核心件,负责处理蛋白质结构的几何信息。其架构基于,通SE(3)实现对和平移的不性特征提取。本系列以RFDiffusion模型中的SE3Transformer例,解如何将DGL中的接口替换为PyG实现

image003.png

在本文中,主要展示消息传递接口的PyG

消息传递接口

一、-点消息传递 (EdgeSoftmax + Aggregation)

位置:

rfdiffusion/modules/equivariant_attention/modules.py 中的 TransformerLayer

:

  • 节点特征: x , 形状为(N, F)
  • 边特征: edge_attr , 形状为(E, F')
  • 图结构: graph

:

  • 更新的节点特征: 形状为(N, F_out)

DGL函数:

  • dgl.nn.EdgeSoftmax:对边特征进行归一化
  • dgl.function.copy_edge:复制边特征
  • dgl.function.sum:聚合消息

数学逻辑:

  1. 算注意力分数: a_{ij}=\mathrm{softmax}_j(e_{ij})
  2. 消息聚合: h_i^{\prime}=\sum_{j\in\mathcal{N}(i)}a_{ij}\cdot h_j

PyG实现:

def edge_softmax_aggregation(x, edge_index, edge_attr):

        # 算源点和目标节点索引

        src, dst = edge_index

 

        # softmax

        exp_edge_attr = torch.exp(edge_attr)

 

        # 按目标节一化

        node_degree = scatter_add(exp_edge_attr, dst, dim=0, dim_size=x.size(0)) norm = node_degree[dst].clamp(min=1e-6)

        norm_edge_attr = exp_edge_attr / norm

 

        # 消息传递

        message = norm_edge_attr * x[src]

 

        # 聚合

        out = scatter_add(message, dst, dim=0, dim_size=x.size(0))

 

        return out

二、矢量特征消息传递

位置:

rfdiffusion/modules/equivariant_attention/modules.py 中的 AttentionBlockSE3

:

  • 标量特征: feat_scalar , 形状为(N, F_s)
  • 矢量特征: feat_vector , 形状为(N, F_v, 3)
  • 图结构: graph

:

  • 更新的标量和矢量特征

DGL函数:

  • dgl.nn.EdgeSoftmax:边特征softmax
  • g.send_and_recv:消息传递与聚合

数学逻辑:

  1. m_{ij}=f_\mathrm{att}(h_i^s,h_j^s,h_i^v,h_j^v)
  2. 矢量特征旋转: h_j^v\cdot R_{ij},其中R_{ij}是相对方向

PyG实现:

  • 需要自定义消息传递函数
  • 实现等变性旋转操作
  • 处理批处理边索引

 

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。