{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "colab": { "base_uri": "https://localhost:8080/" }, "id": "XU7NuMAA2drw", "outputId": "7eb9b063-664f-4a42-e960-728ec9608c42" }, "outputs": [], "source": [ "#@markdown Check type of GPU and VRAM available.\n", "#!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader\n", "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "id": "BzM7j0ZSc_9c" }, "source": [ "Based on https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth\n", "with changes to make it work on SageMaker Studio Lab." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#%conda remove --name studiolab --all\n", "# To restore/reset the environment to the Studio Lab default, use the YAML file from /opt/amazon/sagemaker/environments/\n", "#%conda env create -f /opt/amazon/sagemaker/environments/studiolab.yaml # or default.yaml" ] }, { "cell_type": "markdown", "metadata": { "id": "wnTMyW41cC1E" }, "source": [ "## Install Requirements" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare for installation of xformers for SageMaker Studio Lab\n", "%conda install -y -c pytorch -c conda-forge cudatoolkit=11.6 pytorch=1.12.1 torchvision==0.13.1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aLWXPZqjsZVV" }, "outputs": [], "source": [ "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py\n", "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n", "%pip install -qq git+https://github.com/ShivamShrirao/diffusers\n", "#!wget -q https://github.com/Miraculix200/diffusers/raw/main/examples/dreambooth/train_dreambooth.py\n", "#!wget -q https://github.com/Miraculix200/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n", "#%pip install -qq 
git+https://github.com/Miraculix200/diffusers\n", "%pip install -q -U --pre triton\n", "%pip install -q accelerate==0.12.0 transformers ftfy bitsandbytes tensorboard natsort" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Huggingface token, necessary to download the model\n", "# You only need to do this once, unless you delete the ~/.huggingface folder\n", "\n", "import ipywidgets as widgets\n", "\n", "# Password widget so the token is masked on screen while it is typed or pasted\n", "token_textbox = widgets.Password(\n", "    placeholder='Enter Huggingface token here',\n", "    description='Token:',\n", ")\n", "token_textbox" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save Huggingface token to ~/.huggingface/token\n", "\n", "import os\n", "\n", "hf_token = token_textbox.value.strip()\n", "if not hf_token:\n", "    raise ValueError(\"Enter your Huggingface token in the textbox above before running this cell\")\n", "\n", "# Write the file from Python instead of `!echo` so the token is never interpreted by the shell\n", "hf_dir = os.path.expanduser(\"~/.huggingface\")\n", "os.makedirs(hf_dir, exist_ok=True)\n", "token_path = os.path.join(hf_dir, \"token\")\n", "with open(token_path, \"w\") as token_file:\n", "    token_file.write(hf_token)\n", "# The file holds a credential, so restrict it to the current user\n", "os.chmod(token_path, 0o600)" ] }, { "cell_type": "markdown", "metadata": { "id": "XfTlc8Mqb8iH", "tags": [] }, "source": [ "### Install xformers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Installation of xformers for SageMaker Studio lab\n", "#%conda install -y -c pytorch -c conda-forge cudatoolkit=11.6 pytorch=1.12.1 \n", "%conda install -y xformers -c xformers/label/dev" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "n6dcjPnnaiCn", "jupyter": { "source_hidden": true }, "outputId": "ac7dc3db-27a2-4dd4-963e-4d77dc313a5d", "tags": [] }, "outputs": [], "source": [ "#ignore this cell\n", "\n", "#%pip install -q https://github.com/metrolobo/xformers_wheels/releases/download/1d31a3ac_various_6/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl\n", "# These were compiled on Tesla T4, 
should also work on P100, thanks to https://github.com/metrolobo\n", "\n", "# If precompiled wheels don't work, install it with the following command. It will take around 40 minutes to compile.\n", "#%pip install git+https://github.com/facebookresearch/xformers@1d31a3a#egg=xformers" ] }, { "cell_type": "markdown", "metadata": { "id": "G0NV324ZcL9L" }, "source": [ "## Settings and run" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Rxg0y5MBudmd" }, "outputs": [], "source": [ "#@markdown If model weights should be saved directly in google drive (takes around 4-5 GB).\n", "\n", "import os\n", "\n", "# Saving to gdrive won't work on SageMaker Studio Lab, so don't change this variable\n", "save_to_gdrive = False #@param {type:\"boolean\"}\n", "#if save_to_gdrive:\n", "# from google.colab import drive\n", "# drive.mount('~/sagemaker-studiolab-notebooks/dreambooth/content/drive')\n", "\n", "#@markdown Name/Path of the initial model.\n", "MODEL_NAME = \"runwayml/stable-diffusion-v1-5\" #@param {type:\"string\"}\n", "\n", "#@markdown Enter the directory name to save model at.\n", "\n", "OUTPUT_DIR = \"stable_diffusion_weights/zwx\" #@param {type:\"string\"}\n", "# Prefix the chosen name with the local `content/` folder (or the Drive folder when enabled)\n", "if save_to_gdrive:\n", "    OUTPUT_DIR = \"content/drive/MyDrive/\" + OUTPUT_DIR\n", "else:\n", "    OUTPUT_DIR = \"content/\" + OUTPUT_DIR\n", "\n", "print(f\"[*] Weights will be saved at {OUTPUT_DIR}\")\n", "\n", "# Create the folder from Python rather than `!mkdir` so names with spaces or shell metacharacters work\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "qn5ILIyDJIcX" }, "source": [ "# Start Training\n", "\n", "Use the table below to choose the best flags based on your memory and speed requirements. 
Tested on Tesla T4 GPU.\n", "\n", "\n", "| `fp16` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | GB VRAM usage | Speed (it/s) |\n", "| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |\n", "| fp16 | 1 | 1 | TRUE | TRUE | 9.92 | 0.93 |\n", "| no | 1 | 1 | TRUE | TRUE | 10.08 | 0.42 |\n", "| fp16 | 2 | 1 | TRUE | TRUE | 10.4 | 0.66 |\n", "| fp16 | 1 | 1 | FALSE | TRUE | 11.17 | 1.14 |\n", "| no | 1 | 1 | FALSE | TRUE | 11.17 | 0.49 |\n", "| fp16 | 1 | 2 | TRUE | TRUE | 11.56 | 1 |\n", "| fp16 | 2 | 1 | FALSE | TRUE | 13.67 | 0.82 |\n", "| fp16 | 1 | 2 | FALSE | TRUE | 13.7 | 0.83 |\n", "| fp16 | 1 | 1 | TRUE | FALSE | 15.79 | 0.77 |\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-ioxxvHoicPs" }, "source": [ "Add the `--gradient_checkpointing` flag for around 9.92 GB VRAM usage.\n", "\n", "Remove the `--use_8bit_adam` flag for full precision. This requires 15.79 GB with `--gradient_checkpointing`, otherwise 17.8 GB.\n", "\n", "Remove the `--train_text_encoder` flag to reduce memory usage further, at the cost of degraded output quality." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "5vDpCxId1aCm" }, "outputs": [], "source": [ "# You can also add multiple concepts here. 
Try tweaking `--max_train_steps` accordingly.\n", "\n", "concepts_list = [\n", " {\n", " \"instance_prompt\": \"photo of zwx dog\",\n", " \"class_prompt\": \"photo of a dog\",\n", " \"instance_data_dir\": \"content/data/zwx\",\n", " \"class_data_dir\": \"content/data/dog\"\n", " },\n", "# {\n", "# \"instance_prompt\": \"photo of ukj person\",\n", "# \"class_prompt\": \"photo of a person\",\n", "# \"instance_data_dir\": \"content/data/ukj\",\n", "# \"class_data_dir\": \"content/data/person\"\n", "# }\n", "]\n", "\n", "# `class_data_dir` contains regularization images\n", "import json\n", "import os\n", "for c in concepts_list:\n", " os.makedirs(c[\"instance_data_dir\"], exist_ok=True)\n", "\n", "with open(\"concepts_list.json\", \"w\") as f:\n", " json.dump(concepts_list, f, indent=4)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e8bdc274c2c444bab0507ead6041048", "version_major": 2, "version_minor": 0 }, "text/plain": [ "FileUpload(value={}, accept='.jpg', description='Upload', multiple=True)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# SageMaker Studio Lab:\n", "# manually upload images to content/data/zwx (relative to this notebook's working directory)\n", "# or if you only train a single concept, run this cell to open a file upload dialog\n", "\n", "from ipywidgets import FileUpload\n", "from IPython.display import display\n", "upload = FileUpload(accept='.jpg', multiple=True)\n", "display(upload)" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# if you only train a single concept, run this cell if you used the file upload dialog to upload images\n", "\n", "import os\n", "\n", "INSTANCE_DIR = concepts_list[0][\"instance_data_dir\"]\n", "print(\"Creating instance folder: \" + INSTANCE_DIR)\n", "os.makedirs(INSTANCE_DIR, exist_ok=True)\n", "\n", "# ipywidgets 7 exposes the uploads as a dict keyed by file name,\n", "# ipywidgets 8 as a tuple of dicts that carry the file name in a 'name' entry\n", "uploaded = upload.value\n", "if isinstance(uploaded, dict):\n", "    uploaded_files = [(name, info['content']) for name, info in uploaded.items()]\n", "else:\n", "    uploaded_files = [(info['name'], info['content']) for info in uploaded]\n", "\n", "for name, content in uploaded_files:\n", "    # Keep only the base name so an upload can never be written outside INSTANCE_DIR\n", "    target_path = os.path.join(INSTANCE_DIR, os.path.basename(name))\n", "    print(\"Uploading: \" + name)\n", "    # Write straight into the instance folder instead of shelling out to `mv`,\n", "    # which broke on file names containing spaces or shell metacharacters\n", "    with open(target_path, 'wb') as fp:\n", "        fp.write(content)\n", "\n", "print(\"### Upload complete ###\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jjcSXTp-u-Eg" }, "outputs": [], "source": [ "!accelerate launch train_dreambooth.py \\\n", " --pretrained_model_name_or_path=\"$MODEL_NAME\" \\\n", " --pretrained_vae_name_or_path=\"stabilityai/sd-vae-ft-mse\" \\\n", " --output_dir=\"$OUTPUT_DIR\" \\\n", " --with_prior_preservation --prior_loss_weight=1.0 \\\n", " --seed=1337 \\\n", " --resolution=512 \\\n", " --train_batch_size=1 \\\n", " --train_text_encoder \\\n", " --mixed_precision=\"fp16\" \\\n", " --use_8bit_adam \\\n", " --gradient_accumulation_steps=1 \\\n", " --learning_rate=1e-6 \\\n", " --lr_scheduler=\"constant\" \\\n", " --lr_warmup_steps=0 \\\n", " --num_class_images=50 \\\n", " --sample_batch_size=4 \\\n", " --max_train_steps=1500 \\\n", " --save_interval=10000 \\\n", " --save_sample_prompt=\"photo of zwx dog\" \\\n", " --concepts_list=\"concepts_list.json\"\n", "\n", "# Reduce the `--save_interval` to lower than `--max_train_steps` to save weights from intermediate steps.\n", "# `--save_sample_prompt` can be same as `--instance_prompt` to generate intermediate samples (saved along with weights in samples directory)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "89Az5NUxOWdy", "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "#@markdown Run to generate a grid of preview images from the last saved weights.\n", "import os\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "\n", "weights_folder = OUTPUT_DIR\n", "# Keep only numerically named folders other than \"0\"; strays such as .ipynb_checkpoints would break int()\n", "folders = sorted((f for f in os.listdir(weights_folder) if f.isdecimal() and f != \"0\"), key=int)\n", "if not folders:\n", "    raise FileNotFoundError(f\"No saved weights found in {weights_folder}\")\n", "\n", "# Collect the sample image paths of every folder up front (sorted for a stable layout)\n", "samples = []\n", "for folder in folders:\n", "    image_folder = os.path.join(weights_folder, folder, \"samples\")\n", "    names = sorted(os.listdir(image_folder)) if os.path.isdir(image_folder) else []\n", "    samples.append([os.path.join(image_folder, name) for name in names])\n", "\n", "row = len(folders)\n", "# Size the grid for the folder with the most samples so no image falls outside it\n", "col = max(len(paths) for paths in samples)\n", "if col == 0:\n", "    raise FileNotFoundError(f\"No sample images found in {weights_folder}\")\n", "scale = 4\n", "# squeeze=False keeps `axes` two-dimensional even for a single row or column\n", "fig, axes = plt.subplots(row, col, figsize=(col*scale, row*scale), gridspec_kw={'hspace': 0, 'wspace': 0}, squeeze=False)\n", "\n", "for i, folder in enumerate(folders):\n", "    for j in range(col):\n", "        currAxes = axes[i, j]\n", "        currAxes.axis('off')\n", "        if j >= len(samples[i]):\n", "            continue  # this folder has fewer samples than the widest row\n", "        if i == 0:\n", "            currAxes.set_title(f\"Image {j}\")\n", "        if j == 0:\n", "            currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)\n", "        img = mpimg.imread(samples[i][j])\n", "        currAxes.imshow(img, cmap='gray')\n", "\n", "plt.tight_layout()\n", "plt.savefig('grid.png', dpi=72)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "89Az5NUxOWdy" }, "outputs": [], "source": [ "#@markdown Specify the weights directory to use (leave blank for latest)\n", "import os\n", "from glob import glob\n", "\n", "WEIGHTS_DIR = \"\" #@param {type:\"string\"}\n", "if WEIGHTS_DIR == \"\":\n", "    from natsort import natsorted\n", "    # Take the naturally-last sub-folder of OUTPUT_DIR\n", "    candidates = natsorted([d for d in glob(os.path.join(OUTPUT_DIR, \"*\")) if os.path.isdir(d)])\n", "    if not candidates:\n", "        raise FileNotFoundError(f\"No saved weights found in {OUTPUT_DIR}; run the training cell first\")\n", "    WEIGHTS_DIR = candidates[-1]\n", "print(f\"[*] WEIGHTS_DIR={WEIGHTS_DIR}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "5V8wgU0HN-Kq" }, "source": [ "## Convert weights to ckpt to use in web UIs like AUTOMATIC1111." ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "dcXzsUyG1aCy" }, "outputs": [], "source": [ "#@markdown Run conversion.\n", "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n", "\n", "half_arg = \"\"\n", "#@markdown Whether to convert to fp16, takes half the space (2GB).\n", "fp16 = False #@param {type: \"boolean\"}\n", "if fp16:\n", "    half_arg = \"--half\"\n", "# Quote the paths so folder names containing spaces survive the shell\n", "!python convert_diffusers_to_original_stable_diffusion.py --model_path \"$WEIGHTS_DIR\" --checkpoint_path \"$ckpt_path\" $half_arg\n", "print(f\"[*] Converted ckpt saved at {ckpt_path}\")" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Upload .ckpt to MEGA" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# install mega.py \n", "%pip install mega.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Prompt for username\n", "\n", "import ipywidgets as widgets\n", "\n", "username_textbox = widgets.Text(\n", "    placeholder='MEGA username',\n", "    description='Username:',\n", ")\n", "username_textbox\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Prompt for password\n", "\n", "import ipywidgets as widgets\n", "\n", "# Password widget so the characters are masked while typing\n", "password_textbox = widgets.Password(\n", "    placeholder='MEGA password',\n", "    description='Password:',\n", ")\n", "password_textbox" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# upload model.ckpt to mega.nz\n", "import time\n", "\n", "from mega import Mega\n", "\n", "mega_username = username_textbox.value.strip()\n", "mega_password = password_textbox.value\n", "if not mega_username or not mega_password:\n", "    raise ValueError(\"Enter your MEGA username and password in the textboxes above before running this cell\")\n", "\n", "start = time.time()\n", "\n", "mega = Mega()\n", "m = mega.login(mega_username, mega_password)\n", "\n", "print(\"Uploading to mega.nz. This can take 20-30 minutes...\")\n", "file = m.upload(ckpt_path)\n", "flink = m.get_upload_link(file)\n", "\n", "end = time.time()\n", "elapsed = int(end - start)\n", "\n", "print(\"Upload complete after \" + str(elapsed) + \" seconds\")\n", "print(\"Public download link: \" + flink)\n", "print(\"The content folder can now be deleted to free up space\")\n", "print(\"Alternatively, if you want to use the downloaded weights again, use the next cell to just delete model.ckpt\")\n", "\n", "#files = m.get_files()\n", "#for f in files:\n", "# print(f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Delete model.ckpt to free space\n", "import os\n", "\n", "# Remove from Python rather than `!rm` so paths containing spaces work; a missing file is not an error\n", "if os.path.exists(ckpt_path):\n", "    os.remove(ckpt_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jXgi8HM4c-DA" }, "outputs": [], "source": [ "#@title Free runtime memory\n", "exit()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "provenance": [] }, "kernelspec": { "display_name": "default:Python", "language": "python", "name": "conda-env-default-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "vscode": { "interpreter": { "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" } }, "widgets": { "application/vnd.jupyter.widget-state+json": {} } }, "nbformat": 4, "nbformat_minor": 4 }