{ "cells": [ { "cell_type": "markdown", "id": "949f360e", "metadata": {}, "source": [ "## Protein Folding with ESMFold and 🤗`transformers`" ] }, { "cell_type": "markdown", "id": "ab50e270", "metadata": {}, "source": [ "ESMFold ([paper link](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v2)) is a recently released protein folding model from FAIR. Unlike other protein folding models, it does not require external databases or search tools to predict structures, and is up to 60X faster as a result.\n", "\n", "The port to the HuggingFace `transformers` library is even easier to use, as we've removed the dependency on tools like `openfold` - once you `pip install transformers`, you're ready to use this model! \n", "\n", "Note that all the code that follows will be running the model **locally**, rather than calling an external API. This means that no rate limiting applies here - you can predict as many structures as your computer can handle. \n", "\n", "In testing, we found that ESMFold needs about 16-24GB of GPU memory to run well, depending on protein length. This may be too much for the smaller free GPUs on Colab." ] }, { "cell_type": "markdown", "id": "a2f53405", "metadata": {}, "source": [ "First step, make sure you're up to date - you'll need the most recent release of `transformers` and `accelerate`! If you want to visualize your predicted protein structure in the notebook, you should also install py3Dmol. " ] }, { "cell_type": "code", "execution_count": null, "id": "eb29483f", "metadata": { "scrolled": false }, "outputs": [], "source": [ "! pip install --upgrade transformers py3Dmol accelerate" ] }, { "cell_type": "markdown", "id": "eb4bb6a8", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "id": "889b852f", "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"protein_folding_notebook\", framework=\"pytorch\")" ] }, { "cell_type": "markdown", "id": "1dca9819", "metadata": {}, "source": [ "## Preparing your model and tokenizer" ] }, { "cell_type": "markdown", "id": "c418e286", "metadata": {}, "source": [ "Now we load our model and tokenizer. If using GPU, use `model.cuda()` to transfer the model to GPU." ] }, { "cell_type": "code", "execution_count": 1, "id": "c200c170", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, EsmForProteinFolding\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esmfold_v1\")\n", "model = EsmForProteinFolding.from_pretrained(\"facebook/esmfold_v1\", low_cpu_mem_usage=True)\n", "\n", "model = model.cuda()" ] }, { "cell_type": "markdown", "id": "a6f43d78", "metadata": {}, "source": [ "## Performance optimizations" ] }, { "cell_type": "markdown", "id": "cc0f0186", "metadata": {}, "source": [ "Since ESMFold is quite a large model, there are some considerations regarding memory usage and performance.\n", "\n", "Firstly, we can optionally convert the language model stem to float16 to improve performance and memory usage when running on a modern GPU. This was used during model training, and so should not make the outputs from the rest of the model invalid." 
] }, { "cell_type": "code", "execution_count": 2, "id": "90ee986d", "metadata": {}, "outputs": [], "source": [ "# Uncomment to switch the stem to float16\n", "model.esm = model.esm.half()" ] }, { "cell_type": "markdown", "id": "7c647c77", "metadata": {}, "source": [ "Secondly, you can enable TensorFloat32 computation for a general speedup if your hardware supports it. This line has no effect if your hardware doesn't support it." ] }, { "cell_type": "code", "execution_count": 4, "id": "c2bd9e11", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "torch.backends.cuda.matmul.allow_tf32 = True" ] }, { "cell_type": "markdown", "id": "a8eefe1b", "metadata": {}, "source": [ "Finally, we can reduce the 'chunk_size' used in the folding trunk. Smaller chunk sizes use less memory, but have slightly worse performance." ] }, { "cell_type": "code", "execution_count": 5, "id": "6f8ba985", "metadata": {}, "outputs": [], "source": [ "# Uncomment this line if your GPU memory is 16GB or less, or if you're folding longer (over 600 or so) sequences\n", "model.trunk.set_chunk_size(64)" ] }, { "cell_type": "markdown", "id": "c9a26e91", "metadata": {}, "source": [ "## Folding a single chain" ] }, { "cell_type": "markdown", "id": "8752706a", "metadata": {}, "source": [ "First, we tokenize our input. If you've used `transformers` before, proteins are processed like any other input string. Make sure **not** to add special tokens - ESM was trained with them, but ESMFold was trained without them. " ] }, { "cell_type": "code", "execution_count": 6, "id": "dde34627", "metadata": {}, "outputs": [], "source": [ "# This is the sequence for human GNAT1, because I worked on it when\n", "# I was a postdoc and so everyone else has to learn to appreciate it too.\n", "# Feel free to substitute your own peptides of interest\n", "# Depending on memory constraints you may wish to use shorter sequences.\n", "test_protein = \"MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF\"\n", "\n", "tokenized_input = tokenizer([test_protein], return_tensors=\"pt\", add_special_tokens=False)['input_ids']\n" ] }, { "cell_type": "markdown", "id": "18c00d09", "metadata": {}, "source": [ "If you're using a GPU, you'll need to move the tokenized data to the GPU now." ] }, { "cell_type": "code", "execution_count": 7, "id": "e0520279", "metadata": {}, "outputs": [], "source": [ "tokenized_input = tokenized_input.cuda()" ] }, { "cell_type": "markdown", "id": "89029c8b", "metadata": {}, "source": [ "With our preparations out of the way, getting your model outputs is as simple as..." ] }, { "cell_type": "code", "execution_count": 8, "id": "707fad0d", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import torch\n", "\n", "with torch.no_grad():\n", " output = model(tokenized_input)" ] }, { "cell_type": "markdown", "id": "b34de2f6", "metadata": {}, "source": [ "Now here's the tricky bit - we convert the model outputs to a PDB file. This will likely be moved to a function in `transformers` in the future, but everything's still quite new, so it lives here for now! This code comes from the original ESMFold repo, and uses some functions from `openfold` that have been ported to `transformers`." 
] }, { "cell_type": "code", "execution_count": 9, "id": "1c9de19e", "metadata": {}, "outputs": [], "source": [ "from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein\n", "from transformers.models.esm.openfold_utils.feats import atom14_to_atom37\n", "\n", "def convert_outputs_to_pdb(outputs):\n", " final_atom_positions = atom14_to_atom37(outputs[\"positions\"][-1], outputs)\n", " outputs = {k: v.to(\"cpu\").numpy() for k, v in outputs.items()}\n", " final_atom_positions = final_atom_positions.cpu().numpy()\n", " final_atom_mask = outputs[\"atom37_atom_exists\"]\n", " pdbs = []\n", " for i in range(outputs[\"aatype\"].shape[0]):\n", " aa = outputs[\"aatype\"][i]\n", " pred_pos = final_atom_positions[i]\n", " mask = final_atom_mask[i]\n", " resid = outputs[\"residue_index\"][i] + 1\n", " pred = OFProtein(\n", " aatype=aa,\n", " atom_positions=pred_pos,\n", " atom_mask=mask,\n", " residue_index=resid,\n", " b_factors=outputs[\"plddt\"][i],\n", " chain_index=outputs[\"chain_index\"][i] if \"chain_index\" in outputs else None,\n", " )\n", " pdbs.append(to_pdb(pred))\n", " return pdbs" ] }, { "cell_type": "code", "execution_count": 10, "id": "24613dbb", "metadata": {}, "outputs": [], "source": [ "pdb = convert_outputs_to_pdb(output)" ] }, { "cell_type": "markdown", "id": "f23adc4e", "metadata": {}, "source": [ "Now we have our pdb - can we visualize it?" ] }, { "cell_type": "code", "execution_count": 11, "id": "e094b965", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/3dmoljs_load.v0": "
{ "cell_type": "markdown", "id": "0d2f6e81", "metadata": {}, "source": [ "## Bulk predictions\n", "\n", "[Table: a DataFrame of proteins to fold - 296 rows × 2 columns, where `Entry` holds UniProt accessions (P00393, P00811, P00903, P00914, P00926, ..., P76483) and `Sequence` holds the corresponding amino acid sequences.]" ] },
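{ "cell_type": "markdown", "id": "6e9b2a47", "metadata": {}, "source": [ "Since ESMFold runs locally with no rate limits, folding a few hundred proteins in a loop is perfectly feasible. Here's a minimal sketch of a bulk folding loop, assuming the table above is loaded as a `pandas` DataFrame called `df` - we tokenize each sequence, fold it, and write one PDB file per entry. Remember the smaller `chunk_size` set earlier if any of the sequences are long." ] }, { "cell_type": "code", "execution_count": null, "id": "3f5d8c92", "metadata": {}, "outputs": [], "source": [ "# A sketch of bulk prediction, assuming `df` has columns \"Entry\" and \"Sequence\"\n", "for entry, sequence in zip(df[\"Entry\"], df[\"Sequence\"]):\n", "    inputs = tokenizer([sequence], return_tensors=\"pt\", add_special_tokens=False)[\"input_ids\"].cuda()\n", "    with torch.no_grad():\n", "        outputs = model(inputs)\n", "    # convert_outputs_to_pdb returns a list with one PDB string per input\n", "    with open(f\"{entry}.pdb\", \"w\") as f:\n", "        f.write(convert_outputs_to_pdb(outputs)[0])" ] }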
\n", "You appear to be running in JupyterLab (or JavaScript failed to load for some other reason). You need to install the 3dmol extension:
\n jupyter labextension install jupyterlab_3dmol
You appear to be running in JupyterLab (or JavaScript failed to load for some other reason). You need to install the 3dmol extension:
\n",
" jupyter labextension install jupyterlab_3dmol