视觉生成模型原理及实现

本文简要记录Stable Diffusion为主的视觉生成大模型的原理和实现。

🔥Stable Diffusion

Stable Diffusion架构

  • 如图,输入prompt由CLIP模型进行编码,如果有输入图片则由VAE编码器进行编码映射到低维度的子空间(Latent Space),然后由UNet + Scheduler进行扩散,得到的特征给到VAE解码器。

CLIP

  • CLIP训练时结合了文本标签和图像,所以可以对文本进行编码,CLIP固定了输出token序列长度为77,维度为768。
  • VAE就是通过卷积将特征图进行降维/升维的模块,里面有激活函数、归一化之类的。

UNet

  • UNet利用了注意力和残差块,数据通过下采样、中间块、上采样后实现对噪声的预测,中间同尺寸有跳跃连接减少下采样损失。所以VAE部分其实不是必要的,UNet本身也能下/上采样。

  • 总的来说,扩散模型就是根据文本提示对一张原始带噪声的图预测其噪声,然后原图减去该噪声后得到目标图,简单来说就是一个去噪的过程。训练时就是将图片添加若干噪声,模型预测噪声后和实际加的噪声进行匹配计算损失。

🔥DiT

diffusion transformer,将diffusion架构中的UNet的纯卷积替换为transformer结构,后续成为Sora等视觉生成模型的基础架构。
DiT架构

  • 如图,最左边是模型架构,输入噪声原图和文本提示,分别编码后(图片要像ViT那样分成patch),输入到DiT块当中,输出的经过LN等层后就能输出结果。
  • 右边三个都是尝试过的DiT块的结构。
  • DiT块中大体和transformer的编码器差不多,但是分为两边,左边是图片,右边多出来的是输入的文本,经过自己的MLP后进行缩放不断地连接到图片这一路。

🔥实现

简要通过代码来看看如何实现整个模型。

参考

VAE

class VAE(nn.Module):
    def __init__(self, in_channels=3, latent_dim=4, image_size=512):
        super(VAE, self).__init__()
        self.in_channels = in_channels
        self.latent_dim = latent_dim
        self.image_size = image_size

        # Encoder
        # 3 x 512 x 512 -> 4 x 64 x 64
        self.encoder = nn.Sequential(
            self._conv_block(in_channels, 64),  # 64 x 256 x 256
            self._conv_block(64, 128),  # 128 x 128 x 128
            self._conv_block(128, 256),  # 256 x 64 x 64
        )

        # Encoder 的潜在空间输出
        self.fc_mu = nn.Conv2d(256, latent_dim, 1)  # 4 x 64 x 64 <- Latent Space
        self.fc_var = nn.Conv2d(256, latent_dim, 1)  # 4 x 64 x 64 <- Latent Space

        # Decoder
        # 4 x 64 x 64 -> 3 x 512 x 512
        self.decoder_input = nn.ConvTranspose2d(latent_dim, 256, 1)  # 256 x 64 x 64
        self.decoder = nn.Sequential(
            self._conv_transpose_block(256, 128),  # 128 x 128 x 128
            self._conv_transpose_block(128, 64),  # 64 x 256 x 256
            self._conv_transpose_block(64, in_channels),  # 3 x 512 x 512
        )

        self.sigmoid = nn.Sigmoid()  # [0, 1]
        self.tanh = nn.Tanh()  # [-1, 1]
            def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            # nn.GroupNorm(num_groups=1, num_channels=out_channels),
            nn.BatchNorm2d(out_channels),
            # nn.LeakyReLU(),
            nn.LeakyReLU(0.2)
        )

    def _conv_transpose_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
            # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            # nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            # nn.GroupNorm(num_groups=1, num_channels=out_channels),
            nn.BatchNorm2d(out_channels),
            # nn.LeakyReLU(),
            nn.LeakyReLU(0.2)
        )
  • VAE结构很简单,就是先通过卷积降采样(encoder),然后通过两个卷积层学习latent space输出(噪声的均值方差),最后通过反卷积升采样(decoder)恢复到原图。
  • nn.ConvTranspose2d是通过转置卷积(反卷积)恢复图像尺寸:通过对特征图插入0,然后使用卷积操作获得大尺寸输出特征。
  • 训练起来和普通CNN区别不大,模型的输入和标号都是原图,主要是多了噪声值(mu、var)用于计算损失。
  • 当然这是最简单VAE,实际上还能把卷积换成注意力等其他结构。

DDPM(UNet)

DDPM

  • 现在实现一下扩散模型(DDPM),也就是UNet这部分。

模型

class UNet_Transformer(nn.Module):
    def __init__(self, in_channels=3, time_dim=256, context_dim=512):
        super().__init__()

        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim)
        )
        self.context_dim = context_dim

        # 初始卷积
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, 3, padding=1)
        )  # 64 x H x W

        # 下采样
        self.down1 = self._down_block(64, 128, time_dim)  # 128 x H/2 x W/2
        self.down2 = self._down_block(128, 256, time_dim, self_attn=True, cross_attn=False, num_heads=4, context_dim=context_dim)  # 256 x H/4 x W/4
        self.down3 = self._down_block(256, 512, time_dim, self_attn=True, cross_attn=False, num_heads=8, context_dim=context_dim)  # 512 x H/8 x W/8

        # 中间块
        self.middle_block = _down_block(512, time_dim, context_dim)  # 512 x H/8 x W/8

        # 上采样
        self.up1 = self._up_conv(512, 256, time_dim, self_attn=True, cross_attn=True, num_heads=8, context_dim=context_dim)  # 256 x H/4 x W/4
        self.up2 = self._up_conv(256+256, 128, time_dim, self_attn=True, cross_attn=True, num_heads=4, context_dim=context_dim)  # 128 x H/2 x W/2
        self.up3 = self._up_conv(128+128, 64, time_dim)  # 64 x H x W

        # 最终卷积
        self.final_conv = nn.Sequential(
            ResnetBlock(64 * 2, 64, time_dim),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, in_channels, 3, stride=1, padding=1),
        )
  • UNet结构就是先下采样,然后通过中间块,最后上采样输出即可。
  • 和前面的VAE相比就是再上下采样之间用了一个中间块。并且_down_block_up_conv都不仅是卷积,而是融合resnet和transformer的结构,学习能力更强。

训练

  • 接下来通过部分代码简要地看看如何训练整个 CLIP + UNet(transformer)的diffusion模型:
for epoch in range(n_epochs):
    diffusion_model.train()
    progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{n_epochs}")
    epoch_loss = 0.0

    # 训练模型
    for batch in train_dataloader:
        images = batch["images"].to(device)
        text = batch["text"]

        # 使用 CLIP 模型编码文本
        text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state

        timesteps = torch.randint(0, num_timesteps, (images.shape[0],), device=device).long() # 随机选择 timesteps
        noisy_images, noise = noise_scheduler.add_noise(images, timesteps) # 添加噪声
        noise_pred = diffusion_model(noisy_images, timesteps, text_embeddings) # 预测噪声
        loss = torch.nn.functional.mse_loss(noise_pred, noise) # 预测的噪声与真实噪声的均方误差

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(diffusion_model.parameters(), 1.0)  # 梯度裁剪
        optimizer.step()
        # scheduler.step()  # OneCycleLR 在每个批次后调用

        epoch_loss += loss.item()
        progress_bar.update(1)
        progress_bar.set_postfix({"loss": loss.item()})
  • 数据集是图-文本对,对文本分词后,先使用网上预训练好的CLIP对文本进行编码,然后用噪声器添加随机噪声(噪声添加步数也随机)。
  • 然后把文本、图片、噪声步数一同输入到模型当中。计算损失就是把预测的噪声和实际噪声算均方误差,后面就是基本流程了。

测试(生成)

  • 训练好的模型可以根据噪声和文本生成图像了,这时候就不用输入原图了,直接随机初始化噪声图作为输入进行去噪:
def sample_cfg(model, noise_scheduler, n_samples, in_channels, text_embeddings, image_size=64, guidance_scale=3.0):
    """
    从噪声开始,逐渐减小噪声,直到最终的图像。
    :param model: UNet模型
    :param noise_scheduler: 噪声调度器
    :param n_samples: 生成的样本数量
    :param in_channels: 输入图像的通道数
    :param text_embeddings: 文本嵌入
    :param image_size: 图像的大小
    :param guidance_scale: 用于加权噪声预测的比例
    :return: 生成的图像
    """
    model.eval()
    device = next(model.parameters()).device

    x = torch.randn(n_samples, in_channels, image_size, image_size).to(device) # 随机初始化噪声图像
    null_embeddings = torch.zeros_like(text_embeddings) # 用于无条件生成

    # 逐步去噪
    for t in reversed(range(noise_scheduler.num_timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)

        noise_pred_uncond = model(x, t_batch, y=null_embeddings) # 生成一个无条件的噪声预测
        noise_pred_cond = model(x, t_batch, y=text_embeddings) # 生成一个有条件的噪声预测
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # CFG:结果加权后的噪声预测

        # 采样器的去噪过程
        alpha_t = noise_scheduler.alphas[t]
        alpha_t_bar = noise_scheduler.alphas_cumprod[t]
        beta_t = noise_scheduler.betas[t]

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        # 去噪公式
        x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * noise_pred) + torch.sqrt(beta_t) * noise

    model.train()
    return x
  • 生成,还是要先对文本分词、clip编码后,再生成图像,进行保存:
    if (epoch + 1) % save_checkpoint_interval == 0:
        diffusion_model.eval()
        with torch.no_grad():
            sample_text = ["a water type pokemon", "a red pokemon with a red fire tail"]
            text_input = tokenizer(sample_text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device)).last_hidden_state
            sampled_images = sample_cfg(diffusion_model, noise_scheduler, len(sample_text), in_channels, text_embeddings, image_size=image_size, guidance_scale=3.0)
            # 保存生成的图像
            for i, img in enumerate(sampled_images):
                img = img * 0.5 + 0.5  # Rescale to [0, 1]
                img = img.detach().cpu().permute(1, 2, 0).numpy()
                img = (img * 255).astype(np.uint8)
                img_pil = Image.fromarray(img)
                img_pil.save(f'diffusion_results/generated_image_epoch_{epoch+1}_sample_{i}.png')

Stable Diffusion

模型

  • 接下来就是完整的CLIP + UNet + VAE的SD结构了,模型结构就是把VAE和UNet组合:
class StableDiffusion(nn.Module):
    def __init__(self, in_channels=3, latent_dim=4, image_size=512, diffusion_timesteps=1000, device="cuda"):
        super(StableDiffusion, self).__init__()
        # VAE
        self.vae = VAE(in_channels=in_channels, latent_dim=latent_dim, image_size=image_size)
        # Diffusion model (UNet)
        self.unet = UNet_Transformer(in_channels=latent_dim)
        # Noise scheduler
        self.noise_scheduler = NoiseScheduler(num_timesteps=diffusion_timesteps, device=device)

    def encode(self, x):
        return self.vae.encode(x)[0]

    def decode(self, z):
        return self.vae.decode(z)

    def diffuse(self, latents, t, context):
        return self.unet(latents, t, context)

    def forward(self, latents, t, context):
        noise_pred = self.diffuse(latents, t, context)
        return noise_pred
  • 这里VAE已经是预训练好的,所以训练时只要更新UNet即可。

训练

    # 训练模型
    for batch in train_dataloader:
        latents = batch["latents"].to(device)
        text = batch["text"]

        # 添加噪声
        timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=device).long()
        noisy_latents, noise = model.noise_scheduler.add_noise(latents, timesteps)

        # 使用 CLIP 模型编码文本
        text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state

        # 预测噪声
        noise_pred = model(noisy_latents, timesteps, text_embeddings)
        mse_loss = F.mse_loss(noise_pred, noise)
        div_loss = diversity_loss(noisy_latents, use_cosine=True)

        # 计算去噪后的潜在表示
        alpha_t = model.noise_scheduler.alphas[timesteps][:, None, None, None]
        sqrt_alpha_t = torch.sqrt(alpha_t)
        sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t)
        predicted_latents = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t
        cons_loss = F.mse_loss(predicted_latents, latents)

        # 组合损失
        total_loss = mse_loss + diversity_weight * div_loss + cons_loss * current_lambda_cons
        epoch_loss += total_loss.item()
        num_batches += 1

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # OneCycleLR 学习率调度器

        # 动态调整多样性损失的权重
        if epoch % 10 == 0:
            diversity_weight = min(diversity_weight * 1.05, 0.1)  # 逐渐增加权重,但设置上限
  • 训练流程(部分代码)和之前DDPM部分区别不大,这里数据集的train_dataloader里的数据已经被VAE下采样过了,所以直接配合嵌入好的文本输入模型,UNet会利用网络和公式预测噪声用于损失计算。
  • 损失分为噪声损失多样性损失去噪后特征的损失

测试(生成)

  • 代码和之前一样,就不展示了,就多了一个,在UNet输出子空间表示时,用VAE解码成正常图片的信息:sampled_images = model.decode(sampled_latents)

🔥总结

  • 本文记录了stable diffusion(SD)的框架及其组件细节。并通过理解开源代码了解了网络结构搭建(VAE、UNet、SD)、训练过程(CLIP+UNet、SD)以及生成过程(CLIP+UNet、SD)。
  • 总的来说,SD的网络结构结合了CNN和transformer,本质是一个视觉模型。而模型生成过程就是:CLIP嵌入文本(可选)、生成噪声图像、VAE编码图文、UNet预测噪声、VAE解码出图。而训练不需要生成的图,可以去掉最后一步。模型当中的组件也都是可以换成更优秀的结构。
  • Copyrights © 2023-2025 LegendLeo Chen
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信