边缘大型AI模型:协作部署与物联网应用——论文解读

举报
DuHz 发表于 2025/10/02 22:02:01 2025/10/02
【摘要】 边缘大型AI模型:协作部署与物联网应用的深度解读Wang Z, Shi Y, Letaief K B. Edge Large AI Models: Collaborative Deployment and IoT Applications[J]. IEEE Internet of Things Magazine, 2025. 引言:边缘大模型的时代背景与核心矛盾大型人工智能模型展现出了跨越...

边缘大型AI模型:协作部署与物联网应用的深度解读

Wang Z, Shi Y, Letaief K B. Edge Large AI Models: Collaborative Deployment and IoT Applications[J]. IEEE Internet of Things Magazine, 2025.

引言:边缘大模型的时代背景与核心矛盾

大型人工智能模型展现出了跨越不同领域、模态和任务的类人问题解决能力。这种能力源于三个显著特征:海量的参数规模、大规模高质量数据集以及巨大的计算资源。结合先进的训练范式,这些特征使得LAM能够实现卓越的知识迁移,朝着通用人工智能(AGI)的方向发展。值得注意的是,对预训练LAM(例如DeepSeek、Gemini、ChatGPT)进行微调,可以用显著减少的数据和计算需求高效地适应各种下游任务,与从头训练相比大幅加速了杀手级应用的诞生。例如,TeleGPT利用电信特定数据集实现了下一代智能电信服务。

将LAM集成到具有多样化数据源的物联网(IoT)网络中,可以通过类人交互实现定制化任务执行,从而显著增强系统功能。例如,三星的SmartThings平台通过远程Galaxy AI增强,处理智能家居中的多模态传感器数据,支持实时语音和手势识别,为超过6200万用户提供个性化自动化服务。尽管取得了这些进展,但通过云服务器利用LAM的传统部署范式存在高延迟和隐私风险,限制了其满足物联网应用严格服务质量(QoS)要求的能力。

为了解决这些局限性,论文提出将LAM部署在网络边缘(称为边缘LAM),以降低延迟同时保护数据隐私,从而实现实时、可靠和个性化的智能物联网服务。与依赖单模态数据特定功能的黑盒映射的传统边缘AI方法不同,边缘LAM集成多模态输入并执行上下文建模,以实现个性化生成和多步推理。例如,在工业物联网应用中,边缘LAM分析实时多模态数据(包括设备振动、温度和电流波形),并综合历史维护记录和外部因素(如电网波动)来预测设备故障并生成动态维护策略。

然而,边缘LAM带来了与传统边缘AI不同的独特挑战,这些挑战源于其巨大的资源需求与边缘设备受限能力之间的不匹配。具体而言,现有的边缘AI框架采用联邦学习(Federated Learning, FL)作为分布式训练范式,需要统一的数据模态和目标来进行无线模型调整。然而,由于边缘LAM的海量参数规模、异构数据源和多样化的下游任务需求,这种策略可能不切实际。此外,虽然传统边缘AI推理仅限于单步特征到预测的映射,但边缘LAM通过其架构内模块化组件的协调交互支持上下文泛化和多步推理。应对这些挑战需要面向任务的设计以及对边缘LAM训练和推理中稀缺计算和通信资源的优化配置。

图1:框架、挑战、方法和优势总结

fig1.png

图1以表格形式系统地总结了论文提出的完整框架。该表格分为训练(Training)和推理(Inference)两大部分,每部分又细分为不同的框架类型。在训练部分,包括联邦微调(Federated Fine-tuning)和联邦遗忘(Federated Unlearning)两个框架。联邦微调面临的挑战包括边缘设备间异构的计算和数据模态、隐私保护的全参数微调中的高通信开销;采用的方法包括用于参数高效调优的LoRA、针对计算异构性的自适应截断和零填充、针对数据异构性的无线感知知识蒸馏;其优势是相比全模型调优显著降低通信成本、支持跨设备/模态适应、通过联邦训练保护隐私。

联邦遗忘框架的挑战在于在不完全重新训练的情况下移除特定客户端数据的影响、缓解无线衰落和动态边缘环境下的训练不稳定性;采用的方法包括正交梯度投影、衰落信道上的自适应噪声注入、涉及隐私、效用和传输效率的联合优化;其优势是确保退出客户端的隐私、对无线信道变化具有鲁棒性。

在推理部分,包括基于专家混合(MoE)的微服务和基于思维链(CoT)的微服务两个框架。基于MoE的微服务面临的挑战是MoE架构LAM中的冗余计算、并行专家调度中的低资源利用率和高延迟;采用的方法包括将MoE层虚拟化为微服务、用于延迟感知编排的Lyapunov优化、在线设备调度;其优势是减少冗余计算、通过自适应专家分配最小化推理延迟、提高边缘资源利用率。

基于CoT的微服务面临的挑战是多步CoT推理中的高延迟、动态任务需求与静态资源分配之间的不匹配;采用的方法包括将CoT分解为模块化微服务、用于部署优化的图扩散模型、用于多模态融合的信息瓶颈;其优势是通过分布式推理减少资源瓶颈、适应动态边缘环境、通过结构化推理流程提高准确性。

异构边缘网络上的协作训练:解构与重组的智慧

预训练的LAM可以通过在任务特定数据上进行微调来适应各种下游任务。然而,这个过程会产生巨大的计算和数据开销,对资源受限的边缘网络环境构成重大挑战。传统的基于云的中心化训练框架从边缘设备聚合大量数据,加剧了隐私问题,并由于频繁的数据传输导致高通信成本。这些局限性凸显了能够同时解决隐私和效率问题的替代方法的必要性。

边缘训练通过将数据保留在原位同时利用其分布式计算资源而成为一个有前景的解决方案。这种方法不仅增强了隐私保护,还降低了通信开销并最小化了数据传输延迟,从而实现更快的模型适应。尽管有这些优势,设备上训练范式的实际部署受到几个关键挑战的阻碍。具体而言,有限的通信资源、设备间异构的计算能力、多样化的数据模态以及各异的训练目标共同使得这类框架在现实场景中的直接实施变得复杂。为了应对这些挑战并充分利用LAM的表征能力,论文提出了一个为无线边缘网络量身定制的协作训练框架。该框架明确考虑了边缘设备间计算资源、数据模态和训练目标的异构性,确保高效有效的设备上训练。

异构联邦微调:低秩分解的精妙设计

为了满足跨空间分布数据集对预训练LAM进行专业任务微调的日益增长的需求,论文倡导采用联邦微调(FedFT),这是一种创新方法,能够在多个边缘设备上协作微调全局模型或其组件,而无需上传数据。然而,边缘设备有限的计算能力对FedFT的性能构成重大挑战。为了解决这些问题,论文将参数高效微调(PEFT)方法与联邦学习系统集成,使得能够微调最小集合的网络参数,以实现与全模型调优相当的性能,同时最小化通信开销。

具体而言,论文利用低秩适配(LoRA)将可训练参数分解为低秩矩阵,冻结原始参数以保持模型完整性。这种方法允许高效地并行处理接收到的嵌入,而不会在LAM推理期间产生额外的延迟。从数学角度理解LoRA的核心机制,假设预训练模型的权重矩阵为 W0Rd×kW_0 \in \mathbb{R}^{d \times k},在微调过程中,完整的权重更新可以表示为:

W=W0+ΔWW = W_0 + \Delta W

LoRA的关键创新在于假设 ΔW\Delta W 具有低秩结构,即存在秩 rmin(d,k)r \ll \min(d, k) 使得:

ΔW=BA\Delta W = B \cdot A

其中 BRd×rB \in \mathbb{R}^{d \times r}ARr×kA \in \mathbb{R}^{r \times k}。这样,原本需要更新 d×kd \times k 个参数的任务被简化为更新 (dr)+(rk)=r(d+k)(d \cdot r) + (r \cdot k) = r(d+k) 个参数。当 rr 很小时,参数量可以实现数百倍的压缩。例如,对于 d=4096d=4096k=4096k=4096r=8r=8 的情况,参数量从 4096×4096=16,777,2164096 \times 4096 = 16,777,216 减少到 8×(4096+4096)=65,5368 \times (4096+4096) = 65,536,压缩比达到约256倍。

尽管有这些进展,边缘设备固有的异构性(包括计算能力和数据模态)仍然对有效的FedFT实施构成障碍。为了解决这个问题,论文提出了一个为异构边缘网络量身定制的FedFT框架。该框架优化LoRA结构和受限的无线资源以最大化性能。

图2:异构无线网络上的联邦微调框架

fig2.png

图2详细展示了整个联邦微调的工作流程。图的左下方显示了三个具有异构计算能力的边缘设备,每个设备都部署了不同秩的本地LoRA模块(通过不同大小的矩阵图标表示)。这些设备通过上行链路(Uplink)将其截断的LoRA模块发送到边缘服务器。边缘服务器执行三个关键步骤:首先进行零填充(Zero-Padding)来标准化不同秩的矩阵;然后执行联邦平均(FedAvg)来聚合这些模块;最后执行截断(Truncation)操作。

图的右上方展示了预训练LAM和微调后LAM的对比。在架构中,灰色部分表示冻结参数(Frozen),橙色部分表示可调参数(Tunable)。通过利用稀疏性进行近似(Approximation by exploiting sparsity),可以得到更新后的权重。最后,边缘服务器通过下行链路(Downlink)将更新后的全局LoRA模块传输回边缘设备。

每个边缘设备独立地微调一组具有不同秩的低秩矩阵,而边缘服务器使用单播通信聚合这些多样化的低秩矩阵。此外,为了管理聚合矩阵的可变性,论文引入了一个投影框架,其中边缘服务器应用零填充来标准化传入的矩阵,并利用截断将它们重新分发回边缘设备。

从数学上形式化这个过程,假设有 KK 个边缘设备,第 kk 个设备选择的秩为 rkr_k,其本地LoRA矩阵为 BkRd×rkB_k \in \mathbb{R}^{d \times r_k}AkRrk×dA_k \in \mathbb{R}^{r_k \times d'}。设全局目标秩为 rmax=maxkrkr_{\text{max}} = \max_k r_k,则零填充操作可以表示为:

B~k=[Bk,0d×(rmaxrk)]Rd×rmax\tilde{B}_k = [B_k, \mathbf{0}_{d \times (r_{\text{max}} - r_k)}] \in \mathbb{R}^{d \times r_{\text{max}}}

A~k=[Ak0(rmaxrk)×d]Rrmax×d\tilde{A}_k = \begin{bmatrix} A_k \\ \mathbf{0}_{(r_{\text{max}} - r_k) \times d'} \end{bmatrix} \in \mathbb{R}^{r_{\text{max}} \times d'}

聚合后的全局矩阵通过加权平均获得:

Bˉ=k=1KwkB~k,Aˉ=k=1KwkA~k\bar{B} = \sum_{k=1}^{K} w_k \tilde{B}_k, \quad \bar{A} = \sum_{k=1}^{K} w_k \tilde{A}_k

其中 wkw_k 是第 kk 个设备的权重,通常与其数据量成正比,满足 k=1Kwk=1\sum_{k=1}^{K} w_k = 1

截断操作则是将全局矩阵 Bˉ\bar{B}Aˉ\bar{A} 裁剪到每个设备所需的秩:

Bknew=Bˉ[:,1:rk],Aknew=Aˉ[1:rk,:]B_k^{\text{new}} = \bar{B}[:, 1:r_k], \quad A_k^{\text{new}} = \bar{A}[1:r_k, :]

这种方法与传统的联邦学习框架有显著不同,因为它在下行链路通信中使用单播而非广播,并且聚合低秩矩阵而非全局模型梯度。为了进一步使本地LoRA模块的部署与边缘设备的计算和通信能力相一致,论文提出了联合设备选择和带宽分配优化策略。该策略在保持统一矩阵信息完整性的同时最小化计算和通信延迟,最终提高整体学习性能。

知识蒸馏应对数据异构:跨模态知识转移

除了计算异构性之外,边缘设备的本地数据集模态可能与其他设备不同。为了解决这个问题,论文开发了一个FedFT系统,通过在边缘设备上并行部署共享模块来适应数据异构性。边缘设备联合更新共享模块及其可训练参数。为了在捕获领域特定知识的同时确保客户端无关的持久性,论文提出了一个基于知识蒸馏(KD)的框架,将本地模型学到的知识转移到共享模块中。

知识蒸馏的核心思想是让学生模型(共享模块)学习教师模型(本地模型)的输出分布。具体地,通过最小化基于共享模块的预测与基于本地模型的预测之间的Kullback-Leibler散度,边缘服务器更新共享模块并将其广播给边缘设备。数学上,KL散度定义为:

LKD=KL(plocal(yx)pshared(yx))=yplocal(yx)logplocal(yx)pshared(yx)\mathcal{L}_{\text{KD}} = \text{KL}(p_{\text{local}}(y|x) \| p_{\text{shared}}(y|x)) = \sum_{y} p_{\text{local}}(y|x) \log \frac{p_{\text{local}}(y|x)}{p_{\text{shared}}(y|x)}

其中 xx 是输入数据,yy 是输出标签,plocal(yx)p_{\text{local}}(y|x)pshared(yx)p_{\text{shared}}(y|x) 分别表示本地模型和共享模块的预测分布。

为了使输出分布更加平滑并强调模型学到的"暗知识",通常引入温度参数 TT

pmodel(yx)=exp(zy/T)jexp(zj/T)p_{\text{model}}(y|x) = \frac{\exp(z_y/T)}{\sum_{j} \exp(z_j/T)}

其中 zyz_y 是模型对类别 yy 的logit输出。当 T>1T > 1 时,分布变得更加平滑,突出了类别之间的相对关系。

同时,边缘设备基于本地模块和最新共享模块的预测之间的KL散度更新其本地模块:

Llocal=KL(pshared(yx)plocal(yx))+λLtask\mathcal{L}_{\text{local}} = \text{KL}(p_{\text{shared}}(y|x) \| p_{\text{local}}(y|x)) + \lambda \mathcal{L}_{\text{task}}

其中 Ltask\mathcal{L}_{\text{task}} 是任务特定的损失函数(如交叉熵损失),λ\lambda 是平衡系数。

与传统的联邦蒸馏方法不同,传统方法要求边缘设备将本地预测上传到边缘服务器,而论文的方法促进了边缘服务器与边缘设备之间低秩矩阵的频繁交换。这种交换增加了训练延迟,需要高效的上行和下行链路联合设计。为此,论文旨在开发一个延迟感知的资源分配方案,以联合减少双向通信延迟。

假设上行链路和下行链路的通信时间分别为 TupT_{\text{up}}TdownT_{\text{down}},本地计算时间为 TcompT_{\text{comp}},则每轮训练的总延迟为:

Ttotal=Tup+Tcomp+TdownT_{\text{total}} = T_{\text{up}} + T_{\text{comp}} + T_{\text{down}}

给定带宽 BB 和信道容量 CC,传输大小为 SS 的数据所需时间为:

Tcomm=SC=SBlog2(1+SNR)T_{\text{comm}} = \frac{S}{C} = \frac{S}{B \log_2(1 + \text{SNR})}

其中SNR是信噪比。优化目标是在满足延迟约束的前提下,通过联合优化设备选择、带宽分配和功率控制来最大化学习性能。

联邦遗忘:正交梯度投影的优雅解法

尽管边缘LAM能够实现少样本和零样本适应,但其泛化能力呈现出一把双刃剑,模糊了模型退化和功能限制之间的界限。这使得在不降低整体性能的情况下对特定边缘设备强制执行"移除权"变得复杂,特别是在目标异构性下。传统的机器遗忘方法,如回退到历史模型状态或采用递减学习,由于巨大的资源开销、数据隐私问题以及中心化聚合的通信低效,对于LAM来说是不切实际的。

联邦遗忘(FU)将联邦学习与机器遗忘集成,以协作方式从训练模型中消除特定客户端数据的影响,而无需从头重新训练。然而,由于无线环境的不利影响可能严重破坏模型更新并降低整体系统性能,FU仍然面临重大挑战。

图3:边缘LAM的无线联邦遗忘框架

fig3.png

图3详细展示了联邦遗忘的三步流程。左下方显示了两类边缘设备:保留设备(Remaining devices)和退出设备(Opting-out device),每个设备都部署了LoRA模块。第一步是FedAvg(联邦平均),从保留设备收集梯度。第二步是构建子空间(Construct Subspace),通过对保留设备的梯度进行奇异值分解(SVD)来识别应该保留的知识子空间。第三步是对退出设备的投影(Projection for Opting-out devices),将退出设备的LoRA梯度投影到正交方向,即正交最陡下降方向(Orthogonal steepest descent direction)。

最后,边缘服务器通过下行链路向保留设备发送全局LoRA模块,向退出设备发送投影后的LoRA梯度。图的右下方还展示了设备架构,其中橙色部分表示可调参数(Tunable),灰色部分表示冻结参数(Frozen)。

为了解决前述挑战,论文提出了一个面向无线网络的边缘LAM联邦遗忘框架。该框架通过修改的梯度更新使协作模型训练的同时允许退出设备移除其数据影响,从而保持全局效用。关键创新在于正交投影机制,该机制通过在传播前将梯度投影到正交子空间,将遗忘更新与保留知识的更新解耦。

从数学角度详细阐述这个机制。假设有 KK 个边缘设备,其中 Kr\mathcal{K}_r 是保留设备的集合,Ko\mathcal{K}_o 是退出设备的集合,满足 KrKo={1,2,...,K}\mathcal{K}_r \cup \mathcal{K}_o = \{1, 2, ..., K\}KrKo=\mathcal{K}_r \cap \mathcal{K}_o = \emptyset

在第 tt 轮训练中,每个保留设备 kKrk \in \mathcal{K}_r 计算其梯度 gk(t)g_k^{(t)}。这些梯度张成一个子空间 S(t)\mathcal{S}^{(t)},可以通过奇异值分解构造:

Gr(t)=[g1(t),g2(t),...,gKr(t)]Rp×KrG_r^{(t)} = [g_1^{(t)}, g_2^{(t)}, ..., g_{|\mathcal{K}_r|}^{(t)}] \in \mathbb{R}^{p \times |\mathcal{K}_r|}

其中 pp 是模型参数的维度。对 Gr(t)G_r^{(t)} 进行SVD分解:

Gr(t)=UΣVTG_r^{(t)} = U\Sigma V^T

其中 URp×pU \in \mathbb{R}^{p \times p} 是左奇异向量矩阵,ΣRp×Kr\Sigma \in \mathbb{R}^{p \times |\mathcal{K}_r|} 是奇异值矩阵,VRKr×KrV \in \mathbb{R}^{|\mathcal{K}_r| \times |\mathcal{K}_r|} 是右奇异向量矩阵。

子空间 S(t)\mathcal{S}^{(t)}UU 的前 Kr|\mathcal{K}_r| 列张成。对于退出设备 kKok \in \mathcal{K}_o 的梯度 gk(t)g_k^{(t)},我们需要将其投影到 S(t)\mathcal{S}^{(t)} 的正交补空间 S(t)\mathcal{S}^{(t)\perp}

projS(gk(t))=UUTgk(t)\text{proj}_{\mathcal{S}}(g_k^{(t)}) = UU^T g_k^{(t)}

gk(t),proj=gk(t)projS(gk(t))=gk(t)UUTgk(t)=(IUUT)gk(t)g_k^{(t), \text{proj}} = g_k^{(t)} - \text{proj}_{\mathcal{S}}(g_k^{(t)}) = g_k^{(t)} - UU^T g_k^{(t)} = (I - UU^T)g_k^{(t)}

其中 II 是单位矩阵。投影矩阵 P=IUUTP_{\perp} = I - UU^T 将梯度投影到正交子空间。

更新规则变为:

θ(t+1)=θ(t)η(1KrkKrgk(t)+1KokKogk(t),proj)\theta^{(t+1)} = \theta^{(t)} - \eta \left(\frac{1}{|\mathcal{K}_r|}\sum_{k \in \mathcal{K}_r} g_k^{(t)} + \frac{1}{|\mathcal{K}_o|}\sum_{k \in \mathcal{K}_o} g_k^{(t), \text{proj}}\right)

其中 η\eta 是学习率,θ\theta 是模型参数。

与仅聚合退出设备梯度并广播单一全局模型的传统方法不同,这种方法利用所有参与设备的信息,并通过单播传输提供个性化更新,在异构边缘环境中增强了效率和适应性。

为了提高训练稳定性,论文引入了有界损失函数,通过修改交叉熵中的对数项来有效缓解梯度爆炸,同时对学习效率的影响可以忽略不计。标准交叉熵损失为:

LCE=i=1Nyilog(y^i)\mathcal{L}_{\text{CE}} = -\sum_{i=1}^{N} y_i \log(\hat{y}_i)

其中 yiy_i 是真实标签,y^i\hat{y}_i 是预测概率。当 y^i\hat{y}_i 接近0时,log(y^i)\log(\hat{y}_i) 趋向负无穷,导致梯度爆炸。有界损失函数引入一个下界 ϵ\epsilon

Lbounded=i=1Nyilog(max(y^i,ϵ))\mathcal{L}_{\text{bounded}} = -\sum_{i=1}^{N} y_i \log(\max(\hat{y}_i, \epsilon))

这样可以有效防止梯度爆炸,同时保持良好的学习性能。

此外,可以采用差分隐私技术来增强隐私保护级别,其中交换的梯度可以手动添加可调噪声。具体地,在梯度上添加高斯噪声:

g~k=gk+N(0,σ2I)\tilde{g}_k = g_k + \mathcal{N}(0, \sigma^2 I)

其中 σ2\sigma^2 是噪声方差,与所需的隐私级别 (ϵ,δ)(\epsilon, \delta) 相关。根据差分隐私理论,噪声规模需要满足:

σC2ln(1.25/δ)ϵ\sigma \geq \frac{C \sqrt{2\ln(1.25/\delta)}}{\epsilon}

其中 CC 是梯度的敏感度(灵敏度)。因此,需要建立理论框架来分析FU收敛性、无线衰落下的非线性投影梯度和隐私级别之间的关系,这与传统的基于线性梯度的分析不同。通过将这些见解整合到传输方案设计中,可以开发一种资源分配算法,在保持边缘LAM效用和数据隐私的同时改善FU收敛性。

微服务赋能的边缘LAM推理:模块化与动态调度

边缘LAM推理在网络边缘部署复杂的AI模型以提供实时智能服务。与依赖直接特征到预测映射的传统单步边缘AI范式不同,LAM需要顺序执行专门的计算模块,从潜在特征提取到提示驱动的推理。这种方法实现了复杂的物联网服务,但代价是神经元数量增加和计算强度提高。与按参数分区的传统模型不同,LAM需要与架构角色一致的功能分解,需要在资源受限的边缘设备之间进行协作推理。

然而,这带来了关键挑战:单体架构在推理阶段之间导致冗余计算,而异构设备能力在严格的延迟约束下使协调变得复杂。为了解决这些问题,论文提出了一种新颖的基于微服务的框架,该框架基于计算能力将边缘LAM的功能模块虚拟化为微服务。这是通过利用专家混合(MoE)架构和思维链(CoT)推理过程的特性来协调边缘设备实现的,从而提高资源利用率并降低推理延迟。

专家混合架构的微服务化:从集中到分布

最先进的LAM利用前馈网络的稀疏性支持条件计算,采用专家混合(MoE)框架提升推理效率。在MoE架构中,计算密集型解码器被拆分为由门控系统管理的轻量级专家模型。尽管有其优势,但这个框架需要并行激活多个专家和门控序列,导致计算需求大且延迟高。

现有解决方案通常专注于在单个压缩模型内减少计算和通信开销,往往忽略了不同下游任务中专家重复调度导致的响应延迟增加。此外,传统的基于MoE的边缘LAM推理依赖于同步MoE层,其中一个损坏的专家可能导致后续层的计算异常。

图4:基于微服务的边缘LAM推理架构(MoE)

fig4.png

图4展示了从传统MoE推理到基于微服务的MoE推理的范式转变。左侧显示了传统的MoE推理流程:输入经过堆叠 MM 次的门控网络(Gate Network),每次选择一个专家(Expert 1、2或3)进行计算,然后进行低效的激活调度(Inefficient Activation Scheduling),最后产生输出。整个LAM部署在单个设备上,导致可扩展性不足和资源利用率低。

右侧显示了基于微服务的MoE推理流程。边缘服务器负责广播嵌入(Broadcast Embedding)和反馈中间结果(Feedback Intermediate Result)。三个边缘设备(Edge Device 1、2、3)分别部署一个微服务化的专家。通过多时隙传输(Multi-slot transmission)实现基于微服务的推理流程。这种架构实现了高可扩展性和资源利用率。

为了解决这些问题,论文提出了基于微服务的边缘LAM推理框架。在该框架中,每个MoE层的专家被虚拟化为部署在边缘设备上的微服务,而边缘服务器处理注意力计算和门控调度。推理任务被转换为单向无环图(DAG)的微服务流程,其中序列MoE层中的门控函数调度可以形式化为在线微服务编排问题。

从数学角度详细描述MoE架构。对于一个MoE层,假设有 EE 个专家,第 ee 个专家的函数为 fe()f_e(\cdot)。门控网络 G()G(\cdot) 计算每个专家的权重:

w=G(x)=Softmax(Wgx)RE\mathbf{w} = G(x) = \text{Softmax}(W_g \cdot x) \in \mathbb{R}^E

其中 xx 是输入,WgW_g 是门控网络的权重矩阵。输出通过加权组合专家输出获得:

y=e=1Ewefe(x)y = \sum_{e=1}^{E} w_e \cdot f_e(x)

在稀疏MoE中,只激活top-kk 个专家(kEk \ll E):

y=eTop-k(w)wejTop-k(w)wjfe(x)y = \sum_{e \in \text{Top-k}(\mathbf{w})} \frac{w_e}{\sum_{j \in \text{Top-k}(\mathbf{w})} w_j} \cdot f_e(x)

对于包含 LL 个MoE层的模型,推理过程可以表示为:

x(0)=Inputx^{(0)} = \text{Input}

x()=MoE(x(1)),=1,2,...,Lx^{(\ell)} = \text{MoE}_\ell(x^{(\ell-1)}), \quad \ell = 1, 2, ..., L

Output=x(L)\text{Output} = x^{(L)}

目标是通过优化设备选择策略最小化长期系统成本(例如通信延迟、能量成本)。系统成本可以定义为:

Ctotal==1Le=1E1[expert e activated in layer ]ce,C_{\text{total}} = \sum_{\ell=1}^{L} \sum_{e=1}^{E} \mathbb{1}[\text{expert } e \text{ activated in layer } \ell] \cdot c_{e,\ell}

其中 ce,c_{e,\ell} 是在第 \ell 层激活专家 ee 的成本,包括计算成本和通信成本。

应用Lyapunov优化技术,可以将长期优化问题分解为顺序的单次设备调度问题。定义虚拟队列 Q(t)Q^{(t)} 来跟踪系统状态,Lyapunov函数为:

L(Q(t))=12(Q(t))2L(Q^{(t)}) = \frac{1}{2}(Q^{(t)})^2

Lyapunov漂移加惩罚(drift-plus-penalty)为:

Δ(Q(t))+VC(t)\Delta(Q^{(t)}) + V \cdot C^{(t)}

其中 Δ(Q(t))=E[L(Q(t+1))L(Q(t))Q(t)]\Delta(Q^{(t)}) = \mathbb{E}[L(Q^{(t+1)}) - L(Q^{(t)}) | Q^{(t)}] 是Lyapunov漂移,VV 是权衡参数,C(t)C^{(t)} 是时刻 tt 的系统成本。

通过最小化每个时隙的Lyapunov漂移加惩罚的上界,可以得到一个可处理的上界:

Δ(Q(t))+VC(t)B+E[Q(t)A(t)Q(t)S(t)+VC(t)Q(t)]\Delta(Q^{(t)}) + V \cdot C^{(t)} \leq B + \mathbb{E}[Q^{(t)} \cdot A^{(t)} - Q^{(t)} \cdot S^{(t)} + V \cdot C^{(t)} | Q^{(t)}]

其中 BB 是常数,A(t)A^{(t)} 是到达率,S(t)S^{(t)} 是服务率。在每个时隙,通过求解以下优化问题来选择设备:

mindD[Q(t)A(t)(d)Q(t)S(t)(d)+VC(t)(d)]\min_{d \in \mathcal{D}} \left[Q^{(t)} \cdot A^{(t)}(d) - Q^{(t)} \cdot S^{(t)}(d) + V \cdot C^{(t)}(d)\right]

其中 D\mathcal{D} 是可用设备集合,dd 是设备选择决策。

为了找到系统成本和有限边缘资源之间的最佳平衡,论文开发了一种在线优化方法,该方法在确保推理延迟的同时调度边缘设备。具体地,可以通过动态调整参数 VV 来平衡成本和延迟:VV 越大,系统越重视降低成本;VV 越小,系统越重视降低延迟。

思维链推理的图扩散优化:生成式部署策略

复杂推理是最先进边缘LAM的关键能力,对决策制定和多步推理至关重要。具体而言,思维链(CoT)提示通过将任务分解为结构化子过程来促进复杂推理,每个子过程与特定提示相关联。然而,相同提示的重复调用导致冗余计算和服务延迟增加。

通过将这些子过程转换为跨边缘设备部署的微服务,CoT推理可以被视为部署不同微服务的边缘设备之间的顺序推理。然而,边缘设备的异构计算能力可能导致推理效率与可用网络资源之间的不匹配。微服务部署的目标是在高效利用异构边缘设备有限资源的同时最小化总延迟。

因此,微服务部署问题可以形式化为NP难的组合优化问题,其中部署决策由二值部署指示器表示。设有 NN 个微服务和 MM 个边缘设备,定义部署矩阵 X{0,1}N×MX \in \{0,1\}^{N \times M},其中:

xi,j={1,如果微服务 i 部署在设备 j0,否则x_{i,j} = \begin{cases} 1, & \text{如果微服务 } i \text{ 部署在设备 } j \\ 0, & \text{否则} \end{cases}

约束条件包括:

  1. 每个微服务必须部署在至少一个设备上:j=1Mxi,j1,i\sum_{j=1}^{M} x_{i,j} \geq 1, \forall i
  2. 设备容量约束:i=1Nxi,jriCj,j\sum_{i=1}^{N} x_{i,j} \cdot r_i \leq C_j, \forall j,其中 rir_i 是微服务 ii 的资源需求,CjC_j 是设备 jj 的容量
  3. CoT推理流是路径图:部署必须形成有向无环路径

目标函数是最小化总延迟:

minXTtotal(X)=i=1N1(maxj:xi,j=1ticomp(j)+minj,j:xi,j=1,xi+1,j=1ti,i+1comm(j,j))\min_{X} T_{\text{total}}(X) = \sum_{i=1}^{N-1} \left(\max_{j: x_{i,j}=1} t_i^{\text{comp}}(j) + \min_{j,j': x_{i,j}=1, x_{i+1,j'}=1} t_{i,i+1}^{\text{comm}}(j,j')\right)

其中 ticomp(j)t_i^{\text{comp}}(j) 是微服务 ii 在设备 jj 上的计算时间,ti,i+1comm(j,j)t_{i,i+1}^{\text{comm}}(j,j') 是从设备 jj 到设备 jj' 传输数据的通信时间。

为了高效解决这个组合优化问题,训练一个生成似然模型来学习基于历史执行轨迹的近最优部署策略分布是一种有前景的方法。该模型允许以端到端方式实时生成部署决策,适应边缘环境的动态特性。

通过将边缘设备表示为节点、它们之间的通信链路表示为边,可以使用图扩散模型构造似然模型,同时捕获CoT推理期间的复杂交互。图扩散模型的核心是学习图结构的分布 p(G)p(G)

图可以表示为邻接矩阵 A{0,1}M×MA \in \{0,1\}^{M \times M},其中 Aj,j=1A_{j,j'} = 1 表示设备 jjjj' 之间有通信链路。图扩散过程定义为:

前向扩散过程(加噪):

q(GtG0)=N(Gt;αˉtG0,(1αˉt)I)q(G_t | G_0) = \mathcal{N}(G_t; \sqrt{\bar{\alpha}_t} G_0, (1-\bar{\alpha}_t)I)

其中 αˉt=s=1tαs\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_sαt\alpha_t 是噪声调度参数。

逆向去噪过程(生成):

pθ(Gt1Gt)=N(Gt1;μθ(Gt,t),Σθ(Gt,t))p_\theta(G_{t-1} | G_t) = \mathcal{N}(G_{t-1}; \mu_\theta(G_t, t), \Sigma_\theta(G_t, t))

其中 μθ\mu_\thetaΣθ\Sigma_\theta 是神经网络参数化的均值和方差。

具体地,通信图可以通过图神经网络(GNN)编码为向量化表示:

hj(0)=Embedding(featuresj)h_j^{(0)} = \text{Embedding}(\text{features}_j)

hj(+1)=Aggregate({hj(),jN(j)})h_j^{(\ell+1)} = \text{Aggregate}\left(\left\{h_{j'}^{(\ell)}, \forall j' \in \mathcal{N}(j)\right\}\right)

zG=Readout({hj(L),j})z_G = \text{Readout}\left(\left\{h_j^{(L)}, \forall j\right\}\right)

其中 N(j)\mathcal{N}(j) 是节点 jj 的邻居集合,zGz_G 是图的嵌入表示。

为了确保CoT推理中的微服务流是路径图,采用节点度掩码变量。路径图要求每个中间节点的入度和出度都是1,起始节点入度为0、出度为1,终止节点入度为1、出度为0。因此,在生成过程中添加掩码约束:

Aj,j1[degree constraints violated]=0A_{j,j'} \cdot \mathbb{1}[\text{degree constraints violated}] = 0

以编码表示 zGz_G 为条件变量,图扩散模型直接采样两个相邻步骤的条件分布:

pθ(Gt2Gt,zG)p_\theta(G_{t-2} | G_t, z_G)

而不是传统的逐步去噪 pθ(Gt1Gt,zG)p_\theta(G_{t-1} | G_t, z_G)。这种跳跃采样可以从高噪声分布更快地逼近目标分布,同时减少计算开销。采样步数从 TT 减少到 T/2T/2

训练目标是最大化对数似然的变分下界(ELBO):

L=Eq(G0)[logpθ(G0)KL(q(GTG0)p(GT))t=1TKL(q(Gt1Gt,G0)pθ(Gt1Gt,zG))]\mathcal{L} = \mathbb{E}_{q(G_0)} \left[\log p_\theta(G_0) - \text{KL}(q(G_T|G_0) \| p(G_T)) - \sum_{t=1}^{T} \text{KL}(q(G_{t-1}|G_t, G_0) \| p_\theta(G_{t-1}|G_t, z_G))\right]

通过最小化重参数化的损失函数:

Lsimple=Et,G0,ϵ[ϵϵθ(Gt,t,zG)2]\mathcal{L}_{\text{simple}} = \mathbb{E}_{t, G_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(G_t, t, z_G)\|^2\right]

其中 ϵN(0,I)\epsilon \sim \mathcal{N}(0, I) 是标准高斯噪声,ϵθ\epsilon_\theta 是神经网络预测的噪声。

案例研究:CoT推理性能评估

论文在GSM8K数据集上评估了基于CoT推理的微服务框架性能,该数据集包含253个逻辑推理问题。采用Qwen2.5-7B-Instruct作为基础模型,考虑包含最多10个边缘设备的设置进行基于微服务的CoT推理,其中每个设备被分配一个虚拟化为单独微服务的推理步骤。

图5:CoT推理性能比较

fig5.png

图5以柱状图形式展示了不同配置下的性能对比。横轴显示四种配置:单体基线(Monolithic based)、基于微服务(最大token 64)、基于微服务(最大token 128)、基于微服务(最大token 256)。纵轴左侧显示总内存消耗(MB),右侧显示计算延迟(秒)。蓝色柱状图表示总内存消耗,红色柱状图表示计算延迟。

评估结果显示,在正确解决推理问题时,论文研究了单设备的最大token分配与系统效用(即内存消耗和计算延迟)之间的权衡。由于注意力机制,更大的token数会导致单设备的内存消耗和计算开销呈二次增长,同时减少所需的边缘设备总数。

如图5所示,微服务架构(最大token=64/128/256)相比单体基线显著降低了总内存消耗和计算延迟。128-token配置达到了最优效率,相比基线减少了70.8%的总内存消耗和59.6%的计算延迟。这充分验证了微服务架构在分布式推理中的优势。

从数学角度分析,注意力机制的复杂度为 O(n2d)O(n^2d),其中 nn 是序列长度(token数),dd 是特征维度。对于单体架构,如果总共需要处理 NN 个token,则总复杂度为:

Cmono=O(N2d)C_{\text{mono}} = O(N^2 d)

对于微服务架构,假设有 KK 个设备,每个设备处理 nkn_k 个token(k=1Knk=N\sum_{k=1}^{K} n_k = N),则总复杂度为:

Cmicro=k=1KO(nk2d)=O(dk=1Knk2)C_{\text{micro}} = \sum_{k=1}^{K} O(n_k^2 d) = O\left(d \sum_{k=1}^{K} n_k^2\right)

根据柯西-施瓦茨不等式:

(k=1Knk)2Kk=1Knk2\left(\sum_{k=1}^{K} n_k\right)^2 \leq K \sum_{k=1}^{K} n_k^2

即:

N2Kk=1Knk2N^2 \leq K \sum_{k=1}^{K} n_k^2

nkn_k 均匀分布时(nk=N/Kn_k = N/K),有:

k=1Knk2=K(N/K)2=N2K\sum_{k=1}^{K} n_k^2 = K \cdot (N/K)^2 = \frac{N^2}{K}

因此:

Cmicro=O(N2dK)=CmonoKC_{\text{micro}} = O\left(\frac{N^2 d}{K}\right) = \frac{C_{\text{mono}}}{K}

这解释了为什么微服务架构能够显著降低计算成本。对于 K=10K=10 个设备,理论上可以实现约10倍的加速。


附录:数学推导与理论分析

A. LoRA的梯度计算与反向传播

对于LoRA参数化的权重更新 W=W0+BAW = W_0 + BA,我们需要计算损失函数 L\mathcal{L} 关于 BBAA 的梯度。

设前向传播为:

y=Wx=(W0+BA)x=W0x+BAxy = Wx = (W_0 + BA)x = W_0 x + BAx

定义中间变量:

z=AxRrz = Ax \in \mathbb{R}^r

y=W0x+Bzy = W_0 x + Bz

损失函数对输出的梯度为:

Ly=δyRd\frac{\partial \mathcal{L}}{\partial y} = \delta_y \in \mathbb{R}^d

利用链式法则,计算对 BB 的梯度:

LB=LyyB=δyzTRd×r\frac{\partial \mathcal{L}}{\partial B} = \frac{\partial \mathcal{L}}{\partial y} \frac{\partial y}{\partial B} = \delta_y z^T \in \mathbb{R}^{d \times r}

计算对 zz 的梯度:

Lz=BTδyRr\frac{\partial \mathcal{L}}{\partial z} = B^T \delta_y \in \mathbb{R}^r

计算对 AA 的梯度:

LA=LzzA=(BTδy)xTRr×k\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial z} \frac{\partial z}{\partial A} = (B^T \delta_y) x^T \in \mathbb{R}^{r \times k}

总结梯度公式:

BL=δy(Ax)T\nabla_B \mathcal{L} = \delta_y (Ax)^T

AL=(BTδy)xT\nabla_A \mathcal{L} = (B^T \delta_y) x^T

这些梯度可以高效计算,因为涉及的矩阵维度较小。

B. 正交投影的性质与最优性证明

给定子空间 SRp\mathcal{S} \subseteq \mathbb{R}^p 由向量集合 {v1,v2,...,vm}\{v_1, v_2, ..., v_m\} 张成,其中 viv_i 是标准正交的(即 viTvj=δijv_i^T v_j = \delta_{ij})。投影矩阵为:

P=VVTP = VV^T

其中 V=[v1,v2,...,vm]Rp×mV = [v_1, v_2, ..., v_m] \in \mathbb{R}^{p \times m}

性质1(幂等性)P2=PP^2 = P

证明

P2=(VVT)(VVT)=V(VTV)VT=VIVT=VVT=PP^2 = (VV^T)(VV^T) = V(V^TV)V^T = VIV^T = VV^T = P

性质2(对称性)PT=PP^T = P

证明

PT=(VVT)T=(VT)TVT=VVT=PP^T = (VV^T)^T = (V^T)^T V^T = VV^T = P

对于任意向量 gRpg \in \mathbb{R}^p,其在 S\mathcal{S} 上的投影为:

g=Pg=VVTgg_{\parallel} = Pg = VV^T g

其在 S\mathcal{S}^{\perp} 上的投影为:

g=gPg=(IVVT)gg_{\perp} = g - Pg = (I - VV^T)g

最优性:投影 gg_{\parallel}S\mathcal{S} 中距离 gg 最近的向量。

证明:对于任意 vSv \in \mathcal{S},可以写为 v=Vαv = V\alpha,其中 αRm\alpha \in \mathbb{R}^m。则:

gv2=gVα2=(gVVTg)+(VVTgVα)2\|g - v\|^2 = \|g - V\alpha\|^2 = \|(g - VV^Tg) + (VV^Tg - V\alpha)\|^2

由于 (gVVTg)(VVTgVα)(g - VV^Tg) \perp (VV^Tg - V\alpha)(因为前者在 S\mathcal{S}^{\perp} 中,后者在 S\mathcal{S} 中),有:

gv2=gVVTg2+VVTgVα2=g2+V(VTgα)2\|g - v\|^2 = \|g - VV^Tg\|^2 + \|VV^Tg - V\alpha\|^2 = \|g_{\perp}\|^2 + \|V(V^Tg - \alpha)\|^2

α=VTg\alpha = V^Tg 时,第二项为零,此时 v=VVTg=gv = VV^Tg = g_{\parallel},距离最小。

C. Lyapunov优化的收敛性分析

考虑一个离散时间系统,状态队列为 Q(t)Q(t),系统成本为 C(t)C(t)。Lyapunov函数定义为:

L(Q(t))=12Q(t)2L(Q(t)) = \frac{1}{2}Q(t)^2

Lyapunov漂移为:

Δ(Q(t))=E[L(Q(t+1))L(Q(t))Q(t)]\Delta(Q(t)) = \mathbb{E}[L(Q(t+1)) - L(Q(t)) | Q(t)]

展开:

L(Q(t+1))=12Q(t+1)2=12[Q(t)+A(t)S(t)]2L(Q(t+1)) = \frac{1}{2}Q(t+1)^2 = \frac{1}{2}[Q(t) + A(t) - S(t)]^2

=12[Q(t)2+A(t)2+S(t)2+2Q(t)A(t)2Q(t)S(t)2A(t)S(t)]= \frac{1}{2}[Q(t)^2 + A(t)^2 + S(t)^2 + 2Q(t)A(t) - 2Q(t)S(t) - 2A(t)S(t)]

因此:

L(Q(t+1))L(Q(t))=12[A(t)2+S(t)22A(t)S(t)]+Q(t)[A(t)S(t)]L(Q(t+1)) - L(Q(t)) = \frac{1}{2}[A(t)^2 + S(t)^2 - 2A(t)S(t)] + Q(t)[A(t) - S(t)]

假设 A(t)A(t)S(t)S(t) 有界,即 A(t)AmaxA(t) \leq A_{\max}S(t)SmaxS(t) \leq S_{\max},则:

12[A(t)2+S(t)22A(t)S(t)]12[Amax2+Smax2]=B\frac{1}{2}[A(t)^2 + S(t)^2 - 2A(t)S(t)] \leq \frac{1}{2}[A_{\max}^2 + S_{\max}^2] = B

因此:

Δ(Q(t))B+Q(t)E[A(t)S(t)Q(t)]\Delta(Q(t)) \leq B + Q(t)\mathbb{E}[A(t) - S(t) | Q(t)]

Lyapunov漂移加惩罚为:

Δ(Q(t))+VE[C(t)Q(t)]B+E[Q(t)(A(t)S(t))+VC(t)Q(t)]\Delta(Q(t)) + V\mathbb{E}[C(t) | Q(t)] \leq B + \mathbb{E}[Q(t)(A(t) - S(t)) + VC(t) | Q(t)]

在每个时隙,通过最小化右侧的期望值来选择控制决策,即求解:

mindD[Q(t)(A(t,d)S(t,d))+VC(t,d)]\min_{d \in \mathcal{D}} [Q(t)(A(t,d) - S(t,d)) + VC(t,d)]

队列稳定性:如果存在 ϵ>0\epsilon > 0 使得在所有 tt 和所有 Q(t)>QmaxQ(t) > Q_{\max} 时:

E[A(t)S(t)Q(t)]ϵ\mathbb{E}[A(t) - S(t) | Q(t)] \leq -\epsilon

则队列 Q(t)Q(t) 是均方稳定的,即 limT1Tt=0T1E[Q(t)2]<\lim_{T \to \infty} \frac{1}{T}\sum_{t=0}^{T-1} \mathbb{E}[Q(t)^2] < \infty

成本性能:设最优时间平均成本为 CC^*。Lyapunov优化算法达到的时间平均成本 Cˉ\bar{C} 满足:

CˉC+BV\bar{C} \leq C^* + \frac{B}{V}

证明:对Lyapunov漂移加惩罚不等式在 TT 个时隙上求和:

t=0T1E[Δ(Q(t))]+Vt=0T1E[C(t)]TB+t=0T1E[Q(t)(A(t)S(t))+VC(t)]\sum_{t=0}^{T-1}\mathbb{E}[\Delta(Q(t))] + V\sum_{t=0}^{T-1}\mathbb{E}[C(t)] \leq TB + \sum_{t=0}^{T-1}\mathbb{E}[Q(t)(A(t) - S(t)) + VC(t)]

由于 t=0T1E[Δ(Q(t))]=E[L(Q(T))]E[L(Q(0))]\sum_{t=0}^{T-1}\mathbb{E}[\Delta(Q(t))] = \mathbb{E}[L(Q(T))] - \mathbb{E}[L(Q(0))] 有界,左侧第一项可以忽略。

对于稳定系统,t=0T1E[Q(t)(A(t)S(t))]\sum_{t=0}^{T-1}\mathbb{E}[Q(t)(A(t) - S(t))] 有界。因此:

Vt=0T1E[C(t)]TB+O(T)V\sum_{t=0}^{T-1}\mathbb{E}[C(t)] \leq TB + O(T)

除以 VTVT 并取 TT \to \infty

Cˉ=limT1Tt=0T1E[C(t)]BV\bar{C} = \lim_{T \to \infty} \frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}[C(t)] \leq \frac{B}{V}

比较最优策略,可以得到 CˉC+BV\bar{C} \leq C^* + \frac{B}{V}

这表明通过增大 VV,可以使时间平均成本接近最优,但代价是增加队列长度(延迟)。

D. 信息瓶颈的变分推断

信息瓶颈的目标是找到压缩表示 ZZ,最大化拉格朗日函数:

LIB=I(Z;Y)βI(X;Z)\mathcal{L}_{\text{IB}} = I(Z; Y) - \beta I(X; Z)

其中 I(;)I(\cdot; \cdot) 是互信息,β\beta 是权衡参数。

互信息定义为:

I(X;Z)=Ep(x,z)[logp(zx)p(z)]=KL(p(zx)p(z))I(X; Z) = \mathbb{E}_{p(x,z)}\left[\log\frac{p(z|x)}{p(z)}\right] = \text{KL}(p(z|x) \| p(z))

I(Z;Y)=Ep(y,z)[logp(yz)p(y)]I(Z; Y) = \mathbb{E}_{p(y,z)}\left[\log\frac{p(y|z)}{p(y)}\right]

直接优化互信息是困难的,因为涉及真实分布 p(zx)p(z|x)p(yz)p(y|z)。使用变分推断引入变分分布 qϕ(zx)q_\phi(z|x)qψ(yz)q_\psi(y|z) 来近似。

对于 I(X;Z)I(X; Z)

I(X;Z)=Ep(x)[KL(p(zx)p(z))]I(X; Z) = \mathbb{E}_{p(x)}\left[\text{KL}(p(z|x) \| p(z))\right]

使用 qϕ(zx)q_\phi(z|x) 近似 p(zx)p(z|x),并假设 p(z)=N(0,I)p(z) = \mathcal{N}(0, I)

I(X;Z)Ep(x)[KL(qϕ(zx)N(0,I))]I(X; Z) \approx \mathbb{E}_{p(x)}\left[\text{KL}(q_\phi(z|x) \| \mathcal{N}(0,I))\right]

对于 I(Z;Y)I(Z; Y),使用下界:

I(Z;Y)=Ep(y,z)[logp(yz)p(y)]=Ep(y,z)[logp(yz)]Ep(y)[logp(y)]I(Z; Y) = \mathbb{E}_{p(y,z)}\left[\log\frac{p(y|z)}{p(y)}\right] = \mathbb{E}_{p(y,z)}[\log p(y|z)] - \mathbb{E}_{p(y)}[\log p(y)]

Ep(y,z)[logqψ(yz)]+H(Y)\geq \mathbb{E}_{p(y,z)}[\log q_\psi(y|z)] + H(Y)

其中 H(Y)H(Y)YY 的熵,与优化无关。

因此,信息瓶颈的优化目标变为:

maxϕ,ψEp(x,y),qϕ(zx)[logqψ(yz)]βEp(x)[KL(qϕ(zx)N(0,I))]\max_{\phi, \psi} \mathbb{E}_{p(x,y), q_\phi(z|x)}[\log q_\psi(y|z)] - \beta \mathbb{E}_{p(x)}\left[\text{KL}(q_\phi(z|x) \| \mathcal{N}(0,I))\right]

这等价于最小化:

LVIB=Ep(x,y),qϕ(zx)[logqψ(yz)]+βEp(x)[KL(qϕ(zx)N(0,I))]\mathcal{L}_{\text{VIB}} = -\mathbb{E}_{p(x,y), q_\phi(z|x)}[\log q_\psi(y|z)] + \beta \mathbb{E}_{p(x)}\left[\text{KL}(q_\phi(z|x) \| \mathcal{N}(0,I))\right]

第一项是重构损失(或分类损失),第二项是正则化项。

假设 qϕ(zx)=N(μϕ(x),Σϕ(x))q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \Sigma_\phi(x)),KL散度可以解析计算:

KL(qϕ(zx)N(0,I))=12i=1dz[μi2+σi2logσi21]\text{KL}(q_\phi(z|x) \| \mathcal{N}(0,I)) = \frac{1}{2}\sum_{i=1}^{d_z}\left[\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1\right]

对于分类任务,logqψ(yz)\log q_\psi(y|z) 可以用交叉熵损失实现:

logqψ(yz)=CrossEntropy(y,y^)-\log q_\psi(y|z) = \text{CrossEntropy}(y, \hat{y})

其中 y^=fψ(z)\hat{y} = f_\psi(z) 是分类器的输出。

通过重参数化技巧:

z=μϕ(x)+Σϕ1/2(x)ϵ,ϵN(0,I)z = \mu_\phi(x) + \Sigma_\phi^{1/2}(x) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

可以进行端到端的梯度下降优化。


【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。