From VAE to Diffusion: The Probabilistic Graph Perspective Behind the Evolution of Generative Models

Introduction

Generative modeling is one of the most exciting directions in artificial intelligence: the goal is to learn the underlying distribution of the training data and to generate new samples from it. The evolution from variational autoencoders (VAEs) to diffusion models has not only brought a leap in sample quality, but also reflects a deepening of probabilistic graphical model thinking. This article examines the internal logic of this evolution from the probabilistic graph perspective and provides detailed code implementations.

Foundations of Probabilistic Graphical Models

The Core Idea of Probabilistic Graphical Models

A probabilistic graphical model is the marriage of probability theory and graph theory: it uses a graph structure to represent the conditional dependencies among random variables. In generative modeling we are mainly concerned with two kinds of probabilistic graphs:

  • Directed graphical models (Bayesian networks): represent directed, cause-and-effect style dependencies between variables
  • Undirected graphical models (Markov random fields): represent symmetric correlations between variables
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.distributions import Normal, Bernoulli

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# A basic probabilistic graphical model example
class SimpleBayesianNetwork:
    def __init__(self):
        # Prior probability P(Rain) = 0.2
        self.p_rain = 0.2
        # Conditional probability P(Sprinkler | Rain)
        self.p_sprinkler_given_rain = {True: 0.01, False: 0.4}
        # Conditional probability P(Wet | Rain, Sprinkler)
        self.p_wet = {
            (True, True): 0.99,
            (True, False): 0.8,
            (False, True): 0.9,
            (False, False): 0.001
        }
    
    def sample(self):
        rain = np.random.random() < self.p_rain
        sprinkler = np.random.random() < self.p_sprinkler_given_rain[rain]
        wet = np.random.random() < self.p_wet[(rain, sprinkler)]
        return rain, sprinkler, wet

# Test the simple Bayesian network
bn = SimpleBayesianNetwork()
samples = [bn.sample() for _ in range(1000)]
rain_count = sum(1 for r, _, _ in samples if r)
print(f"Estimated probability of rain: {rain_count/1000:.3f}")

Generative Models and Probabilistic Graphs

A generative model is essentially learning a joint distribution $p(\mathbf{x}, \mathbf{z})$, where $\mathbf{x}$ is the observed variable (e.g., an image) and $\mathbf{z}$ is a latent variable. Probabilistic graphs give us a framework for representing this joint distribution and performing inference over it.
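
Concretely, for a directed latent-variable model the joint distribution factorizes along the graph, and the data likelihood is obtained by marginalizing out the latent variable:

$$p_\theta(\mathbf{x}, \mathbf{z}) = p(\mathbf{z})\, p_\theta(\mathbf{x}|\mathbf{z}), \qquad p_\theta(\mathbf{x}) = \int p(\mathbf{z})\, p_\theta(\mathbf{x}|\mathbf{z})\, d\mathbf{z}$$

This integral is intractable for expressive decoders, which is precisely why the model families below turn to variational inference, invertible transforms, or Markov chains.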

Variational Autoencoders (VAE)

The VAE from a Probabilistic Graph Perspective

The probabilistic graph of a VAE is a simple directed graph:

$\mathbf{z} \rightarrow \mathbf{x}$

Here the latent variable $\mathbf{z}$ follows the prior $p(\mathbf{z})$, and the observed variable $\mathbf{x}$ follows the conditional distribution $p_\theta(\mathbf{x}|\mathbf{z})$. The core idea of the VAE is to approximate the intractable posterior $p(\mathbf{z}|\mathbf{x})$ through variational inference.
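
The resulting training objective is the evidence lower bound (ELBO), which trades reconstruction quality against staying close to the prior:

$$\log p_\theta(\mathbf{x}) \geq \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - D_{\mathrm{KL}}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big)$$

The two terms map directly onto the reconstruction loss and the KL term in the `vae_loss` function below.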

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Mean and log-variance of the latent distribution
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    
    # KL divergence loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return BCE + KLD

# Training example
def train_vae(model, dataloader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(dataloader):
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = vae_loss(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        
        print(f'Epoch {epoch+1}, Loss: {train_loss/len(dataloader.dataset):.4f}')

# Generate new samples
def generate_vae_samples(model, num_samples=16):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, 20)  # 20 = the model's latent_dim
        samples = model.decode(z)
    return samples
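
For completeness, here is a minimal usage sketch for the pieces above; the dataset, transform, and hyperparameters are illustrative choices rather than part of the original recipe:

# Illustrative usage sketch (assumes torchvision is installed)
import torchvision
from torch.utils.data import DataLoader

transform = torchvision.transforms.ToTensor()  # keeps pixels in [0, 1], matching the BCE loss
mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(mnist, batch_size=128, shuffle=True)

vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
train_vae(vae, loader, optimizer, epochs=10)
samples = generate_vae_samples(vae)  # 16 digits decoded from prior draws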

Limitations of the VAE

Although the VAE enjoys a clean probabilistic interpretation, its sample quality generally lags behind later models. The main reasons include:

  1. Overly simple posterior assumption: the approximate posterior is usually a factorized Gaussian
  2. Looseness of the ELBO objective: the variational lower bound can be far from the true likelihood
  3. Simplicity of the generation process: a single decoding step struggles to capture complex distributions

Normalizing Flows

From VAEs to Normalizing Flows

Normalizing flows build complex distributions out of a sequence of invertible transforms, addressing the overly simple posterior of the VAE. The core tool is the change-of-variables formula:

$$p_X(\mathbf{x}) = p_Z(f^{-1}(\mathbf{x})) \left| \det \frac{\partial f^{-1}(\mathbf{x})}{\partial \mathbf{x}} \right|$$
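
As a quick sanity check of the formula, the following self-contained toy (not part of the flow model below) pushes standard-normal samples through $f(z) = e^z$ and compares a histogram of the transformed samples against the change-of-variables density $p_X(x) = p_Z(\log x) / x$:

# Toy verification of the change-of-variables formula with f(z) = exp(z)
z = torch.randn(100000)
x = torch.exp(z)  # f^{-1}(x) = log(x), |d f^{-1}/dx| = 1/x

xs = torch.linspace(0.05, 5, 200)
base = Normal(0.0, 1.0)
p_x = base.log_prob(torch.log(xs)).exp() / xs  # the log-normal density

plt.hist(x.numpy(), bins=200, range=(0.05, 5), density=True, alpha=0.5, label='samples of f(z)')
plt.plot(xs.numpy(), p_x.numpy(), label='change-of-variables density')
plt.legend()
plt.show()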

class PlanarFlow(nn.Module):
    """平面流:一种简单的标准化流"""
    def __init__(self, dim):
        super(PlanarFlow, self).__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.randn(1, dim))
        self.bias = nn.Parameter(torch.randn(1))
        self.scale = nn.Parameter(torch.randn(1, dim))
        
    def forward(self, z):
        # Forward transform: f(z) = z + scale * tanh(w·z + b)
        activation = torch.tanh(F.linear(z, self.weight, self.bias))
        return z + self.scale * activation
    
    def log_det_jacobian(self, z):
        # Log-determinant of the Jacobian of the planar transform
        hidden = F.linear(z, self.weight, self.bias)
        derivative = 1 - torch.tanh(hidden)**2
        psi = derivative * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())
        return torch.log(det_grad.abs() + 1e-8)

class NormalizingFlowVAE(nn.Module):
    """结合标准化流的VAE"""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, flow_length=4):
        super(NormalizingFlowVAE, self).__init__()
        
        # Encoder identical to the plain VAE
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Normalizing flows
        self.flows = nn.ModuleList([PlanarFlow(latent_dim) for _ in range(flow_length)])
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z0 = mu + eps * std
        
        # Apply the flow transforms, accumulating log-determinants
        z_k = z0
        log_det_jacobian = 0
        for flow in self.flows:
            log_det_jacobian += flow.log_det_jacobian(z_k)
            z_k = flow(z_k)
        
        return z_k, log_det_jacobian
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z, log_det_jacobian = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar, log_det_jacobian

# Loss function for the flow-augmented VAE
def flow_vae_loss(recon_x, x, mu, logvar, log_det_jacobian):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    
    # KL term of the Gaussian base posterior q(z_0|x) against the prior,
    # corrected by the flow's log-determinants (a common simplification
    # that evaluates the prior at z_0 rather than z_K)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return BCE + KLD - log_det_jacobian.sum()
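
To see the flow machinery in isolation, this short, illustrative sketch pushes 2-D Gaussian samples through a stack of PlanarFlow layers and tracks the density correction from each transform:

# Illustrative sketch: track log q_K(z_K) through a stack of planar flows
dim, flow_length = 2, 4
flows = nn.ModuleList([PlanarFlow(dim) for _ in range(flow_length)])

z = torch.randn(1024, dim)
log_q = Normal(0.0, 1.0).log_prob(z).sum(dim=1)  # log q_0(z_0) under the base Gaussian
for flow in flows:
    # Each transform subtracts its log|det J| from the running log-density
    log_q = log_q - flow.log_det_jacobian(z).squeeze(1)
    z = flow(z)
print(z.shape, log_q.shape)  # (1024, 2), (1024,)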

Diffusion Models

The Probabilistic Graph Framework of Diffusion Models

A diffusion model can be viewed as a hierarchical Markov chain consisting of two processes:

  1. Forward process (diffusion): gradually add noise to the data

    $$q(\mathbf{x}_{1:T}|\mathbf{x}_0) = \prod_{t=1}^T q(\mathbf{x}_t|\mathbf{x}_{t-1})$$

  2. Reverse process (generation): gradually reconstruct the data from noise

    $$p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$$
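
A property that makes training practical, and that the `q_sample` method below implements directly, is that the forward process collapses into a single Gaussian step from $\mathbf{x}_0$. With $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$:

$$q(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t;\ \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1 - \bar{\alpha}_t)\mathbf{I}\big)$$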

class DiffusionModel(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, timesteps=1000):
        super(DiffusionModel, self).__init__()
        
        self.timesteps = timesteps
        
        # Define the noise schedule
        self.betas = self._linear_beta_schedule(timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        
        # Noise-prediction network
        self.denoise_net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # +1 for timestep embedding
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim)
        )
    
    def _linear_beta_schedule(self, timesteps, start=0.0001, end=0.02):
        return torch.linspace(start, end, timesteps)
    
    def _extract(self, arr, timesteps, broadcast_shape):
        # Gather the schedule values at the given timesteps and broadcast them
        res = arr.to(timesteps.device)[timesteps].float()
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)
    
    def q_sample(self, x_start, t, noise=None):
        # Forward diffusion: sample x_t from q(x_t | x_0) in closed form
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_loss(self, x_start, t, noise=None):
        # Training loss: predict the injected noise
        if noise is None:
            noise = torch.randn_like(x_start)
        
        x_noisy = self.q_sample(x_start, t, noise)
        
        # Scalar timestep embedding (the network input size is input_dim + 1)
        t_embed = t.float().view(-1, 1) / self.timesteps
        
        # Predict the noise
        predicted_noise = self.denoise_net(torch.cat([x_noisy, t_embed], dim=1))
        
        # Simple MSE loss
        loss = F.mse_loss(predicted_noise, noise)
        return loss
    
    def p_sample(self, x, t, t_index):
        # One reverse sampling step
        t_embed = torch.full((x.shape[0], 1), t_index, device=x.device).float() / self.timesteps
        
        # Predict the noise
        predicted_noise = self.denoise_net(torch.cat([x, t_embed], dim=1))
        
        # Gather the schedule coefficients for this timestep
        t_tensor = torch.tensor([t_index], device=x.device)
        beta_t = self._extract(self.betas, t_tensor, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t_tensor, x.shape
        )
        sqrt_recip_alphas_t = self._extract(torch.sqrt(1.0 / self.alphas), t_tensor, x.shape)
        
        # Mean of p(x_{t-1} | x_t)
        model_mean = sqrt_recip_alphas_t * (
            x - beta_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )
        
        if t_index == 0:
            return model_mean
        else:
            # Add noise scaled by the (approximate) posterior variance
            posterior_variance_t = self._extract(self.betas, t_tensor, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
    
    def sample(self, num_samples, image_size=784):
        # Full generation process
        shape = (num_samples, image_size)
        device = next(self.denoise_net.parameters()).device
        
        # Start from pure noise
        x = torch.randn(shape, device=device)
        
        # Denoise step by step
        for i in reversed(range(0, self.timesteps)):
            x = self.p_sample(x, torch.tensor([i] * num_samples, device=device), i)
            
        return x

# Train the diffusion model
def train_diffusion(model, dataloader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, _) in enumerate(dataloader):
            optimizer.zero_grad()
            
            # Sample random timesteps
            t = torch.randint(0, model.timesteps, (data.shape[0],), device=data.device)
            
            # Compute the loss
            loss = model.p_loss(data.view(data.shape[0], -1), t)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f'Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}')

A Probabilistic Graph Interpretation of Diffusion Models

From the probabilistic graph point of view, the diffusion model defines an elaborate Markov chain:

  • Forward chain: $q(\mathbf{x}_{1:T}|\mathbf{x}_0)$ is a fixed inference process
  • Reverse chain: $p_\theta(\mathbf{x}_{0:T})$ is the learned generative process

This framework lets the model be trained via a variational lower bound:

$$\log p_\theta(\mathbf{x}_0) \geq \mathbb{E}_q\left[\log p(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right]$$
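
In practice, DDPM-style training does not optimize this bound directly. Reparameterizing its terms leads to the simplified noise-prediction objective that the `p_loss` methods above implement:

$$L_{\text{simple}} = \mathbb{E}_{t,\,\mathbf{x}_0,\,\boldsymbol{\epsilon}}\Big[\big\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon},\ t\big)\big\|^2\Big]$$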

A Unified View Through the Probabilistic Graph Lens

The Evolution Path of Generative Models

From the VAE to diffusion models, a clear evolution path emerges:

  1. VAE: a single-step latent-variable model, emphasizing modeling of the latent space
  2. Normalizing flows: multi-step invertible transforms, emphasizing exact modeling of the distribution
  3. Diffusion models: a multi-step Markov chain, emphasizing a progressive generation process

Evolution of the Probabilistic Graph Structure

import networkx as nx
import matplotlib.pyplot as plt

def plot_probabilistic_graphs():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Probabilistic graph of the VAE
    G_vae = nx.DiGraph()
    G_vae.add_edges_from([('z', 'x')])
    pos_vae = {'z': (0, 0), 'x': (1, 0)}
    nx.draw(G_vae, pos_vae, with_labels=True, node_size=2000, 
            node_color='lightblue', arrows=True, ax=axes[0])
    axes[0].set_title('VAE: z → x')
    
    # Probabilistic graph of a normalizing flow
    G_flow = nx.DiGraph()
    G_flow.add_edges_from([('z_0', 'z_1'), ('z_1', 'z_2'), ('z_2', 'x')])
    pos_flow = {'z_0': (0, 0), 'z_1': (1, 0), 'z_2': (2, 0), 'x': (3, 0)}
    nx.draw(G_flow, pos_flow, with_labels=True, node_size=2000, 
            node_color='lightgreen', arrows=True, ax=axes[1])
    axes[1].set_title('Normalizing Flows: z₀ → z₁ → z₂ → x')
    
    # Probabilistic graph of a diffusion model (forward chain shown)
    G_diff = nx.DiGraph()
    G_diff.add_edges_from([('x_0', 'x_1'), ('x_1', 'x_2'), ('x_2', 'x_3')])
    pos_diff = {'x_0': (0, 0), 'x_1': (1, 0), 'x_2': (2, 0), 'x_3': (3, 0)}
    nx.draw(G_diff, pos_diff, with_labels=True, node_size=2000, 
            node_color='lightcoral', arrows=True, ax=axes[2])
    axes[2].set_title('Diffusion Models: x₀ → x₁ → x₂ → x₃')
    
    plt.tight_layout()
    plt.show()

# Plot the comparison of probabilistic graphs
plot_probabilistic_graphs()

Theoretical Connections and Differences

Model              Graph structure                 Latent variables  Training objective  Generation
VAE                Single-layer directed graph     Continuous        ELBO                One-step decoding
Normalizing flows  Stack of invertible transforms  Continuous        Exact likelihood    Invertible transform
Diffusion models   Markov chain                    Discrete-time     Variational bound   Multi-step denoising

Code Example: A Complete Diffusion Model Implementation

Below we implement a more complete diffusion model for generating MNIST digits:

import math

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class ImprovedDiffusionModel(nn.Module):
    def __init__(self, image_size=28, channels=1, hidden_dim=256, timesteps=1000):
        super(ImprovedDiffusionModel, self).__init__()
        
        self.timesteps = timesteps
        self.image_size = image_size
        self.channels = channels
        
        # Noise schedule
        self.betas = self._cosine_beta_schedule(timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        
        # Variance of the reverse-process posterior
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        
        # Improved UNet-style denoising network
        self.denoise_net = DenoiseNet(image_size, channels, hidden_dim)
    
    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """余弦调度,来自Improved DDPM论文"""
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)
    
    def _extract(self, arr, timesteps, broadcast_shape):
        res = arr.to(timesteps.device)[timesteps].float()
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)
    
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_loss(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        
        x_noisy = self.q_sample(x_start, t, noise)
        predicted_noise = self.denoise_net(x_noisy, t)
        
        loss = F.mse_loss(predicted_noise, noise)
        return loss
    
    @torch.no_grad()
    def p_sample(self, x, t, t_index):
        betas_t = self._extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = self._extract(torch.sqrt(1.0 / self.alphas), t, x.shape)
        
        # Mean of p(x_{t-1} | x_t), using the network's noise prediction
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.denoise_net(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        
        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = self._extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
    
    @torch.no_grad()
    def sample(self, batch_size=16):
        device = next(self.denoise_net.parameters()).device
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        
        # Start from pure noise
        x = torch.randn(shape, device=device)
        
        # Denoise step by step
        for i in reversed(range(0, self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, t, i)
            
        return x

class DenoiseNet(nn.Module):
    """Simplified UNet-style denoising network"""
    def __init__(self, image_size=28, channels=1, hidden_dim=256):
        super(DenoiseNet, self).__init__()
        
        self.image_size = image_size
        self.channels = channels
        self.hidden_dim = hidden_dim
        
        # Timestep embedding MLP
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, hidden_dim//4, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim//4, hidden_dim//2, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim//2, hidden_dim, 3, padding=1),
            nn.SiLU()
        )
        
        # Middle block
        self.middle = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.SiLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim//2, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim//2, hidden_dim//4, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim//4, channels, 3, padding=1)
        )
    
    def forward(self, x, t):
        # Sinusoidal timestep embedding; its width must match the MLP input (hidden_dim)
        t_embed = self._timestep_embedding(t, self.hidden_dim)
        t_embed = self.time_embed(t_embed)
        
        # Broadcast the time embedding over the spatial dimensions
        B, C, H, W = x.shape
        t_embed = t_embed.view(B, -1, 1, 1).expand(B, -1, H, W)
        
        # Encoder
        h = self.encoder(x)
        h = h + t_embed  # Inject the time information
        
        # Middle block
        h = self.middle(h)
        
        # Decoder
        h = self.decoder(h)
        
        return h
    
    def _timestep_embedding(self, timesteps, dim, max_period=10000):
        """创建正弦时间步嵌入"""
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

# Training and generation demo
def demo_diffusion_model():
    # Data loading
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    dataset = torchvision.datasets.MNIST(root='./data', train=True, 
                                        download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    
    # Model and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImprovedDiffusionModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Training loop
    model.train()
    for epoch in range(5):  # A short run for demonstration
        total_loss = 0
        for i, (images, _) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()
            
            # Random timesteps
            t = torch.randint(0, model.timesteps, (images.shape[0],), device=device)
            
            # Compute the loss
            loss = model.p_loss(images, t)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if i % 100 == 0:
                print(f'Epoch {epoch}, Batch {i}, Loss: {loss.item():.4f}')
        
        print(f'Epoch {epoch} completed, Average Loss: {total_loss/len(dataloader):.4f}')
    
    # Generate samples
    model.eval()
    with torch.no_grad():
        generated = model.sample(batch_size=16)
    
    # Display the generated images
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Run the demo
# demo_diffusion_model()  # Uncomment to run the full training

Conclusion and Outlook

The evolution from VAEs to diffusion models showcases the expressive power of probabilistic graphical models in generative AI. The core of this evolution lies in:

  1. Richer probabilistic graph structures: from a single latent-variable layer to a full Markov chain
  2. More refined training objectives: from an approximate ELBO to tighter variational bounds
  3. Multi-step generation: from one-shot decoding to progressive refinement

Looking ahead, generative models will likely continue along this trajectory, combining the strengths of different families, such as the sample quality of diffusion models with the compact latent space of VAEs. Probabilistic graphical models will keep providing the theoretical foundation for understanding and improving these powerful models.

The evolution of generative models is not merely technical progress; it reflects our deepening understanding of the data-generating process. Through the probabilistic graph lens, we can better grasp the internal logic and future direction of this field.
