{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural network training example" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "!{sys.executable} -m pip install \"torch>=1.10\" --index-url https://download.pytorch.org/whl/cu118\n", "!{sys.executable} -m pip install cesnet-datazoo cesnet-models tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare data transformations for the model." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from cesnet_models.transforms import ClipAndScaleFlowstats, ClipAndScalePPI, NormalizeHistograms, ScalerEnum\n", "\n", "ppi_transform = ClipAndScalePPI(psizes_scaler_enum=ScalerEnum.STANDARD,\n", " ipt_scaler_enum=ScalerEnum.STANDARD,)\n", "flowstats_transform = ClipAndScaleFlowstats(flowstats_scaler_enum=ScalerEnum.ROBUST, quantile_clip=0.99)\n", "packet_histograms_transform = NormalizeHistograms()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the dataset class and prepare its configuration.\n", "\n", "* Define train and test periods from which the train and test sets will be built\n", "* Split the train set - use 20% of its samples as the validation set\n", "* We use all available applications for a closed-world classification task\n", "* Set data transforms" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2024-04-08 17:40:19,224][cesnet_datazoo.pytables_data.indices_setup][INFO] - Processing train indices\n", "[2024-04-08 17:40:19,774][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221114 took 0.51 seconds\n", "[2024-04-08 17:40:20,281][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221115 took 0.51 seconds\n", "[2024-04-08 
17:40:20,696][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221116 took 0.42 seconds\n", "[2024-04-08 17:40:20,870][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221117 took 0.17 seconds\n", "[2024-04-08 17:40:21,101][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221118 took 0.23 seconds\n", "[2024-04-08 17:40:21,236][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221119 took 0.13 seconds\n", "[2024-04-08 17:40:21,413][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20221120 took 0.18 seconds\n", "[2024-04-08 17:40:21,431][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Found applications with less than 100 train samples: ['livescore']. Disabling these applications\n", "[2024-04-08 17:40:21,442][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Selected 101 known applications and 0 unknown applications\n", "[2024-04-08 17:40:23,261][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Processing indices took 1.85 seconds\n", "[2024-04-08 17:40:27,834][cesnet_datazoo.pytables_data.data_scalers][INFO] - Reading data and fitting scalers took 3.68 seconds\n" ] } ], "source": [ "import logging\n", "from cesnet_datazoo.config import AppSelection, DatasetConfig, ValidationApproach\n", "from cesnet_datazoo.datasets import CESNET_QUIC22\n", "\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format=\"[%(asctime)s][%(name)s][%(levelname)s] - %(message)s\")\n", "\n", "DATASET_SIZE = \"XS\"\n", "dataset = CESNET_QUIC22(data_root=\"data/CESNET-QUIC22\", size=DATASET_SIZE)\n", "\n", "dataset_config = DatasetConfig(\n", " dataset=dataset,\n", " train_period_name=\"W-2022-46\",\n", " test_period_name=\"W-2022-47\",\n", " # train_size=500_000, # Uncomment to limit the number of training samples to speed up this example\n", " 
val_approach=ValidationApproach.SPLIT_FROM_TRAIN,\n", " train_val_split_fraction=0.2,\n", " apps_selection=AppSelection.ALL_KNOWN,\n", " return_tensors=True,\n", " use_packet_histograms=True,\n", " ppi_transform=ppi_transform,\n", " flowstats_transform=flowstats_transform,\n", " flowstats_phist_transform=packet_histograms_transform,)\n", "\n", "dataset.set_dataset_config_and_initialize(dataset_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show dataset classes in the current configuration, together with train, validation, and test counts." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Train Validation Test\n", "google-www 121836 30459 205010\n", "google-ads 116419 29105 195979\n", "google-services 109998 27499 177295\n", "google-play 97905 24476 161546\n", "google-gstatic 92789 23197 150633\n", "... ... ... ...\n", "toggl 150 37 247\n", "ebay-kleinanzeigen 150 38 176\n", "alza-identity 130 32 215\n", "bitdefender-nimbus 118 29 204\n", "uber 87 22 118\n", "\n", "[101 rows x 3 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.known_app_counts.sort_values(by=\"Train\", ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reuse neural network architecture from the `cesnet-models` package without using pre-trained weights, i.e., start the training from scratch." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Multimodal_CESNET(\n", " (cnn_ppi): Sequential(\n", " (0): Conv1d(3, 200, kernel_size=(7,), stride=(1,), padding=(3,))\n", " (1): ReLU()\n", " (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Sequential(\n", " (0): Conv1d(200, 200, kernel_size=(5,), stride=(1,), padding=(2,))\n", " (1): ReLU()\n", " (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): Sequential(\n", " (0): Conv1d(200, 200, kernel_size=(5,), stride=(1,), padding=(2,))\n", " (1): ReLU()\n", " (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (5): Sequential(\n", " (0): Conv1d(200, 200, kernel_size=(5,), stride=(1,), padding=(2,))\n", " (1): ReLU()\n", " (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (6): Conv1d(200, 300, kernel_size=(5,), stride=(1,))\n", " (7): ReLU()\n", " (8): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (9): Conv1d(300, 300, 
kernel_size=(5,), stride=(1,))\n", " (10): ReLU()\n", " (11): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (12): Conv1d(300, 300, kernel_size=(4,), stride=(2,))\n", " (13): ReLU()\n", " )\n", " (cnn_global_pooling): Sequential(\n", " (0): GeM(kernel_size=10, p=3.0000, eps=1e-06)\n", " (1): Flatten(start_dim=1, end_dim=-1)\n", " (2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " (mlp_flowstats): Sequential(\n", " (0): Linear(in_features=43, out_features=225, bias=True)\n", " (1): ReLU()\n", " (2): BatchNorm1d(225, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Sequential(\n", " (0): Linear(in_features=225, out_features=225, bias=True)\n", " (1): ReLU()\n", " (2): BatchNorm1d(225, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): Sequential(\n", " (0): Linear(in_features=225, out_features=225, bias=True)\n", " (1): ReLU()\n", " (2): BatchNorm1d(225, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (5): Linear(in_features=225, out_features=225, bias=True)\n", " (6): ReLU()\n", " (7): BatchNorm1d(225, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (8): Dropout(p=0.1, inplace=False)\n", " )\n", " (mlp_shared): Sequential(\n", " (0): Linear(in_features=525, out_features=600, bias=True)\n", " (1): ReLU()\n", " (2): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Dropout(p=0.2, inplace=False)\n", " )\n", " (classifier): Linear(in_features=600, out_features=101, bias=True)\n", ")\n" ] } ], "source": [ "from cesnet_models.models import mm_cesnet_v2\n", "\n", "model = mm_cesnet_v2(weights=None, num_classes=dataset.get_num_classes(), ppi_input_channels=len(dataset_config.get_ppi_channels()), flowstats_input_size=dataset_config.get_flowstats_features_len())\n", "print(model)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [
 "## Training loop\n",
 "Train the model with a standard training loop using the cross-entropy loss, the AdamW optimizer, and the OneCycleLR learning scheduler.\n",
 "\n",
 "The number of epochs is set to five, and the model is validated after each epoch."
 ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [
 { "name": "stdout", "output_type": "stream", "text": [
 "\n",
 "Epoch 1, validation accuracy: 0.8280 \n",
 "Epoch 2, validation accuracy: 0.8897 \n",
 "Epoch 3, validation accuracy: 0.8751 \n",
 "Epoch 4, validation accuracy: 0.9321 \n",
 "Epoch 5, validation accuracy: 0.9413 \n",
 "100%|██████████| 5/5 [09:41<00:00, 116.35s/it]\n"
 ] } ],
 "source": [
 "import torch\n",
 "from torch import nn, optim\n",
 "from torch.utils.data import DataLoader\n",
 "from tqdm import tqdm\n",
 "\n",
 "def train_one_epoch(model: nn.Module, train_dataloader: DataLoader, optimizer: optim.Optimizer, scheduler, loss_fn, device) -> None:\n",
 "    \"\"\"Train the model for one pass over train_dataloader.\n",
 "\n",
 "    Every batch is a 4-tuple whose first element is ignored; the remaining\n",
 "    three (PPI, flowstats, labels) are moved to `device` before the forward pass.\n",
 "    The optimizer and the scheduler are both stepped once per batch.\n",
 "    \"\"\"\n",
 "    model.train()\n",
 "    for _, batch_ppi, batch_flowstats, batch_labels in train_dataloader:\n",
 "        batch_ppi, batch_flowstats, batch_labels = batch_ppi.to(device), batch_flowstats.to(device), batch_labels.to(device)\n",
 "        optimizer.zero_grad()\n",
 "        out = model((batch_ppi, batch_flowstats))\n",
 "        loss = loss_fn(out, batch_labels)\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
 "        # The scheduler below is created with steps_per_epoch=len(train_dataloader), so it advances per batch\n",
 "        scheduler.step()\n",
 "\n",
 "def test(model: nn.Module, dataloader: DataLoader, device, progress: bool = False) -> float:\n",
 "    \"\"\"Return the fraction of samples in dataloader whose argmax prediction equals the label.\n",
 "\n",
 "    Set progress=True to display a tqdm progress bar over the batches.\n",
 "    \"\"\"\n",
 "    model.eval()\n",
 "    true_labels = []\n",
 "    preds = []\n",
 "    with torch.no_grad():\n",
 "        for _, batch_ppi, batch_flowstats, batch_labels in tqdm(dataloader, total=len(dataloader), disable=not progress):\n",
 "            batch_ppi, batch_flowstats, batch_labels = batch_ppi.to(device), batch_flowstats.to(device), batch_labels.to(device)\n",
 "            out = model((batch_ppi, batch_flowstats))\n",
 "            batch_preds = out.argmax(dim=1)\n",
 "            true_labels.append(batch_labels)\n",
 "            preds.append(batch_preds)\n",
 "    true_labels, preds = torch.cat(true_labels).cpu().numpy(), torch.cat(preds).cpu().numpy()\n",
 "    # Convert the NumPy scalar to a built-in float to match the return annotation\n",
 "    return float((true_labels == preds).mean())\n",
 "\n",
 "EPOCHS = 5\n",
 "train_dataloader = dataset.get_train_dataloader()\n",
 "val_dataloader = dataset.get_val_dataloader()\n",
 "optimizer = optim.AdamW(model.parameters())\n",
 "scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, steps_per_epoch=len(train_dataloader), epochs=EPOCHS)\n",
 "loss_fn = nn.CrossEntropyLoss()\n",
 "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
 "model = model.to(device)\n",
 "\n",
 "for i in tqdm(range(1, EPOCHS + 1), total=EPOCHS, file=sys.stdout):\n",
 "    train_one_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device)\n",
 "    validation_accuracy = test(model, val_dataloader, device)\n",
 "    tqdm.write(f\"Epoch {i}, validation accuracy: {validation_accuracy:.4f}\")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate the trained model on the test set." ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [
 { "name": "stdout", "output_type": "stream", "text": [ "Computing model predictions on the test set.\n" ] },
 { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1289/1289 [00:38<00:00, 33.07it/s]" ] },
 { "name": "stdout", "output_type": "stream", "text": [ "The trained model achieved an accuracy of 0.93962 on the test period W-2022-47 of the CESNET-QUIC22-XS dataset.\n" ] },
 { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ],
 "source": [
 "test_dataloader = dataset.get_test_dataloader()\n",
 "print(\"Computing model predictions on the test set.\")\n",
 "test_accuracy = test(model, test_dataloader, device, progress=True)\n",
 "print(f\"The trained model achieved an accuracy of {test_accuracy:.5f} on the test period {dataset_config.test_period_name} of the {dataset.name} dataset.\")"
 ] } ],
 "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name":