1.概述
由快手、北京大学和北京邮电大学的研究团队联合开发的Pyramid-Flow,是一项在AI视频生成领域取得重大突破的开源项目。这一创新模型以其“特征金字塔+流匹配”技术,通过高效的空间和时间金字塔表示,显著提升了视频生成的训练效率和输出质量。Pyramid-Flow能够依据简单的文本指令,生成长达10秒、分辨率高达1280x768、帧率为24fps的高清视频,不仅在光影效果、动作流畅度、画面质量上表现出色,还在文本语义的准确还原和色彩搭配上展现了卓越的能力,生成的视频效果令人印象深刻。
这一技术的核心优势在于其高效生成过程,相较于其他开源视频生成模型,Pyramid-Flow在生成效率方面具有明显优势。这种优化的计算方法使其特别适合资源有限的中小企业和个人开发者,为他们提供了一个既经济又高效的视频生成解决方案。此外,Pyramid-Flow采用的流模型生成方法,赋予了模型灵活生成不同长度视频序列的能力,进一步拓宽了其应用范围。实验结果证明,即使在公开数据集上进行训练,Pyramid-Flow也能生成与竞争对手相媲美的高质量视频,同时大幅度降低了训练成本。这一成果不仅为视频内容创作带来了新的可能性,也为AI技术的普及和应用开辟了新的道路。
这篇博客将带你详细了解论文的大概内容及模型的整体架构。
相关链接
github地址:
GitHub - jy0205/Pyramid-Flow: Code of Pyramidal Flow Matching for Efficient Video Generative Modeling
论文地址:https://arxiv.org/abs/2410.05954
hugging face权重:https://huggingface.co/rain1011/pyramid-flow-sd3
2.论文
flow-based模型基础
个人认为大家不要把flow-based模型看的太玄乎,实际上这个模型就是扩散模型的翻版,原理都是学习数据分布,生成新样本。
flow-based模型和diffusion模型的主要区别:
- 流模型的过程通常是可逆的,这意味着可以从生成的数据反向推导出模型的参数。这也是为什么在代码里会频繁看到插值运算,因为可逆意味着数据经历上下插值之后的结果和原结果一样。
- 生成样本通常较快,因为它们通过直接的函数映射来生成数据,但训练过程可能更复杂,需要优化整个流的路径。
GPT生成的flow-based模型通俗易懂版,仅供参考:
流生成模型就像是一个魔法画师,它的工作是把一张白纸(随机噪声)变成一幅精美的画作(数据样本,比如图片)。这个魔法画师的工作过程是这样的:
- 开始作画:画师从一张白纸开始,这张白纸上有一些随机的涂鸦,我们可以想象成是一些随机的点和线。
- 逐步细化:画师会慢慢地、一步步地在白纸上添加细节,让涂鸦开始变得有意义。每一步,画师都会根据他的记忆(模型学到的数据分布知识)来决定下一步怎么画。
- 方向指引:为了让画作最终变成一幅精美的图,画师需要一个指引,告诉他每一步该怎么走。这个指引就像是一张地图,告诉画师从涂鸦到画作的最短路径。
- 不断调整:画师在画的过程中,会不断地检查自己的作品,看看是否按照地图的指引在正确的道路上。如果有偏差,他就调整自己的画笔,确保每一步都更接近最终的目标。
- 完成作品:经过一系列的步骤,画师最终会完成画作,这时候白纸上的涂鸦就变成了一幅漂亮的画。
在流生成模型中,这个“画师”是一个数学模型,它通过一系列的计算来逐步从随机噪声生成数据。每一步的计算都基于模型之前学到的知识,最终目的是生成新的、看起来自然的数据样本。
空间金字塔的图解
图像金字塔的功效(Adelson等人,1984)已经被广泛地验证并用于CV领域的各大模型中,图像金字塔是一种强大的视觉概念,它通过构建一系列不同分辨率的图像层级来捕捉场景的多尺度特征。在这个层级结构中,每一层都是下一层的高分辨率版本,就像金字塔一样,从宽广的基础逐渐过渡到尖锐的顶端。这种结构不仅有助于我们从宏观到微观地理解图像内容,还使得图像处理算法能够更加灵活地处理不同尺寸的图像。无论是在图像压缩、边缘检测还是对象识别等任务中,图像金字塔都扮演着至关重要的角色,它通过提供不同层次的细节信息,增强了计算机视觉系统的性能和适应性。
在这篇论文中,视频被转换到三个金字塔层级分别运算,不同层级、时间的数据会进行交流运算。简单来说,的数据会经过一次上采样,转换为尺寸更大的图片,然后加上一段随机噪声(增加多样性),变成然后进行流模型的生成,生成出。需要注意的是不仅考虑了当前帧的数据,还考虑了以前帧的数据,这点在论文的图中没有出现,需要查看代码才能理解。
3.代码
环境安装
git clone https://github.com/jy0205/Pyramid-Flow
cd Pyramid-Flow
# create env using conda
conda create -n pyramid python==3.8.10
conda activate pyramid
pip install -r requirements.txt
个人使用上次清华模型的虚拟环境,稍微安装几个库就能正常运行了,建议跑了清华CogvideoX的朋友继续使用同款环境,不要多花时间了。
inference
官方给的是ipynb版本的代码,我这里把里面的内容复制到新的文件中使用。
import torch
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import load_image, export_to_video
torch.cuda.set_device(0)
model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16 (not support fp16 yet)
model = PyramidDiTForVideoGeneration(
'pyramid-flow-sd3', # The downloaded checkpoint dir
model_dtype,
model_variant='diffusion_transformer_768p', # 'diffusion_transformer_384p'
)
model.vae.enable_tiling()
model.vae.to("cuda") # 使用GPU推理要把这三行注释消掉
model.dit.to("cuda")
model.text_encoder.to("cuda")
# if you're not using sequential offloading bellow uncomment the lines above ^
# model.enable_sequential_cpu_offload() # 如果GPU显存不足,打开这里的显存。
prompt = "The Glenfinnan Viaduct is a historic railway bridge.It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
video_num_inference_steps=[10, 10, 10],
height=768,
width=1280,
temp=16, # 设置时长的,f+1->8f+1,即16相当于121帧视频,temp=16: 5s, temp=31: 10s
guidance_scale=9.0, # The guidance for the first frame, set it to 7 for 384p variant
video_guidance_scale=5.0, # The guidance for the other video latent
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
)
export_to_video(frames, "./text_to_video_sample1.mp4", fps=24)
简单来说,大家需要注意这几点
- 如果有GPU,记得把model.vae.to("cuda")这几行取消注释,这样模型会放到GPU上跑,比CPU快很多。
- 如果GPU显存不足,把model.enable_sequential_cpu_offload()的注释取消掉,这个部分会在推理时把暂时不需要的模块放回CPU,节省内存。如果显存充足,要注释掉,因为CPU和GPU的切换也要浪费时间。
- num_inference_steps、video_num_inference_steps分别代表金字塔三层各自的推理时间,可适当放大。目前测试下来,不一定越大越好。
- temp代表时长,这个数值是压缩的隐空间大小,即f+1,然后输出时会放大回RGB空间,即8f+1,如果是16,输出结果是121,以下面的fps=24为例,会变成5秒的视频。
generate
模型进入model.generate后,会进入pyramid_dit_for_video_gen_pipeline.py的generate函数,即531行。
整个模型大体上分为5个阶段,接下来我们一一查看。
1.初始化
第554至588行是第一个阶段,这个阶段主要是做一些初始化的工作,故不多赘述。这部分主要需要关注两点:
首先是559行,这里是用于节省显存的操作,如果用户在inference.py里面设置了model.enable_sequential_cpu_offload(),这里会将vae放到CPU上,因为暂时用不到vae,直到最后一步才将vae放回GPU。
if cpu_offloading:
# skip caring about the text encoder here as its about to be used anyways.
if not self.sequential_offload_enabled:
if str(self.dit.device) != "cpu":
print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.dit.to("cpu")
torch.cuda.empty_cache()
if str(self.vae.device) != "cpu":
print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.vae.to("cpu")
torch.cuda.empty_cache()
接着是574行第prompt,这里默认会给prompt加上高质量和高清的提示词,以帮助模型提高质量。
if isinstance(prompt, str):
batch_size = 1
prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
2.文本处理
第2个部分的主要目的是将文本提示词prompt和negative prompt编码成向量,然后拼接在一起。
以下是代码的核心部分,具体来说,如果开启了cpu_offloading,代码会讲text_encoder暂时放到GPU上,用完了再放回CPU上。而text_encoder是基于transformer的编码器架构,它将提示词prompt和negative_prompt编码为prompt_embeds, prompt_attention_mask, pooled_prompt_embeds。
prompt_embeds
:- 这是一个包含prompt的嵌入向量,然后每个token被映射到一个高维空间中的嵌入向量。这些嵌入向量捕获了token的语义信息和上下文信息。- 形状是[batch_size, sequence_length, embedding_dim]
,其中: -batch_size
是批处理大小,表示同时处理的文本数量。-sequence_length
是序列长度,表示每个文本中的token数量。-embedding_dim
是嵌入向量的维度,表示每个token嵌入的高维表示的大小。- 在这段代码中,prompt_embeds
的形状是[1, 128, 4096]
,表示有一个文本序列,长度为128个token,每个token的嵌入维度为4096。prompt_attention_mask
:- 这是一个用于指示模型应该关注哪些token的掩码(mask)。在处理变长序列时,一些token可能是填充的(padding),这些填充的token不应该被模型考虑。- 形状是[batch_size, sequence_length]
,其中的每个元素通常是0或1,1表示对应的token是有效的,0表示对应的token是填充的。- 在这段代码中,prompt_attention_mask
的形状是[1, 128]
,与prompt_embeds
的序列长度相对应。pooled_prompt_embeds
:- 这是一个聚合(pooled)的嵌入向量,通常用于表示整个输入序列的全局信息。在一些模型中,如BERT,会使用一个特殊的token(例如[CLS]
)来聚合序列的信息,或者通过其他方式(如平均池化)来获得整个序列的表示。- 形状通常是[batch_size, embedding_dim]
,其中embedding_dim
与prompt_embeds
中的相同。- 在这段代码中,pooled_prompt_embeds
的形状是[1, 2048]
,表示有一个聚合的嵌入向量,维度为2048。
# 2.Get the text embeddings
if cpu_offloading and not self.sequential_offload_enabled: # 文本需要text_encoder,故放到GPU上
self.text_encoder.to("cuda")
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device) # [b,seq_len,c]=[1,128,4096],[1,128],[1,2048]
negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
if cpu_offloading: # 用完了,放回CPU
if not self.sequential_offload_enabled:
self.text_encoder.to("cpu")
self.dit.to("cuda")
torch.cuda.empty_cache()
text_encoder架构如下所示,可以看到其实都是Transformer架构,和其他文本编码器没有太大区别。
SD3TextEncoderWithMask(
(text_encoder): CLIPTextModelWithProjection(
(text_model): CLIPTextTransformer(
(embeddings): CLIPTextEmbeddings(
(token_embedding): Embedding(49408, 768)
(position_embedding): Embedding(77, 768)
)
(encoder): CLIPEncoder(
(layers): ModuleList(
(0-11): 12 x CLIPEncoderLayer(
(self_attn): CLIPSdpaAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP(
(activation_fn): QuickGELUActivation()
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
)
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(text_projection): Linear(in_features=768, out_features=768, bias=False)
)
(text_encoder_2): CLIPTextModelWithProjection(
(text_model): CLIPTextTransformer(
(embeddings): CLIPTextEmbeddings(
(token_embedding): Embedding(49408, 1280)
(position_embedding): Embedding(77, 1280)
)
(encoder): CLIPEncoder(
(layers): ModuleList(
(0-31): 32 x CLIPEncoderLayer(
(self_attn): CLIPSdpaAttention(
(k_proj): Linear(in_features=1280, out_features=1280, bias=True)
(v_proj): Linear(in_features=1280, out_features=1280, bias=True)
(q_proj): Linear(in_features=1280, out_features=1280, bias=True)
(out_proj): Linear(in_features=1280, out_features=1280, bias=True)
)
(layer_norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP(
(activation_fn): GELUActivation()
(fc1): Linear(in_features=1280, out_features=5120, bias=True)
(fc2): Linear(in_features=5120, out_features=1280, bias=True)
)
(layer_norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
)
)
(final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
(text_projection): Linear(in_features=1280, out_features=1280, bias=False)
)
(text_encoder_3): T5EncoderModel(
(shared): Embedding(32128, 4096)
(encoder): T5Stack(
(embed_tokens): Embedding(32128, 4096)
(block): ModuleList(
(0): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(in_features=4096, out_features=4096, bias=False)
(k): Linear(in_features=4096, out_features=4096, bias=False)
(v): Linear(in_features=4096, out_features=4096, bias=False)
(o): Linear(in_features=4096, out_features=4096, bias=False)
(relative_attention_bias): Embedding(32, 64)
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=4096, out_features=10240, bias=False)
(wi_1): Linear(in_features=4096, out_features=10240, bias=False)
(wo): Linear(in_features=10240, out_features=4096, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(1-23): 23 x T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(in_features=4096, out_features=4096, bias=False)
(k): Linear(in_features=4096, out_features=4096, bias=False)
(v): Linear(in_features=4096, out_features=4096, bias=False)
(o): Linear(in_features=4096, out_features=4096, bias=False)
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=4096, out_features=10240, bias=False)
(wi_1): Linear(in_features=4096, out_features=10240, bias=False)
(wo): Linear(in_features=10240, out_features=4096, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(final_layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
接着模型会将正负两个提示词以dim=0拼接在一起,即变成[2,128,4096]、[2,128]、[2,2048]的三个向量。
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # [2,128,4096]
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
3.生成随机噪声
具体来说,代码会在prepare_latents()生成一个大小为[b,c,t,h,w]=[1,16,16,96,160]的随机噪声,然后利用双线性插值的方法生成特征金字塔最底层的向量,即[1,16,16,24,40]
# 3.Create the initial random noise
num_channels_latents = self.dit.config.in_channels # 16
latents = self.prepare_latents( # [b,c,t,h,w]=[1,16,16,96,160]
batch_size * num_images_per_prompt,
num_channels_latents,
temp,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
latents = rearrange(latents, 'b c t h w -> (b t) c h w') # [1,16,16,24,40]->[16,16,24,40]
# by default, we needs to start from the block noise
for _ in range(len(self.stages)-1): # 双线性插值(bilinear interpolation)方法对latents张量进行上采样
height //= 2;width //= 2
latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2 # 结果变化过程:[16,16,96,160]->[16,16,48,80]->[16,16,24,40]
latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp) # [16,16,24,40]->[1,16,16,24,40]
num_units = 1 + (temp - 1) // self.frame_per_unit # 训练步长 16
stages = self.stages # 特征金字塔,三层[1,2,4]
generated_latents_list = [] # The generated results
last_generated_latents = None
prepare_latents()的核心代码:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
为什么不直接生成低分辨率的,而要先生成高分辨率的,再降到低分辨率的?
- 细节保留:高分辨率的潜在表示可以捕获更多的细节信息。通过先在高分辨率下生成这些细节,然后逐步降低分辨率,可以更好地保留图像的重要特征。
- 多尺度处理:在不同的分辨率层次上处理数据可以让模型在每个阶段专注于不同尺度的特征。例如,在高分辨率阶段,模型可以专注于细节和纹理的生成,而在低分辨率阶段,模型可以专注于整体结构和布局。
虽然直接生成低分辨率图像在某些情况下可能看起来更直接,但逐步从高分辨率到低分辨率的生成方法在实践中往往能产生更好的结果,特别是在需要生成高质量图像的应用中。
4.flow-based生成过程
这个部分是整个代码的核心,整体流程如下所示:
其中,当unit_index=0(第一帧)时,会执行if,而其他帧会执行else。这一点和之前的CogvideoX是一样的,即第一帧单独编码,剩余帧沿时间t维度扩大8倍,如我设置的是16(15+1),最后会放大到121帧(15*8+1)。这种操作是从图像生成领域照搬过来的。
pastpast_condition_latents代表以前时间步的三层金字塔特征。整个代码的核心是self.generate_one_unit,其生成的intermed_latents包含当前时间步的三层金字塔特征。
# 4.flow
for unit_index in tqdm(range(num_units)):
gc.collect()
torch.cuda.empty_cache()
if callback:
callback(unit_index, num_units)
if use_linear_guidance:
self._guidance_scale = guidance_scale_list[unit_index]
self._video_guidance_scale = guidance_scale_list[unit_index]
if unit_index == 0: # 三个层级,每个层级迭代num_inference_steps步
past_condition_latents = [[] for _ in range(len(stages))]
intermed_latents = self.generate_one_unit( # 3个,分别是[1,16,1,24,40],[1,16,1,48,80],[1,16,1,96,160]
latents[:,:,:1],
past_condition_latents,
prompt_embeds,
prompt_attention_mask,
pooled_prompt_embeds,
num_inference_steps,
height,
width,
1,
device,
dtype,
generator,
is_first_frame=True,
)
else:
# prepare the condition latents
past_condition_latents = [] # 上一步三个层级的特征
clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1) # 每个层级拼起来[1,16,f,h,w]
for i_s in range(len(stages)):
last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent] # [2,16,1,,]
# pad the past clean latents
cur_unit_num = unit_index
cur_stage = i_s
cur_unit_ptx = 1
while cur_unit_ptx < cur_unit_num:
cur_stage = max(cur_stage - 1, 0)
if cur_stage == 0:
break
cur_unit_ptx += 1
cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
stage_input = list(reversed(stage_input))
past_condition_latents.append(stage_input)
intermed_latents = self.generate_one_unit(
latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
past_condition_latents,
prompt_embeds,
prompt_attention_mask,
pooled_prompt_embeds,
video_num_inference_steps,
height,
width,
self.frame_per_unit,
device,
dtype,
generator,
is_first_frame=False,
)
generated_latents_list.append(intermed_latents[-1]) # 取最后一层,即[1,16,1,96,160]
last_generated_latents = intermed_latents
接下来我们详细看看self.generate_one_unit()函数,这个函数包括三个部分:预处理、预测噪声、生成下一时间步的状态
- 预处理:如果是最小层(i_s=0)的patch(24*40),跳过;如果是其他两层,会将patch沿h和w维度插值放大两倍,同时还会引入额外的噪声,这些噪声的目的是在模型的不同阶段引入随机性,帮助模型更好地学习数据的分布。
- 预测噪声:首先先讲隐空间的数据沿着batch维度扩展2倍,变成[2,16,1,h,w],这里的2包含无条件控制的和有条件控制的数据。然后引入上一帧预测的结果past_conditions,然后进入DiT模型,预测噪声。下文为大家详细拆解DiT流程
- 生成下一个时间步的状态:noise_pred输出包括无条件和有条件的噪声,使用self.guidance_scale将二者融合,这里同时考虑无条件和有条件的噪声,使得生成模型能够灵活地适应不同的应用场景,同时保持生成图像的质量和多样性。最后进入self.scheduler.step,预测下一时间步的状态。下文为大家详细拆解schedular流程。
- 循环生成每个金字塔层级,返回。
for i_s in range(len(stages)):
self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
timesteps = self.scheduler.timesteps
if i_s > 0:
height *= 2; width *= 2
latents = rearrange(latents, 'b c t h w -> (b t) c h w')
latents = F.interpolate(latents, size=(height, width), mode='nearest')
latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
# Fix the stage
ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal
gamma = self.scheduler.config.gamma
alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
bs, ch, temp, height, width = latents.shape
noise = self.sample_block_noise(bs, ch, temp, height, width) # 在模型的不同阶段引入随机性,帮助模型更好地学习数据的分布
noise = noise.to(device=device, dtype=dtype)
latents = alpha * latents + beta * noise # To fix the block artifact
for idx, t in enumerate(timesteps): # num_inference_steps=20
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
if is_sequence_parallel_initialized():
# sync the input latent
sp_group_rank = get_sequence_parallel_group_rank()
global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
torch.distributed.broadcast(latent_model_input, global_src_rank, group=get_sequence_parallel_group())
latent_model_input = past_conditions[i_s] + [latent_model_input] # 将前一帧的状态和当前阶段的状态合在一起
noise_pred = self.dit( # [2,16,1,h,w]
sample=[latent_model_input],
timestep_ratio=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
)
noise_pred = noise_pred[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # [1,16,1,24,40] noise_pred_uncond表示无条件的噪声预测,而noise_pred_text可能表示与文本条件相关的噪声预测
if is_first_frame: # 同时考虑无条件和有条件的噪声,使得生成模型能够灵活地适应不同的应用场景,同时保持生成图像的质量和多样性。
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
model_output=noise_pred,
timestep=timestep,
sample=latents,
generator=generator,
).prev_sample
intermed_latents.append(latents) # [3个,分别是[1,16,1,24,40],[1,16,1,48,80],[1,16,1,96,160]]
return intermed_latents
为了方便大家理解,我这里为283行代码画了一张图,帮助大家搞清楚每个时间t,每个层级的交流过程。
- i_s,来自stages:当前金字塔层级
- unit_index,来自unit-indexs:当前在推理第几帧,超参设置为16
- t,来自timesteps:当前推理时间步,也就是设置的超参,默认20
#283
latent_model_input = past_conditions[i_s] + [latent_model_input] # 将前一帧的状态和当前阶段的状态合在一起
DiT
接下来,我们来看一下DiT部分。
首先,模型会通过线性层将时间和文本的池化信息融合在一起,具体步骤如下:
- 时间t:[2,256]->[2,1536]->[2,1536]
- text:[2,2048]->[2,1536]->[2,1536]
- 二者相加,[2,1536]
除此之外,还会把文本信息进行embedding
# Get the timestep embedding
temb = self.time_text_embed(timestep_ratio, pooled_projections) # 控制条件 t:[2,256]->[2,1536]->[2,1536];text:[2,2048]->[2,1536]->[2,1536] 两者加起来=[2,1536]
encoder_hidden_states = self.context_embedder(encoder_hidden_states) # [2,128,4096]->[2,128,1536]
encoder_hidden_length = encoder_hidden_states.shape[1] # 128
这一步把图像噪声的信息embedding为[2,240,1536]的尺寸,注意,240随着金字塔的层级变化而变化。
# Get the input sequence
hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
接下来,模型进入Transformer模块进行推理:
encoder_hidden_states, hidden_states = block( # [2,128,1536],[2,1920,1536]
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
temb=temb,
attention_mask=attention_mask,
hidden_length=hidden_length,
image_rotary_emb=image_rotary_emb,
)
模型的架构如下,可以看到,这个部分就是基于Transformer模块去做的,和CogVideoX类似,里面也添加了门控机制、偏移缩放。故不多赘述。
(transformer_blocks): ModuleList(
(0-22): 23 x JointTransformerBlock(
(norm1): AdaLayerNormZero(
(silu): SiLU()
(linear): Linear(in_features=1536, out_features=9216, bias=True)
(norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=False)
)
(norm1_context): AdaLayerNormZero(
(silu): SiLU()
(linear): Linear(in_features=1536, out_features=9216, bias=True)
(norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=False)
)
(attn): JointAttention(
(norm_q): RMSNorm()
(norm_k): RMSNorm()
(to_q): Linear(in_features=1536, out_features=1536, bias=True)
(to_k): Linear(in_features=1536, out_features=1536, bias=True)
(to_v): Linear(in_features=1536, out_features=1536, bias=True)
(add_k_proj): Linear(in_features=1536, out_features=1536, bias=True)
(add_v_proj): Linear(in_features=1536, out_features=1536, bias=True)
(add_q_proj): Linear(in_features=1536, out_features=1536, bias=True)
(norm_add_q): RMSNorm()
(norm_add_k): RMSNorm()
(to_out): ModuleList(
(0): Linear(in_features=1536, out_features=1536, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
(to_add_out): Linear(in_features=1536, out_features=1536, bias=True)
)
(norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=False)
(ff): FeedForward(
(net): ModuleList(
(0): GELU(
(proj): Linear(in_features=1536, out_features=6144, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=6144, out_features=1536, bias=True)
)
)
(norm2_context): LayerNorm((1536,), eps=1e-06, elementwise_affine=False)
(ff_context): FeedForward(
(net): ModuleList(
(0): GELU(
(proj): Linear(in_features=1536, out_features=6144, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=6144, out_features=1536, bias=True)
)
)
)
具体来说,每个block会先对视频帧和文本分别进行normlization,并生成门控、偏移量和缩放量。
接着模型会进入attention的计算,我们主要看var_len_attn()(也就是真正计算的部分)怎么计算的,可以看到,这一部分实际上是把文本和图像的编码concat到一起,然后拆分出qkv,这里的qkv既有文本的信息,又有图像的信息,然后计算注意力得分,这一操作和CogvideoX一模一样!
encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # 文本的qkv[bs, sub_seq, 3, head, head_dim]
qkv = torch.stack([query, key, value], dim=2) # 图像的qkv[bs, sub_seq, 3, head, head_dim]
i_sum = 0
output_encoder_hidden_list = []
output_hidden_list = []
for i_p, length in enumerate(hidden_length):
encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
qkv_tokens = qkv[:, i_sum:i_sum+length]
concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) #文本和图像拼起来 [bs, tot_seq, 3, nhead, dim]
if image_rotary_emb is not None:
concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
stage_hidden_states = F.scaled_dot_product_attention( # 计算attention
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
)
stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
output_hidden_list.append(stage_hidden_states[:, encoder_length:])
i_sum += length
最后经过全连接层和残差链接,再经过输出层,得到最终的输出。这一部分和其他transformer完全一样,故不多赘述。
接下来,我们来看一下如何把无条件和有条件的噪声预测结果融合起来。可以看到这里其实是无条件预测结果加上有无条件噪声的插值。这里同时考虑无条件和有条件的噪声,使得生成模型能够灵活地适应不同的应用场景,同时保持生成图像的质量和多样性。
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # [1,16,1,24,40] noise_pred_uncond表示无条件的噪声预测,而noise_pred_text可能表示与文本条件相关的噪声预测
if is_first_frame: # 同时考虑无条件和有条件的噪声,使得生成模型能够灵活地适应不同的应用场景,同时保持生成图像的质量和多样性。
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
schedular.step()
接下来,我们来看一下schedular.step()部分
核心代码如下,这一步就是利用当前的sigma和下一步的sigma计算预测的结果。
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
prev_sample = sample + (sigma_next - sigma) * model_output # 预测下一步的数据 [b,c,t,h,w]=[1,16,1,h,w]
5.decoder输出
这个部分的核心代码如下
image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
decode_latent的内容如下:
实际上这里就是先调整特征的数值范围,使其更适合后续的处理步骤。然后进行decode,再还原回rgb通道的正常空间。核心部分在vae.decode
def decode_latent(self, latents, save_memory=True, inference_multigpu=False):
# only the main process needs vae decoding
if inference_multigpu and get_rank() != 0:
return None
if latents.shape[2] == 1:
latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
else:
latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor # 在处理编码或嵌入表示时,用于调整特征的数值范围,使其更适合后续的处理步骤。
latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
if save_memory:
# reducing the tile size and temporal chunk window size
image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample # [b,c,f,w,h]=[1,3,121,768,1280]
else:
image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample
image = image.mul(127.5).add(127.5).clamp(0, 255).byte() # 转换回[0-255]的区间
image = rearrange(image, "B C T H W -> (B T) H W C") # [121,768,1280,3]
image = image.cpu().numpy()
image = self.numpy_to_pil(image) # 转换为PIL格式
return image
vae.decode部分的代码:
def decode(self, z: torch.FloatTensor, is_init_image=True, temporal_chunk=False,
return_dict: bool = True, window_size: int = 2, tile_sample_min_size: int = 256,) -> Union[DecoderOutput, torch.FloatTensor]:
self.tile_sample_min_size = tile_sample_min_size
self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, is_init_image=is_init_image,
temporal_chunk=temporal_chunk, window_size=window_size, return_dict=return_dict)
if temporal_chunk:
dec = self.chunk_decode(z, window_size=window_size)
else:
z = self.post_quant_conv(z, is_init_image=is_init_image, temporal_chunk=False)
dec = self.decoder(z, is_init_image=is_init_image, temporal_chunk=False)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
tiled_decode()
这个函数实际上是隐空间的向量切成[1,16,16,32,32]的patch,每个patch经过chunk_decode()函数后变成[1,3,121,256,256],放入列表rows中。然后进行尺寸限制,变成[1,3,121,224,224],最后将所有patch concat起来,变成[1,3,121,768,1280]。这样视频的数据就生成好了。
overlap_size = int(self.tile_latent_min_size * (1 - self.decode_tile_overlap_factor)) # 28
blend_extent = int(self.tile_sample_min_size * self.decode_tile_overlap_factor) # 32
row_limit = self.tile_sample_min_size - blend_extent # 224
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[3], overlap_size): # 按列
row = []
for j in range(0, z.shape[4], overlap_size): # 按行
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] # 切成一个个小方格(patch)
if temporal_chunk:
decoded = self.chunk_decode(tile, window_size=window_size) # 对每个patch进行处理[1,16,16,32,32]->[1,3,121,256,256]
else:
tile = self.post_quant_conv(tile, is_init_image=True, temporal_chunk=False) # [1,16,16,32,32]
decoded = self.decoder(tile, is_init_image=True, temporal_chunk=False) # [1,3,121,256,256]
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit]) # 尺寸限制 [1,3,121,224,224]
result_rows.append(torch.cat(result_row, dim=4)) # 按行拼起来 [1,3,121,224,1280]
dec = torch.cat(result_rows, dim=3) # 按列拼起来[1,3,121,768,1280]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
chunk_decode()
整个decoder部分一层套一层,其实这个这个函数才是整个部分的核心。这个代码实际上就是把16帧的视频分成两部分:前两帧和后n帧(16的话,n=14),前两帧经过Vae的decoder后变成1,3,9,256,256,后面n帧都是将t上采样到了8。一共生成121帧图像。
注意,这里的图像实际上是patch,也就是256*256的方格,在经过tiled_decode()的多轮循环后才是完整的图像。
def chunk_decode(self, z: torch.FloatTensor, window_size=2):
num_frames = z.shape[2]
init_window_size = window_size + 1
frame_list = [z[:,:,:init_window_size]] # 前2帧,[1,16,2,h,w]
# To chunk the long video
full_chunk_size = (num_frames - init_window_size) // window_size
fid = init_window_size
for idx in range(full_chunk_size):
frame_list.append(z[:, :, fid:fid+window_size]) # 后面的14帧[1,16,1,h,w]
fid += window_size
if fid < num_frames:
frame_list.append(z[:, :, fid:])
dec_list = []
for idx, frames in enumerate(frame_list):
if idx == 0: # 包含初始图像
z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True) # [1,3,f+1=9,256,256]
else: # 不包含原始图像
z_h = self.post_quant_conv(frames, is_init_image=False, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=False, temporal_chunk=True) # [1,3,8,256,256]
dec_list.append(dec)
dec = torch.cat(dec_list, dim=2)
return dec
输出视频
inference.py最后输出视频。
export_to_video(frames, "./text_to_video_sample1.mp4", fps=24)
这个代码的核心内容如下,实际上就是将模型输出的结果写入文件。
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
至此模型的代码分析就结束了。
4.评测
在观看第一段视频时,可以明显感受到整体的制作水平是相当高的,画面流畅,色彩饱满,给人以视觉上的享受。然而,在对人物手部细节的呈现上,我们注意到了一些不尽如人意的地方。具体来说,手部的渲染出现了较为明显的问题,让人在观看时产生了一种错觉,仿佛人物手中正握着一个塑料袋。这种视觉上的误差无疑分散了观众的注意力,降低了观看体验的质量。
第二段视频在初次观看时,似乎并没有明显的瑕疵,画面连贯,内容也颇为吸引人。但仔细观察后,我们发现了一个技术上的疏忽:视频中出现了两辆火车在同一条轨道上并行的奇异现象。这种情形,类似于我们在电子游戏中偶尔遇到的“穿模”现象,即两个本不应重叠的物体在视觉上发生了重叠。
5.总结
Pyramid-Flow模型的核心优势在于其创新的金字塔流匹配算法,这种算法通过将视频生成过程分解为多个不同分辨率的金字塔阶段,有效降低了计算复杂度,同时保持了生成视频的高分辨率和帧率。Pyramid-Flow模型支持端到端优化,使用单一的统一扩散变换器(DiT)进行训练,简化了模型的实现,使其在生成长达10秒、分辨率高达1280x768、帧率为24fps的高清视频方面表现出色。
然而,Pyramid-Flow模型也存在一些局限性。尽管它在处理静态图像和简单动态场景时表现出色,但在生成复杂动态场景,尤其是涉及快速移动物体的视频时,可能会出现细节处理不足的情况,同时对物体细节的捕捉和处理也需要进一步优化。此外,尽管模型在训练效率上取得了显著提升,但在生成极高分辨率或极长视频时,仍可能面临计算资源和时间的挑战。
尽管如此,Pyramid-Flow模型在开源社区中提供了一个高效、灵活的视频生成解决方案,特别适合资源有限的中小企业和个人开发者使用。
探索AI视频生成的无限可能,与Pyramid-flow一起开启创意之旅!如果你对这篇分享充满热情,别忘了点赞和关注,让我们的创新故事持续发光发热。每一次互动都是我们前进的动力,感谢你的支持,让我们共同见证科技与艺术的完美融合!
版权归原作者 Sherlock Ma 所有, 如有侵权,请联系我们删除。