【图神经网络DGL】消息传递范式

举报
野猪佩奇996 发表于 2022/01/23 01:24:37 2022/01/23
【摘要】 学习总结 文章目录 学习总结一、消息传递范式1.1 内置函数和消息传递API(1)API属性介绍(2)内置&自定义函数 1.2 编写高效的消息传递代码1.3 在图的一部分上进行消息...

学习总结

一、消息传递范式

聚合函数和更新函数。
在这里插入图片描述

1.1 内置函数和消息传递API

(1)API属性介绍

  • 消息函数:接受一个参数 edges,这是一个 dgl.EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges有三个成员属性:srcdstdata,分别用于访问源节点、目标节点和边的特征。

  • 聚合函数:接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 summaxminmean等。聚合函数一般有2个参数,它们的类型都是字符串:

    • 一个用于指定mailbox中的字段名;
    • 一个用于指示目标节点特征的字段名,例如dgl.function.sum('m', 'h')等价于如下所示的对接收到消息求和的用户定义函数:
import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

  
 
  • 1
  • 2
  • 3
  • 更新函数:接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征

(2)内置&自定义函数

(1)命名空间dfl.function中实现了常用的(内置的)消息函数和聚合函数(能够自动处理维度广播),当然也可以自定义函数。

(2)(自定义)内置消息函数:可以是一元函数(dgl支持copy函数),也支持二元函数(dgl支持add、sub、mul、div、dot函数):

  • 消息的内置函数的命名约定是u表示源节点,v表示目标节点,e表示边。
  • 这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。
    • ex:要对源节点的hu特征和目标节点的hv特征求和,然后将结果保存在边的he特征上:dgl.function.u_add_v('hu', 'hv', 'he')
    • 如下自定义消息函数和内置函数相同:
def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

  
 
  • 1
  • 2

(3)在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。

  • apply_edges() 的参数是一个消息函数。
  • 在默认情况下,这个接口将更新所有的边。
  • 例如:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

  
 
  • 1
  • 2

(4)消息传递的高级APIupdate_all()

  • 它在单个API调用里合并了消息生成、消息聚合和节点特征更新,为从整体上进行系统优化提供了空间。
  • update_all()的参数:一个消息、聚合、更新函数(可选,也可以在外面操作,dgl不推荐在update_all中指定更新函数)。
    • 更新函数是一个可选择的参数,可以不使用,而是在update_all执行完后直接对节点特征进行操作;
    • 因为更新函数通常可用纯张量操作实现,所以DGL不推荐在update_all中指定更新函数,如函数:

 final  f t i = 2 ∗ ∑ j ∈ N ( i ) ( f t j ∗ a i j ) \text { final } f t_{i}=2 * \sum_{j \in \mathcal{N}(i)}\left(f t_{j} * a_{i j}\right)  final fti=2jN(i)(ftjaij)

def updata_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

ex:graph.update_all(fn.u_mul_e('ft', 'a', 'm')将源节点特征tf和边特征a相乘生成消息mfn.sum('m', 'ft')再对所有消息求和来更新节点特征ft,再乘2后得到最终结果final_ft。调用后,中间消息m会被清除。

1.2 编写高效的消息传递代码

关于dgl内置函数是如何优化消息传递的内存消耗和计算速度的, 详见文字描述: DGL官方文档 ; 总结来说主要是合并内核, 并行逐边运算, 减少点边拷贝等; 如update_all()函数就是一个效率很高的接口; 如果确实需要使用apply_edges()函数在边上保存消息, 则内存占用会非常大;

(1)一个通过对节点特征降维来减少消息维度的示例:

  • 拼接源节点与目标节点特征, 然后应用一个线性层: W × ( u ∣ ∣ v ) W\times (u||v) W×(uv)
  • 这样源节点与目标节点特征维数较高, 而线性层输出维数较低;
  • 代码示例:
import torch
import torch.nn as nn

linear = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim * 2)))
def concat_message_function(edges):
	 return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] * linear	

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

也可以将先行操作分成两部分, 即分别对源节点特征和目标节点特征进行线性变换后再相加, 即 W l × u + W r × v W_{l} \times u+W_{r} \times v Wl×u+Wr×v,其中 W = ( W l ∥ W r ) W = \left(W_{l} \| W_{r}\right) W=(WlWr),这样可能会更加优化。代码实例:

import dgl.function as fn

linear_src = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
out_src = g.ndata['feat'] * linear_src
out_dst = g.ndata['feat'] * linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

这两种方法数学上等价, 但后一种方法更加高效, 因为无需再边上保存feat_srcfeat_dst, 空间占用小, 另外加法可以直接用内置函数u_add_v进行优化, 内置函数的效率一般比自定义函数要高。

1.3 在图的一部分上进行消息传递

如果用户只想更新图中部分节点,先将想处理的节点编号创建一个子图,然后对其调用update_all()(这也是小批量处理中的常见用法)。

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

  
 
  • 1
  • 2
  • 3

1.4 在消息传递中使用边的权重

常见的GNN建模做法:在消息聚合前使用边的权重,如GAT和一些GCN的变种。dgl的处理:

  • 将权重存为边的特征;
  • 在消息函数中用边的特征和源节点的特征相乘。

ex:假定下面的权重eweight是一个形状为(E, *)的张量,E是边的数量。权重存为边的特征,即eweight被用作边的权重(通常是一个标量)。

import dgl.function as fn

# 假定eweight是一个形状为(E, *)的张量,E是边的数量。
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

1.5 在异构图上进行消息传递

本质上异构图的消息传递与同构图并没有太大区别,异构图上的消息传递可以分为两个部分:

  • 对每个关系计算和聚合消息。

  • 对每个结点聚合来自不同关系的消息。

  • DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)

    • etype_dict: dict类型, 键为一种关系, 值为这种关系对应的update_all()的参数;
    • cross_reducer: str类型, 表示跨类型整合函数, 来指定整合不同关系聚合结果的方式, 可以是sum, min, max, mean, stack中之一;

在DGL中,对异构图进行消息传递的接口是 multi_update_all()multi_update_all() 接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all() 的参数。 multi_update_all() 还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 summinmaxmeanstack 中的一个。

import dgl.function as fn

for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # 把它存在图中用来做消息传递
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # 指定每个关系的消息传递函数:(message_func, reduce_func).
    # 注意结果保存在同一个目标特征“h”,说明聚合是逐类进行的。
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# 将每个类型消息聚合的结果相加。
G.multi_update_all(funcs, 'sum')
# 返回更新过的节点特征字典
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

小结消息传递的流程:
在这里插入图片描述

  • 消息函数(message function):传递消息的目的是将节点计算时需要的信息传递给它,因此对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递到目的节点;对于每个目的节点来说,它可能会受到多个源节点传过来的消息,它会将这些消息存储在"邮箱"中。
  • 汇聚函数(reduce function):汇聚函数的目的是根据邻居传过来的消息更新跟新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox[‘m’])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox[‘m’])内消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。
  • 更新: message函数的参数是边,包括源节点,目标节点的特征信息,处理完的数据放置到节点的mailbox中。 聚合函数reduce_function或 apply_node函数作用于节点本身,即传入的参数是节点信息和节点的邮箱信息。

二、不带边权重的例子

2.1 消息传递框架

DGL遵循Gilmer等人提出的消息8传递框架,很多GNN模型能符合如下框架:
m u → v ( l ) = M ( l ) ( h v ( l − 1 ) , h u ( l − 1 ) , e u → v ( l − 1 ) ) m v ( l ) = ∑ u ∈ N ( v ) m u → v ( l ) h v ( l ) = U ( l ) ( h v ( l − 1 ) , m v ( l ) )

m(l)uv=M(l)(h(l1)v,h(l1)u,e(l1)uv)m(l)v=uN(v)m(l)uvh(l)v=U(l)(h(l1)v,m(l)v) m u v ( l ) = M ( l ) ( h v ( l 1 ) , h u ( l 1 ) , e u v ( l 1 ) ) m v ( l ) = u N ( v ) m u v ( l ) h v ( l ) = U ( l ) ( h v ( l 1 ) , m v ( l ) )
muv(l)=M(l)(hv(l1),hu(l1),euv(l1))mv(l)=uN(v)muv(l)hv(l)=U(l)(hv(l1),mv(l))

  • M ( l ) M^{(l)} M(l)是消息message函数;
  • ∑ \sum 是聚合函数(reduce function),不一定是求和;
  • U ( l ) U^{(l)} U(l)是更新函数(update function)。

2.2 GraphSAGE的消息传递

如GraphSAGE可表示为:
h N ( v ) k ←  Average  { h u k − 1 , ∀ u ∈ N ( v ) } h v k ← ReLU ⁡ ( W k ⋅ CONCAT ⁡ ( h v k − 1 , h N ( v ) k ) )

hkN(v) Average {hk1u,uN(v)}hkvReLU(WkCONCAT(hk1v,hkN(v))) h N ( v ) k  Average  { h u k 1 , u N ( v ) } h v k ReLU ( W k CONCAT ( h v k 1 , h N ( v ) k ) )
hN(v)k Average {huk1,uN(v)}hvkReLU(WkCONCAT(hvk1,hN(v)k))

我们可以看到消息传递是定向(有方向)的:从一个节点u发送到另一个节点v的消息不一定与从节点v发送到相反方向的节点u的消息相同。
DGL提供了GraphSAGE的实现dgl.nn.SAGEConv

import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

上述代码中的核心部分是g.update_all函数,该函数收集并平均相邻特征。这里有三个概念:

  • 消息函数fn.copy_u('h','m'),它将名为“h”的节点特征复制为发送给邻居的消息
  • 聚合函数fn.mean('m', 'h_N'),该函数对所有接收到的消息中名为’m’的信息进行平均,并将结果保存为新的节点特征’h_N’
  • update_all让DGL触发所有节点和边的消息函数和聚合函数

2.3 堆叠网络

然后我们可以堆叠自己的GraphSAGE卷积层以构成多层GraphSAGE网络:

import torch.nn.functional as F

class Model(nn.Module):
    
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = GraphSAGE(in_feats, h_feats)
        self.conv2 = GraphSAGE(h_feats, num_classes)
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

2.4 训练网络

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

2.5 结果

Using backend: pytorch
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.951, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103)
In epoch 5, loss: 1.900, val acc: 0.290 (best 0.292), test acc: 0.278 (best 0.277)
In epoch 10, loss: 1.790, val acc: 0.462 (best 0.462), test acc: 0.435 (best 0.435)
In epoch 15, loss: 1.614, val acc: 0.502 (best 0.502), test acc: 0.489 (best 0.489)
In epoch 20, loss: 1.372, val acc: 0.548 (best 0.548), test acc: 0.529 (best 0.529)
In epoch 25, loss: 1.087, val acc: 0.592 (best 0.592), test acc: 0.591 (best 0.591)
In epoch 30, loss: 0.798, val acc: 0.650 (best 0.650), test acc: 0.639 (best 0.639)
In epoch 35, loss: 0.547, val acc: 0.690 (best 0.690), test acc: 0.682 (best 0.682)
In epoch 40, loss: 0.358, val acc: 0.710 (best 0.710), test acc: 0.721 (best 0.721)
In epoch 45, loss: 0.230, val acc: 0.736 (best 0.736), test acc: 0.734 (best 0.734)
In epoch 50, loss: 0.149, val acc: 0.738 (best 0.738), test acc: 0.743 (best 0.744)
In epoch 55, loss: 0.099, val acc: 0.740 (best 0.740), test acc: 0.744 (best 0.743)
In epoch 60, loss: 0.068, val acc: 0.742 (best 0.742), test acc: 0.743 (best 0.745)
In epoch 65, loss: 0.048, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.745)
In epoch 70, loss: 0.036, val acc: 0.736 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 75, loss: 0.028, val acc: 0.734 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 80, loss: 0.023, val acc: 0.738 (best 0.742), test acc: 0.757 (best 0.745)
In epoch 85, loss: 0.019, val acc: 0.738 (best 0.742), test acc: 0.758 (best 0.745)
In epoch 90, loss: 0.017, val acc: 0.742 (best 0.742), test acc: 0.756 (best 0.745)
In epoch 95, loss: 0.015, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 100, loss: 0.013, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 105, loss: 0.012, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 110, loss: 0.011, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 115, loss: 0.010, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 120, loss: 0.009, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745)
In epoch 125, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745)
In epoch 130, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745)
In epoch 135, loss: 0.007, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745)
In epoch 140, loss: 0.007, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751)
In epoch 145, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751)
In epoch 150, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.749 (best 0.751)
In epoch 155, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.750 (best 0.751)
In epoch 160, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.750 (best 0.751)
In epoch 165, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 170, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 175, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.752 (best 0.751)
In epoch 180, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 185, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 190, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)
In epoch 195, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

三、带边权重的例子

这里需要改的两个位置,g.update_all的2个参数,以及在Model中要设置边权重的传参。其他就没啥变化了。即使用带权平均聚合邻居表示,edata成员可以保存边权重(特征),这些特征也可以参与消息传递。

# data可以包含边特征信息,同时传递
class WeightedSAGEConv(nn.Module):
    """
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # 将input和邻近节点特征映射到outpu线性子模块
        self.linear = nn.Linear(in_feat * 2, out_feat)
    
    def forward(self, g, h, w):
        """
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # 加入边的权重,进行消息传递和更新
            g.edata['w'] = w
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


# 因为这个数据集中的图没有边的权值,
# 所以我们在模型的 forward 函数中手动将所有边的权值赋给1。
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)
    
    def forward(self, g, in_feat):
        # 3个参数,(g, h, w)即图,点特征,边权重
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        # 设置所有边的权重为1
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

四、DGL按照优先级高低排序推荐做法

  • 直接调用dgl.nn模块;
  • 使用dgl.nn.functional内置方法,适合一些简单操作,如为每个节点计算softmax;
  • 使用update_all,内置的消息函数和聚合函数;
  • 使用用户自定义的消息(message)函数和聚合(reduce)函数。

五、用户自定义函数

DGL允许用户自定义消息函数和聚合函数以获得最大的表达能力。以下是一个用户定义的消息函数,它等价于fn.u_mul_e('h', 'w', 'm')

def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}

  
 
  • 1
  • 2

参数edges共有三个成员:src,data和dst,分别代表所有边的源节点特征,边特征和目标节点特征。

也可以编写自己的聚合函数。例如,下面的函数相当于内置的fn.sum(‘m’, ‘h’)函数,它对传入的消息求和:

def sum_udf(nodes):
    return {"h": nodes.mailbox["m"].sum(dim=1)} 
    # dim=1,按行求和

  
 
  • 1
  • 2
  • 3

总之,DGL将按节点的度数对节点进行分组,对于每个组DGL将传入的消息沿着第2维度(按行)进行堆叠,然后沿第2个维度执行缩减(reduce)以聚合消息。

Reference

(1)NYU、AWS联合推出:全新图神经网络框架DGL正式发布
(2)https://www.dgl.ai/
(3)Write your own GNN module

文章来源: andyguo.blog.csdn.net,作者:山顶夕景,版权归原作者所有,如需转载,请联系作者。

原文链接:andyguo.blog.csdn.net/article/details/122351629

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。