多模态实践--扩散模型的逆向扩散代码实践

举报
剑指南天 发表于 2026/06/05 17:46:11 2026/06/05
【摘要】 反向扩散过程从一张完全由高斯噪声构成的图片出发,逐步去除噪声,最终生成一张图片。由于每一步需要去除的噪声无法直接获得,因此使用神经网络来预测噪声。

1.概述

反向扩散过程从一张完全由高斯噪声构成的图片出发,逐步去除噪声,最终生成一张图片。由于每一步需要去除的噪声无法直接获得,因此使用神经网络来预测每一步的噪声。

2. U-Net 神经网络

2.1 U-Net 神经网络的结构

U-Net 的网络结构由卷积块、池化层(下采样)、上采样层以及跳跃连接(skip connection,将编码器的特征图与解码器对应层的特征图按通道拼接)构成。代码中对网络做了简化(只有两级下采样/上采样,并在每个卷积块中注入时间步嵌入),但思想与 U-Net 一致。

# 按照时间步进行位置编码
def _pos_encoding(time_idx, output_dim, device='cpu'):
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)
    _2i = torch.arange(0, D, step=2, device=device)
    div_term = torch.pow(10000, _2i / D)
    v[0::2] = torch.sin(t / div_term)
    if D % 2 == 1:
        div_term = div_term[:-1]
    v[1::2] = torch.cos(t / div_term)
    return v


def pos_encoding(time_steps, output_dim, device='cpu'):
    """Return sinusoidal embeddings for a batch of time steps.

    Row ``i`` holds ``sin(t_i / 10000**(2k/D))`` in the even columns and the
    matching ``cos`` terms in the odd columns. All rows are computed in a
    single broadcasted operation instead of a per-sample Python loop, because
    this function runs on every U-Net forward pass.

    Args:
        time_steps: 1-D sequence or tensor of time steps, length ``N``
            (may be empty).
        output_dim: Embedding length ``D``.
        device: Device on which the result is allocated.

    Returns:
        A float tensor of shape ``(N, output_dim)``.
    """
    t = torch.as_tensor(time_steps, device=device).reshape(-1, 1)  # (N, 1)
    even_idx = torch.arange(0, output_dim, step=2, device=device)
    div_term = torch.pow(10000, even_idx / output_dim)  # (ceil(D/2),)
    angles = t / div_term  # (N, ceil(D/2)) by broadcasting
    v = torch.zeros(t.size(0), output_dim, device=device)
    v[:, 0::2] = torch.sin(angles)
    # For odd D there is one fewer cos column than sin column.
    v[:, 1::2] = torch.cos(angles[:, : output_dim // 2])
    return v


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages conditioned on a time embedding.

    A small MLP projects the time embedding to ``in_ch`` values per sample.
    These values are added to every spatial position of the input before the
    convolutions run.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_embed_dim: Length of the time-step embedding passed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        conv_layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*conv_layers)
        # Map the time embedding to one additive value per input channel.
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch),
        )

    def forward(self, x, v):
        """Apply the block to ``x`` conditioned on the time embedding ``v``.

        Args:
            x: Input of shape ``(N, in_ch, H, W)``.
            v: Time embedding of shape ``(N, time_embed_dim)``.

        Returns:
            Tensor of shape ``(N, out_ch, H, W)``.
        """
        batch, channels, _, _ = x.shape
        # Reshape to (N, C, 1, 1) so the bias broadcasts over H and W.
        bias = self.mlp(v).view(batch, channels, 1, 1)
        return self.convs(x + bias)


class UNet(nn.Module):
    """Small two-level U-Net that predicts the noise in an image at a time step.

    Architecture:
        * Encoder: ``down1`` (in_ch -> 64) and ``down2`` (64 -> 128), each
          followed by 2x max-pooling.
        * Bottleneck: ``bot1`` (128 -> 256).
        * Decoder: 2x bilinear upsampling, concatenation with the matching
          encoder feature map (skip connection), then ``up2`` (256+128 -> 128)
          and ``up1`` (128+64 -> 64).
        * Output: a 1x1 conv back to ``in_ch`` channels.

    Every ConvBlock receives the same sinusoidal time-step embedding. The
    input is pooled twice and upsampled twice, so H and W must be divisible
    by 4; otherwise the skip-connection concatenation fails.

    Args:
        in_ch: Channels of the input image (and of the predicted noise).
        time_embed_dim: Length of the time-step embedding.
    """

    def __init__(self, in_ch=3, time_embed_dim=100):
        super().__init__()
        self.time_embed_dim = time_embed_dim

        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)

        self.max_pool = nn.MaxPool2d(2)
        # align_corners=False is what PyTorch uses when it is left unset for
        # bilinear mode; stating it explicitly documents the behaviour (older
        # PyTorch releases warn when it is omitted).
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, time_steps):
        """Predict the noise contained in ``x`` at the given time steps.

        Args:
            x: Noisy images of shape ``(B, in_ch, H, W)``, H and W divisible by 4.
            time_steps: ``B`` integer time steps, one per image.

        Returns:
            Tensor of shape ``(B, in_ch, H, W)``.
        """
        v = pos_encoding(time_steps, self.time_embed_dim, x.device)  # (B, time_embed_dim)
        x1 = self.down1(x, v)           # (B, 64, H, W)
        x = self.max_pool(x1)           # (B, 64, H/2, W/2)
        x2 = self.down2(x, v)           # (B, 128, H/2, W/2)
        x = self.max_pool(x2)           # (B, 128, H/4, W/4)
        x = self.bot1(x, v)             # (B, 256, H/4, W/4)
        x = self.up_sample(x)           # (B, 256, H/2, W/2)
        x = torch.cat([x, x2], dim=1)   # (B, 256 + 128, H/2, W/2)
        x = self.up2(x, v)              # (B, 128, H/2, W/2)
        x = self.up_sample(x)           # (B, 128, H, W)
        x = torch.cat([x, x1], dim=1)   # (B, 128 + 64, H, W)
        x = self.up1(x, v)              # (B, 64, H, W)
        return self.out(x)              # (B, in_ch, H, W)

2.2 模型训练

class ImagesDatasets(Dataset):
    """Forward-diffusion training set built from two fixed images.

    Each of the two images is repeated ``batch_size * 2`` times. Every item
    gets a fixed time step ``t`` in ``[1, num_timesteps]`` at construction.
    ``__getitem__`` returns ``(x_t, noise, t)``, where
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`` and the
    Gaussian ``noise`` is freshly sampled on every access.

    The module-level ``batch_size``, ``num_timesteps`` and ``device``
    settings are read at construction time.

    Args:
        img_size: Side length the images are resized to.
        beta_start: First value of the linear variance schedule (beta_1).
        beta_end: Last value of the linear variance schedule (beta_T).
    """

    def __init__(self, img_size, beta_start=0.0001, beta_end=0.02):
        self.preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),  # Resize the input image
            transforms.ToTensor(),  # Convert to torch tensor (scales data into [0,1])
            transforms.Lambda(lambda t: (t * 2) - 1),  # Scale data between [-1, 1]
        ])
        # Use self.preprocess: no module-level `preprocess` exists when the
        # dataset is built. Close the files after decoding and force 3
        # channels so the tensors match UNet's default in_ch=3.
        with Image.open("trump.jpeg") as img:
            image_trump = self.preprocess(img.convert("RGB"))
        with Image.open("biden.jpeg") as img:
            image_biden = self.preprocess(img.convert("RGB"))
        # Repeat each image batch_size * 2 times and stack into one tensor.
        copies = batch_size * 2
        self.images = torch.stack([image_trump] * copies + [image_biden] * copies).to(device=device)
        # One fixed time step per item.
        # NOTE(review): DDPM (Algorithm 1) resamples t at every step; here t
        # stays fixed per item across epochs.
        self.T = torch.randint(1, num_timesteps + 1, (len(self.images),)).to(device=device)
        # Variance schedule [beta_1, beta_2, ..., beta_T].
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        """Return ``(x_t, noise, t)`` for the item(s) selected by ``item``.

        Only the requested item is noised, so each call does a constant
        amount of work regardless of the dataset size.
        """
        t = self.T[item]
        x_0 = self.images[item]
        alpha_bar = self.alpha_bars[t - 1]  # alpha_bars[0] is for t=1
        # Append three singleton dims so alpha_bar broadcasts over (C, H, W).
        alpha_bar = alpha_bar.view(*alpha_bar.shape, 1, 1, 1)
        noise = torch.randn_like(x_0)  # same shape, dtype and device as x_0
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise, t
# Basic configuration
ROOT_DIR = Path(__file__).parent.parent
device = 'cuda' if torch.cuda.is_available() else 'cpu'
log_dir = ROOT_DIR / 'logs' / 'U-Net'

# Hyperparameters
img_size = 32
batch_size = 128
num_timesteps = 1000
epochs = 1000
lr = 5e-4
# Endpoints beta_1 and beta_T of the linear variance schedule (values from
# the DDPM paper). They were previously referenced but never defined.
beta_start = 0.0001
beta_end = 0.02

images_datasets = ImagesDatasets(img_size)
dataloader = DataLoader(images_datasets, batch_size=batch_size, shuffle=True)
model = UNet().to(device)
optimizer = Adam(model.parameters(), lr=lr)
loss_f = nn.MSELoss()
with SummaryWriter(log_dir=str(log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
    for epoch in range(epochs):
        # Explicitly select training mode so BatchNorm uses batch statistics.
        model.train()
        loss_sum = 0.0
        cnt = 0
        for x_t, noise, t in tqdm(dataloader, "开始训练: "):
            x_t, noise, t = x_t.to(device=device), noise.to(device=device), t.to(device=device)
            # The network is trained to predict the noise that produced x_t.
            noise_p = model(x_t, t)
            optimizer.zero_grad()
            loss = loss_f(noise_p, noise)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            cnt += 1
        # Average once per epoch; guard against an empty loader.
        loss_avg = loss_sum / cnt if cnt else 0.0
        writer.add_scalar('loss', loss_avg, epoch + 1)
        print(f'Epoch {epoch + 1} | Loss: {loss_avg}')

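3. Denoising Diffusion Probabilistic Models(DDPM,去噪扩散概率模型)

采样时从 $t=T$ 到 $t=1$ 逐步执行下面的去噪更新,其中 $\epsilon_\theta$ 是 U-Net 预测的噪声:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z,\qquad \sigma_t^2 = \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}$$

其中 $z \sim \mathcal{N}(0, I)$,并且在 $t=1$ 时取 $z=0$。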


def reverse_to_img(x):
    """Convert a tensor with values in [-1, 1] back into a PIL image.

    The values are mapped to [0, 255], clamped, truncated to ``uint8`` and
    moved to the CPU before conversion with ``transforms.ToPILImage``.

    Args:
        x: Image tensor, e.g. one ``(C, H, W)`` sample from the sampling loop.

    Returns:
        A PIL image.
    """
    pixels = ((x + 1) / 2 * 255).clamp(0, 255)
    pixels = pixels.to(torch.uint8).cpu()
    return transforms.ToPILImage()(pixels)

def show_images(images, rows=2, cols=10):
    """Display images in a ``rows`` x ``cols`` grid (2 x 10 = 20 by default).

    Cells are filled row by row. If fewer than ``rows * cols`` images are
    given, the remaining cells stay empty instead of raising IndexError.
    Extra images beyond the grid size are ignored.

    Args:
        images: Sequence of images accepted by ``Axes.imshow`` (e.g. PIL images).
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    fig = plt.figure(figsize=(cols, rows))
    for i, image in enumerate(images[: rows * cols]):
        ax = fig.add_subplot(rows, cols, i + 1)
        # cmap only affects single-channel images; matplotlib ignores it for RGB(A).
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.show()
alphas = images_datasets.alphas
alpha_bars = images_datasets.alpha_bars
preprocess = images_datasets.preprocess
with Image.open("trump.jpeg") as img:
    # Move to the same device as alpha_bars; otherwise mixing it with the
    # noise below fails with a device mismatch when running on CUDA.
    image_trump = preprocess(img.convert("RGB")).to(device=device)
images_1 = []
images_2 = []
# Sample a pure white-noise image x_T with the same shape as the training data.
image_noise_pure = torch.randn_like(image_trump)
# A real image pushed through the whole forward process:
# x_T = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * eps
alpha_bar_T = alpha_bars[num_timesteps - 1]
image_noise_add = torch.sqrt(alpha_bar_T) * image_trump + torch.sqrt(1 - alpha_bar_T) * image_noise_pure
x = torch.stack([image_noise_pure, image_noise_add])
batch_size = 2
# Sampling uses BatchNorm running statistics; switch mode once, not per step.
model.eval()
# for t = T, T-1, ..., 1
for i in tqdm(range(num_timesteps, 0, -1)):
    t = torch.full((batch_size,), i, dtype=torch.long, device=device)
    # One denoising step, x_t --> x_{t-1}
    t_idx = t - 1  # alphas[0] is for t=1
    alpha = alphas[t_idx].view(-1, 1, 1, 1)
    alpha_bar = alpha_bars[t_idx].view(-1, 1, 1, 1)
    # alpha_bar_0 is 1 by definition. Plain alpha_bars[t_idx - 1] would wrap
    # around to alpha_bars[-1] at t=1.
    prev = alpha_bars[(t_idx - 1).clamp(min=0)]
    alpha_bar_prev = torch.where(t_idx > 0, prev, torch.ones_like(prev)).view(-1, 1, 1, 1)

    with torch.no_grad():
        eps = model(x, t)
    noise = torch.randn_like(x)
    noise[t == 1] = 0  # no noise at t=1
    mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
    std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
    # x_{t-1}
    x = mu + noise * std
    # Keep a snapshot every 50 steps (20 snapshots per sample).
    if (i - 1) % 50 == 0:
        images_1.append(reverse_to_img(x[0]))
        images_2.append(reverse_to_img(x[1]))

show_images(images_1)
show_images(images_2)

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。