{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "7d0d9eba-7944-49b4-8b6b-6e05a02dfd02", "metadata": {}, "source": [ "# Image Generation with Stable Diffusion and IP-Adapter\n", "\n", "[IP-Adapter](https://hf.co/papers/2308.06721) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. It works by decoupling the cross-attention layers for image features from those for text features. All other model components stay frozen, and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only about 100 MB.\n", "![ip-adapter-pipe.png](https://huggingface.co/h94/IP-Adapter/resolve/main/fig1.png)\n", "\n", "In this tutorial, we will consider how to convert and run a Stable Diffusion pipeline with an IP-Adapter loaded. We will use [stable-diffusion-v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model and apply the official [IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights. To speed up the generation process, we will also use [LCM-LoRA](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5).\n", "\n", "\n", "\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a754d3e8", "metadata": {}, "source": [ "#### Table of contents:\n", "\n", "- [Prerequisites](#Prerequisites)\n", "- [Prepare Diffusers pipeline](#Prepare-Diffusers-pipeline)\n", "- [Convert PyTorch models](#Convert-PyTorch-models)\n", " - [Image Encoder](#Image-Encoder)\n", " - [U-net](#U-net)\n", " - [VAE Encoder and Decoder](#VAE-Encoder-and-Decoder)\n", " - [Text Encoder](#Text-Encoder)\n", "- [Prepare OpenVINO inference pipeline](#Prepare-OpenVINO-inference-pipeline)\n", "- [Run model inference](#Run-model-inference)\n", " - [Select inference device](#Select-inference-device)\n", " - [Generation image variation](#Generation-image-variation)\n", " - [Generation conditioned by image and text](#Generation-conditioned-by-image-and-text)\n", " - [Generation image 
blending](#Generation-image-blending)\n", "- [Interactive demo](#Interactive-demo)\n", "\n", "\n", "### Installation Instructions\n", "\n", "This is a self-contained example that relies solely on its own code.\n", "\n", "We recommend running the notebook in a virtual environment. You only need a Jupyter server to start.\n", "For details, please refer to the [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "37460b75-dea3-4d6b-a068-2b5c70af9298", "metadata": {}, "source": [ "## Prerequisites\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 1, "id": "04547ab5-2262-44c6-ba11-fa9d21c61d0b", "metadata": {}, "outputs": [], "source": [ "%pip install -q \"torch>=2.1\" transformers accelerate \"diffusers>=0.24.0\" \"openvino>=2023.3.0\" \"gradio>=4.19\" opencv-python \"peft>=0.6.2\" \"protobuf>=3.20\" --extra-index-url https://download.pytorch.org/whl/cpu\n", "%pip install -q \"matplotlib>=3.4\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c1293390-38a5-4ab5-8bec-7a7145989fb6", "metadata": {}, "source": [ "## Prepare Diffusers pipeline\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "First of all, we need to collect all components of our pipeline. To work with Stable Diffusion, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). 
Additionally, the pipeline supports loading adapters that extend Stable Diffusion functionality, such as [Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685), [PEFT](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference), [IP-Adapter](https://ip-adapter.github.io/), and [Textual Inversion](https://textual-inversion.github.io/). You can find more information about supported adapters in the [Diffusers documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters).\n", "\n", "In this tutorial, we will focus on IP-Adapter. IP-Adapter can be integrated into a diffusion pipeline using the `load_ip_adapter` method. It allows you to use both an image and text to condition the image generation process. To adjust the ratio between the text prompt and image prompt conditions, use the `set_ip_adapter_scale()` method. If you only use the image prompt, set the scale to 1.0. Lowering the scale gives more diverse generations, but they will be less aligned with the image prompt. A scale of 0.5 usually gives good results when you use both text and image prompts.\n", "\n", "As mentioned before, we will also use LCM-LoRA to speed up the generation process. You can find more information about LCM-LoRA in this [notebook](../latent-consistency-models-image-generation/lcm-lora-controlnet.ipynb). To apply LCM-LoRA, use the `load_lora_weights` method. Additionally, LCM requires `LCMScheduler` for efficient generation."
] }, { "cell_type": "code", "execution_count": 2, "id": "1d51d006-1568-437e-a240-11eb86a6de36", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "243ccb21b5ef41029f425a821ad602d0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/7 [00:00 1.0\n", " # get prompt text embeddings\n", " text_embeddings = self._encode_prompt(\n", " prompt,\n", " do_classifier_free_guidance=do_classifier_free_guidance,\n", " negative_prompt=negative_prompt,\n", " )\n", " # get ip-adapter image embeddings\n", " image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image)\n", " if do_classifier_free_guidance:\n", " image_embeds = np.concatenate([negative_image_embeds, image_embeds])\n", "\n", " # set timesteps\n", " accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n", " extra_set_kwargs = {}\n", " if accepts_offset:\n", " extra_set_kwargs[\"offset\"] = 1\n", "\n", " self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n", " timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)\n", " latent_timestep = timesteps[:1]\n", "\n", " # get the initial random noise unless the user supplied it\n", " latents, meta = self.prepare_latents(\n", " 1,\n", " 4,\n", " height or self.height,\n", " width or self.width,\n", " generator=generator,\n", " latents=latents,\n", " image=image,\n", " latent_timestep=latent_timestep,\n", " )\n", "\n", " # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n", " # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n", " # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n", " # and should be between [0, 1]\n", " accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n", " extra_step_kwargs = {}\n", " if accepts_eta:\n", " 
extra_step_kwargs[\"eta\"] = eta\n", "\n", " for i, t in enumerate(self.progress_bar(timesteps)):\n", " # expand the latents if you are doing classifier free guidance\n", " latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents\n", " latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n", "\n", " # predict the noise residual\n", " noise_pred = self.unet([latent_model_input, t, text_embeddings, image_embeds])[0]\n", " # perform guidance\n", " if do_classifier_free_guidance:\n", " noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]\n", " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n", "\n", " # compute the previous noisy sample x_t -> x_t-1\n", " latents = self.scheduler.step(\n", " torch.from_numpy(noise_pred),\n", " t,\n", " torch.from_numpy(latents),\n", " **extra_step_kwargs,\n", " )[\"prev_sample\"].numpy()\n", "\n", " # scale and decode the image latents with vae\n", " image = self.vae_decoder(latents * (1 / 0.18215))[0]\n", "\n", " image = self.postprocess_image(image, meta, output_type)\n", " return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)\n", "\n", " def _encode_prompt(\n", " self,\n", " prompt: Union[str, List[str]],\n", " num_images_per_prompt: int = 1,\n", " do_classifier_free_guidance: bool = True,\n", " negative_prompt: Union[str, List[str]] = None,\n", " ):\n", " \"\"\"\n", " Encodes the prompt into text encoder hidden states.\n", "\n", " Parameters:\n", " prompt (str or list(str)): prompt to be encoded\n", " num_images_per_prompt (int): number of images that should be generated per prompt\n", " do_classifier_free_guidance (bool): whether to use classifier free guidance or not\n", " negative_prompt (str or list(str)): negative prompt to be encoded.\n", " Returns:\n", " text_embeddings (np.ndarray): text encoder hidden states\n", " \"\"\"\n", " batch_size = len(prompt) if isinstance(prompt, list) else 
1\n", "\n", " # tokenize input prompts\n", " text_inputs = self.tokenizer(\n", " prompt,\n", " padding=\"max_length\",\n", " max_length=self.tokenizer.model_max_length,\n", " truncation=True,\n", " return_tensors=\"np\",\n", " )\n", " text_input_ids = text_inputs.input_ids\n", "\n", " text_embeddings = self.text_encoder(text_input_ids)[0]\n", "\n", " # duplicate text embeddings for each generation per prompt\n", " if num_images_per_prompt != 1:\n", " bs_embed, seq_len, _ = text_embeddings.shape\n", " text_embeddings = np.tile(text_embeddings, (1, num_images_per_prompt, 1))\n", " text_embeddings = np.reshape(text_embeddings, (bs_embed * num_images_per_prompt, seq_len, -1))\n", "\n", " # get unconditional embeddings for classifier free guidance\n", " if do_classifier_free_guidance:\n", " uncond_tokens: List[str]\n", " max_length = text_input_ids.shape[-1]\n", " if negative_prompt is None:\n", " uncond_tokens = [\"\"] * batch_size\n", " elif isinstance(negative_prompt, str):\n", " uncond_tokens = [negative_prompt]\n", " else:\n", " uncond_tokens = negative_prompt\n", " uncond_input = self.tokenizer(\n", " uncond_tokens,\n", " padding=\"max_length\",\n", " max_length=max_length,\n", " truncation=True,\n", " return_tensors=\"np\",\n", " )\n", "\n", " uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]\n", "\n", " # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n", " seq_len = uncond_embeddings.shape[1]\n", " uncond_embeddings = np.tile(uncond_embeddings, (1, num_images_per_prompt, 1))\n", " uncond_embeddings = np.reshape(uncond_embeddings, (batch_size * num_images_per_prompt, seq_len, -1))\n", "\n", " # For classifier-free guidance, we need to do two forward passes.\n", " # Here we concatenate the unconditional and text embeddings into a single batch\n", " # to avoid doing two forward passes\n", " text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])\n", "\n", " return text_embeddings\n", "\n", 
" def prepare_latents(\n", " self,\n", " batch_size,\n", " num_channels_latents,\n", " height,\n", " width,\n", " dtype=torch.float32,\n", " generator=None,\n", " latents=None,\n", " image=None,\n", " latent_timestep=None,\n", " ):\n", " shape = (\n", " batch_size,\n", " num_channels_latents,\n", " height // self.vae_scale_factor,\n", " width // self.vae_scale_factor,\n", " )\n", " if isinstance(generator, list) and len(generator) != batch_size:\n", " raise ValueError(\n", " f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n", " f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n", " )\n", "\n", " if latents is None:\n", " latents = randn_tensor(shape, generator=generator, dtype=dtype)\n", "\n", " if image is None:\n", " # scale the initial noise by the standard deviation required by the scheduler\n", " latents = latents * self.scheduler.init_noise_sigma\n", " return latents.numpy(), {}\n", " input_image, meta = preprocess(image, height, width)\n", " image_latents = self.vae_encoder(input_image)[0]\n", " image_latents = image_latents * 0.18215\n", " latents = self.scheduler.add_noise(torch.from_numpy(image_latents), latents, latent_timestep).numpy()\n", " return latents, meta\n", "\n", " def postprocess_image(self, image: np.ndarray, meta: Dict, output_type: str = \"pil\"):\n", " \"\"\"\n", " Postprocessing for decoded image. Takes generated image decoded by VAE decoder, unpad it to initial image size (if required),\n", " normalize and convert to [0, 255] pixels range. 
Optionally, convert it from np.ndarray to PIL.Image format\n", "\n", " Parameters:\n", " image (np.ndarray):\n", " Generated image\n", " meta (Dict):\n", " Metadata obtained at the latents preparation step; can be empty\n", " output_type (str, *optional*, pil):\n", " Output format for result, can be pil or numpy\n", " Returns:\n", " image (List of np.ndarray or PIL.Image.Image):\n", " Post-processed images\n", " \"\"\"\n", " if \"padding\" in meta:\n", " pad = meta[\"padding\"]\n", " (_, end_h), (_, end_w) = pad[1:3]\n", " h, w = image.shape[2:]\n", " unpad_h = h - end_h\n", " unpad_w = w - end_w\n", " image = image[:, :, :unpad_h, :unpad_w]\n", " image = np.clip(image / 2 + 0.5, 0, 1)\n", " image = np.transpose(image, (0, 2, 3, 1))\n", " # convert to PIL if requested\n", " if output_type == \"pil\":\n", " image = self.numpy_to_pil(image)\n", " if \"src_height\" in meta:\n", " orig_height, orig_width = meta[\"src_height\"], meta[\"src_width\"]\n", " image = [img.resize((orig_width, orig_height), PIL.Image.Resampling.LANCZOS) for img in image]\n", " else:\n", " if \"src_height\" in meta:\n", " orig_height, orig_width = meta[\"src_height\"], meta[\"src_width\"]\n", " # cv2.resize expects dsize as (width, height)\n", " image = [cv2.resize(img, (orig_width, orig_height)) for img in image]\n", " return image\n", "\n", " def encode_image(self, image, num_images_per_prompt=1):\n", " if not isinstance(image, torch.Tensor):\n", " image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n", "\n", " # the compiled OpenVINO image encoder returns NumPy arrays\n", " image_embeds = self.image_encoder(image)[0]\n", " if num_images_per_prompt > 1:\n", " image_embeds = np.repeat(image_embeds, num_images_per_prompt, axis=0)\n", "\n", " # zero image embeddings act as the unconditional input for classifier-free guidance\n", " uncond_image_embeds = np.zeros_like(image_embeds)\n", " return image_embeds, uncond_image_embeds\n", "\n", " def get_timesteps(self, num_inference_steps: int, strength: float):\n", " \"\"\"\n", " Helper function for getting scheduler timesteps for generation\n", " In case of image-to-image generation, it updates the number of steps according to strength\n", 
"\n", " Parameters:\n", " num_inference_steps (int):\n", " number of inference steps for generation\n", " strength (float):\n", " value between 0.0 and 1.0 that controls the amount of noise that is added to the input image.\n", " Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.\n", " \"\"\"\n", " # get the original timestep using init_timestep\n", " init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n", "\n", " t_start = max(num_inference_steps - init_timestep, 0)\n", " timesteps = self.scheduler.timesteps[t_start:]\n", "\n", " return timesteps, num_inference_steps - t_start" ] }, { "attachments": {}, "cell_type": "markdown", "id": "73369a60-990e-4dc7-b498-1db962398298", "metadata": {}, "source": [ "## Run model inference\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Now let's configure our pipeline and take a look at the generation results.\n", "\n", "### Select inference device\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Select the inference device from the dropdown list."
] }, { "cell_type": "code", "execution_count": 8, "id": "d3a33e7a-1397-4170-bcfe-48809dd4d963", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fa7731148d4c42468c9f518e3be708b3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Device:', options=('CPU', 'GPU.0', 'GPU.1', 'AUTO'), value='CPU')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import requests\n", "\n", "r = requests.get(\n", " url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\",\n", ")\n", "open(\"notebook_utils.py\", \"w\").write(r.text)\n", "\n", "from notebook_utils import device_widget\n", "\n", "device = device_widget()\n", "\n", "device" ] }, { "cell_type": "code", "execution_count": 9, "id": "e3cb7120-9b5b-44b2-907c-f6e02df00371", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'skip_prk_steps': True} were passed to LCMScheduler, but are not expected and will be ignored. 
Please verify your scheduler_config.json configuration file.\n" ] } ], "source": [ "from transformers import AutoTokenizer\n", "\n", "core = ov.Core()\n", "\n", "# keep the VAE models in full (f32) inference precision on non-CPU devices\n", "ov_config = {\"INFERENCE_PRECISION_HINT\": \"f32\"} if device.value != \"CPU\" else {}\n", "vae_decoder = core.compile_model(VAE_DECODER_PATH, device.value, ov_config)\n", "vae_encoder = core.compile_model(VAE_ENCODER_PATH, device.value, ov_config)\n", "text_encoder = core.compile_model(TEXT_ENCODER_PATH, device.value)\n", "image_encoder = core.compile_model(IMAGE_ENCODER_PATH, device.value)\n", "unet = core.compile_model(UNET_PATH, device.value)\n", "\n", "scheduler = LCMScheduler.from_pretrained(models_dir / \"scheduler\")\n", "tokenizer = AutoTokenizer.from_pretrained(models_dir / \"tokenizer\")\n", "feature_extractor = CLIPImageProcessor.from_pretrained(models_dir / \"feature_extractor\")\n", "\n", "ov_pipe = OVStableDiffusionPipeline(\n", " vae_decoder,\n", " text_encoder,\n", " tokenizer,\n", " unet,\n", " scheduler,\n", " image_encoder,\n", " feature_extractor,\n", " vae_encoder,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7aaa5f01-dafd-463d-be36-13505d2f7007", "metadata": {}, "source": [ "### Generation image variation\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "If we leave the text prompt empty and provide only an IP-Adapter image, we get a variation of the same image."
] }, { "cell_type": "code", "execution_count": 10, "id": "2556696d-9a33-4300-8199-dc2e15e393fa", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "\n", "def visualize_results(images, titles):\n", " \"\"\"\n", " Helper function for results visualization\n", "\n", " Parameters:\n", " images (List[PIL.Image.Image]): images to display, in order\n", " titles (List[str]): titles for the images, in the same order as images\n", " Returns:\n", " fig (matplotlib.pyplot.Figure): matplotlib figure containing the drawn images\n", " \"\"\"\n", " im_w, im_h = images[0].size\n", " is_horizontal = im_h <= im_w\n", " figsize = (10, 15 * len(images)) if is_horizontal else (15 * len(images), 10)\n", " fig, axs = plt.subplots(\n", " 1 if is_horizontal else len(images),\n", " len(images) if is_horizontal else 1,\n", " figsize=figsize,\n", " sharex=\"all\",\n", " sharey=\"all\",\n", " )\n", " fig.patch.set_facecolor(\"white\")\n", " list_axes = list(axs.flat)\n", " for a in list_axes:\n", " a.set_xticklabels([])\n", " a.set_yticklabels([])\n", " a.get_xaxis().set_visible(False)\n", " a.get_yaxis().set_visible(False)\n", " a.grid(False)\n", " for image, title, ax in zip(images, titles, list_axes):\n", " ax.imshow(np.array(image))\n", " ax.set_title(title, fontsize=20)\n", " fig.subplots_adjust(wspace=0.0 if is_horizontal else 0.01, hspace=0.01 if is_horizontal else 0.0)\n", " fig.tight_layout()\n", " return fig" ] }, { "cell_type": "code", "execution_count": 11, "id": "12fc99c7-684c-49cb-acca-753ef21950a3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f1965ac03cee4f279be657d40dc88f3c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/4 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "generator = torch.Generator(device=\"cpu\").manual_seed(576)\n", "\n", "image 
= load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png\")\n", "\n", "result = ov_pipe(\n", " prompt=\"\",\n", " ip_adapter_image=image,\n", " guidance_scale=1,\n", " negative_prompt=\"\",\n", " num_inference_steps=4,\n", " generator=generator,\n", ")\n", "\n", "fig = visualize_results([image, result.images[0]], [\"input image\", \"result\"])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8a1405c3-fd46-4ff6-b3e0-c3d40bb72f0e", "metadata": {}, "source": [ "### Generation conditioned by image and text\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "IP-Adapter allows you to use both an image and text to condition the image generation process. The IP-Adapter image and the text prompt complement each other; for example, we can use a text prompt to add “sunglasses” 😎 to the previous image." ] }, { "cell_type": "code", "execution_count": 12, "id": "94a2679e-9801-49a3-895f-e8465a16fbc3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f53693ca72c342d890e25c806232d8ca", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/4 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = visualize_results([image, result.images[0]], [\"input image\", \"result\"])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "fe0acc45-a646-4838-ab56-88bd3b726d59", "metadata": {}, "source": [ "### Generation image blending\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "IP-Adapter also works great with image-to-image translation, where it helps achieve an image blending effect."
] }, { "cell_type": "code", "execution_count": 14, "id": "90da4fbf-8368-4ef1-b917-e8853eeb03f9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "89e1c79bb73c40a0beef4b1a8acff384", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = visualize_results([image, ip_image, result.images[0]], [\"input image\", \"ip-adapter image\", \"result\"])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6eba1997-6c8c-45e0-bb5d-958ac3066bff", "metadata": {}, "source": [ "## Interactive demo\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Now you can try the model with your own images and text prompts." ] }, { "cell_type": "code", "execution_count": null, "id": "8b81ba28-934e-46d2-b061-65888da9103a", "metadata": {}, "outputs": [], "source": [ "import requests\n", "from pathlib import Path\n", "\n", "if not Path(\"gradio_helper.py\").exists():\n", " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/stable-diffusion-ip-adapter/gradio_helper.py\")\n", " open(\"gradio_helper.py\", \"w\").write(r.text)\n", "\n", "from gradio_helper import make_demo\n", "\n", "demo = make_demo(ov_pipe)\n", "\n", "try:\n", " demo.queue().launch(debug=True)\n", "except Exception:\n", " demo.queue().launch(share=True, debug=True)\n", "# if you are launching remotely, specify server_name and server_port\n", "# demo.launch(server_name='your server name', server_port='server port in int')\n", "# Read more in the docs: https://gradio.app/docs/" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, 
"openvino_notebooks": { "imageUrl": "https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/stable-diffusion-ip-adapter/stable-diffusion-ip-adapter.png?raw=true", "tags": { "categories": [ "Model Demos", "AI Trends" ], "libraries": [], "other": [ "Stable Diffusion" ], "tasks": [ "Image-to-Image", "Text-to-Image" ] } }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }