import logging from collections.abc import Iterator import torch from ltx_core.components.noisers import GaussianNoiser from ltx_core.loader import LoraPathStrengthAndSDOps from ltx_core.loader.registry import Registry from ltx_core.model.transformer.compiling import CompilationConfig from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number from ltx_core.quantization import QuantizationPolicy from ltx_core.types import Audio from ltx_pipelines.utils.args import ( ImageConditioningInput, default_2_stage_distilled_arg_parser, detect_checkpoint_path, ) from ltx_pipelines.utils.blocks import ( AudioDecoder, DiffusionStage, ImageConditioner, PromptEncoder, VideoDecoder, VideoUpsampler, ) from ltx_pipelines.utils.constants import ( DISTILLED_SIGMAS, STAGE_2_DISTILLED_SIGMAS, detect_params, ) from ltx_pipelines.utils.denoisers import SimpleDenoiser from ltx_pipelines.utils.helpers import ( assert_resolution, combined_image_conditionings, get_device, ) from ltx_pipelines.utils.media_io import encode_video from ltx_pipelines.utils.types import ModalitySpec, OffloadMode class DistilledPipeline: """ Two-stage distilled video generation pipeline. Stage 1 generates video at half of the target resolution, then Stage 2 upsamples by 2x and refines with additional denoising steps for higher quality output. """ def __init__( self, distilled_checkpoint_path: str, gemma_root: str, spatial_upsampler_path: str, loras: list[LoraPathStrengthAndSDOps], device: torch.device | None = None, quantization: QuantizationPolicy | None = None, registry: Registry | None = None, compilation_config: CompilationConfig | None = None, offload_mode: OffloadMode = OffloadMode.NONE, ): self.device = device or get_device() self.dtype = torch.bfloat16 self.prompt_encoder = PromptEncoder( distilled_checkpoint_path, gemma_root, self.dtype, self.device, registry=registry, offload_mode=offload_mode, ) self.image_conditioner = ImageConditioner(distilled_checkpoint_path, self.dtype, self.device, registry=registry) self.stage = DiffusionStage( distilled_checkpoint_path, self.dtype, self.device, loras=tuple(loras), quantization=quantization, registry=registry, compilation_config=compilation_config, offload_mode=offload_mode, ) self.upsampler = VideoUpsampler( distilled_checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry ) self.video_decoder = VideoDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry) self.audio_decoder = AudioDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry) def __call__( # noqa: PLR0913 self, prompt: str, seed: int, height: int, width: int, num_frames: int, frame_rate: float, images: list[ImageConditioningInput], tiling_config: TilingConfig | None = None, enhance_prompt: bool = False, stage_1_sigmas: torch.Tensor = DISTILLED_SIGMAS, stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS, ) -> tuple[Iterator[torch.Tensor], Audio]: assert_resolution(height=height, width=width, is_two_stage=True) generator = torch.Generator(device=self.device).manual_seed(seed) noiser = GaussianNoiser(generator=generator) dtype = torch.bfloat16 (ctx_p,) = self.prompt_encoder( [prompt], enhance_first_prompt=enhance_prompt, enhance_prompt_image=images[0][0] if len(images) > 0 else None, ) video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding # Stage 1: Initial low resolution video generation. stage_1_sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device) stage_1_w, stage_1_h = width // 2, height // 2 stage_1_conditionings = self.image_conditioner( lambda enc: combined_image_conditionings( images=images, height=stage_1_h, width=stage_1_w, video_encoder=enc, dtype=dtype, device=self.device, ) ) video_state, audio_state = self.stage( denoiser=SimpleDenoiser(video_context, audio_context), sigmas=stage_1_sigmas, noiser=noiser, width=stage_1_w, height=stage_1_h, frames=num_frames, fps=frame_rate, video=ModalitySpec(context=video_context, conditionings=stage_1_conditionings), audio=ModalitySpec(context=audio_context), ) # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. upscaled_video_latent = self.upsampler(video_state.latent[:1]) stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device) stage_2_conditionings = self.image_conditioner( lambda enc: combined_image_conditionings( images=images, height=height, width=width, video_encoder=enc, dtype=dtype, device=self.device, ) ) video_state, audio_state = self.stage( denoiser=SimpleDenoiser(video_context, audio_context), sigmas=stage_2_sigmas, noiser=noiser, width=width, height=height, frames=num_frames, fps=frame_rate, video=ModalitySpec( context=video_context, conditionings=stage_2_conditionings, noise_scale=stage_2_sigmas[0].item(), initial_latent=upscaled_video_latent, ), audio=ModalitySpec( context=audio_context, noise_scale=stage_2_sigmas[0].item(), initial_latent=audio_state.latent, ), ) decoded_video = self.video_decoder(video_state.latent, tiling_config, generator) decoded_audio = self.audio_decoder(audio_state.latent) return decoded_video, decoded_audio @torch.inference_mode() def main() -> None: logging.basicConfig(level=logging.INFO) checkpoint_path = detect_checkpoint_path(distilled=True) params = detect_params(checkpoint_path) parser = default_2_stage_distilled_arg_parser(params=params) args = parser.parse_args() pipeline = DistilledPipeline( distilled_checkpoint_path=args.distilled_checkpoint_path, spatial_upsampler_path=args.spatial_upsampler_path, gemma_root=args.gemma_root, loras=tuple(args.lora) if args.lora else (), quantization=args.quantization, compilation_config=args.compile, offload_mode=args.offload_mode, ) tiling_config = TilingConfig.default() video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config) video, audio = pipeline( prompt=args.prompt, seed=args.seed, height=args.height, width=args.width, num_frames=args.num_frames, frame_rate=args.frame_rate, images=args.images, tiling_config=tiling_config, enhance_prompt=args.enhance_prompt, ) encode_video( video=video, fps=args.frame_rate, audio=audio, output_path=args.output_path, video_chunks_number=video_chunks_number, ) if __name__ == "__main__": main()