DGL & RDKit | Visualizing Atom Weights of a Trained Model with Attentive FP

DrugAI, published 2021/07/15 03:28:24

DGL provides many functions for cheminformatics, drug discovery and bioinformatics tasks.

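The DGL developers provide code for visualizing the atom weights of a trained model. After building a model with Attentive FP, you can visualize the atom weights of a given molecule, which indicate how much each atom contributes to the target value.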


Visualizing atom weights of a trained model with Attentive FP

Environment setup

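  • PyTorch: deep learning framework
  • DGL: a PyTorch-based library for deep learning on graphs
  • RDKit: used to build molecular graphs and to draw structures from string representations such as SMILES
  • MDTraj: an open-source library for analyzing molecular dynamics trajectories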
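A quick way to see which of these packages are installed, and in which versions, is a few lines of standard-library Python. This is a sketch: the PyPI distribution names are assumptions, and `is_legacy_dgl` is a hypothetical helper that encodes the assumption that the code in this post needs a DGL 0.4 release.

```python
from importlib import metadata

# Assumed PyPI distribution names for the packages listed above
REQUIRED = ["torch", "dgl", "rdkit", "mdtraj"]


def installed_versions(packages):
    """Return {package: installed version string, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def is_legacy_dgl(version):
    """Hypothetical check (assumption): True for a DGL 0.4.x version string."""
    if version is None:
        return False
    parts = version.split(".")
    return parts[0] == "0" and len(parts) > 1 and parts[1] == "4"


if __name__ == "__main__":
    for pkg, ver in installed_versions(REQUIRED).items():
        print(f"{pkg:8s} {ver or 'not installed'}")
    print("DGL 0.4.x:", is_legacy_dgl(installed_versions(["dgl"])["dgl"]))
```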

Importing libraries


  
```python
# Jupyter magic: only valid inside a notebook
%matplotlib inline
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPaths
import dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoo
from dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraph
from dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizer
from dgl.data.utils import split_dataset
from functools import partial
from sklearn.metrics import roc_auc_score
```

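The code is adapted from the DGL examples (dgl/example). It uses the `dgl.model_zoo.chem` and `dgl.data.chem` modules from early DGL 0.4 releases; this functionality was later moved into the separate DGL-LifeSci package (`dgllife`), so the code needs an older DGL installation to run as written.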

 


  
```python
def chirality(atom):
    """One-hot CIP label (R, S) plus a flag for possible chirality."""
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        # The atom carries no '_CIPCode' property
        return [False, False] + [atom.HasProp('_ChiralityPossible')]


def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.

    Parameters
    ----------
    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    assert len(data[0]) in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)

    if masks is None:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks


# Atom features; their total length must equal node_feat_size below
atom_featurizer = BaseAtomFeaturizer(
    {'hv': ConcatFeaturizer([
        partial(atom_type_one_hot, allowable_set=[
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
            encode_unknown=True),
        partial(atom_degree_one_hot, allowable_set=list(range(6))),
        atom_formal_charge, atom_num_radical_electrons,
        partial(atom_hybridization_one_hot, encode_unknown=True),
        lambda atom: [0],  # A placeholder for aromatic information
        atom_total_num_H_one_hot, chirality
    ])})

# Dummy bond features: 10 zeros per bond
bond_featurizer = BaseBondFeaturizer({
    'he': lambda bond: [0 for _ in range(10)]
})

# Solubility data sets with the measured value stored in the 'SOL' property
train_mols = Chem.SDMolSupplier('solubility.train.sdf')
train_smi = [Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1, 1)
test_mols = Chem.SDMolSupplier('solubility.test.sdf')
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1, 1)

# Build the graphs from the same SMILES strings that drawmol() parses again,
# so that graph construction and drawing start from the same atom order
train_graph = [mol_to_bigraph(Chem.MolFromSmiles(smi),
                              node_featurizer=atom_featurizer,
                              edge_featurizer=bond_featurizer) for smi in train_smi]
test_graph = [mol_to_bigraph(Chem.MolFromSmiles(smi),
                             node_featurizer=atom_featurizer,
                             edge_featurizer=bond_featurizer) for smi in test_smi]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    losses = []
    for batch_data in data_loader:
        smiles, bg, labels, masks = batch_data
        # Inputs must live on the same device as the model
        atom_feats = bg.ndata['hv'].to(device)
        bond_feats = bg.edata['he'].to(device)
        labels, masks = labels.to(device), masks.to(device)
        prediction = model(bg, atom_feats, bond_feats)
        # Masked MSE: entries without a label do not contribute
        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    total_score = np.mean(losses)
    print('epoch {:d}/{:d}, training loss {:.4f}'.format(epoch + 1, n_epochs, total_score))
    return total_score


model = model_zoo.chem.AttentiveFP(node_feat_size=39,
                                   edge_feat_size=10,
                                   num_layers=2,
                                   num_timesteps=2,
                                   graph_feat_size=200,
                                   output_size=1,
                                   dropout=0.2)
model = model.to(device)

train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)),
                          batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)),
                         batch_size=128, collate_fn=collate_molgraphs)

loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0))
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):
    score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)
    epochs.append(e)
    scores.append(score)
model.eval()
```
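The `node_feat_size=39` passed to `AttentiveFP` has to equal the length of the vector that `atom_featurizer` produces. The tally below is a sketch under stated assumptions: DGL's default allowable sets are taken to be five hybridization types and total hydrogen counts 0 to 4, `encode_unknown=True` is taken to add one extra slot, and `one_hot` is a hypothetical stand-in for `one_hot_encoding`, not DGL's code.

```python
def one_hot(x, allowable_set, encode_unknown=False):
    """Hypothetical stand-in for one_hot_encoding: one slot per allowed value,
    plus a trailing 'unknown' slot when encode_unknown is True."""
    allowable_set = list(allowable_set)
    if encode_unknown and allowable_set[-1] is not None:
        allowable_set.append(None)
    if encode_unknown and x not in allowable_set:
        x = None
    return [x == s for s in allowable_set]


ATOM_TYPES = ['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At']
HYBRIDIZATIONS = ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2']  # assumed default set
NUM_HS = [0, 1, 2, 3, 4]                               # assumed default set

# Length of each block, in the order used by atom_featurizer
BLOCK_SIZES = {
    'atom type (+unknown)': len(one_hot('C', ATOM_TYPES, encode_unknown=True)),
    'degree 0-5': len(one_hot(2, list(range(6)))),
    'formal charge': 1,
    'radical electrons': 1,
    'hybridization (+unknown)': len(one_hot('SP3', HYBRIDIZATIONS, encode_unknown=True)),
    'aromatic placeholder': 1,
    'total H count': len(one_hot(1, NUM_HS)),
    'chirality (R, S, possible)': 3,
}

if __name__ == "__main__":
    for name, size in BLOCK_SIZES.items():
        print(f"{name:28s} {size}")
    print("total:", sum(BLOCK_SIZES.values()))
```

If you change an allowable set in the featurizer, the same tally tells you the new value for `node_feat_size`.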

Importing the dependencies for molecule visualization


  
```python
import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm
```

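Defining the visualization function

  • The code comes from the DGL library.
  • Called with `get_node_weight=True`, the DGL AttentiveFP model also returns the atom (node) weights computed in its readout phase, as a list with one tensor per readout round. The model above was built with `num_timesteps=2`, i.e. two GRU-based readout rounds, so `timestep` must be 0 or 1; the code below uses 0.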
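To make the `timestep` argument concrete, here is a simplified, pure-Python sketch of a multi-round attention readout: in each round every atom is scored against the current graph embedding, the scores are softmax-normalized over the atoms of the molecule, and the weighted sum of atom features updates the graph embedding. The real AttentiveFP readout uses learned linear layers for the scores and a GRU cell for the update; the dot-product score and the plain averaging update here are simplifying assumptions, and `attention_readout` is a hypothetical name. What carries over is the returned structure: one weight list per round, each non-negative and summing to 1.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_readout(node_feats, num_timesteps):
    """Simplified multi-round attention pooling over one molecule.

    node_feats: list of per-atom feature vectors (lists of floats).
    Returns (graph_feat, node_weights); node_weights[t] holds one
    weight per atom for readout round t.
    """
    dim = len(node_feats[0])
    # Initial graph embedding: sum of the atom features
    graph_feat = [sum(f[k] for f in node_feats) for k in range(dim)]
    node_weights = []
    for _ in range(num_timesteps):
        # Score each atom against the current graph embedding
        scores = [sum(a * b for a, b in zip(f, graph_feat)) for f in node_feats]
        weights = softmax(scores)
        node_weights.append(weights)
        # Attention-weighted context vector
        context = [sum(w * f[k] for w, f in zip(weights, node_feats)) for k in range(dim)]
        # Simplified update (the real model uses a GRU cell here)
        graph_feat = [(g + c) / 2 for g, c in zip(graph_feat, context)]
    return graph_feat, node_weights


if __name__ == "__main__":
    atoms = [[2.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    _, weights = attention_readout(atoms, num_timesteps=2)
    timestep = 0
    assert timestep < len(weights), 'Unexpected id for the readout round'
    print([round(w, 3) for w in weights[timestep]])
```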

  
```python
def drawmol(idx, dataset, timestep):
    smiles, graph, _ = dataset[idx]
    print(smiles)
    bg = dgl.batch([graph])
    # Put the input features on the same device as the model parameters
    device = next(model.parameters()).device
    atom_feats = bg.ndata['hv'].to(device)
    bond_feats = bg.edata['he'].to(device)
    with torch.no_grad():
        # get_node_weight=True also returns the atom weights of every readout round
        _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep].cpu()

    # Min-max normalize the weights of this molecule to [0, 1]
    min_value = torch.min(atom_weights)
    max_value = torch.max(atom_weights)
    atom_weights = (atom_weights - min_value) / (max_value - min_value)

    # Map the normalized weights onto the blue-white-red colormap
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)
    cmap = plt.get_cmap('bwr')
    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].item())
                   for i in range(bg.number_of_nodes())}

    # Draw the molecule with every atom highlighted in its weight color
    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    drawer.SetFontSize(1)
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol, highlightAtoms=list(range(bg.number_of_nodes())),
                        highlightBonds=[],
                        highlightAtomColors=atom_colors)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    svg = svg.replace('svg:', '')
    return (Chem.MolFromSmiles(smiles), atom_weights.numpy(), svg)
```
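The color step in `drawmol` can be reproduced without matplotlib to see what the fixed `vmax=1.28` does. This sketch assumes `bwr` is a linear ramp from blue (0, 0, 1) through white (1, 1, 1) at the midpoint to red (1, 0, 0); matplotlib samples the map at a finite number of levels, so its exact colors differ slightly. `min_max_normalize` and `bwr_color` are hypothetical names. Under this assumption, a normalized weight of 0.64 is white and the most heavily weighted atom (normalized weight 1.0) lands at about 0.78 on the colormap, a light red rather than saturated red.

```python
def min_max_normalize(weights):
    """Scale weights to [0, 1]. Returns all zeros when every weight is equal
    (the tensor version in drawmol would produce NaN in that case)."""
    lo, hi = min(weights), max(weights)
    if hi == lo:
        return [0.0 for _ in weights]
    return [(w - lo) / (hi - lo) for w in weights]


def bwr_color(value, vmin=0.0, vmax=1.28):
    """Approximate 'bwr' after Normalize(vmin, vmax) as an (r, g, b) tuple."""
    t = (value - vmin) / (vmax - vmin)
    t = min(max(t, 0.0), 1.0)  # out-of-range values get the end colors
    if t <= 0.5:
        s = t / 0.5            # blue -> white
        return (s, s, 1.0)
    s = (t - 0.5) / 0.5        # white -> red
    return (1.0, 1.0 - s, 1.0 - s)


if __name__ == "__main__":
    raw = [0.10, 0.25, 0.40, 0.25]  # made-up attention weights for 4 atoms
    for i, w in enumerate(min_max_normalize(raw)):
        r, g, b = bwr_color(w)
        print(i, round(w, 3), (round(r, 2), round(g, 2), round(b, 2)))
```

Raising `vmax` toward 1.0 would let the top-weighted atom reach full red; the original value is kept in `drawmol`.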

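Drawing the molecules of the test set

The model predicts solubility. The colors show each atom's weight in the chosen readout round relative to the other atoms of the same molecule: red marks atoms with higher weight and blue atoms with lower weight. These attention weights are non-negative and are min-max normalized per molecule, so they indicate how strongly the model attends to an atom rather than whether the atom raises or lowers solubility.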


  
```python
target = test_loader.dataset
for i in range(len(target)):
    mol, aw, svg = drawmol(i, target, 0)
    display(SVG(svg))
```
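To read the pictures more precisely, one can list the atoms with the largest weights from the `aw` array that `drawmol` returns. A hedged sketch: `top_atoms` is a hypothetical helper that takes atom symbols (for example `[a.GetSymbol() for a in mol.GetAtoms()]`) and a flat list of weights (for example `aw.ravel().tolist()`).

```python
def top_atoms(symbols, weights, k=3):
    """Return the k (index, symbol, weight) triples with the largest weights,
    highest first; ties keep the lower atom index first."""
    if len(symbols) != len(weights):
        raise ValueError("symbols and weights must have the same length")
    ranked = sorted(range(len(weights)), key=lambda i: (-weights[i], i))
    return [(i, symbols[i], weights[i]) for i in ranked[:k]]


if __name__ == "__main__":
    # Made-up example: ethanol heavy atoms C, C, O with normalized weights
    print(top_atoms(['C', 'C', 'O'], [0.0, 0.35, 1.0], k=2))
    # [(2, 'O', 1.0), (1, 'C', 0.35)]
```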



References

1. https://github.com/dmlc/dgl/tree/master/apps/life_sci

2. https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/attentive_fp.py

3. https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00387

 

Source: drugai.blog.csdn.net, author: DrugAI. Copyright belongs to the original author; please contact the author before reposting.

Original link: drugai.blog.csdn.net/article/details/104868996

This article was reposted by a Huawei Cloud community user.