Geneformer:基于Transformer的基因表达预测深度学习模型

举报
AI4S_NPU 发表于 2025/06/14 22:50:30 2025/06/14
【摘要】 摘要Geneformer被广泛应用于疾病建模、治疗靶点发掘、基因网络预测与调控分析、基因功能预测与剂量敏感性分析、单细胞转录组数据集成与标准化、遗传变异解释与GWAS靶点优先排序。该案例既有算法原理,也有手把手的昇腾部署教学,包含细胞分类、基因分类、提取细胞嵌入图、细胞多分类的微调任务1    Geneformer介绍GeneFormer是一种基于 Transformer 架构的深度学习模型...

摘要

Geneformer被广泛应用于疾病建模、治疗靶点发掘、基因网络预测与调控分析、基因功能预测与剂量敏感性分析、单细胞转录组数据集成与标准化、遗传变异解释与GWAS靶点优先排序。该案例既有算法原理,也有手把手的昇腾部署教学,包含细胞分类、基因分类、提取细胞嵌入图、细胞多分类的微调任务

1    Geneformer介绍

GeneFormer是一种基于 Transformer 架构的深度学习模型,专为基因表达数据分析而设计。它将基因视为“词汇”,将整个基因组的表达谱视为“句子”,通过自监督学习捕捉基因间的复杂调控关系和生物学背景,在医学研究中展现出强大的应用潜力。借助GeneFormer,研究人员能够更有效地处理和理解大量的基因组数据,从而加速新药开发、疾病治疗等领域的研究进展。在基因序列分析、蛋白质结构预测疾病机制解析和药物发现等领域也具有突出的应用价值。

image001.png

1:自监督大规模预训练迁移学习策略示意图

初始自监督大规模预训练,将预训练权重复制到每个微调任务的模型中,添加微调层,并使用有限的特定任务数据对每个下游任务进行微调。通过对可泛化的学习目标进行单次初始自监督大规模预训练,模型获得学习领域的基础知识,然后将其推广到众多不同于预训练学习目标的下游应用中,将知识迁移到新任务。

2    网络结构

image002.png

2Geneformer预训练架构图

预训练的 Geneformer 架构。每个单细胞转录组被编码为秩值编码,然后经过六层 Transformer 编码器单元,参数如下:输入大小为 2,048(完全代表 Geneformer-30M 93% 的秩值编码),嵌入维度为 256,每层四个注意力头,前馈大小为 512Geneformer 2,048 的输入大小上使用完全密集的自注意力机制。可提取的输出包括上下文基因和细胞嵌入、上下文注意力权重以及上下文预测。

2.1  输入层

输入层针对基因表达数据的特性在数据预处理、嵌入表示(Embedding)和位置编码(Positional Encoding)等进行了专门优化。

数据预处理:

  1. 基因嵌入:对基因表达值进行归一化处理,消除不同基因表达水平之间的差异,并对缺失值进行合理填充或插值处理,以确保数据的完整性。
  2. 输入数据:通常包括基因表达矩阵(如单细胞RNA测序数据)和基因序列(如DNA序列)。基因表达矩阵是一个二维矩阵,其中行代表样本,列代表基因,每个元素代表对应基因在该样本中的表达值。基因序列则是由碱基ATCG组成的字符串序列。

嵌入层:将基因表达值或基因序列映射到高维向量空间,以捕捉基因间的复杂关系,便于后续模型处理序列结构。维度设置需要根据具体任务和计算资源进行权衡,过低的维度可能导致信息丢失,而过高会增加计算复杂度。此外,嵌入层通常通过反向传播进行训练,使模型能够自动学习最优的基因嵌入表示,从而更好地适应任务需求。

位置编码:用于提供基因序列中各碱基的位置信息,帮助模型理解基因序列中碱基的顺序关系和位置依赖性,对于分析基因序列的功能和结构至关重要。

2.2   Transformer

GeneFormer的核心由多个 Transformer 层堆叠而来。通过多头自注意力、残差连接和前馈神经网络,从高维基因表达数据中提取复杂的调控模式。在保持标准的Transformer 结构的同时,针对基因表达数据的特性(高维度、稀疏性、基因共表达模式)进行了优化,使模型能够有效捕捉基因间的功能关联,为下游任务(微调)提供强有力的表征。

多头注意力:并行使用多个注意力头,每个头学习不同的交互模式,同时计算多组注意力权重,捕捉基因间的全局依赖关系(如协同表达的基因网络)。通过计算查询(Query)、键(Key)和值(Value)之间的点积来确定权重,并通过 Softmax 函数进行归一化,且总和为1

image003.png

image004.png

将输入拆分为h个头,每个头单独计算后拼接。

前馈神经网络:由两层全连接层和激活函数组成,每个多头注意力层后接一个前馈神经网络层,对注意力层的输出进行非线性映射增强非线性表达能力,用于学习并保存知识。

image005.png

稀疏注意力:基因表达数据中,大部分基因表达值为0,可能采用局部稀疏注意力以降低计算开销。

image006.png

相对位置编码:由于基因在序列中的物理位置可能无关紧要,Geneformer 采用相对位置编码,仅编码基因间的相对顺序或距离,增强对基因序列位置的敏感性。

image007.png

i,j为基因在序列中的位置,k为最大相对距离。

层归一化与残差连接:层归一化稳定单细胞数据的高变异表达分布,残差连接保留原始基因表达信息,缓解梯度消失,加速收敛。

image010.png

μσ为样本内均值和方差,γβ为可学习的缩放和平移参数

image011.png

2.3   输出层

经过transformer层之后,张量被传入输出层,但Geneformer输出层的设计根据具体任务(如基因表达预测、分类或自监督预训练)有所不同,主要操作通常包括以下几个关键步骤:

线性变换:使用全连接层,将Transformer最后一层输出的隐藏状态映射到目标维度(如基因数量或类别数)。

image012.png

激活函数:根据任务需求不同调整使用的激活函数,回归任务可能使用ReLUSoftplus确保输出非负,分类任务使用Softmax(多分类)或Sigmoid(二分类)输出概率分布,对于线性输出,则没有激活函数。

image016.png

损失计算:对于回归任务,使用均方误差(MSE)或负对数似然。分类任务,交叉熵损失。自监督任务(掩码基因预测),使用对比损失或遮蔽语言建模(MLM)类似的损失。

image017.png

细胞分类任务中的损失计算,交叉熵损失

image019.png

基因扰动预测,对比损失

image020.png

3    微调介绍

GeneFormer 先在大规模单细胞数据上预训练,结合特定任务的需求和数据特点,灵活选择冻结策略、调整输出头、引入适配器或领域特定模块。通过平衡预训练知识的保留与任务适配,高效实现模型优化。

网络结构的微调操作:

  1. 根据具体的下游任务,确定输入输出格式。即指定数据集。在输入层将数据预处理为与 GeneFormer 兼容的格式,加载预训练的 GeneFormer 权重。
  2. 选择冻结一定数目的transformer层,但不会全部冻结,会保留几层用于保留预训练模型的底层知识(如基因共现模式、 基础序列特征),防止小数据过拟合。
  3. 在预训练模型的基础上额外增加一个transformer层,用于学习新的知识。并在每一层插入小型适配器模块,保持预训练权重冻结,仅训练适配器参数,用于减少参数更新量,适用于小样本微调。
  4. 在输出层,也会根据具体的下游任务进行调整,仅训练最后一层transformer层及输出头。对于分类任务:替换最后的全局平均池化层 + 全连接层。回归任务:调整输出层为线性回归头。生成任务:添加解码器。

4    实验准备

4.1  设备&组件

机器:

Atlas 800T A2

组件:

hdk:24.1.rc3

image022.png

cann:8.0.RC3

image023.png

python:3.10.16

image024.png

torch:2.1.0

torch:2.1.0.post8

image025.png

4.2  安装LFS

git lfs install

4.3  下载源码

未标题-2.png

4.4  下载数据集

未标题-3.png

image026.png

4.5  安装环境

requirements.txt里面torch的版本>=2.0.1即可,这里选用2.1.0版本的torch

cd Geneformer

vi requirements.txt,将torch>=2.0.1修改为torch==2.1.0。再:wq保存退出。

pip install .

4.6  安装torch-npu

4.6.1  下载

未标题-1.png

4.6.2  安装

pip3 install torch_npu-2.1.0.post8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

4.6.3  验证npu是否可用

numpy报错,需降低至1.x版本

image027.png

image028.png

更换完numpy版本之后,再次验证

image029.png

5    微调

5.1  微调1:细胞分类

5.1.1  数据集&权重任务

任务:cell_classification

数据集:human_dcm_hcm_nf.dataset

预训练权重:gf-6L-30M-i2048

5.1.2  新建微调脚本

cd /Geneformer/examples

mkdir cell_classification.py

vi cell_classification.py

cell_classification.ipynb的代码复制过来。需注意修改权重路径和数据集路径。导入os包,将第10行的!mkdir $output_dir修改为os.makedirs(output_dir, exist_ok=True)

image030.png

5.1.3  修改评估模型脚本

vi /root/miniconda3/envs/Genecorpus_py310/lib/python3.10/site-packages/geneformer/

evaluation_utils.py

导入torch_npu包,替换相关cudaapi

image031.png

5.1.4  微调前source cann

source /usr/local/Ascend/ascend-toolkit/set_env.sh

5.1.5  开始微调

image032.png

再开一个窗口,命令行输入 npu-smi info查看显存占用率

image033.png

image034.png

5.1.6  评估模型时报错

image035.png

vi /root/miniconda3/envs/Genecorpus_py310/lib/python3.10/site-packages/geneformer/evaluation_utils.py

在第86classifier_predict函数内添加

device = torch.device('npu' if torch_npu.npu.is_available() else 'cpu')

并将119120121三行中的.to(cuda)修改为.to(device)

image036.png

再重新运行

image037.png

输出精度0.9542330129066371

image038.png

输出文件

image039.png

混淆矩阵

image040.png

评估微调模型的预测结果

image041.png

5.2  微调2:基因分类

5.2.1  数据集&权重文件

任务:gene_classification

数据集:gc-30M_sample50k.dataset

预训练权重:gf-6L-30M-i2048

5.2.2  新建微调脚本

cd /Geneformer/examples

touch gene_classification.py

vi gene_classification.py

gene_classification.ipynb的代码复制过来。需注意修改权重路径和数据集路径

image042.png

5.2.3  开始微调

image043.png

image044.png

5.2.4  输出文件

image045.png

5.3  微调3:绘制细胞嵌入图

5.3.1  数据集&权重文件

任务:extract_and_plot_cell_embeddings

数据集:human_dcm_hcm_nf.dataset

预训练权重:gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224

5.3.2  新建微调脚本

cd /Geneformer/examples

touch extract_and_plot_cell_embeddings.py

vi extract_and_plot_cell_embeddings.py

extract_and_plot_cell_embeddings.ipynb的代码复制过来。需注意修改权重路径和数据集路径

image046.png

5.3.3  开始微调

image047.png

5.3.4   输出文件

image048.png

5.3.5  细胞嵌入UMAP

image049.png


5.3.6  细胞嵌入heapmap

image050.png


5.4  微调4:多任务细胞分类

5.4.1  数据集&权重文件

任务:multitask_cell_classification

数据集:human_dcm_hcm_nf.dataset

预训练权重:gf-6L-30M-i2048

5.4.2   新建微调脚本

cd /Geneformer/examples

touch multitask_cell_classification.py

vi multitask_cell_classification.py

multitask_cell_classification.ipynb的代码复制过来。需注意修改权重路径、数据集路径以及token_dictionary路径。

image051.png

5.4.3  微调过程

image052.png

5.4.4  输出

image053.png

6    参考文献

Theodoris, C. V., Xiao, L., Chopra, A., Chaffin, M. D., Al Sayed, Z. R., Hill, M. C., ... & Ellinor, P. T. (2023). Transfer learning enables predictions in network biology. Nature, 618(7965), 616-624. https://doi.org/10.1038/s41586-023-06139-9

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。