{ "cells": [ { "cell_type": "markdown", "id": "af5d6f2e", "metadata": {}, "source": [ "If you're opening this Notebook on colab, you will probably need to install 馃 Transformers as well as some other libraries. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "id": "4c5bf8d4", "metadata": {}, "outputs": [], "source": [ "#! pip install transformers evaluate datasets requests pandas sklearn" ] }, { "cell_type": "markdown", "id": "76e71a3f", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!) then execute the following cell and input your username and password:" ] }, { "cell_type": "code", "execution_count": null, "id": "25b8526a", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "ab8b2712", "metadata": {}, "source": [ "Then you need to install Git-LFS. Uncomment the following instructions:" ] }, { "cell_type": "code", "execution_count": null, "id": "19e8f77c", "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "id": "80dbad4e", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." 
] }, { "cell_type": "code", "execution_count": null, "id": "d107b8d9", "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"protein_language_modeling_notebook\", framework=\"pytorch\")" ] }, { "cell_type": "markdown", "id": "5c0749e1", "metadata": {}, "source": [ "# Fine-Tuning Protein Language Models" ] }, { "cell_type": "markdown", "id": "1d81db83", "metadata": {}, "source": [ "In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating to you, don't panic - there's [a blog post](https://huggingface.co/blog/deep-learning-with-proteins) that explains the concepts here in much more detail.\n", "\n", "The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (November 2022). The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).\n", "\n", "There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:\n", "\n", "| Checkpoint name | Num layers | Num parameters |\n", "|------------------------------|----|----------|\n", "| `esm2_t48_15B_UR50D` | 48 | 15B |\n", "| `esm2_t36_3B_UR50D` | 36 | 3B | \n", "| `esm2_t33_650M_UR50D` | 33 | 650M | \n", "| `esm2_t30_150M_UR50D` | 30 | 150M | \n", "| `esm2_t12_35M_UR50D` | 12 | 35M | \n", "| `esm2_t6_8M_UR50D` | 6 | 8M | \n", "\n", "Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! 
Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU." ] }, { "cell_type": "code", "execution_count": 1, "id": "32e605a2", "metadata": {}, "outputs": [], "source": [ "model_checkpoint = \"facebook/esm2_t12_35M_UR50D\"" ] }, { "cell_type": "markdown", "id": "a8e6ac19", "metadata": {}, "source": [ "# Sequence classification" ] }, { "cell_type": "markdown", "id": "c3eb400c", "metadata": {}, "source": [ "One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into a category, from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence. To keep things simple for this example, though, let's try classifying proteins by their cellular localization - given their sequence, can we predict if they're going to be found in the cytosol (the fluid inside the cell) or embedded in the cell membrane?" ] }, { "cell_type": "markdown", "id": "c5bc122f", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "id": "4c91d394", "metadata": {}, "source": [ "In this section, we're going to gather some training data from UniProt. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like \"MNKL...\", where each letter represents a single amino acid in the complete protein. `labels` will be a list of the category for each sequence. The categories will just be integers, with 0 representing the first category, 1 representing the second and so on. 
In other words, if `sequences[i]` is a protein sequence then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.\n", "\n", "If you're adapting this notebook for your own use, this will probably be the main section you want to change! You can do whatever you want here, as long as you create those two lists by the end of it. If you want to follow along with this example, though, first we'll need to `import requests` and set up our query to UniProt." ] }, { "cell_type": "code", "execution_count": 2, "id": "c718ffbc", "metadata": {}, "outputs": [], "source": [ "import requests\n", "\n", "query_url =\"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29\"" ] }, { "cell_type": "markdown", "id": "3d2edc14", "metadata": {}, "source": [ "This query URL might seem mysterious, but it isn't! To get it, we searched for `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])` on UniProt to get a list of reasonably-sized human proteins,\n", "then selected 'Download', and set the format to TSV and the columns to `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task.\n", "\n", "Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab you can just download the data through the web interface and open the file locally." ] }, { "cell_type": "code", "execution_count": 3, "id": "dd03ef98", "metadata": {}, "outputs": [], "source": [ "uniprot_request = requests.get(query_url)" ] }, { "cell_type": "markdown", "id": "b7217b77", "metadata": {}, "source": [ "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. 
If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`." ] }, { "cell_type": "code", "execution_count": 4, "id": "f2c05017", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EntrySequenceSubcellular location [CC]
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
2A0AVI4MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...SUBCELLULAR LOCATION: Endoplasmic reticulum me...
3A0JLT2MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
............
11977Q9NZ38MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...NaN
11978Q9UFV3MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...NaN
11979Q9Y6C7MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...NaN
11980X6R8D5MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...NaN
11981X6R8R1MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG...NaN
\n", "

11982 rows 脳 3 columns

\n", "
" ], "text/plain": [ " Entry Sequence \\\n", "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", "2 A0AVI4 MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... \n", "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", "... ... ... \n", "11977 Q9NZ38 MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... \n", "11978 Q9UFV3 MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... \n", "11979 Q9Y6C7 MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... \n", "11980 X6R8D5 MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... \n", "11981 X6R8R1 MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... \n", "\n", " Subcellular location [CC] \n", "0 SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... \n", "1 SUBCELLULAR LOCATION: Cell membrane {ECO:00003... \n", "2 SUBCELLULAR LOCATION: Endoplasmic reticulum me... \n", "3 SUBCELLULAR LOCATION: Nucleus {ECO:0000305}. \n", "4 SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... \n", "... ... \n", "11977 NaN \n", "11978 NaN \n", "11979 NaN \n", "11980 NaN \n", "11981 NaN \n", "\n", "[11982 rows x 3 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from io import BytesIO\n", "import pandas\n", "\n", "bio = BytesIO(uniprot_request.content)\n", "\n", "df = pandas.read_csv(bio, compression='gzip', sep='\\t')\n", "df" ] }, { "cell_type": "markdown", "id": "0bcdf34b", "metadata": {}, "source": [ "Nice! Now we have some proteins and their subcellular locations. Let's start filtering this down. First, let's ditch the columns without subcellular location information. 
" ] }, { "cell_type": "code", "execution_count": 5, "id": "31d87663", "metadata": {}, "outputs": [], "source": [ "df = df.dropna() # Drop proteins with missing columns" ] }, { "cell_type": "markdown", "id": "10d1af5c", "metadata": {}, "source": [ "Now we'll make one dataframe of proteins that contain `cytosol` or `cytoplasm` in their subcellular localization column, and a second that mentions the `membrane` or `cell membrane`. To ensure we don't get overlap, we ensure each dataframe only contains proteins that don't match the other search term." ] }, { "cell_type": "code", "execution_count": 6, "id": "c831bb16", "metadata": {}, "outputs": [], "source": [ "cytosolic = df['Subcellular location [CC]'].str.contains(\"Cytosol\") | df['Subcellular location [CC]'].str.contains(\"Cytoplasm\")\n", "membrane = df['Subcellular location [CC]'].str.contains(\"Membrane\") | df['Subcellular location [CC]'].str.contains(\"Cell membrane\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "f41139a2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EntrySequenceSubcellular location [CC]
10A1E959MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
15A1XBS5MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
19A2RU49MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
21A2RUH7MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa...
22A4D126MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:...
............
11555Q96L03MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}.
11597Q9BYD9MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ...SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ...
11639Q9NPB0MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG...SUBCELLULAR LOCATION: Cytoplasmic vesicle memb...
11652Q9NUJ7MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
11662Q9P2W6MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS...SUBCELLULAR LOCATION: Cytoplasm.
\n", "

2495 rows 脳 3 columns

\n", "
" ], "text/plain": [ " Entry Sequence \\\n", "10 A1E959 MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL... \n", "15 A1XBS5 MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD... \n", "19 A2RU49 MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ... \n", "21 A2RUH7 MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP... \n", "22 A4D126 MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA... \n", "... ... ... \n", "11555 Q96L03 MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC... \n", "11597 Q9BYD9 MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ... \n", "11639 Q9NPB0 MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG... \n", "11652 Q9NUJ7 MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD... \n", "11662 Q9P2W6 MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS... \n", "\n", " Subcellular location [CC] \n", "10 SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un... \n", "15 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... \n", "19 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}. \n", "21 SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa... \n", "22 SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:... \n", "... ... \n", "11555 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}. \n", "11597 SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ... \n", "11639 SUBCELLULAR LOCATION: Cytoplasmic vesicle memb... \n", "11652 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... \n", "11662 SUBCELLULAR LOCATION: Cytoplasm. \n", "\n", "[2495 rows x 3 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cytosolic_df = df[cytosolic & ~membrane]\n", "cytosolic_df" ] }, { "cell_type": "code", "execution_count": 8, "id": "be5c420e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EntrySequenceSubcellular location [CC]
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
18A2RU14MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
35A5X5Y0MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
............
11843Q6UWF5MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11917Q8N8V8MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11958Q96N68MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11965Q9H0A3MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...
11968Q9H354MNKHNLRLVQLASELILIEIIPKLFLSQVTTISHIKREKIPPNHRK...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
\n", "

2579 rows 脳 3 columns

\n", "
" ], "text/plain": [ " Entry Sequence \\\n", "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", "18 A2RU14 MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG... \n", "35 A5X5Y0 MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN... \n", "... ... ... \n", "11843 Q6UWF5 MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA... \n", "11917 Q8N8V8 MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP... \n", "11958 Q96N68 MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ... \n", "11965 Q9H0A3 MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT... \n", "11968 Q9H354 MNKHNLRLVQLASELILIEIIPKLFLSQVTTISHIKREKIPPNHRK... \n", "\n", " Subcellular location [CC] \n", "0 SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... \n", "1 SUBCELLULAR LOCATION: Cell membrane {ECO:00003... \n", "4 SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... \n", "18 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", "35 SUBCELLULAR LOCATION: Cell membrane {ECO:00002... \n", "... ... \n", "11843 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", "11917 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", "11958 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", "11965 SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ... \n", "11968 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", "\n", "[2579 rows x 3 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "membrane_df = df[membrane & ~cytosolic]\n", "membrane_df" ] }, { "cell_type": "markdown", "id": "77e8cea6", "metadata": {}, "source": [ "We're almost done! Now, let's make a list of sequences from each df and generate the associated labels. We'll use `0` as the label for cytosolic proteins and `1` as the label for membrane proteins." 
] }, { "cell_type": "code", "execution_count": 9, "id": "023ec31b", "metadata": {}, "outputs": [], "source": [ "cytosolic_sequences = cytosolic_df[\"Sequence\"].tolist()\n", "cytosolic_labels = [0 for protein in cytosolic_sequences]" ] }, { "cell_type": "code", "execution_count": 10, "id": "d0e7318b", "metadata": {}, "outputs": [], "source": [ "membrane_sequences = membrane_df[\"Sequence\"].tolist()\n", "membrane_labels = [1 for protein in membrane_sequences]" ] }, { "cell_type": "markdown", "id": "5a4bbda2", "metadata": {}, "source": [ "Now we can concatenate these lists together to get the `sequences` and `labels` lists that will form our final training data. Don't worry - they'll get shuffled during training!" ] }, { "cell_type": "code", "execution_count": 11, "id": "7dec7a4a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sequences = cytosolic_sequences + membrane_sequences\n", "labels = cytosolic_labels + membrane_labels\n", "\n", "# Quick check to make sure we got it right\n", "len(sequences) == len(labels)" ] }, { "cell_type": "markdown", "id": "bc782dd0", "metadata": {}, "source": [ "Phew!" ] }, { "cell_type": "markdown", "id": "e0aac39c", "metadata": {}, "source": [ "## Splitting the data" ] }, { "cell_type": "markdown", "id": "a9099e7c", "metadata": {}, "source": [ "Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets ourselves! 
We can use sklearn's function for that:" ] }, { "cell_type": "code", "execution_count": 12, "id": "366147ad", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)" ] }, { "cell_type": "markdown", "id": "7d29b4ed", "metadata": {}, "source": [ "## Tokenizing the data" ] }, { "cell_type": "markdown", "id": "c02baaf7", "metadata": {}, "source": [ "All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.\n", "\n", "With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model on `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different. Let's get our tokenizer!" ] }, { "cell_type": "code", "execution_count": 13, "id": "ddbe2b2d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cc61f599adc641da8d40eefac0179aa3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/40.0 [00:00\n", " \n", " \n", " [1428/1428 04:35, Epoch 3/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Epoch | Training Loss | Validation Loss | Accuracy
1 | No log | 0.210921 | 0.936958
2 | 0.232100 | 0.205099 | 0.944050
3 | 0.145700 | 0.200019 | 0.946414

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 1269\n", " Batch size = 8\n", "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476\n", "Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 1269\n", " Batch size = 8\n", "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952\n", "Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 1269\n", " Batch size = 8\n", "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428\n", "Configuration saved in 
esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n", "Loading best model from esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428 (score: 0.946414499605989).\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=1428, training_loss=0.1632746127473206, metrics={'train_runtime': 281.8102, 'train_samples_per_second': 40.506, 'train_steps_per_second': 5.067, 'total_flos': 1032423103475172.0, 'train_loss': 0.1632746127473206, 'epoch': 3.0})" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "dfec59f4", "metadata": {}, "source": [ "Nice! After three epochs we have a model accuracy of ~94%. Note that we didn't do a lot of work to filter the training data or tune hyperparameters for this experiment, and also that we used one of the smallest ESM-2 models. With a larger starting model and more effort to ensure that the training data categories were cleanly separable, accuracy could almost certainly go a lot higher!" ] }, { "cell_type": "markdown", "id": "bc2ef458", "metadata": {}, "source": [ "***\n", "# Token classification" ] }, { "cell_type": "markdown", "id": "78d701ed", "metadata": {}, "source": [ "Another common language model task is **token classification**. 
In this task, instead of classifying the whole sequence into a single category, we categorize each token (amino acid, in this case!) into one or more categories. This kind of model could be useful for:\n", "\n", "- Predicting secondary structure\n", "- Predicting buried vs. exposed residues\n", "- Predicting residues that will receive post-translational modifications\n", "- Predicting residues involved in binding pockets or active sites\n", "- Probably several other things, it's been a while since I was a postdoc" ] }, { "cell_type": "markdown", "id": "20e00afe", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "id": "f1b9e75c", "metadata": {}, "source": [ "In this section, we're going to gather some training data from UniProt. As in the sequence classification example, we aim to create two lists: `sequences` and `labels`. Unlike in that example, however, the `labels` are more than just single integers. Instead, the label for each sample will be **one integer per token in the input**. This should make sense - when we do token classification, different tokens in the input may have different categories!\n", "\n", "To demonstrate token classification, we're going to go back to UniProt and get some data on protein secondary structures. As above, this will probably be the main section you want to change when adapting this code to your own problems."
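] }, { "cell_type": "markdown", "id": "f1b9e75d", "metadata": {}, "source": [ "To make the label format concrete, here's a minimal, hypothetical sketch - the sequence, the categories and their numbering are invented for illustration, and the real labels will be built from UniProt annotations below:" ] }, { "cell_type": "code", "execution_count": null, "id": "f1b9e75e", "metadata": {}, "outputs": [], "source": [ "# Hypothetical example: 0 = no structure, 1 = alpha helix, 2 = beta strand\n", "example_sequence = \"MNKLV\"\n", "example_labels = [0, 1, 1, 2, 2]  # one integer label per amino acid\n", "assert len(example_labels) == len(example_sequence)"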
] }, { "cell_type": "code", "execution_count": 23, "id": "bf52cfb8", "metadata": {}, "outputs": [], "source": [ "import requests\n", "\n", "query_url =\"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29\"" ] }, { "cell_type": "markdown", "id": "73c902be", "metadata": {}, "source": [ "This time, our UniProt search was `(organism_id:9606) AND (reviewed:true) AND (length:[100 TO 1000])` as it was in the first example, but instead of `Subcellular location [CC]` we take the `Helix` and `Beta strand` columns, as they contain the secondary structure information we want." ] }, { "cell_type": "code", "execution_count": 24, "id": "be65f529", "metadata": {}, "outputs": [], "source": [ "uniprot_request = requests.get(query_url)" ] }, { "cell_type": "markdown", "id": "3f683dd7", "metadata": {}, "source": [ "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`." ] }, { "cell_type": "code", "execution_count": 25, "id": "f49439ab", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EntrySequenceBeta strandHelix
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...NaNNaN
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"...HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ...
2A0AVI4MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...NaNNaN
3A0JLT2MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\"HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"...
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...NaNNaN
...............
11977Q9NZ38MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...NaNNaN
11978Q9UFV3MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...NaNNaN
11979Q9Y6C7MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...NaNNaN
11980X6R8D5MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...NaNNaN
11981X6R8R1MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG...NaNNaN
\n", "

11982 rows 脳 4 columns

\n", "
" ], "text/plain": [ " Entry Sequence \\\n", "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", "2 A0AVI4 MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... \n", "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", "... ... ... \n", "11977 Q9NZ38 MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... \n", "11978 Q9UFV3 MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... \n", "11979 Q9Y6C7 MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... \n", "11980 X6R8D5 MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... \n", "11981 X6R8R1 MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... \n", "\n", " Beta strand \\\n", "0 NaN \n", "1 STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... \n", "2 NaN \n", "3 STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" \n", "4 NaN \n", "... ... \n", "11977 NaN \n", "11978 NaN \n", "11979 NaN \n", "11980 NaN \n", "11981 NaN \n", "\n", " Helix \n", "0 NaN \n", "1 HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... \n", "2 NaN \n", "3 HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... \n", "4 NaN \n", "... ... \n", "11977 NaN \n", "11978 NaN \n", "11979 NaN \n", "11980 NaN \n", "11981 NaN \n", "\n", "[11982 rows x 4 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from io import BytesIO\n", "import pandas\n", "\n", "bio = BytesIO(uniprot_request.content)\n", "\n", "df = pandas.read_csv(bio, compression='gzip', sep='\\t')\n", "df" ] }, { "cell_type": "markdown", "id": "736010f0", "metadata": {}, "source": [ "Since not all proteins have this structural information, we discard proteins that have no annotated beta strands or alpha helices." ] }, { "cell_type": "code", "execution_count": 26, "id": "39ce9a5c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Entry Sequence \\\n", "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", "14 A1L3X0 MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY... \n", "16 A1Z1Q3 MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE... \n", "20 A2RUC4 MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV... \n", "... ... ... \n", "11551 Q96I45 MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF... \n", "11614 Q9H0W7 MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP... \n", "11659 Q9P1F3 MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL... \n", "11661 Q9P298 MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI... \n", "11668 Q9UIY3 MSASVKESLQLQLLEMEMLFSMFPNQGEVKLEDVNALTNIKRYLEG... \n", "\n", " Beta strand \\\n", "1 STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... \n", "3 STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" \n", "14 STRAND 97..99; /evidence=\"ECO:0007829|PDB:6Y7F\" \n", "16 STRAND 71..77; /evidence=\"ECO:0007829|PDB:4IQY... \n", "20 STRAND 10..13; /evidence=\"ECO:0007829|PDB:3AL5... \n", "... ... \n", "11551 STRAND 3..5; /evidence=\"ECO:0007829|PDB:2LOR\";... \n", "11614 STRAND 7..9; /evidence=\"ECO:0007829|PDB:2D8R\";... \n", "11659 STRAND 24..29; /evidence=\"ECO:0007829|PDB:2L2O... \n", "11661 STRAND 11..14; /evidence=\"ECO:0007829|PDB:2LON... \n", "11668 STRAND 28..32; /evidence=\"ECO:0007829|PDB:2DAW... \n", "\n", " Helix \n", "1 HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... \n", "3 HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... \n", "14 HELIX 17..20; /evidence=\"ECO:0007829|PDB:6Y7F\"... \n", "16 HELIX 11..19; /evidence=\"ECO:0007829|PDB:4IQY\"... \n", "20 HELIX 16..22; /evidence=\"ECO:0007829|PDB:3AL5\"... \n", "... ... \n", "11551 HELIX 6..16; /evidence=\"ECO:0007829|PDB:2LOR\";... \n", "11614 HELIX 29..38; /evidence=\"ECO:0007829|PDB:2D8R\" \n", "11659 HELIX 3..17; /evidence=\"ECO:0007829|PDB:2L2O\";... \n", "11661 HELIX 18..24; /evidence=\"ECO:0007829|PDB:2LON\"... 
\n", "11668 HELIX 5..22; /evidence=\"ECO:0007829|PDB:2DAW\";... \n", "\n", "[3911 rows x 4 columns]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "no_structure_rows = df[\"Beta strand\"].isna() & df[\"Helix\"].isna()\n", "df = df[~no_structure_rows]\n", "df" ] }, { "cell_type": "markdown", "id": "f43e372c", "metadata": {}, "source": [ "Well, this works, but that data still isn't in a clean format that we can use to build our labels. Let's take a look at one sample to see what exactly we're dealing with:" ] }, { "cell_type": "code", "execution_count": 27, "id": "73e99d1b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; HELIX 17..23; /evidence=\"ECO:0007829|PDB:4UDT\"; HELIX 83..86; /evidence=\"ECO:0007829|PDB:4UDT\"'" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.iloc[0][\"Helix\"]" ] }, { "cell_type": "markdown", "id": "6cd5160a", "metadata": {}, "source": [ "We'll need to use a [regex](https://docs.python.org/3/howto/regex.html) to pull out each segment that's marked as being a STRAND or HELIX. What we're asking for is a list of everywhere we see the word STRAND or HELIX followed by two numbers separated by two dots. In each case where this pattern is found, we tell the regex to extract the two numbers as a tuple for us." ] }, { "cell_type": "code", "execution_count": 28, "id": "7540949e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('2', '4'), ('17', '23'), ('83', '86')]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import re\n", "\n", "strand_re = r\"STRAND\\s(\\d+)\\.\\.(\\d+)\\;\"\n", "helix_re = r\"HELIX\\s(\\d+)\\.\\.(\\d+)\\;\"\n", "\n", "re.findall(helix_re, df.iloc[0][\"Helix\"])" ] }, { "cell_type": "markdown", "id": "4457b1a0", "metadata": {}, "source": [ "Looks good! We can use this to build our training data. 
Recall that the **labels** need to be a list or array of integers that's the same length as the input sequence. We're going to use 0 to indicate residues without any annotated structure, 1 for residues in an alpha helix, and 2 for residues in a beta strand. To build that, we'll start with an array of all 0s, and then fill in values based on the positions that our regex pulls out of the UniProt results.\n", "\n", "We'll use NumPy arrays rather than lists here, since these allow [slice assignment](https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays), which will be a lot simpler than editing a list of integers. Note also that UniProt annotates residues starting from 1 (unlike Python, which starts from 0), and region annotations are inclusive (so 1..3 means residues 1, 2 and 3). To turn these into Python slices, we subtract 1 from the start of each annotation, but not the end." ] }, { "cell_type": "code", "execution_count": 29, "id": "a4c97179", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def build_labels(sequence, strands, helices):\n", " # Start with all 0s\n", " labels = np.zeros(len(sequence), dtype=np.int64)\n", " \n", " if isinstance(helices, float): # Indicates missing (NaN)\n", " found_helices = []\n", " else:\n", " found_helices = re.findall(helix_re, helices)\n", " for helix_start, helix_end in found_helices:\n", " helix_start = int(helix_start) - 1\n", " helix_end = int(helix_end)\n", " assert helix_end <= len(sequence)\n", " labels[helix_start: helix_end] = 1 # Helix category\n", " \n", " if isinstance(strands, float): # Indicates missing (NaN)\n", " found_strands = []\n", " else:\n", " found_strands = re.findall(strand_re, strands)\n", " for strand_start, strand_end in found_strands:\n", " strand_start = int(strand_start) - 1\n", " strand_end = int(strand_end)\n", " assert strand_end <= len(sequence)\n", " labels[strand_start: strand_end] = 2 # Strand category\n", " return labels" ] }, { 
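As a quick sanity check of the labeling scheme, here's a self-contained sketch (condensed from `build_labels` above, with a toy sequence and invented annotation strings) showing how the 1-based, inclusive UniProt annotations turn into a 0-based label array:

```python
import re
import numpy as np

strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

def build_labels(sequence, strands, helices):
    # 0 = no annotated structure, 1 = alpha helix, 2 = beta strand
    labels = np.zeros(len(sequence), dtype=np.int64)
    if not isinstance(helices, float):  # a float here means NaN (missing)
        for start, end in re.findall(helix_re, helices):
            # 1-based inclusive annotation -> 0-based half-open slice
            labels[int(start) - 1 : int(end)] = 1
    if not isinstance(strands, float):
        for start, end in re.findall(strand_re, strands):
            labels[int(start) - 1 : int(end)] = 2
    return labels

# Toy 10-residue sequence with an invented helix at 2..4 and strand at 7..9
toy = build_labels("MKTAYIAKQR", "STRAND 7..9;", "HELIX 2..4;")
print(toy.tolist())  # [0, 1, 1, 1, 0, 0, 2, 2, 2, 0]
```

Note that the helix annotation `2..4` labels residues at 0-based indices 1, 2 and 3, exactly as the subtract-one-from-the-start rule above describes.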
"cell_type": "markdown", "id": "5ad7e7fd", "metadata": {}, "source": [ "Now we've defined a helper function, let's build our lists of sequences and labels:" ] }, { "cell_type": "code", "execution_count": 30, "id": "313811fe", "metadata": {}, "outputs": [], "source": [ "sequences = []\n", "labels = []\n", "\n", "for row_idx, row in df.iterrows():\n", " row_labels = build_labels(row[\"Sequence\"], row[\"Beta strand\"], row[\"Helix\"])\n", " sequences.append(row[\"Sequence\"])\n", " labels.append(row_labels)" ] }, { "cell_type": "markdown", "id": "8e8b3ba8", "metadata": {}, "source": [ "## Creating our dataset" ] }, { "cell_type": "markdown", "id": "e619d9ae", "metadata": {}, "source": [ "Nice! Now we'll split and tokenize the data, and then create datasets - I'll go through this quite quickly here, since it's identical to how we did it in the sequence classification example above." ] }, { "cell_type": "code", "execution_count": 31, "id": "3c208c30", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 32, "id": "2182fae2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "loading file vocab.txt from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/vocab.txt\n", "loading file added_tokens.json from cache at None\n", "loading file special_tokens_map.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/special_tokens_map.json\n", "loading file tokenizer_config.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/tokenizer_config.json\n" ] } ], "source": [ 
"from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", "\n", "train_tokenized = tokenizer(train_sequences)\n", "test_tokenized = tokenizer(test_sequences)" ] }, { "cell_type": "code", "execution_count": 33, "id": "3939f13a", "metadata": {}, "outputs": [], "source": [ "from datasets import Dataset\n", "\n", "train_dataset = Dataset.from_dict(train_tokenized)\n", "test_dataset = Dataset.from_dict(test_tokenized)\n", "\n", "train_dataset = train_dataset.add_column(\"labels\", train_labels)\n", "test_dataset = test_dataset.add_column(\"labels\", test_labels)" ] }, { "cell_type": "markdown", "id": "4766fe4b", "metadata": {}, "source": [ "## Model loading" ] }, { "cell_type": "markdown", "id": "de8419b5", "metadata": {}, "source": [ "The key difference here with the above example is that we use `AutoModelForTokenClassification` instead of `AutoModelForSequenceClassification`. We will also need a `data_collator` this time, as we're in the slightly more complex case where both inputs and labels must be padded in each batch." 
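To make that padding behaviour concrete, here's a plain-Python sketch of what a token-classification collator does - this is an illustration, not the actual `DataCollatorForTokenClassification` implementation. It pads `input_ids` with the pad token (id 1 for ESM-2) and pads labels with `-100`, the value that the loss function and our metrics will ignore:

```python
# Minimal sketch of token-classification collation (hypothetical helper, for
# illustration only): pad every example in the batch to the longest length,
# using pad_token_id for inputs and -100 for labels so padding is ignored.
def collate(batch, pad_token_id=1, label_pad_id=-100):
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, labels, attention_mask = [], [], []
    for ex in batch:
        pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [pad_token_id] * pad)
        labels.append(ex["labels"] + [label_pad_id] * pad)
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * pad)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

batch = collate([
    {"input_ids": [0, 5, 6, 2], "labels": [0, 1, 1, 0]},
    {"input_ids": [0, 7, 2], "labels": [0, 2, 0]},
])
print(batch["labels"])  # [[0, 1, 1, 0], [0, 2, 0, -100]]
```

The shorter example gets one `-100` label appended, which is why `compute_metrics` below filters out positions where the label is `-100` before computing accuracy.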
] }, { "cell_type": "code", "execution_count": 34, "id": "4b26b828", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "loading configuration file config.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/config.json\n", "Model config EsmConfig {\n", " \"_name_or_path\": \"facebook/esm2_t12_35M_UR50D\",\n", " \"architectures\": [\n", " \"EsmForMaskedLM\"\n", " ],\n", " \"attention_probs_dropout_prob\": 0.0,\n", " \"classifier_dropout\": null,\n", " \"emb_layer_norm_before\": false,\n", " \"esmfold_config\": null,\n", " \"hidden_act\": \"gelu\",\n", " \"hidden_dropout_prob\": 0.0,\n", " \"hidden_size\": 480,\n", " \"id2label\": {\n", " \"0\": \"LABEL_0\",\n", " \"1\": \"LABEL_1\",\n", " \"2\": \"LABEL_2\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"intermediate_size\": 1920,\n", " \"is_folding_model\": false,\n", " \"label2id\": {\n", " \"LABEL_0\": 0,\n", " \"LABEL_1\": 1,\n", " \"LABEL_2\": 2\n", " },\n", " \"layer_norm_eps\": 1e-05,\n", " \"mask_token_id\": 32,\n", " \"max_position_embeddings\": 1026,\n", " \"model_type\": \"esm\",\n", " \"num_attention_heads\": 20,\n", " \"num_hidden_layers\": 12,\n", " \"pad_token_id\": 1,\n", " \"position_embedding_type\": \"rotary\",\n", " \"token_dropout\": true,\n", " \"torch_dtype\": \"float32\",\n", " \"transformers_version\": \"4.25.0.dev0\",\n", " \"use_cache\": true,\n", " \"vocab_list\": null,\n", " \"vocab_size\": 33\n", "}\n", "\n", "loading weights file pytorch_model.bin from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/pytorch_model.bin\n", "Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.decoder.bias', 
'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']\n", "- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n", "\n", "num_labels = 3\n", "model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)" ] }, { "cell_type": "code", "execution_count": 35, "id": "eec0005a", "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorForTokenClassification\n", "\n", "data_collator = DataCollatorForTokenClassification(tokenizer)" ] }, { "cell_type": "markdown", "id": "bd3c7305", "metadata": {}, "source": [ "Now we set up our `TrainingArguments` as before." ] }, { "cell_type": "code", "execution_count": 36, "id": "e7724323", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "PyTorch: setting up devices\n", "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. 
You should start updating your code and make this info disappear :-).\n" ] } ], "source": [ "model_name = model_checkpoint.split(\"/\")[-1]\n", "batch_size = 8\n", "\n", "args = TrainingArguments(\n", " f\"{model_name}-finetuned-secondary-structure\",\n", " evaluation_strategy = \"epoch\",\n", " save_strategy = \"epoch\",\n", " learning_rate=1e-4,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " num_train_epochs=3,\n", " weight_decay=0.001,\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"accuracy\",\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "markdown", "id": "fb5fba9a", "metadata": {}, "source": [ "Our `compute_metrics` function is a bit more complex than in the sequence classification task, as we need to ignore padding tokens (those where the label is `-100`)." ] }, { "cell_type": "code", "execution_count": 37, "id": "736886a0", "metadata": {}, "outputs": [], "source": [ "from evaluate import load\n", "import numpy as np\n", "\n", "metric = load(\"accuracy\")\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " labels = labels.reshape((-1,))\n", " predictions = np.argmax(predictions, axis=2)\n", " predictions = predictions.reshape((-1,))\n", " predictions = predictions[labels!=-100]\n", " labels = labels[labels!=-100]\n", " return metric.compute(predictions=predictions, references=labels)" ] }, { "cell_type": "markdown", "id": "37491af5", "metadata": {}, "source": [ "And now we're ready to train our model! " ] }, { "cell_type": "code", "execution_count": 38, "id": "4c97836c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/PycharmProjects/notebooks/examples/esm2_t12_35M_UR50D-finetuned-secondary-structure is already a clone of https://huggingface.co/Rocketknight1/esm2_t12_35M_UR50D-finetuned-secondary-structure. 
Make sure you pull the latest changes with `repo.git_pull()`.\n", "/home/matt/PycharmProjects/transformers/src/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 2933\n", " Num Epochs = 3\n", " Instantaneous batch size per device = 8\n", " Total train batch size (w. parallel, distributed & accumulation) = 8\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 1101\n", " Number of trainable parameters = 33763203\n", "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1101/1101 03:52, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Epoch    Training Loss    Validation Loss    Accuracy
1        No log           0.465408           0.809475
2        0.496200         0.443926           0.818526
3        0.371100         0.449109           0.821522

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 978\n", " Batch size = 8\n", "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367\n", "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 978\n", " Batch size = 8\n", "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734\n", "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 978\n", " Batch size = 8\n", "Saving model checkpoint to 
esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101\n", "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/config.json\n", "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/pytorch_model.bin\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/special_tokens_map.json\n", "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n", "Loading best model from esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101 (score: 0.8215224822508546).\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=1101, training_loss=0.42545173083728927, metrics={'train_runtime': 232.9156, 'train_samples_per_second': 37.778, 'train_steps_per_second': 4.727, 'total_flos': 794586720601188.0, 'train_loss': 0.42545173083728927, 'epoch': 3.0})" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = Trainer(\n", " model,\n", " args,\n", " train_dataset=train_dataset,\n", " eval_dataset=test_dataset,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics,\n", " data_collator=data_collator,\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "f503fc00", "metadata": {}, "source": [ "This definitely seems harder than the first task, but we still attain a very respectable accuracy. 
Remember that to keep this demo lightweight, we used one of the smallest ESM models, focused on human proteins only and didn't put a lot of work into making sure we only included completely-annotated proteins in our training set. With a bigger model and a cleaner, broader training set, accuracy on this task could definitely go a lot higher!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }
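As a closing illustration, here's a sketch of turning per-token logits into a per-residue secondary-structure string. The logits array below is invented for demonstration; in practice it would come from running the fine-tuned model on a tokenized sequence (e.g. `model(**tokenizer(seq, return_tensors="pt")).logits`). The key steps are stripping the special BOS/EOS token positions that the tokenizer adds and taking an argmax over the three classes:

```python
import numpy as np

# Invented logits for a 6-residue sequence plus BOS/EOS special tokens,
# shape (seq_len + 2, num_labels). Real logits would come from the model.
logits = np.array([
    [9.0, 0.0, 0.0],  # BOS token - discarded below
    [0.1, 2.0, 0.3],
    [0.2, 3.1, 0.0],
    [1.5, 0.2, 0.1],
    [0.0, 0.1, 2.7],
    [0.0, 0.4, 3.0],
    [2.0, 0.1, 0.3],
    [9.0, 0.0, 0.0],  # EOS token - discarded below
])
id2label = {0: "-", 1: "H", 2: "E"}  # none / alpha helix / beta strand
preds = logits.argmax(axis=-1)[1:-1]  # strip BOS/EOS to align with residues
print("".join(id2label[int(i)] for i in preds))  # "HH-EE-"
```

Slicing off the first and last positions keeps the predictions aligned with the residues of the original sequence, mirroring the 0/1/2 label scheme we used when building the training data.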