{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Image generation with Sana and OpenVINO\n", "\n", "**Sana** is a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution developed by NVLabs. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. \n", "Core designs include: \n", "* **Deep compression autoencoder**: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens.\n", "* **Linear DiT**: authors replaced all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality.\n", "* **Decoder-only text encoder***: T5 replaced by modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment.\n", "* **Efficient training and sampling**: Proposed Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence.\n", "\n", "More details about model can be found in [paper](https://arxiv.org/abs/2410.10629), [model page](https://nvlabs.github.io/Sana/) and [original repo](https://github.com/NVlabs/Sana).\n", "\n", "**SANA-1.5** is a linear Diffusion Transformer for efficient scaling in text-to-image generation. SANA-1.5 is built on SANA-1.0 with introduction following improvements:\n", "* **Efficient Training Scaling**: A depth-growth paradigm that enables scaling from 1.6B to 4.8B parameters with significantly reduced computational resources, combined with a memory-efficient 8-bit optimizer.\n", "* **Model Depth Pruning**: A block importance analysis technique for efficient model compression to arbitrary sizes with minimal quality loss.\n", "* **Inference-time Scaling**: A repeated sampling strategy that trades computation for model capacity, enabling smaller models to match larger model quality at inference time.\n", "\n", "More details about model can be found in [paper](https://arxiv.org/abs/2501.18427), [model page](https://nvlabs.github.io/Sana/Sana-1.5/) and [original repo](https://github.com/NVlabs/Sana).\n", "\n", "**SANA-Sprint** is an efficient diffusion model for ultra-fast text-to-image. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. \n", "Core innovations include:\n", "* **Training-Free Transformation to TrigFlow**: the paper proposes a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. The hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity.\n", "* **Stabilizing Continuous-Time Distillation**: To stabilize continuous-time consistency distillation, we address two key challenges: training instabilities and excessively large gradient norms that occur when scaling up the model size and increasing resolution, leading to model collapse. We achieve this by refining dense time-embedding and integrating QK-Normalization into self- and cross-attention mechanisms. 
These modifications enable efficient training and improve stability, allowing for robust performance at higher resolutions and larger model sizes. \n", "* **Improving Continuous-Time CMs with GAN**: CTM's analysis shows that CMs distill teacher information in a local manner, where at each iteration the student model learns from local time intervals. This leads the model to learn cross-timestep information through implicit extrapolation, which can slow convergence. To address this limitation, the authors introduce an additional adversarial loss to provide direct global supervision across different timesteps, improving both the convergence speed and the output quality.\n", "\n", "More details about the model can be found in the [paper](https://arxiv.org/pdf/2503.09641), [model page](https://nvlabs.github.io/Sana/Sprint/) and [original repo](https://github.com/NVlabs/Sana).\n", "\n", "In this tutorial, we consider how to optimize and run models from the Sana family using OpenVINO.\n", "#### Table of contents:\n", "\n", "- [Prerequisites](#Prerequisites)\n", "- [Select model variant](#Select-model-variant)\n", "- [Convert and Optimize model with OpenVINO](#Convert-and-Optimize-model-with-OpenVINO)\n", "    - [Convert model using Optimum Intel](#Convert-model-using-Optimum-Intel)\n", "    - [Compress model weights](#compress-model-weights)\n", "- [Run OpenVINO model inference](#Run-OpenVINO-model-inference)\n", "- [Interactive demo](#Interactive-demo)\n", "\n", "\n", "### Installation Instructions\n", "\n", "This is a self-contained example that relies solely on its own code.\n", "\n", "We recommend running the notebook in a virtual environment. You only need a Jupyter server to start.\n", "For details, please refer to [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide).\n", "\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import platform\n", "\n", "%pip uninstall -q -y optimum optimum-intel optimum-onnx\n", "%pip install -q \"gradio>=4.19,<6\" \"torch==2.8\" \"transformers==4.53.3\" \"nncf>=2.14.0\" \"opencv-python\" \"pillow\" \"peft>=0.15.0\" --extra-index-url https://download.pytorch.org/whl/cpu\n", "%pip install -q \"sentencepiece\" \"protobuf\"\n", "%pip install -q \"git+https://github.com/huggingface/optimum-intel.git\" --extra-index-url https://download.pytorch.org/whl/cpu\n", "%pip install -qU \"openvino>=2025.1.0\"\n", "%pip install -q \"diffusers>=0.33.0\"\n", "\n", "if platform.system() == \"Darwin\":\n", "    %pip install \"numpy<2.0\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "helpers = [\"notebook_utils.py\", \"cmd_helper.py\"]\n", "base_url = \"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils\"\n", "\n", "for helper in helpers:\n", "    if not Path(helper).exists():\n", "        r = requests.get(f\"{base_url}/{helper}\")\n", "        with open(helper, \"w\") as f:\n", "            f.write(r.text)\n", "\n", "if not Path(\"gradio_helper.py\").exists():\n", "    r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/sana-image-generation/gradio_helper.py\")\n", "    open(\"gradio_helper.py\", \"w\").write(r.text)\n", "\n", "# Read more about telemetry collection at 
https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry\n", "from notebook_utils import collect_telemetry\n", "\n", "collect_telemetry(\"sana-image-generation.ipynb\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Select model variant\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "test_replace": { "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers": "katuni4ka/tiny-random-sana-sprint" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "adbe3f60ec9341a885a68cc65bbc1248", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Model:', options=('Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers', 'Efficient-…" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ipywidgets as widgets\n", "\n", "model_ids = [\n", "    \"Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers\",\n", "    \"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers\",\n", "    \"Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers\",\n", "    \"Efficient-Large-Model/Sana_600M_512px_diffusers\",\n", "    \"Efficient-Large-Model/Sana_600M_1024px_diffusers\",\n", "    \"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers\",\n", "    \"Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers\",\n", "    \"Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers\",\n", "]\n", "\n", "model_selector = widgets.Dropdown(\n", "    options=model_ids,\n", "    value=model_ids[0],\n", "    description=\"Model:\",\n", ")\n", "\n", "\n", "model_selector" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Convert and Optimize model with OpenVINO\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Starting from the 2023.0 release, OpenVINO supports PyTorch models directly via the Model Conversion API. The `ov.convert_model` function accepts an instance of a PyTorch model and example inputs for tracing and returns an object of the `ov.Model` class, ready to use or to save on disk using the `ov.save_model` function. \n", "\n", "\n", "The pipeline consists of three important parts:\n", "\n", "* Gemma Text Encoder to create the condition to generate an image from a text prompt.\n", "* Transformer for step-by-step denoising of the latent image representation.\n", "* Deep Compression Autoencoder (DCAE) for decoding the latent representation into an image.\n", " \n", "### Convert model using Optimum Intel\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "For convenience, we will use OpenVINO integration with HuggingFace Optimum. 🤗 [Optimum Intel](https://huggingface.co/docs/optimum/intel/index) is the interface between the 🤗 Transformers and Diffusers libraries and the different tools and libraries provided by Intel to accelerate end-to-end pipelines on Intel architectures.\n", "\n", "Among other use cases, Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime. `optimum-cli` provides a command-line interface for model conversion and optimization. \n", "\n", "General command format:\n", "\n", "```bash\n", "optimum-cli export openvino --model <model_id_or_path> --task <task> <output_dir>\n", "```\n", "\n", "where `--task` is the task to export the model for; if not specified, the task will be auto-inferred based on the model (in the case of image generation, **text-to-image** should be selected). 
You can find a mapping between tasks and model classes in the Optimum TaskManager [documentation](https://huggingface.co/docs/optimum/exporters/task_manager). Additionally, you can specify weights compression using the `--weight-format` argument with one of the following options: `fp32`, `fp16`, `int8` and `int4`. For `int8` and `int4`, [nncf](https://github.com/openvinotoolkit/nncf) will be used for weight compression. More details about model export are provided in the [Optimum Intel documentation](https://huggingface.co/docs/optimum/intel/openvino/export#export-your-model)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "model_id = model_selector.value\n", "additional_args = {\"weight-format\": \"fp16\"}\n", "if \"sprint\" not in model_id.lower() and \"1.5\" not in model_id:\n", "    variant = \"fp16\" if \"BF16\" not in model_id else \"bf16\"\n", "    additional_args[\"variant\"] = variant\n", "\n", "model_dir = Path(model_id.split(\"/\")[-1])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from cmd_helper import optimum_cli\n", "\n", "if not model_dir.exists():\n", "    optimum_cli(model_id, model_dir, additional_args=additional_args)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Compress model weights\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "To reduce model memory consumption, we will use weights compression. The [Weights Compression](https://docs.openvino.ai/2024/openvino-workflow/model-optimization-guide/weight-compression.html) algorithm is aimed at compressing the weights of a model and can be used to optimize the footprint and performance of large models where the size of the weights is relatively larger than the size of the activations, for example, Large Language Models (LLMs). Compared to INT8 compression, INT4 compression improves performance even more, but introduces a minor drop in prediction quality. We will use [NNCF](https://github.com/openvinotoolkit/nncf) for transformer weight compression."
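, "\n", "\n", "After running the compression cell below, you can optionally compare the on-disk size of the FP16 and INT4 transformer weights. This is a minimal sketch; the file names follow the paths used in this notebook, and the INT4 `.bin` file exists only after compression has been performed:\n", "\n", "```python\n", "# Optional footprint check: compare FP16 vs. INT4 transformer weights on disk\n", "fp16_bin = model_dir / \"transformer/openvino_model.bin\"\n", "int4_bin = model_dir / \"transformer/openvino_model_i4.bin\"\n", "if fp16_bin.exists() and int4_bin.exists():\n", "    print(f\"FP16 transformer: {fp16_bin.stat().st_size / 1024**2:.1f} MB\")\n", "    print(f\"INT4 transformer: {int4_bin.stat().st_size / 1024**2:.1f} MB\")\n", "```"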
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "825a93daac56468abcd497e6d5c7ab50", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Checkbox(value=True, description='Weight compression')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to_compress = widgets.Checkbox(\n", " value=True,\n", " description=\"Weight compression\",\n", " disabled=False,\n", ")\n", "\n", "to_compress" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "test_replace": { "group_size=64": "group_size=-1" } }, "outputs": [], "source": [ "import openvino as ov\n", "import nncf\n", "import gc\n", "\n", "compressed_transformer = Path(model_dir) / \"transformer/openvino_model_i4.xml\"\n", "\n", "if to_compress.value and not compressed_transformer.exists():\n", " core = ov.Core()\n", "\n", " ov_model = core.read_model(model_dir / \"transformer/openvino_model.xml\")\n", "\n", " compressed_model = nncf.compress_weights(ov_model, mode=nncf.CompressWeightsMode.INT4_SYM, group_size=64, ratio=1.0)\n", " ov.save_model(compressed_model, compressed_transformer)\n", " del compressed_model\n", " del ov_model\n", "\n", " gc.collect();" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Run OpenVINO model inference\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "`OVDiffusionPipeline` from Optimum Intel provides ready-to-use interface for running Diffusers models using OpenVINO. It supports various models including Stable Diffusion, Stable Diffusion XL, LCM, Stable Diffusion v3 and Flux. Similar to original Diffusers pipeline, for initialization, we should use `from_preptrained` method providing model id from HuggingFace hub or local directory (both original PyTorch and OpenVINO models formats supported, in the first case model class additionally will trigger model conversion)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ea/work/py311/lib/python3.11/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29b7f90069464b478ebc79f2e4a12b80", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Device:', options=('CPU', 'AUTO'), value='CPU')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from notebook_utils import device_widget\n", "\n", "device = device_widget(default=\"CPU\", exclude=[\"NPU\"])\n", "device" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-04-16 20:55:37.264382: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-04-16 20:55:37.278500: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1744822537.292992 1227763 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1744822537.297771 1227763 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-04-16 20:55:37.314771: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [ "from optimum.intel.openvino import OVDiffusionPipeline\n", "\n", "ov_pipe = OVDiffusionPipeline.from_pretrained(model_dir, device=device.value, transformer_file_name=compressed_transformer.name if to_compress.value else None)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set timesteps: tensor([1.5708, 1.3000, 0.0000])\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f7ce615095aa43619cce7c37874614c7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "\n", "prompt = \"Cute 🐶 Wearing 🕶 flying on the 🌈\"\n", "\n", "image = ov_pipe(\n", " prompt,\n", " generator=torch.Generator(\"cpu\").manual_seed(1234563),\n", ").images[0]\n", "\n", "image" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Interactive demo\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gradio_helper import make_demo\n", "\n", "demo = make_demo(ov_pipe, sprint=\"sprint\" in model_id.lower())\n", "\n", "# if you are launching remotely, specify server_name and server_port\n", "# demo.launch(server_name='your server name', server_port='server port in int')\n", "# if you have any issue to launch on your platform, you can pass share=True to launch method:\n", "# demo.launch(share=True)\n", "# it creates a publicly shareable link for the interface. 
Read more in the docs: https://gradio.app/docs/\n", "try:\n", " demo.launch(debug=True)\n", "except Exception:\n", " demo.launch(debug=True, share=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" }, "openvino_notebooks": { "imageUrl": "https://github.com/user-attachments/assets/bacfcd2a-ac36-4421-9d1b-4e34aa0a9f62", "tags": { "categories": [ "Model Demos", "AI Trends" ], "libraries": [], "other": [ "Stable Diffusion" ], "tasks": [ "Text-to-Image" ] } }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "10cc00bb64de4e2e973d6fcc67015d0c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_fe82d8fb9a9f411eb7b59c476d4ff377", "style": "IPY_MODEL_5f4dfc9529b54883b4b4052900bc0565", "value": "100%" } }, "1456ae93e02b4cf78647c78dd68bb9ba": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "18573adbbdcc40f7a01e31664a0d69f0": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "1e8c24ac8e7149edb22bd1d4c8a27954": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "description_width": "" } }, "1f2ae49c63d44020825c4786c892330d": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "215992d572744b14997f0bc562c1c368": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "25f9b6cc4df24ca68e4f3fc695a60764": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_bf86142855a742158a52f66aa1441f13", "style": "IPY_MODEL_f72f72b42470443193c4bb6be984992f", "value": "100%" } }, "294241cc8206489ba60486dfcc27480e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "29b7f90069464b478ebc79f2e4a12b80": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "DropdownModel", "state": { "_options_labels": [ "CPU", "AUTO" ], "description": "Device:", "index": 0, "layout": "IPY_MODEL_2a99ea7dedf54b91a6f44f19fea9cbd3", "style": "IPY_MODEL_2e36640abbf7467eae8acf597fb98e7c" } }, "2a99ea7dedf54b91a6f44f19fea9cbd3": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "2e36640abbf7467eae8acf597fb98e7c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "31e5ba79edc0420c8e3dda9594ec5adc": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "3b5312ad38534e2ca06ac1823444af30": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "CheckboxStyleModel", "state": { "description_width": "" } }, "445705c6be134f22b8de9287a3a87b8e": { "model_module": 
"@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "bar_style": "success", "layout": "IPY_MODEL_5e9a75ee7f9a4482955843732cf5eaaf", "max": 2, "style": "IPY_MODEL_e599a5783b43467196593527fb7edb44", "value": 2 } }, "449957bf178b450c9172927791ceac57": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "description_width": "" } }, "4563a65764a14145a1162df470d070cb": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "4ef4f0cf38fe4405a82d5dac21bee3d4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "56decb57197b40769dc57ad3751121e8": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "5a06e1bb088543f3b21a641ecc0778af": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "bar_style": "success", "layout": "IPY_MODEL_1f2ae49c63d44020825c4786c892330d", "max": 2, "style": "IPY_MODEL_92ee19577aa745babec9968c5a8818c4", "value": 2 } }, "5e9a75ee7f9a4482955843732cf5eaaf": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "5f4dfc9529b54883b4b4052900bc0565": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "60677a1186f745a7bdb7899a80992869": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "6aab6ab32fc34a7ea4ace689563d66d5": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "6b8cc3dba13348e5beed3c8170bf035f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_b880cbd3ca934f75b686c71bda25e8c8", "style": "IPY_MODEL_f28fc004c13e4ba4a266e8a47a766165", "value": " 2/2 [00:01<00:00,  1.45steps/s]" } }, "6ec311b1b08b4b2585bf7d65f4c5c6c6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "764104cfe4f946f6aa12576a22dd0f64": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "79873f5c63524783954468667a9c6216": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "7a2878e99b2446b9b491acfef64136dc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_1456ae93e02b4cf78647c78dd68bb9ba", "style": "IPY_MODEL_b53bee3bfece4b5fb6db98da3b0b521a", "value": "100%" } }, "825a93daac56468abcd497e6d5c7ab50": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "CheckboxModel", "state": { "description": "Weight compression", "disabled": false, "layout": "IPY_MODEL_d6ddf4502e3d4f7aafd4f89f7ae554ff", "style": "IPY_MODEL_3b5312ad38534e2ca06ac1823444af30", "value": true } }, 
"87dc8d8e536348d3b32b7819e12ac6e1": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_10cc00bb64de4e2e973d6fcc67015d0c", "IPY_MODEL_445705c6be134f22b8de9287a3a87b8e", "IPY_MODEL_6b8cc3dba13348e5beed3c8170bf035f" ], "layout": "IPY_MODEL_e3fbbf28cfb54234995ddbcb6bb42d3a" } }, "92ee19577aa745babec9968c5a8818c4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "description_width": "" } }, "9902932f4e5647809db84f0a5ac86f06": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_6aab6ab32fc34a7ea4ace689563d66d5", "style": "IPY_MODEL_4ef4f0cf38fe4405a82d5dac21bee3d4", "value": " 2/2 [00:01<00:00,  1.24it/s]" } }, "adbe3f60ec9341a885a68cc65bbc1248": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "DropdownModel", "state": { "_options_labels": [ "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers", "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers", "Efficient-Large-Model/Sana_600M_512px_diffusers", "Efficient-Large-Model/Sana_600M_1024px_diffusers", "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", "Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers", "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers" ], "description": "Model:", "index": 0, "layout": "IPY_MODEL_4563a65764a14145a1162df470d070cb", "style": "IPY_MODEL_6ec311b1b08b4b2585bf7d65f4c5c6c6" } }, "ade5e04a8fbe4f4eb423b4d068bd1e91": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_ecd185ce14044443b4487ef064737335", "style": "IPY_MODEL_294241cc8206489ba60486dfcc27480e", "value": "100%" } }, "aeb2c6c25b1045829c7f2c58aafba075": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "bar_style": "success", "layout": "IPY_MODEL_18573adbbdcc40f7a01e31664a0d69f0", "max": 2, "style": "IPY_MODEL_449957bf178b450c9172927791ceac57", "value": 2 } }, "b53bee3bfece4b5fb6db98da3b0b521a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "b880cbd3ca934f75b686c71bda25e8c8": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "bf86142855a742158a52f66aa1441f13": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "c5d0bb0ee54c4e708485b8b485171a68": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_7a2878e99b2446b9b491acfef64136dc", "IPY_MODEL_aeb2c6c25b1045829c7f2c58aafba075", "IPY_MODEL_db3e826543b641fb927202bfcd43dd93" ], "layout": "IPY_MODEL_215992d572744b14997f0bc562c1c368" } }, "d6ddf4502e3d4f7aafd4f89f7ae554ff": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "db3e826543b641fb927202bfcd43dd93": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_764104cfe4f946f6aa12576a22dd0f64", "style": 
"IPY_MODEL_79873f5c63524783954468667a9c6216", "value": " 2/2 [00:01<00:00,  1.47steps/s]" } }, "dcf9dc892ce34e0095456cca9361b873": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "e3fbbf28cfb54234995ddbcb6bb42d3a": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "e599a5783b43467196593527fb7edb44": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "description_width": "" } }, "ecd185ce14044443b4487ef064737335": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} }, "ef9831b22a614821b4e7e51583025aaa": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_dcf9dc892ce34e0095456cca9361b873", "style": "IPY_MODEL_f838605564eb4754b5eacc099d26994b", "value": " 2/2 [00:01<00:00,  1.47steps/s]" } }, "f28fc004c13e4ba4a266e8a47a766165": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "f411f35c75eb41348a5cf4f689664eaa": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_ade5e04a8fbe4f4eb423b4d068bd1e91", "IPY_MODEL_5a06e1bb088543f3b21a641ecc0778af", "IPY_MODEL_ef9831b22a614821b4e7e51583025aaa" ], "layout": "IPY_MODEL_56decb57197b40769dc57ad3751121e8" } }, "f72f72b42470443193c4bb6be984992f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "f75cc4f9631d444a816241a1f638ed50": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "bar_style": "success", "layout": "IPY_MODEL_60677a1186f745a7bdb7899a80992869", "max": 2, "style": "IPY_MODEL_1e8c24ac8e7149edb22bd1d4c8a27954", "value": 2 } }, "f7ce615095aa43619cce7c37874614c7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_25f9b6cc4df24ca68e4f3fc695a60764", "IPY_MODEL_f75cc4f9631d444a816241a1f638ed50", "IPY_MODEL_9902932f4e5647809db84f0a5ac86f06" ], "layout": "IPY_MODEL_31e5ba79edc0420c8e3dda9594ec5adc" } }, "f838605564eb4754b5eacc099d26994b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "description_width": "", "font_size": null, "text_color": null } }, "fe82d8fb9a9f411eb7b59c476d4ff377": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {} } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }