{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "rilo8nzE4L2F" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers as well as some other libraries. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "NWl3z9tR4Un1" }, "outputs": [], "source": [ "# Install\n", "!pip install -q biopython transformers datasets huggingface_hub accelerate" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "8XJthJun4mVw" }, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First you have to login to the huggingface hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "AUY6jNR85SbG" }, "source": [ "Then you need to install Git-LFS. Uncomment the following instructions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "439dGQnD5Uul", "outputId": "18c8291c-0af0-4f96-f360-7af4e49f21e2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading package lists... Done\n", "Building dependency tree \n", "Reading state information... Done\n", "git-lfs is already the newest version (2.9.2-1).\n", "0 upgraded, 0 newly installed, 0 to remove and 16 not upgraded.\n" ] } ], "source": [ "!apt install git-lfs" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "vYX1k_HIHqbE" }, "source": [ "# **Fine-Tuning the Nucleotide-transformer**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "16RWKoAcPdLx" }, "source": [ "The **Nucleotide Transformer** paper [Dalla-torre et al, 2023](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2) introduces 4 genomics foundational models developed by **InstaDeep**. These transformers, of various sizes and trained on different datasets, allow powerful representations of DNA sequences that allow to tackle a very diverse set of problems such as chromatin accessibility, deleteriousness prediction, promoter and enhancer prediction etc... These representations can be extracted from the transformer and used as proxies of the DNA sequences (this is called probing) or the transformer can be trained further on a specific task (this is called finetuning)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "_QJ13X_uPfSb" }, "source": [ "" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "AqGW_WWzn8TZ" }, "source": [ "This notebook allows you to fine-tune these models.\n", "\n", "The model we are going to use is the [500M Human Ref model](https://huggingface.co/InstaDeepAI/nucleotide-transformer-500m-1000g), which is a 500M parameters transformer pre-trained on the human reference genome, per the training methodology presented in the Nucleotide Transformer Paper. 
{ "attachments": {}, "cell_type": "markdown", "metadata": { "id": "AqGW_WWzn8TZ" }, "source": [ "This notebook allows you to fine-tune these models.\n", "\n", "The model we are going to use is the [500M Human Ref model](https://huggingface.co/InstaDeepAI/nucleotide-transformer-500m-human-ref), a 500M-parameter transformer pre-trained on the human reference genome, following the training methodology presented in the Nucleotide Transformer paper. It is one of the four models introduced, all available on the [InstaDeep Hugging Face page](https://huggingface.co/InstaDeepAI):\n", "\n", "```\n", "| Model name          | Num layers | Num parameters | Training dataset       |\n", "|---------------------|------------|----------------|------------------------|\n", "| `500M Human Ref`    | 24         | 500M           | Human reference genome |\n", "| `500M 1000G`        | 24         | 500M           | 1000G genomes          |\n", "| `2.5B 1000G`        | 32         | 2.5B           | 1000G genomes          |\n", "| `2.5B Multispecies` | 32         | 2.5B           | Multi-species dataset  |\n", "```\n", "\n", "Note that using the larger models will require more GPU memory and produce longer fine-tuning times.\n", "\n", "In the following, we showcase the Nucleotide Transformer's ability to classify genomic sequences as two of the most basic genomic motifs: **promoters** and **enhancer types**. Both are classification tasks, but the enhancer types task is much more challenging with its 3 classes.\n", "\n", "These two tasks are still very basic, but the Nucleotide Transformer models have been shown to match or beat state-of-the-art models on much more complex tasks such as [DeepSEA](https://www.nature.com/articles/nmeth.3547), which predicts 919 chromatin profiles across a diverse set of human cells and tissues from a single DNA sequence, or [DeepSTARR](https://www.nature.com/articles/s41588-022-01048-5), which predicts an enhancer's activity.\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "i2rtkspVfXFw" }, "source": [ "## **Importing required packages**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "BKb8ZufeHccM" }, "source": [ "### **Import and install**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "S_dMf_n-Epvp" }, "outputs": [], "source": [ "# Imports\n", "from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification\n", "import torch\n", "from sklearn.metrics import matthews_corrcoef, f1_score\n", "from sklearn.model_selection import train_test_split\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "I7NSdbS-O6DI" }, "outputs": [], "source": [ "from accelerate.test_utils.testing import get_backend\n", "\n", "# Automatically selects the best available device (CUDA, MPS, CPU, ...)\n", "device, _, _ = get_backend()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "uuM3JB7SH45T" }, "source": [ "### **Prepare and create the model for fine-tuning**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "ch7deTIvWrjF" }, "source": [ "The nucleotide transformer will be fine-tuned on two **classification tasks**: **promoter** and **enhancer types** classification.\n", "The `AutoModelForSequenceClassification` module automatically loads the model and adds a simple classification head on top of the final embeddings." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "k0shQr8_pu2s" }, "source": [ "## **First task: Promoter prediction**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "ykhgl7B7-3EC" }, "source": [ "Promoter prediction is a **sequence classification** problem, in which the DNA sequence is predicted to be either a promoter or not.\n", "\n", "A promoter is a region of DNA where transcription of a gene is initiated. Promoters are a vital component of expression vectors because they control the binding of RNA polymerase to DNA. RNA polymerase transcribes DNA to mRNA, which is ultimately translated into a functional protein.\n", "\n", "This task was introduced in [DeePromoter](https://www.frontiersin.org/articles/10.3389/fgene.2019.00286/full), where a set of TATA and non-TATA promoters was gathered. A negative sequence was generated from each promoter by randomly sampling subsets of the sequence, to guarantee that some obvious motifs were present in both the positive and the negative dataset; a toy illustration of this kind of negative sampling follows below.\n" ] },
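{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following toy sketch only illustrates the general idea of building negatives by recombining pieces of a positive sequence, so that local motifs survive while the global promoter structure is destroyed. It is not the exact DeePromoter procedure (see the paper for the real construction), and the helper name `make_negative` is ours." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "def make_negative(sequence: str, n_chunks: int = 20, keep_fraction: float = 0.4) -> str:\n", "    \"\"\"Toy negative sampling: split a positive sequence into chunks, keep some\n", "    chunks in place and replace the others with random DNA.\"\"\"\n", "    chunk_size = max(1, len(sequence) // n_chunks)\n", "    chunks = [sequence[i:i + chunk_size] for i in range(0, len(sequence), chunk_size)]\n", "    kept = set(random.sample(range(len(chunks)), int(keep_fraction * len(chunks))))\n", "    negative_chunks = [\n", "        c if i in kept else \"\".join(random.choices(\"ACGT\", k=len(c)))\n", "        for i, c in enumerate(chunks)\n", "    ]\n", "    return \"\".join(negative_chunks)\n", "\n", "print(make_negative(\"ACGT\" * 25)[:40])" ] },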
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "QyaFCHqqQwpM" }, "outputs": [], "source": [ "num_labels_promoter = 2\n", "# Load the model\n", "model = AutoModelForSequenceClassification.from_pretrained(\"InstaDeepAI/nucleotide-transformer-500m-human-ref\", num_labels=num_labels_promoter)\n", "model = model.to(device)\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "HmIVmKByPk_T" }, "source": [ "### **Dataset loading and preparation**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jGmNDcaKnoW6", "outputId": "cd2081f5-b51c-4694-a42d-7d96060bbe37" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Found cached dataset nucleotide_transformer_downstream_tasks_public (/root/.cache/huggingface/datasets/InstaDeepAI___nucleotide_transformer_downstream_tasks_public/promoter_all/0.0.0/d649d80b49e7b062da8a12a4d80a5d636571467e76a0a036d89078ffded1e5c9)\n", "WARNING:datasets.builder:Found cached dataset nucleotide_transformer_downstream_tasks_public (/root/.cache/huggingface/datasets/InstaDeepAI___nucleotide_transformer_downstream_tasks_public/promoter_all/0.0.0/d649d80b49e7b062da8a12a4d80a5d636571467e76a0a036d89078ffded1e5c9)\n" ] } ], "source": [ "from datasets import load_dataset, Dataset\n", "\n", "# Load the promoter dataset from the InstaDeep Hugging Face resources\n", "dataset_name = \"promoter_all\"\n", "train_dataset_promoter = load_dataset(\n", "    \"InstaDeepAI/nucleotide_transformer_downstream_tasks\",\n", "    dataset_name,\n", "    split=\"train\",\n", "    streaming=False,\n", ")\n", "test_dataset_promoter = load_dataset(\n", "    \"InstaDeepAI/nucleotide_transformer_downstream_tasks\",\n", "    dataset_name,\n", "    split=\"test\",\n", "    streaming=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0N81WU4_r_l" }, "outputs": [], "source": [ "# Get training data\n", "train_sequences_promoter = train_dataset_promoter['sequence']\n", "train_labels_promoter = train_dataset_promoter['label']\n", "\n", "# Split the training data into a training and a validation set\n", "train_sequences_promoter, validation_sequences_promoter, train_labels_promoter, validation_labels_promoter = train_test_split(\n", "    train_sequences_promoter, train_labels_promoter, test_size=0.05, random_state=42\n", ")\n", "\n", "# Get test data\n", "test_sequences_promoter = test_dataset_promoter['sequence']\n", "test_labels_promoter = test_dataset_promoter['label']" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "bJYXsXPamVmq" }, "source": [ "Let us have a look at the data. If we extract the last sequence of the dataset, we see that it is indeed a promoter, as its label is 1. Furthermore, we can also see that it is a TATA promoter, as the TATA motif is found at index 221 of the sequence!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BRyL7nw6mDBF", "outputId": "9fc7c6f5-bc84-4fb0-cea0-2aec6f2dabf8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The DNA sequence is CACACCAGACAAAATTTGGTTAATTTGCGCCCAATATTCATTACTTTGACCTAACCTTTGTTCTGAAGGCCGTGTACAAGGACAAGGCCCTGAGATTATTGCAACAGTAACTTGAAAAACTTTCAGAAGTCTATTCTGTAGGATTAAAGGAATGCTGAGACTATTCAAGTTTGAAGTCCTGGGGGTGGGGAAAAATAAAAAACCTGTGCTAGAAAGCTTAGTATAGCATGTAACTTTAGAGTCCTGTGGAGTCCTGAGTCTCCCACAGACCAGAACAGTCATTTAAAAGTTTTCAGGAAA.\n", "Its associated label is label 1.\n", "This promoter is a TATA promoter, as the TATA motif is present at the 221th nucleotide.\n" ] } ], "source": [ "idx_sequence = -1\n", "sequence, label = train_sequences_promoter[idx_sequence], train_labels_promoter[idx_sequence]\n", "print(f\"The DNA sequence is {sequence}.\")\n", "print(f\"Its associated label is label {label}.\")\n", "\n", "idx_TATA = sequence.find(\"TATA\")\n", "\n", "print(f\"This promoter is a TATA promoter, as the TATA motif is present at the {idx_TATA}th nucleotide.\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "LvWLmFNjb2e0" }, "source": [ "### **Tokenizing the datasets**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "CZWBIP8WUqLc" }, "source": [ "All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "41f0bd944b344306b2d733d7fd4f774b", "22d9dbe577b24d11aef695d50c758849", "f9d22785d89349f5b90c3fea031138ce", "7fd6f4b9269b4472b1816c4421fcc6de", "c1aada443f4c455796baaddf4a2601df", "da8d9fe59e354fd6b5275a5b1db2b042", "20a59fb6437d428ead7804d1d11b5a85", "e35890f77765447c9bb222eb6a102693", "0145b27c4df24e0db4197ab7c6e51929", "0c0d9a6536e54316a806e70f1f25b1b5", "7bc66105a9ca4dfa8a7f0bd35eece04b", "77d8294039da4b64b8e4621ec90ccc72", "190a8f17e3784b459de580f3b1a7fb88", "a3de9e2b6ee34a6aba03e3c0b57d748a", "663cc040b34743819d963e96516b4923", "fe0c86ad3f204244bee37243562f3727", "17bc5e5803434ac2b4f7e3a32fd607fe", "5f484378e87d4372ae3e29dec39da0fe", "7f5d7aa87cd7446fbca7a376cf08799d", "18896cac46ae4b17b5ed3f52290614dd", "e86fda4855ab45308f8de11c279e5bac", "adc0e65b9d394d9e8e019b25d0fd4ea9", "003615bc2ed445c0a356fef5866048aa", "dfdb0b9cf3464a8f9e80116f8b813df5", "ee0204452662487198fc8a75af7ae710", "ce5fe292799444cca3a0187a40960ff8", "bfbfe09ab54c4a7ba7c6bf972201e17c", "2f02dd06e29740e9ab9b80629bd3571a", "d6215f85005c44c2b37ae88825f48652", "47157cecfa5f465ab817172771fd32a0", "0e4f98525a154213a2b51f4590945a1c", "530bc19fd4914773b2dd8afbb768f984", "368ff5a5bc0c412080b939fa4bc8aa9a" ] }, "id": "tLebn4WTOUvO", "outputId": "011cc2bf-8f49-4e27-a07d-9c753b8f0179" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "41f0bd944b344306b2d733d7fd4f774b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)okenizer_config.json: 0%| | 0.00/129 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "77d8294039da4b64b8e4621ec90ccc72", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)solve/main/vocab.txt: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "003615bc2ed445c0a356fef5866048aa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/101 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Load the tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"InstaDeepAI/nucleotide-transformer-500m-human-ref\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X-pjPlthh2yv" }, "outputs": [], "source": [ "# Promoter dataset\n", "ds_train_promoter = Dataset.from_dict({\"data\": train_sequences_promoter,'labels':train_labels_promoter})\n", "ds_validation_promoter = Dataset.from_dict({\"data\": validation_sequences_promoter,'labels':validation_labels_promoter})\n", "ds_test_promoter = Dataset.from_dict({\"data\": test_sequences_promoter,'labels':test_labels_promoter})" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "wlaF29Xkn0qb" }, "outputs": [], "source": [ "def tokenize_function(examples):\n", " outputs = tokenizer(examples[\"data\"])\n", " return outputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17, "referenced_widgets": [ "551bf73f511845139a68c4954238e35a", "e94490ab2f6b499b9b3fd9a75ac34ecb", "3c50aef619844d3cb17b46112feef9e7", "6766349bcff54c6eba335b5e338fb3c2", "d72ffe2f537f4046aafe2071ad590986", "fac68609669e47a088920a9b49ac02ba", "de2b63f32abf428fb7207088e0345a41", "90e1d6ce95c045a18fcbbbfd5642cd33", "37810c305a4045e9bc7c77e75afb02e3", "be2ea4ec591f47069d89442348280d0e", "35c6085cd0b64bc79d3b40255eeb5e0c", "0e51ea01bc1a42279f60e8e948e4b698", "b668f41b31bf49c09a683b5dbb58f41e", "09970e98908a464aa39a2225c334e636", "0a4fd91da75546e9a6cb71f471868d2f", "ece0b154d7c64841bcbd10f62eb470a7", "05557b863f604c5b965da537d62675a4", "303e0aa79d6049b1b2c19394063102b3", "abad7fded3ac4d8a8f36a441b29f5743", "8f421ffe0f434e38aafe011bd75f8e3f", "900995fb7e1e4ebd9d1604108ef9b7f3", "96d4a4bdc23c4a49a40ff47be76ed986", "ca34f595b6e44f6cbae64b0e6709c619", "ef9ba1f0bed24ca4bb26b0f377e5e8c9", "bb439e793ae843d8bb7c3bf1ddd0e644", "c0eed79d70f4465eb5338ac22e7995be", "e1d8e2fa407644ae9398dd36004f19e7", "cf013941c28b4716b0401e98072a24f5", "c9300935f3c341c5b8fca9260ac3efc5", "5f8d3d55518643869060002f0a9dbea4", "8b31036099484077b74e21e226b6d4c3", "c5804024c20a4cc497d3fd123bf42204", "122df545ab684d67a3237bf27b2f429f" ] }, "id": "C3y1xdemnITN", "outputId": "c740abfe-2262-40f7-db96-ac36da2e8b26" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "551bf73f511845139a68c4954238e35a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/50612 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0e51ea01bc1a42279f60e8e948e4b698", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/2664 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca34f595b6e44f6cbae64b0e6709c619", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/5920 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Creating tokenized promoter dataset\n", "tokenized_datasets_train_promoter = ds_train_promoter.map(\n", " tokenize_function,\n", " batched=True,\n", " remove_columns=[\"data\"],\n", ")\n", "tokenized_datasets_validation_promoter = ds_validation_promoter.map(\n", " tokenize_function,\n", " batched=True,\n", " remove_columns=[\"data\"],\n", ")\n", "tokenized_datasets_test_promoter = ds_test_promoter.map(\n", " tokenize_function,\n", " batched=True,\n", " remove_columns=[\"data\"],\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "HYlwYjn0N71G" }, "source": [ "### **Fine-tuning and evaluation**" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "CoSW_uHVaSVN" }, "source": [ "The hyper-parameters introduced here are different from the ones used in the paper since we are training the whole model. Further hyper-parameters search will surely improve the performance on the task!.\n", "We initialize our `TrainingArguments`. These control the various training hyperparameters, and will be passed to our `Trainer`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9ph4KeV1EyW3" }, "outputs": [], "source": [ "batch_size = 8\n", "model_name='nucleotide-transformer'\n", "args_promoter = TrainingArguments(\n", " f\"{model_name}-finetuned-NucleotideTransformer\",\n", " remove_unused_columns=False,\n", " eval_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " learning_rate=1e-5,\n", " per_device_train_batch_size=batch_size,\n", " gradient_accumulation_steps= 1,\n", " per_device_eval_batch_size= 64,\n", " num_train_epochs= 2,\n", " logging_steps= 100,\n", " load_best_model_at_end=True, # Keep the best model according to the evaluation\n", " metric_for_best_model=\"f1_score\",\n", " label_names=[\"labels\"],\n", " dataloader_drop_last=True,\n", " max_steps= 1000\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "eGxAeSt6YWal" }, "source": [ "Next, we define the metric we will use to evaluate our models and write a `compute_metrics` function. We can load this from the `scikit-learn` library." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OezXy__KTRyH" }, "outputs": [], "source": [ "# Define the metric for the evaluation using the f1 score\n", "def compute_metrics_f1_score(eval_pred):\n", " \"\"\"Computes F1 score for binary classification\"\"\"\n", " predictions = np.argmax(eval_pred.predictions, axis=-1)\n", " references = eval_pred.label_ids\n", " r={'f1_score': f1_score(references, predictions)}\n", " return r" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8uKdPOcTMDKG" }, "outputs": [], "source": [ "trainer = Trainer(\n", " model.to(device),\n", " args_promoter,\n", " train_dataset= tokenized_datasets_train_promoter,\n", " eval_dataset= tokenized_datasets_validation_promoter,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics_f1_score,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "HQi2sqLVXy1I" }, "source": [ "We can now finetune our model by just calling the `train` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 757 }, "id": "6DCpUK3Lfor3", "outputId": "eea494b8-6443-458a-e4c7-188776121b6a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
"<table border=\"1\">\n",
"  <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>F1 Score</th></tr>\n",
"  <tr><td>50</td><td>0.537400</td><td>0.474810</td><td>0.716025</td></tr>\n",
"  <tr><td>100</td><td>0.382600</td><td>0.541949</td><td>0.829284</td></tr>\n",
"  <tr><td>150</td><td>0.306800</td><td>0.239439</td><td>0.917620</td></tr>\n",
"  <tr><td>200</td><td>0.277000</td><td>0.258893</td><td>0.920110</td></tr>\n",
"  <tr><td>250</td><td>0.382500</td><td>0.339054</td><td>0.900000</td></tr>\n",
"  <tr><td>300</td><td>0.251800</td><td>0.297043</td><td>0.902489</td></tr>\n",
"  <tr><td>350</td><td>0.271700</td><td>0.229136</td><td>0.913751</td></tr>\n",
"  <tr><td>400</td><td>0.264300</td><td>0.252098</td><td>0.925468</td></tr>\n",
"  <tr><td>450</td><td>0.375900</td><td>0.203039</td><td>0.924782</td></tr>\n",
"  <tr><td>500</td><td>0.242500</td><td>0.185888</td><td>0.925326</td></tr>\n",
"  <tr><td>550</td><td>0.293500</td><td>0.255502</td><td>0.928889</td></tr>\n",
"  <tr><td>600</td><td>0.209800</td><td>0.287625</td><td>0.920796</td></tr>\n",
"  <tr><td>650</td><td>0.205400</td><td>0.232000</td><td>0.933126</td></tr>\n",
"  <tr><td>700</td><td>0.221300</td><td>0.245803</td><td>0.929134</td></tr>\n",
"  <tr><td>750</td><td>0.270400</td><td>0.221323</td><td>0.936412</td></tr>\n",
"  <tr><td>800</td><td>0.243600</td><td>0.229891</td><td>0.939006</td></tr>\n",
"  <tr><td>850</td><td>0.281000</td><td>0.210025</td><td>0.933492</td></tr>\n",
"  <tr><td>900</td><td>0.158600</td><td>0.202798</td><td>0.939675</td></tr>\n",
"  <tr><td>950</td><td>0.205100</td><td>0.210799</td><td>0.937712</td></tr>\n",
"  <tr><td>1000</td><td>0.159900</td><td>0.202635</td><td>0.940362</td></tr>\n",
"</table>"
], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.train()" ] },
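{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With training done, we can check how the fine-tuned model behaves on the held-out test split prepared earlier. This is a minimal evaluation sketch using only the objects defined above (`trainer`, `tokenized_datasets_test_promoter`, `f1_score`): it computes the same F1 score on the test sequences." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Evaluate the fine-tuned promoter model on the test set\n", "test_predictions = trainer.predict(tokenized_datasets_test_promoter)\n", "predicted_labels = np.argmax(test_predictions.predictions, axis=-1)\n", "print(f\"F1 score on the promoter test set: {f1_score(test_predictions.label_ids, predicted_labels):.3f}\")" ] },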
"text/plain": [
"\n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" Step \n",
" Training Loss \n",
" Validation Loss \n",
" Mcc Score \n",
" \n",
" \n",
" 100 \n",
" 0.894000 \n",
" 1.261910 \n",
" 0.154693 \n",
" \n",
" \n",
" 200 \n",
" 0.777800 \n",
" 0.519611 \n",
" 0.697207 \n",
" \n",
" \n",
" 300 \n",
" 0.658400 \n",
" 0.610950 \n",
" 0.632458 \n",
" \n",
" \n",
" 400 \n",
" 0.524100 \n",
" 0.394602 \n",
" 0.765819 \n",
" \n",
" \n",
" 500 \n",
" 0.517400 \n",
" 0.370607 \n",
" 0.781838 \n",
" \n",
" \n",
" 600 \n",
" 0.483000 \n",
" 0.382210 \n",
" 0.764866 \n",
" \n",
" \n",
" 700 \n",
" 0.384300 \n",
" 0.460822 \n",
" 0.783114 \n",
" \n",
" \n",
" 800 \n",
" 0.319300 \n",
" 0.331123 \n",
" 0.798741 \n",
" \n",
" \n",
" 900 \n",
" 0.368600 \n",
" 0.316699 \n",
" 0.802810 \n",
" \n",
" \n",
" \n",
"1000 \n",
" 0.304100 \n",
" 0.315886 \n",
" 0.802805 \n",
"