{ "cells": [ { "cell_type": "markdown", "id": "af5d6f2e", "metadata": {}, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers as well as some other libraries. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "id": "4c5bf8d4", "metadata": {}, "outputs": [], "source": [ "#! pip install transformers evaluate datasets requests pandas scikit-learn" ] }, { "cell_type": "markdown", "id": "76e71a3f", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the Inference API, there are a few more steps to follow.\n", "\n", "First, get an authentication token from the Hugging Face website (sign up if you haven't already!), then execute the following cell and paste in your token:" ] }, { "cell_type": "code", "execution_count": null, "id": "25b8526a", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "ab8b2712", "metadata": {}, "source": [ "Then you need to install Git-LFS. Uncomment the following line and run it:" ] }, { "cell_type": "code", "execution_count": null, "id": "19e8f77c", "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "id": "22c7b6be", "metadata": {}, "source": [ "We also quickly upload some telemetry: this tells us which examples and software versions are being used, so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely."
] }, { "cell_type": "code", "execution_count": null, "id": "7f28d9b7", "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"protein_language_modeling_notebook\", framework=\"tensorflow\")" ] }, { "cell_type": "markdown", "id": "5c0749e1", "metadata": {}, "source": [ "# Fine-Tuning Protein Language Models" ] }, { "cell_type": "markdown", "id": "1d81db83", "metadata": {}, "source": [ "In this notebook, we're going to use transfer learning to fine-tune large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating, don't panic: there's [a blog post](https://huggingface.co/blog/deep-learning-with-proteins) that explains the concepts here in much more detail.\n", "\n", "The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (November 2022). The citation for this model is [Lin et al., 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).\n", "\n", "There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:\n", "\n", "| Checkpoint name | Num layers | Num parameters |\n", "|------------------------------|----|----------|\n", "| `esm2_t48_15B_UR50D` | 48 | 15B |\n", "| `esm2_t36_3B_UR50D` | 36 | 3B |\n", "| `esm2_t33_650M_UR50D` | 33 | 650M |\n", "| `esm2_t30_150M_UR50D` | 30 | 150M |\n", "| `esm2_t12_35M_UR50D` | 12 | 35M |\n", "| `esm2_t6_8M_UR50D` | 6 | 8M |\n", "\n", "Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B-parameter checkpoint will probably be impossible to train on **any** single GPU! 
Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU." ] }, { "cell_type": "code", "execution_count": 1, "id": "32e605a2", "metadata": {}, "outputs": [], "source": [ "model_checkpoint = \"facebook/esm2_t12_35M_UR50D\"" ] }, { "cell_type": "markdown", "id": "a8e6ac19", "metadata": {}, "source": [ "# Sequence classification" ] }, { "cell_type": "markdown", "id": "c3eb400c", "metadata": {}, "source": [ "One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into one category from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence. To keep things simple for this example, though, let's try classifying proteins by their cellular localization: given their sequence, can we predict if they're going to be found in the cytosol (the fluid inside the cell) or embedded in the cell membrane?" ] }, { "cell_type": "markdown", "id": "c5bc122f", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "id": "4c91d394", "metadata": {}, "source": [ "In this section, we're going to gather some training data from UniProt. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like \"MNKL...\", where each letter represents a single amino acid in the complete protein. `labels` will be a list containing the category of each sequence. The categories will just be integers, with 0 representing the first category, 1 representing the second, and so on. 
In other words, if `sequences[i]` is a protein sequence, then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.\n", "\n", "If you're adapting this notebook for your own use, this will probably be the main section you want to change! You can do whatever you want here, as long as you create those two lists by the end of it. If you want to follow along with this example, though, first we'll need to `import requests` and set up our query to UniProt." ] }, { "cell_type": "code", "execution_count": 2, "id": "c718ffbc", "metadata": {}, "outputs": [], "source": [ "import requests\n", "\n", "query_url = \"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29\"" ] }, { "cell_type": "markdown", "id": "3d2edc14", "metadata": {}, "source": [ "This query URL might seem mysterious, but it isn't! To get it, we searched for `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])` on UniProt to get a list of reasonably sized human proteins,\n", "then selected 'Download', and set the format to TSV and the columns to `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task.\n", "\n", "Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab, you can just download the data through the web interface and open the file locally." ] }, { "cell_type": "code", "execution_count": 3, "id": "dd03ef98", "metadata": {}, "outputs": [], "source": [ "uniprot_request = requests.get(query_url)\n", "uniprot_request.raise_for_status()  # stop here if the download failed" ] }, { "cell_type": "markdown", "id": "b7217b77", "metadata": {}, "source": [ "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. 
If you downloaded the data as a file, you can skip this bit and just pass the file path directly to `read_csv`." ] }, { "cell_type": "code", "execution_count": 4, "id": "f2c05017", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | Entry | \n", "Sequence | \n", "Subcellular location [CC] | \n", "
|---|---|---|---|
| 0 | \n", "A0A0K2S4Q6 | \n", "MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... | \n", "SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... | \n", "
| 1 | \n", "A0A5B9 | \n", "DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... | \n", "SUBCELLULAR LOCATION: Cell membrane {ECO:00003... | \n", "
| 2 | \n", "A0AVI4 | \n", "MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... | \n", "SUBCELLULAR LOCATION: Endoplasmic reticulum me... | \n", "
| 3 | \n", "A0JLT2 | \n", "MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... | \n", "SUBCELLULAR LOCATION: Nucleus {ECO:0000305}. | \n", "
| 4 | \n", "A0M8Q6 | \n", "GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... | \n", "SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "
| 11977 | \n", "Q9NZ38 | \n", "MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... | \n", "NaN | \n", "
| 11978 | \n", "Q9UFV3 | \n", "MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... | \n", "NaN | \n", "
| 11979 | \n", "Q9Y6C7 | \n", "MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... | \n", "NaN | \n", "
| 11980 | \n", "X6R8D5 | \n", "MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... | \n", "NaN | \n", "
| 11981 | \n", "X6R8R1 | \n", "MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... | \n", "NaN | \n", "
11982 rows × 3 columns
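The table above is the raw UniProt download, and the last rows show that some proteins have no subcellular-location annotation at all (`NaN`), so they can't be given a label. As a rough, standard-library-only sketch of this loading step (the notebook itself hands the `BytesIO` object to Pandas' `read_csv`), here is one way to parse gzip-compressed TSV bytes and drop unannotated rows. We're assuming the response body is gzip-compressed because the query URL contains `compressed=true`; the helper names and the toy TSV below are made up for illustration.

```python
import csv
import gzip
import io


def read_gzipped_tsv(raw_bytes):
    """Hypothetical helper: parse gzip-compressed TSV bytes into a list of dicts."""
    # BytesIO makes the in-memory bytes look like a file, so gzip can read them.
    with gzip.open(io.BytesIO(raw_bytes), mode="rt", encoding="utf-8", newline="") as handle:
        return list(csv.DictReader(handle, delimiter="\t"))


def drop_missing(rows, column):
    """Hypothetical helper: keep only rows whose `column` is non-empty."""
    return [row for row in rows if row.get(column)]


# Toy stand-in for the UniProt response body (not real data).
tsv_text = (
    "Entry\tSequence\tSubcellular location [CC]\n"
    "TEST1\tMKTAYIAK\tSUBCELLULAR LOCATION: Cytoplasm.\n"
    "TEST2\tMLLAVLFL\t\n"
)
raw = gzip.compress(tsv_text.encode("utf-8"))

rows = read_gzipped_tsv(raw)
labelled_rows = drop_missing(rows, "Subcellular location [CC]")
print(len(rows), len(labelled_rows))  # 2 1
```

With Pandas, the equivalent is roughly `read_csv(bio, sep='\t', compression='gzip')` followed by `dropna()`.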
\n", "| \n", " | Entry | \n", "Sequence | \n", "Subcellular location [CC] | \n", "
|---|---|---|---|
| 10 | \n", "A1E959 | \n", "MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL... | \n", "SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un... | \n", "
| 15 | \n", "A1XBS5 | \n", "MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD... | \n", "SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... | \n", "
| 19 | \n", "A2RU49 | \n", "MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ... | \n", "SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}. | \n", "
| 21 | \n", "A2RUH7 | \n", "MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP... | \n", "SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa... | \n", "
| 22 | \n", "A4D126 | \n", "MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA... | \n", "SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:... | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "
| 11555 | \n", "Q96L03 | \n", "MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC... | \n", "SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}. | \n", "
| 11597 | \n", "Q9BYD9 | \n", "MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ... | \n", "SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ... | \n", "
| 11639 | \n", "Q9NPB0 | \n", "MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG... | \n", "SUBCELLULAR LOCATION: Cytoplasmic vesicle memb... | \n", "
| 11652 | \n", "Q9NUJ7 | \n", "MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD... | \n", "SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... | \n", "
| 11662 | \n", "Q9P2W6 | \n", "MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS... | \n", "SUBCELLULAR LOCATION: Cytoplasm. | \n", "
2495 rows × 3 columns
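A table like the one above, holding only proteins we treat as cytosolic, can be built by keyword matching on the `Subcellular location [CC]` text, discarding proteins that match both groups so the two classes don't overlap. Here's a minimal sketch, assuming case-sensitive substring matching and the hypothetical keyword lists below (our guess for illustration, not necessarily the exact criteria used for this notebook):

```python
# Hypothetical keyword choices; adjust them for your own categories.
CYTOSOLIC_KEYWORDS = ("Cytosol", "Cytoplasm")
MEMBRANE_KEYWORDS = ("Membrane", "Cell membrane")


def mentions_any(text, keywords):
    """True if any keyword appears in the text (case-sensitive substring match)."""
    return any(keyword in text for keyword in keywords)


def split_by_location(rows, column="Subcellular location [CC]"):
    """Return (cytosolic, membrane) rows; proteins matching both or neither are dropped."""
    cytosolic, membrane = [], []
    for row in rows:
        is_cyto = mentions_any(row[column], CYTOSOLIC_KEYWORDS)
        is_memb = mentions_any(row[column], MEMBRANE_KEYWORDS)
        if is_cyto and not is_memb:
            cytosolic.append(row)
        elif is_memb and not is_cyto:
            membrane.append(row)
    return cytosolic, membrane


# Toy rows, made up for illustration.
rows = [
    {"Entry": "TEST1", "Subcellular location [CC]": "SUBCELLULAR LOCATION: Cytoplasm, cytosol."},
    {"Entry": "TEST2", "Subcellular location [CC]": "SUBCELLULAR LOCATION: Cell membrane; Single-pass membrane protein."},
    {"Entry": "TEST3", "Subcellular location [CC]": "SUBCELLULAR LOCATION: Cytoplasm. Membrane."},
    {"Entry": "TEST4", "Subcellular location [CC]": "SUBCELLULAR LOCATION: Nucleus."},
]
cytosolic, membrane = split_by_location(rows)
print([r["Entry"] for r in cytosolic], [r["Entry"] for r in membrane])  # ['TEST1'] ['TEST2']
```

Substring matching is crude: `Cytoplasmic vesicle membrane`, for example, contains `Cytoplasm` but neither `Membrane` nor `Cell membrane`, so this sketch files it as cytosolic. It's worth spot-checking the results before training on them.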
\n", "| \n", " | Entry | \n", "Sequence | \n", "Subcellular location [CC] | \n", "
|---|---|---|---|
| 0 | \n", "A0A0K2S4Q6 | \n", "MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... | \n", "SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... | \n", "
| 1 | \n", "A0A5B9 | \n", "DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... | \n", "SUBCELLULAR LOCATION: Cell membrane {ECO:00003... | \n", "
| 4 | \n", "A0M8Q6 | \n", "GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... | \n", "SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... | \n", "
| 18 | \n", "A2RU14 | \n", "MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... | \n", "
| 35 | \n", "A5X5Y0 | \n", "MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN... | \n", "SUBCELLULAR LOCATION: Cell membrane {ECO:00002... | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "
| 11843 | \n", "Q6UWF5 | \n", "MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... | \n", "
| 11917 | \n", "Q8N8V8 | \n", "MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... | \n", "
| 11958 | \n", "Q96N68 | \n", "MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... | \n", "
| 11965 | \n", "Q9H0A3 | \n", "MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ... | \n", "
| 11968 | \n", "Q9H354 | \n", "MNKHNLRLVQLASELILIEIIPKLFLSQVTTISHIKREKIPPNHRK... | \n", "SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... | \n", "
2579 rows × 3 columns
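With a cytosolic table and a membrane table in hand, the last step of data preparation is turning them into the parallel `sequences` and `labels` lists described at the start of this section. A small sketch, assuming label 0 means cytosolic and 1 means membrane (an arbitrary choice); it also holds out a test set with a seeded shuffle, as a standard-library stand-in for something like scikit-learn's `train_test_split`. The function names are hypothetical.

```python
import random


def build_dataset(cytosolic_sequences, membrane_sequences):
    """Combine two classes into parallel lists (0 = cytosolic, 1 = membrane)."""
    sequences = list(cytosolic_sequences) + list(membrane_sequences)
    labels = [0] * len(cytosolic_sequences) + [1] * len(membrane_sequences)
    return sequences, labels


def shuffle_split(sequences, labels, test_fraction=0.25, seed=42):
    """Shuffle (sequence, label) pairs together, then hold out a test fraction."""
    pairs = list(zip(sequences, labels))
    random.Random(seed).shuffle(pairs)  # seeded, so the split is reproducible
    n_test = int(len(pairs) * test_fraction)
    test_pairs, train_pairs = pairs[:n_test], pairs[n_test:]
    train_seqs = [seq for seq, _ in train_pairs]
    train_labels = [lab for _, lab in train_pairs]
    test_seqs = [seq for seq, _ in test_pairs]
    test_labels = [lab for _, lab in test_pairs]
    return train_seqs, train_labels, test_seqs, test_labels


# Toy sequences, made up for illustration.
sequences, labels = build_dataset(["MKTAYIAK", "MSDNEQLA"], ["MLLAVLFL", "MFWLLGVV"])
train_seqs, train_labels, test_seqs, test_labels = shuffle_split(sequences, labels)
print(labels)  # [0, 0, 1, 1]
print(len(train_seqs), len(test_seqs))  # 3 1
```

Shuffling the pairs (rather than each list separately) is what keeps `sequences[i]` and `labels[i]` matched after the split.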
\n", "| \n", " | Entry | \n", "Sequence | \n", "Beta strand | \n", "Helix | \n", "
|---|---|---|---|---|
| 0 | \n", "A0A0K2S4Q6 | \n", "MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... | \n", "NaN | \n", "NaN | \n", "
| 1 | \n", "A0A5B9 | \n", "DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... | \n", "STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... | \n", "HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... | \n", "
| 2 | \n", "A0AVI4 | \n", "MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... | \n", "NaN | \n", "NaN | \n", "
| 3 | \n", "A0JLT2 | \n", "MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... | \n", "STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" | \n", "HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... | \n", "
| 4 | \n", "A0M8Q6 | \n", "GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... | \n", "NaN | \n", "NaN | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 11977 | \n", "Q9NZ38 | \n", "MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... | \n", "NaN | \n", "NaN | \n", "
| 11978 | \n", "Q9UFV3 | \n", "MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... | \n", "NaN | \n", "NaN | \n", "
| 11979 | \n", "Q9Y6C7 | \n", "MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... | \n", "NaN | \n", "NaN | \n", "
| 11980 | \n", "X6R8D5 | \n", "MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... | \n", "NaN | \n", "NaN | \n", "
| 11981 | \n", "X6R8R1 | \n", "MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... | \n", "NaN | \n", "NaN | \n", "
11982 rows × 4 columns
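The `Beta strand` and `Helix` columns above use UniProt's feature notation, for example `STRAND 9..14; /evidence=...`, where `9..14` is a residue range. Here's a minimal sketch of pulling those ranges out with a regular expression, assuming positions are 1-based and inclusive (as UniProt feature positions are) and that only the `STRAND start..end` and `HELIX start..end` parts matter. It also includes a check for whether a protein has any strand or helix annotation at all, which rows with `NaN` in both columns would fail. The function names are hypothetical.

```python
import re

# Matches entries like "STRAND 9..14" or "HELIX 2..4"; the evidence text is ignored.
FEATURE_RANGE = re.compile(r"\b(?:STRAND|HELIX)\s+(\d+)\.\.(\d+)")


def parse_feature_ranges(feature_text):
    """Return the (start, end) positions listed in one feature-column cell."""
    if not isinstance(feature_text, str):  # missing cell, e.g. None or a pandas NaN
        return []
    return [(int(start), int(end)) for start, end in FEATURE_RANGE.findall(feature_text)]


def has_structure(strand_text, helix_text):
    """True if a protein has at least one annotated strand or helix."""
    return bool(parse_feature_ranges(strand_text) or parse_feature_ranges(helix_text))


# The first entry mirrors the A0A5B9 row above; the second STRAND entry is made up.
example = 'STRAND 9..14; /evidence="ECO:0007829|PDB:4UDT"; STRAND 20..25; /evidence="ECO:0007829|PDB:4UDT"'
print(parse_feature_ranges(example))  # [(9, 14), (20, 25)]
print(has_structure(None, "HELIX 2..4"))  # True
```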
\n", "| \n", " | Entry | \n", "Sequence | \n", "Beta strand | \n", "Helix | \n", "
|---|---|---|---|---|
| 1 | \n", "A0A5B9 | \n", "DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... | \n", "STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... | \n", "HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... | \n", "
| 3 | \n", "A0JLT2 | \n", "MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... | \n", "STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" | \n", "HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... | \n", "
| 14 | \n", "A1L3X0 | \n", "MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY... | \n", "STRAND 97..99; /evidence=\"ECO:0007829|PDB:6Y7F\" | \n", "HELIX 17..20; /evidence=\"ECO:0007829|PDB:6Y7F\"... | \n", "
| 16 | \n", "A1Z1Q3 | \n", "MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE... | \n", "STRAND 71..77; /evidence=\"ECO:0007829|PDB:4IQY... | \n", "HELIX 11..19; /evidence=\"ECO:0007829|PDB:4IQY\"... | \n", "
| 20 | \n", "A2RUC4 | \n", "MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV... | \n", "STRAND 10..13; /evidence=\"ECO:0007829|PDB:3AL5... | \n", "HELIX 16..22; /evidence=\"ECO:0007829|PDB:3AL5\"... | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 11551 | \n", "Q96I45 | \n", "MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF... | \n", "STRAND 3..5; /evidence=\"ECO:0007829|PDB:2LOR\";... | \n", "HELIX 6..16; /evidence=\"ECO:0007829|PDB:2LOR\";... | \n", "
| 11614 | \n", "Q9H0W7 | \n", "MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP... | \n", "STRAND 7..9; /evidence=\"ECO:0007829|PDB:2D8R\";... | \n", "HELIX 29..38; /evidence=\"ECO:0007829|PDB:2D8R\" | \n", "
| 11659 | \n", "Q9P1F3 | \n", "MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL... | \n", "STRAND 24..29; /evidence=\"ECO:0007829|PDB:2L2O... | \n", "HELIX 3..17; /evidence=\"ECO:0007829|PDB:2L2O\";... | \n", "
| 11661 | \n", "Q9P298 | \n", "MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI... | \n", "STRAND 11..14; /evidence=\"ECO:0007829|PDB:2LON... | \n", "HELIX 18..24; /evidence=\"ECO:0007829|PDB:2LON\"... | \n", "
| 11668 | \n", "Q9UIY3 | \n", "MSASVKESLQLQLLEMEMLFSMFPNQGEVKLEDVNALTNIKRYLEG... | \n", "STRAND 28..32; /evidence=\"ECO:0007829|PDB:2DAW... | \n", "HELIX 5..22; /evidence=\"ECO:0007829|PDB:2DAW\";... | \n", "
3911 rows × 4 columns
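The table above keeps only proteins with at least one strand or helix annotation. For token classification, every residue needs its own label, rather than one label for the whole protein. A minimal sketch of converting position ranges (such as `(9, 14)` from `STRAND 9..14`) into one label per residue, assuming 1-based inclusive positions and a hypothetical label scheme of 0 = no annotation, 1 = beta strand, 2 = helix:

```python
def residue_labels(sequence_length, strand_ranges, helix_ranges):
    """Per-residue labels: 0 = no annotation, 1 = beta strand, 2 = helix.

    Ranges are 1-based and inclusive. Helix ranges are written after strand
    ranges, so a helix label wins if the two ever overlap.
    """
    labels = [0] * sequence_length
    for label, ranges in ((1, strand_ranges), (2, helix_ranges)):
        for start, end in ranges:
            # Convert to 0-based indices and clip to the sequence length.
            for i in range(max(start - 1, 0), min(end, sequence_length)):
                labels[i] = label
    return labels


example_labels = residue_labels(10, strand_ranges=[(2, 4)], helix_ranges=[(7, 9)])
print(example_labels)  # [0, 1, 1, 1, 0, 0, 2, 2, 2, 0]
```

The output list always has exactly one entry per residue, which is what a token-classification model expects to line up against the sequence (before any special tokens the tokenizer adds).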
\n", "