{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Requirements for running this notebook:\n", "- Python >= 3.6 (uses f-strings)\n", "- PyTorch >= 1.0\n", "- scikit-learn\n", "- NumPy\n", "- PyTables\n", "- mt_data.h5 file: you can generate it yourself with this other notebook or [download it](http://ftp.ebi.ac.uk/pub/databases/chembl/blog/pytorch_mtl/mt_data.h5)\n", "\n", "## This notebook trains and tests a multi-task neural network on ChEMBL data\n", "- It uses a simple shuffled 80/20 train/test split\n", "- It automatically sizes the output layer to the number of targets in the training data\n", "- It uses a GPU if one is available\n", "- It saves the model to a file and loads it back\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "import torch.utils.data as D\n", "import tables as tb\n", "from sklearn.metrics import (matthews_corrcoef,\n", "                             confusion_matrix,\n", "                             f1_score,\n", "                             roc_auc_score,\n", "                             accuracy_score)\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# use the first GPU if available, otherwise fall back to the CPU\n", "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set some config values" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "MAIN_PATH = '.'\n", "DATA_FILE = 'mt_data.h5'\n", "MODEL_FILE = 'chembl_mt.model'\n", "N_WORKERS = 8  # DataLoader workers: prefetch batches in parallel so the next one is ready as soon as a training step finishes\n", "BATCH_SIZE = 32  # https://twitter.com/ylecun/status/989610208497360896?lang=es\n", "LR = 2  # Learning rate. Deliberately large to compensate for the way the per-target losses are weighted\n", "N_EPOCHS = 2  # Only 2 epochs for this demo: you should train longer!" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set up the dataset loaders\n", "\n", "A simple shuffled 80/20 train/test split for this example." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class ChEMBLDataset(D.Dataset):\n", "\n", "    def __init__(self, file_path):\n", "        self.file_path = file_path\n", "        with tb.open_file(self.file_path, mode='r') as t_file:\n", "            self.length = t_file.root.fps.shape[0]\n", "            self.n_targets = t_file.root.labels.shape[1]\n", "\n", "    def __len__(self):\n", "        return self.length\n", "\n", "    def __getitem__(self, index):\n", "        # open the file on every access so that each DataLoader worker uses its own file handle\n", "        with tb.open_file(self.file_path, mode='r') as t_file:\n", "            structure = t_file.root.fps[index]\n", "            labels = t_file.root.labels[index]\n", "        return structure, labels\n", "\n", "\n", "dataset = ChEMBLDataset(f\"{MAIN_PATH}/{DATA_FILE}\")\n", "validation_split = 0.2\n", "random_seed = 42\n", "\n", "dataset_size = len(dataset)\n", "indices = list(range(dataset_size))\n", "split = int(np.floor(validation_split * dataset_size))\n", "\n", "np.random.seed(random_seed)\n", "np.random.shuffle(indices)\n", "train_indices, test_indices = indices[split:], indices[:split]\n", "\n", "train_sampler = D.sampler.SubsetRandomSampler(train_indices)\n", "test_sampler = D.sampler.SubsetRandomSampler(test_indices)\n", "\n", "# with N_WORKERS > 0, the DataLoaders prefetch upcoming batches\n", "# while the model is training\n", "train_loader = torch.utils.data.DataLoader(dataset,\n", "                                           batch_size=BATCH_SIZE,\n", "                                           num_workers=N_WORKERS,\n", "                                           sampler=train_sampler)\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset,\n", "                                          batch_size=BATCH_SIZE,\n", "                                          num_workers=N_WORKERS,\n", "                                          sampler=test_sampler)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define the model, the optimizer and the loss criterion" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class ChEMBLMultiTask(nn.Module):\n", "    \"\"\"\n", "    Architecture 
borrowed from: https://arxiv.org/abs/1502.02072\n", "    \"\"\"\n", "    def __init__(self, n_tasks):\n", "        super(ChEMBLMultiTask, self).__init__()\n", "        self.n_tasks = n_tasks\n", "        self.fc1 = nn.Linear(1024, 2000)\n", "        self.fc2 = nn.Linear(2000, 100)\n", "        self.dropout = nn.Dropout(0.25)\n", "\n", "        # add an independent output for each task in the output layer\n", "        for n_m in range(self.n_tasks):\n", "            self.add_module(f\"y{n_m}o\", nn.Linear(100, 1))\n", "\n", "    def forward(self, x):\n", "        h1 = self.dropout(F.relu(self.fc1(x)))\n", "        h2 = F.relu(self.fc2(h1))\n", "        # one sigmoid probability per task, each of shape (batch_size, 1)\n", "        out = [torch.sigmoid(getattr(self, f\"y{n_m}o\")(h2)) for n_m in range(self.n_tasks)]\n", "        return out\n", "\n", "\n", "# create the model and move it to the GPU if available\n", "model = ChEMBLMultiTask(dataset.n_targets).to(device)\n", "\n", "# binary cross-entropy loss for each task\n", "# each task's loss is weighted in inverse proportion to its number of datapoints, borrowed from:\n", "# http://www.bioinf.at/publications/2014/NIPS2014a.pdf\n", "with tb.open_file(f\"{MAIN_PATH}/{DATA_FILE}\", mode='r') as t_file:\n", "    weights = torch.tensor(t_file.root.weights[:])\n", "    weights = weights.to(device)\n", "\n", "criterion = [nn.BCELoss(weight=w) for w in weights.float()]\n", "\n", "# stochastic gradient descent as the optimizer\n", "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train the model\n", "Given the extremely sparse nature of the dataset, it is difficult to see clearly how the loss improves from one batch to the next. 
It becomes clearer after several epochs, and much clearer at test time :)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: [1/2], Step: [500/17789], Loss: 0.01780553348362446\n", "Epoch: [1/2], Step: [1000/17789], Loss: 0.01136045902967453\n", "Epoch: [1/2], Step: [1500/17789], Loss: 0.018664617091417313\n", "Epoch: [1/2], Step: [2000/17789], Loss: 0.013626799918711185\n", "Epoch: [1/2], Step: [2500/17789], Loss: 0.012855792418122292\n", "Epoch: [1/2], Step: [3000/17789], Loss: 0.013796127401292324\n", "Epoch: [1/2], Step: [3500/17789], Loss: 0.021601887419819832\n", "Epoch: [1/2], Step: [4000/17789], Loss: 0.00950919184833765\n", "Epoch: [1/2], Step: [4500/17789], Loss: 0.02028888650238514\n", "Epoch: [1/2], Step: [5000/17789], Loss: 0.013251284137368202\n", "Epoch: [1/2], Step: [5500/17789], Loss: 0.008788244798779488\n", "Epoch: [1/2], Step: [6000/17789], Loss: 0.012066680938005447\n", "Epoch: [1/2], Step: [6500/17789], Loss: 0.013928443193435669\n", "Epoch: [1/2], Step: [7000/17789], Loss: 0.011484757997095585\n", "Epoch: [1/2], Step: [7500/17789], Loss: 0.0071386718191206455\n", "Epoch: [1/2], Step: [8000/17789], Loss: 0.014712771400809288\n", "Epoch: [1/2], Step: [8500/17789], Loss: 0.010457032360136509\n", "Epoch: [1/2], Step: [9000/17789], Loss: 0.00854165107011795\n", "Epoch: [1/2], Step: [9500/17789], Loss: 0.009312299080193043\n", "Epoch: [1/2], Step: [10000/17789], Loss: 0.010153095237910748\n", "Epoch: [1/2], Step: [10500/17789], Loss: 0.006983090192079544\n", "Epoch: [1/2], Step: [11000/17789], Loss: 0.010238541290163994\n", "Epoch: [1/2], Step: [11500/17789], Loss: 0.012679124251008034\n", "Epoch: [1/2], Step: [12000/17789], Loss: 0.01116170920431614\n", "Epoch: [1/2], Step: [12500/17789], Loss: 0.011749005876481533\n", "Epoch: [1/2], Step: [13000/17789], Loss: 0.015176426619291306\n", "Epoch: [1/2], Step: [13500/17789], Loss: 0.013586488552391529\n", 
"Epoch: [1/2], Step: [14000/17789], Loss: 0.012365413829684258\n", "Epoch: [1/2], Step: [14500/17789], Loss: 0.009591283276677132\n", "Epoch: [1/2], Step: [15000/17789], Loss: 0.01857740990817547\n", "Epoch: [1/2], Step: [15500/17789], Loss: 0.009823130443692207\n", "Epoch: [1/2], Step: [16000/17789], Loss: 0.01805167831480503\n", "Epoch: [1/2], Step: [16500/17789], Loss: 0.011896809563040733\n", "Epoch: [1/2], Step: [17000/17789], Loss: 0.008349821902811527\n", "Epoch: [1/2], Step: [17500/17789], Loss: 0.013517800718545914\n", "Epoch: [2/2], Step: [500/17789], Loss: 0.007128629367798567\n", "Epoch: [2/2], Step: [1000/17789], Loss: 0.01153416559100151\n", "Epoch: [2/2], Step: [1500/17789], Loss: 0.02041609212756157\n", "Epoch: [2/2], Step: [2000/17789], Loss: 0.0165218748152256\n", "Epoch: [2/2], Step: [2500/17789], Loss: 0.011772445403039455\n", "Epoch: [2/2], Step: [3000/17789], Loss: 0.011200090870261192\n", "Epoch: [2/2], Step: [3500/17789], Loss: 0.012209323234856129\n", "Epoch: [2/2], Step: [4000/17789], Loss: 0.007769708056002855\n", "Epoch: [2/2], Step: [4500/17789], Loss: 0.012243629433214664\n", "Epoch: [2/2], Step: [5000/17789], Loss: 0.018942933529615402\n", "Epoch: [2/2], Step: [5500/17789], Loss: 0.013197326101362705\n", "Epoch: [2/2], Step: [6000/17789], Loss: 0.011520257219672203\n", "Epoch: [2/2], Step: [6500/17789], Loss: 0.020596494898200035\n", "Epoch: [2/2], Step: [7000/17789], Loss: 0.018161792308092117\n", "Epoch: [2/2], Step: [7500/17789], Loss: 0.01610906422138214\n", "Epoch: [2/2], Step: [8000/17789], Loss: 0.004183729644864798\n", "Epoch: [2/2], Step: [8500/17789], Loss: 0.01284581795334816\n", "Epoch: [2/2], Step: [9000/17789], Loss: 0.014269811101257801\n", "Epoch: [2/2], Step: [9500/17789], Loss: 0.009626287035644054\n", "Epoch: [2/2], Step: [10000/17789], Loss: 0.008639814332127571\n", "Epoch: [2/2], Step: [10500/17789], Loss: 0.011639382690191269\n", "Epoch: [2/2], Step: [11000/17789], Loss: 0.005331861320883036\n", "Epoch: [2/2], 
Step: [11500/17789], Loss: 0.011540957726538181\n", "Epoch: [2/2], Step: [12000/17789], Loss: 0.010148015804588795\n", "Epoch: [2/2], Step: [12500/17789], Loss: 0.011556670069694519\n", "Epoch: [2/2], Step: [13000/17789], Loss: 0.0069694253616034985\n", "Epoch: [2/2], Step: [13500/17789], Loss: 0.008971192874014378\n", "Epoch: [2/2], Step: [14000/17789], Loss: 0.02061212807893753\n", "Epoch: [2/2], Step: [14500/17789], Loss: 0.013362124562263489\n", "Epoch: [2/2], Step: [15000/17789], Loss: 0.00966110359877348\n", "Epoch: [2/2], Step: [15500/17789], Loss: 0.017838571220636368\n", "Epoch: [2/2], Step: [16000/17789], Loss: 0.007174369413405657\n", "Epoch: [2/2], Step: [16500/17789], Loss: 0.0074622794054448605\n", "Epoch: [2/2], Step: [17000/17789], Loss: 0.015448285266757011\n", "Epoch: [2/2], Step: [17500/17789], Loss: 0.011626753024756908\n" ] } ], "source": [ "# A new model starts in train mode. If .eval() has been called (e.g. after testing), call .train() again before resuming training\n", "model.train()\n", "for ep in range(N_EPOCHS):\n", "    for i, (fps, labels) in enumerate(train_loader):\n", "        # move the batch to the GPU if available\n", "        fps, labels = fps.to(device), labels.to(device)\n", "\n", "        optimizer.zero_grad()\n", "        outputs = model(fps)\n", "\n", "        # the total loss is the sum of the per-task losses\n", "        loss = torch.tensor(0.0, device=device)\n", "        for j, crit in enumerate(criterion):\n", "            # keep only the molecules labeled for this task (unlabeled entries are negative)\n", "            mask = labels[:, j] >= 0.0\n", "            if mask.any():\n", "                # there are labeled samples for this task, so add its loss\n", "                loss += crit(outputs[j][mask], labels[:, j][mask].view(-1, 1))\n", "\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if (i + 1) % 500 == 0:\n", "            print(f\"Epoch: [{ep+1}/{N_EPOCHS}], Step: [{i+1}/{len(train_indices)//BATCH_SIZE}], Loss: {loss.item()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test the model" ] }, { "cell_type": "code", 
"execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy: 0.8371918235997756, auc: 0.8942389411754185, sens: 0.7053822792666977, spec: 0.8987519347341067, prec: 0.7649158653846154, mcc: 0.6179805824644773, f1: 0.733943790291889\n", "Not bad for only 2 epochs!\n" ] } ], "source": [ "y_trues = []\n", "y_preds = []\n", "y_preds_proba = []\n", "\n", "# switch to eval mode once, so that the dropout layer is disabled\n", "model.eval()\n", "\n", "# disable gradient tracking during evaluation\n", "with torch.no_grad():\n", "    for fps, labels in test_loader:\n", "        # move the batch to the GPU if available\n", "        fps, labels = fps.to(device), labels.to(device)\n", "        outputs = model(fps)\n", "        for j, out in enumerate(outputs):\n", "            # keep only the molecules labeled for this task\n", "            mask = labels[:, j] >= 0.0\n", "            if mask.any():\n", "                probas = out[mask].view(-1)\n", "                y_trues.extend(labels[:, j][mask].long().tolist())\n", "                # threshold the predicted probabilities at 0.5 (stays on the same device as the outputs)\n", "                y_preds.extend((probas > 0.5).long().tolist())\n", "                y_preds_proba.extend(probas.tolist())\n", "\n", "tn, fp, fn, tp = confusion_matrix(y_trues, y_preds).ravel()\n", "sens = tp / (tp + fn)\n", "spec = tn / (tn + fp)\n", "prec = tp / (tp + fp)\n", "f1 = f1_score(y_trues, y_preds)\n", "acc = accuracy_score(y_trues, y_preds)\n", "mcc = matthews_corrcoef(y_trues, y_preds)\n", "auc = roc_auc_score(y_trues, y_preds_proba)\n", "\n", "print(f\"accuracy: {acc}, auc: {auc}, sens: {sens}, spec: {spec}, prec: {prec}, mcc: {mcc}, f1: {f1}\")\n", "print(f\"Not bad for only {N_EPOCHS} epochs!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Save the model to a file" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), f\"{MAIN_PATH}/{MODEL_FILE}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the model from the file" ] }, { "cell_type": "code", 
"execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ChEMBLMultiTask(\n", " (fc1): Linear(in_features=1024, out_features=2000, bias=True)\n", " (fc2): Linear(in_features=2000, out_features=100, bias=True)\n", " (dropout): Dropout(p=0.25)\n", " (y0o): Linear(in_features=100, out_features=1, bias=True)\n", " (y1o): Linear(in_features=100, out_features=1, bias=True)\n", " (y2o): Linear(in_features=100, out_features=1, bias=True)\n", " (y3o): Linear(in_features=100, out_features=1, bias=True)\n", " (y4o): Linear(in_features=100, out_features=1, bias=True)\n", " (y5o): Linear(in_features=100, out_features=1, bias=True)\n", " (y6o): Linear(in_features=100, out_features=1, bias=True)\n", " (y7o): Linear(in_features=100, out_features=1, bias=True)\n", " (y8o): Linear(in_features=100, out_features=1, bias=True)\n", " (y9o): Linear(in_features=100, out_features=1, bias=True)\n", " (y10o): Linear(in_features=100, out_features=1, bias=True)\n", " (y11o): Linear(in_features=100, out_features=1, bias=True)\n", " (y12o): Linear(in_features=100, out_features=1, bias=True)\n", " (y13o): Linear(in_features=100, out_features=1, bias=True)\n", " (y14o): Linear(in_features=100, out_features=1, bias=True)\n", " (y15o): Linear(in_features=100, out_features=1, bias=True)\n", " (y16o): Linear(in_features=100, out_features=1, bias=True)\n", " (y17o): Linear(in_features=100, out_features=1, bias=True)\n", " (y18o): Linear(in_features=100, out_features=1, bias=True)\n", " (y19o): Linear(in_features=100, out_features=1, bias=True)\n", " (y20o): Linear(in_features=100, out_features=1, bias=True)\n", " (y21o): Linear(in_features=100, out_features=1, bias=True)\n", " (y22o): Linear(in_features=100, out_features=1, bias=True)\n", " (y23o): Linear(in_features=100, out_features=1, bias=True)\n", " (y24o): Linear(in_features=100, out_features=1, bias=True)\n", " (y25o): Linear(in_features=100, out_features=1, bias=True)\n", " (y26o): Linear(in_features=100, out_features=1, 
bias=True)\n", " (y27o): Linear(in_features=100, out_features=1, bias=True)\n", " (y28o): Linear(in_features=100, out_features=1, bias=True)\n", " (y29o): Linear(in_features=100, out_features=1, bias=True)\n", " (y30o): Linear(in_features=100, out_features=1, bias=True)\n", " (y31o): Linear(in_features=100, out_features=1, bias=True)\n", " (y32o): Linear(in_features=100, out_features=1, bias=True)\n", " (y33o): Linear(in_features=100, out_features=1, bias=True)\n", " (y34o): Linear(in_features=100, out_features=1, bias=True)\n", " (y35o): Linear(in_features=100, out_features=1, bias=True)\n", " (y36o): Linear(in_features=100, out_features=1, bias=True)\n", " (y37o): Linear(in_features=100, out_features=1, bias=True)\n", " (y38o): Linear(in_features=100, out_features=1, bias=True)\n", " (y39o): Linear(in_features=100, out_features=1, bias=True)\n", " (y40o): Linear(in_features=100, out_features=1, bias=True)\n", " (y41o): Linear(in_features=100, out_features=1, bias=True)\n", " (y42o): Linear(in_features=100, out_features=1, bias=True)\n", " (y43o): Linear(in_features=100, out_features=1, bias=True)\n", " (y44o): Linear(in_features=100, out_features=1, bias=True)\n", " (y45o): Linear(in_features=100, out_features=1, bias=True)\n", " (y46o): Linear(in_features=100, out_features=1, bias=True)\n", " (y47o): Linear(in_features=100, out_features=1, bias=True)\n", " (y48o): Linear(in_features=100, out_features=1, bias=True)\n", " (y49o): Linear(in_features=100, out_features=1, bias=True)\n", " (y50o): Linear(in_features=100, out_features=1, bias=True)\n", " (y51o): Linear(in_features=100, out_features=1, bias=True)\n", " (y52o): Linear(in_features=100, out_features=1, bias=True)\n", " (y53o): Linear(in_features=100, out_features=1, bias=True)\n", " (y54o): Linear(in_features=100, out_features=1, bias=True)\n", " (y55o): Linear(in_features=100, out_features=1, bias=True)\n", " (y56o): Linear(in_features=100, out_features=1, bias=True)\n", " (y57o): Linear(in_features=100, 
out_features=1, bias=True)\n", " (y58o): Linear(in_features=100, out_features=1, bias=True)\n", " (y59o): Linear(in_features=100, out_features=1, bias=True)\n", " (y60o): Linear(in_features=100, out_features=1, bias=True)\n", " (y61o): Linear(in_features=100, out_features=1, bias=True)\n", " (y62o): Linear(in_features=100, out_features=1, bias=True)\n", " (y63o): Linear(in_features=100, out_features=1, bias=True)\n", " (y64o): Linear(in_features=100, out_features=1, bias=True)\n", " (y65o): Linear(in_features=100, out_features=1, bias=True)\n", " (y66o): Linear(in_features=100, out_features=1, bias=True)\n", " (y67o): Linear(in_features=100, out_features=1, bias=True)\n", " (y68o): Linear(in_features=100, out_features=1, bias=True)\n", " (y69o): Linear(in_features=100, out_features=1, bias=True)\n", " (y70o): Linear(in_features=100, out_features=1, bias=True)\n", " (y71o): Linear(in_features=100, out_features=1, bias=True)\n", " (y72o): Linear(in_features=100, out_features=1, bias=True)\n", " (y73o): Linear(in_features=100, out_features=1, bias=True)\n", " (y74o): Linear(in_features=100, out_features=1, bias=True)\n", " (y75o): Linear(in_features=100, out_features=1, bias=True)\n", " (y76o): Linear(in_features=100, out_features=1, bias=True)\n", " (y77o): Linear(in_features=100, out_features=1, bias=True)\n", " (y78o): Linear(in_features=100, out_features=1, bias=True)\n", " (y79o): Linear(in_features=100, out_features=1, bias=True)\n", " (y80o): Linear(in_features=100, out_features=1, bias=True)\n", " (y81o): Linear(in_features=100, out_features=1, bias=True)\n", " (y82o): Linear(in_features=100, out_features=1, bias=True)\n", " (y83o): Linear(in_features=100, out_features=1, bias=True)\n", " (y84o): Linear(in_features=100, out_features=1, bias=True)\n", " (y85o): Linear(in_features=100, out_features=1, bias=True)\n", " (y86o): Linear(in_features=100, out_features=1, bias=True)\n", " (y87o): Linear(in_features=100, out_features=1, bias=True)\n", " (y88o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y89o): Linear(in_features=100, out_features=1, bias=True)\n", " (y90o): Linear(in_features=100, out_features=1, bias=True)\n", " (y91o): Linear(in_features=100, out_features=1, bias=True)\n", " (y92o): Linear(in_features=100, out_features=1, bias=True)\n", " (y93o): Linear(in_features=100, out_features=1, bias=True)\n", " (y94o): Linear(in_features=100, out_features=1, bias=True)\n", " (y95o): Linear(in_features=100, out_features=1, bias=True)\n", " (y96o): Linear(in_features=100, out_features=1, bias=True)\n", " (y97o): Linear(in_features=100, out_features=1, bias=True)\n", " (y98o): Linear(in_features=100, out_features=1, bias=True)\n", " (y99o): Linear(in_features=100, out_features=1, bias=True)\n", " (y100o): Linear(in_features=100, out_features=1, bias=True)\n", " (y101o): Linear(in_features=100, out_features=1, bias=True)\n", " (y102o): Linear(in_features=100, out_features=1, bias=True)\n", " (y103o): Linear(in_features=100, out_features=1, bias=True)\n", " (y104o): Linear(in_features=100, out_features=1, bias=True)\n", " (y105o): Linear(in_features=100, out_features=1, bias=True)\n", " (y106o): Linear(in_features=100, out_features=1, bias=True)\n", " (y107o): Linear(in_features=100, out_features=1, bias=True)\n", " (y108o): Linear(in_features=100, out_features=1, bias=True)\n", " (y109o): Linear(in_features=100, out_features=1, bias=True)\n", " (y110o): Linear(in_features=100, out_features=1, bias=True)\n", " (y111o): Linear(in_features=100, out_features=1, bias=True)\n", " (y112o): Linear(in_features=100, out_features=1, bias=True)\n", " (y113o): Linear(in_features=100, out_features=1, bias=True)\n", " (y114o): Linear(in_features=100, out_features=1, bias=True)\n", " (y115o): Linear(in_features=100, out_features=1, bias=True)\n", " (y116o): Linear(in_features=100, out_features=1, bias=True)\n", " (y117o): Linear(in_features=100, out_features=1, bias=True)\n", " (y118o): Linear(in_features=100, 
out_features=1, bias=True)\n", " (y119o): Linear(in_features=100, out_features=1, bias=True)\n", " (y120o): Linear(in_features=100, out_features=1, bias=True)\n", " (y121o): Linear(in_features=100, out_features=1, bias=True)\n", " (y122o): Linear(in_features=100, out_features=1, bias=True)\n", " (y123o): Linear(in_features=100, out_features=1, bias=True)\n", " (y124o): Linear(in_features=100, out_features=1, bias=True)\n", " (y125o): Linear(in_features=100, out_features=1, bias=True)\n", " (y126o): Linear(in_features=100, out_features=1, bias=True)\n", " (y127o): Linear(in_features=100, out_features=1, bias=True)\n", " (y128o): Linear(in_features=100, out_features=1, bias=True)\n", " (y129o): Linear(in_features=100, out_features=1, bias=True)\n", " (y130o): Linear(in_features=100, out_features=1, bias=True)\n", " (y131o): Linear(in_features=100, out_features=1, bias=True)\n", " (y132o): Linear(in_features=100, out_features=1, bias=True)\n", " (y133o): Linear(in_features=100, out_features=1, bias=True)\n", " (y134o): Linear(in_features=100, out_features=1, bias=True)\n", " (y135o): Linear(in_features=100, out_features=1, bias=True)\n", " (y136o): Linear(in_features=100, out_features=1, bias=True)\n", " (y137o): Linear(in_features=100, out_features=1, bias=True)\n", " (y138o): Linear(in_features=100, out_features=1, bias=True)\n", " (y139o): Linear(in_features=100, out_features=1, bias=True)\n", " (y140o): Linear(in_features=100, out_features=1, bias=True)\n", " (y141o): Linear(in_features=100, out_features=1, bias=True)\n", " (y142o): Linear(in_features=100, out_features=1, bias=True)\n", " (y143o): Linear(in_features=100, out_features=1, bias=True)\n", " (y144o): Linear(in_features=100, out_features=1, bias=True)\n", " (y145o): Linear(in_features=100, out_features=1, bias=True)\n", " (y146o): Linear(in_features=100, out_features=1, bias=True)\n", " (y147o): Linear(in_features=100, out_features=1, bias=True)\n", " (y148o): Linear(in_features=100, out_features=1, 
bias=True)\n", " (y149o): Linear(in_features=100, out_features=1, bias=True)\n", " (y150o): Linear(in_features=100, out_features=1, bias=True)\n", " (y151o): Linear(in_features=100, out_features=1, bias=True)\n", " (y152o): Linear(in_features=100, out_features=1, bias=True)\n", " (y153o): Linear(in_features=100, out_features=1, bias=True)\n", " (y154o): Linear(in_features=100, out_features=1, bias=True)\n", " (y155o): Linear(in_features=100, out_features=1, bias=True)\n", " (y156o): Linear(in_features=100, out_features=1, bias=True)\n", " (y157o): Linear(in_features=100, out_features=1, bias=True)\n", " (y158o): Linear(in_features=100, out_features=1, bias=True)\n", " (y159o): Linear(in_features=100, out_features=1, bias=True)\n", " (y160o): Linear(in_features=100, out_features=1, bias=True)\n", " (y161o): Linear(in_features=100, out_features=1, bias=True)\n", " (y162o): Linear(in_features=100, out_features=1, bias=True)\n", " (y163o): Linear(in_features=100, out_features=1, bias=True)\n", " (y164o): Linear(in_features=100, out_features=1, bias=True)\n", " (y165o): Linear(in_features=100, out_features=1, bias=True)\n", " (y166o): Linear(in_features=100, out_features=1, bias=True)\n", " (y167o): Linear(in_features=100, out_features=1, bias=True)\n", " (y168o): Linear(in_features=100, out_features=1, bias=True)\n", " (y169o): Linear(in_features=100, out_features=1, bias=True)\n", " (y170o): Linear(in_features=100, out_features=1, bias=True)\n", " (y171o): Linear(in_features=100, out_features=1, bias=True)\n", " (y172o): Linear(in_features=100, out_features=1, bias=True)\n", " (y173o): Linear(in_features=100, out_features=1, bias=True)\n", " (y174o): Linear(in_features=100, out_features=1, bias=True)\n", " (y175o): Linear(in_features=100, out_features=1, bias=True)\n", " (y176o): Linear(in_features=100, out_features=1, bias=True)\n", " (y177o): Linear(in_features=100, out_features=1, bias=True)\n", " (y178o): Linear(in_features=100, out_features=1, bias=True)\n", " 
(y179o): Linear(in_features=100, out_features=1, bias=True)\n", " (y180o): Linear(in_features=100, out_features=1, bias=True)\n", " (y181o): Linear(in_features=100, out_features=1, bias=True)\n", " (y182o): Linear(in_features=100, out_features=1, bias=True)\n", " (y183o): Linear(in_features=100, out_features=1, bias=True)\n", " (y184o): Linear(in_features=100, out_features=1, bias=True)\n", " (y185o): Linear(in_features=100, out_features=1, bias=True)\n", " (y186o): Linear(in_features=100, out_features=1, bias=True)\n", " (y187o): Linear(in_features=100, out_features=1, bias=True)\n", " (y188o): Linear(in_features=100, out_features=1, bias=True)\n", " (y189o): Linear(in_features=100, out_features=1, bias=True)\n", " (y190o): Linear(in_features=100, out_features=1, bias=True)\n", " (y191o): Linear(in_features=100, out_features=1, bias=True)\n", " (y192o): Linear(in_features=100, out_features=1, bias=True)\n", " (y193o): Linear(in_features=100, out_features=1, bias=True)\n", " (y194o): Linear(in_features=100, out_features=1, bias=True)\n", " (y195o): Linear(in_features=100, out_features=1, bias=True)\n", " (y196o): Linear(in_features=100, out_features=1, bias=True)\n", " (y197o): Linear(in_features=100, out_features=1, bias=True)\n", " (y198o): Linear(in_features=100, out_features=1, bias=True)\n", " (y199o): Linear(in_features=100, out_features=1, bias=True)\n", " (y200o): Linear(in_features=100, out_features=1, bias=True)\n", " (y201o): Linear(in_features=100, out_features=1, bias=True)\n", " (y202o): Linear(in_features=100, out_features=1, bias=True)\n", " (y203o): Linear(in_features=100, out_features=1, bias=True)\n", " (y204o): Linear(in_features=100, out_features=1, bias=True)\n", " (y205o): Linear(in_features=100, out_features=1, bias=True)\n", " (y206o): Linear(in_features=100, out_features=1, bias=True)\n", " (y207o): Linear(in_features=100, out_features=1, bias=True)\n", " (y208o): Linear(in_features=100, out_features=1, bias=True)\n", " (y209o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y210o): Linear(in_features=100, out_features=1, bias=True)\n", " (y211o): Linear(in_features=100, out_features=1, bias=True)\n", " (y212o): Linear(in_features=100, out_features=1, bias=True)\n", " (y213o): Linear(in_features=100, out_features=1, bias=True)\n", " (y214o): Linear(in_features=100, out_features=1, bias=True)\n", " (y215o): Linear(in_features=100, out_features=1, bias=True)\n", " (y216o): Linear(in_features=100, out_features=1, bias=True)\n", " (y217o): Linear(in_features=100, out_features=1, bias=True)\n", " (y218o): Linear(in_features=100, out_features=1, bias=True)\n", " (y219o): Linear(in_features=100, out_features=1, bias=True)\n", " (y220o): Linear(in_features=100, out_features=1, bias=True)\n", " (y221o): Linear(in_features=100, out_features=1, bias=True)\n", " (y222o): Linear(in_features=100, out_features=1, bias=True)\n", " (y223o): Linear(in_features=100, out_features=1, bias=True)\n", " (y224o): Linear(in_features=100, out_features=1, bias=True)\n", " (y225o): Linear(in_features=100, out_features=1, bias=True)\n", " (y226o): Linear(in_features=100, out_features=1, bias=True)\n", " (y227o): Linear(in_features=100, out_features=1, bias=True)\n", " (y228o): Linear(in_features=100, out_features=1, bias=True)\n", " (y229o): Linear(in_features=100, out_features=1, bias=True)\n", " (y230o): Linear(in_features=100, out_features=1, bias=True)\n", " (y231o): Linear(in_features=100, out_features=1, bias=True)\n", " (y232o): Linear(in_features=100, out_features=1, bias=True)\n", " (y233o): Linear(in_features=100, out_features=1, bias=True)\n", " (y234o): Linear(in_features=100, out_features=1, bias=True)\n", " (y235o): Linear(in_features=100, out_features=1, bias=True)\n", " (y236o): Linear(in_features=100, out_features=1, bias=True)\n", " (y237o): Linear(in_features=100, out_features=1, bias=True)\n", " (y238o): Linear(in_features=100, out_features=1, bias=True)\n", " (y239o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y240o): Linear(in_features=100, out_features=1, bias=True)\n", " (y241o): Linear(in_features=100, out_features=1, bias=True)\n", " (y242o): Linear(in_features=100, out_features=1, bias=True)\n", " (y243o): Linear(in_features=100, out_features=1, bias=True)\n", " (y244o): Linear(in_features=100, out_features=1, bias=True)\n", " (y245o): Linear(in_features=100, out_features=1, bias=True)\n", " (y246o): Linear(in_features=100, out_features=1, bias=True)\n", " (y247o): Linear(in_features=100, out_features=1, bias=True)\n", " (y248o): Linear(in_features=100, out_features=1, bias=True)\n", " (y249o): Linear(in_features=100, out_features=1, bias=True)\n", " (y250o): Linear(in_features=100, out_features=1, bias=True)\n", " (y251o): Linear(in_features=100, out_features=1, bias=True)\n", " (y252o): Linear(in_features=100, out_features=1, bias=True)\n", " (y253o): Linear(in_features=100, out_features=1, bias=True)\n", " (y254o): Linear(in_features=100, out_features=1, bias=True)\n", " (y255o): Linear(in_features=100, out_features=1, bias=True)\n", " (y256o): Linear(in_features=100, out_features=1, bias=True)\n", " (y257o): Linear(in_features=100, out_features=1, bias=True)\n", " (y258o): Linear(in_features=100, out_features=1, bias=True)\n", " (y259o): Linear(in_features=100, out_features=1, bias=True)\n", " (y260o): Linear(in_features=100, out_features=1, bias=True)\n", " (y261o): Linear(in_features=100, out_features=1, bias=True)\n", " (y262o): Linear(in_features=100, out_features=1, bias=True)\n", " (y263o): Linear(in_features=100, out_features=1, bias=True)\n", " (y264o): Linear(in_features=100, out_features=1, bias=True)\n", " (y265o): Linear(in_features=100, out_features=1, bias=True)\n", " (y266o): Linear(in_features=100, out_features=1, bias=True)\n", " (y267o): Linear(in_features=100, out_features=1, bias=True)\n", " (y268o): Linear(in_features=100, out_features=1, bias=True)\n", " (y269o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y270o): Linear(in_features=100, out_features=1, bias=True)\n", " (y271o): Linear(in_features=100, out_features=1, bias=True)\n", " (y272o): Linear(in_features=100, out_features=1, bias=True)\n", " (y273o): Linear(in_features=100, out_features=1, bias=True)\n", " (y274o): Linear(in_features=100, out_features=1, bias=True)\n", " (y275o): Linear(in_features=100, out_features=1, bias=True)\n", " (y276o): Linear(in_features=100, out_features=1, bias=True)\n", " (y277o): Linear(in_features=100, out_features=1, bias=True)\n", " (y278o): Linear(in_features=100, out_features=1, bias=True)\n", " (y279o): Linear(in_features=100, out_features=1, bias=True)\n", " (y280o): Linear(in_features=100, out_features=1, bias=True)\n", " (y281o): Linear(in_features=100, out_features=1, bias=True)\n", " (y282o): Linear(in_features=100, out_features=1, bias=True)\n", " (y283o): Linear(in_features=100, out_features=1, bias=True)\n", " (y284o): Linear(in_features=100, out_features=1, bias=True)\n", " (y285o): Linear(in_features=100, out_features=1, bias=True)\n", " (y286o): Linear(in_features=100, out_features=1, bias=True)\n", " (y287o): Linear(in_features=100, out_features=1, bias=True)\n", " (y288o): Linear(in_features=100, out_features=1, bias=True)\n", " (y289o): Linear(in_features=100, out_features=1, bias=True)\n", " (y290o): Linear(in_features=100, out_features=1, bias=True)\n", " (y291o): Linear(in_features=100, out_features=1, bias=True)\n", " (y292o): Linear(in_features=100, out_features=1, bias=True)\n", " (y293o): Linear(in_features=100, out_features=1, bias=True)\n", " (y294o): Linear(in_features=100, out_features=1, bias=True)\n", " (y295o): Linear(in_features=100, out_features=1, bias=True)\n", " (y296o): Linear(in_features=100, out_features=1, bias=True)\n", " (y297o): Linear(in_features=100, out_features=1, bias=True)\n", " (y298o): Linear(in_features=100, out_features=1, bias=True)\n", " (y299o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y300o): Linear(in_features=100, out_features=1, bias=True)\n", " (y301o): Linear(in_features=100, out_features=1, bias=True)\n", " (y302o): Linear(in_features=100, out_features=1, bias=True)\n", " (y303o): Linear(in_features=100, out_features=1, bias=True)\n", " (y304o): Linear(in_features=100, out_features=1, bias=True)\n", " (y305o): Linear(in_features=100, out_features=1, bias=True)\n", " (y306o): Linear(in_features=100, out_features=1, bias=True)\n", " (y307o): Linear(in_features=100, out_features=1, bias=True)\n", " (y308o): Linear(in_features=100, out_features=1, bias=True)\n", " (y309o): Linear(in_features=100, out_features=1, bias=True)\n", " (y310o): Linear(in_features=100, out_features=1, bias=True)\n", " (y311o): Linear(in_features=100, out_features=1, bias=True)\n", " (y312o): Linear(in_features=100, out_features=1, bias=True)\n", " (y313o): Linear(in_features=100, out_features=1, bias=True)\n", " (y314o): Linear(in_features=100, out_features=1, bias=True)\n", " (y315o): Linear(in_features=100, out_features=1, bias=True)\n", " (y316o): Linear(in_features=100, out_features=1, bias=True)\n", " (y317o): Linear(in_features=100, out_features=1, bias=True)\n", " (y318o): Linear(in_features=100, out_features=1, bias=True)\n", " (y319o): Linear(in_features=100, out_features=1, bias=True)\n", " (y320o): Linear(in_features=100, out_features=1, bias=True)\n", " (y321o): Linear(in_features=100, out_features=1, bias=True)\n", " (y322o): Linear(in_features=100, out_features=1, bias=True)\n", " (y323o): Linear(in_features=100, out_features=1, bias=True)\n", " (y324o): Linear(in_features=100, out_features=1, bias=True)\n", " (y325o): Linear(in_features=100, out_features=1, bias=True)\n", " (y326o): Linear(in_features=100, out_features=1, bias=True)\n", " (y327o): Linear(in_features=100, out_features=1, bias=True)\n", " (y328o): Linear(in_features=100, out_features=1, bias=True)\n", " (y329o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y330o): Linear(in_features=100, out_features=1, bias=True)\n", " (y331o): Linear(in_features=100, out_features=1, bias=True)\n", " (y332o): Linear(in_features=100, out_features=1, bias=True)\n", " (y333o): Linear(in_features=100, out_features=1, bias=True)\n", " (y334o): Linear(in_features=100, out_features=1, bias=True)\n", " (y335o): Linear(in_features=100, out_features=1, bias=True)\n", " (y336o): Linear(in_features=100, out_features=1, bias=True)\n", " (y337o): Linear(in_features=100, out_features=1, bias=True)\n", " (y338o): Linear(in_features=100, out_features=1, bias=True)\n", " (y339o): Linear(in_features=100, out_features=1, bias=True)\n", " (y340o): Linear(in_features=100, out_features=1, bias=True)\n", " (y341o): Linear(in_features=100, out_features=1, bias=True)\n", " (y342o): Linear(in_features=100, out_features=1, bias=True)\n", " (y343o): Linear(in_features=100, out_features=1, bias=True)\n", " (y344o): Linear(in_features=100, out_features=1, bias=True)\n", " (y345o): Linear(in_features=100, out_features=1, bias=True)\n", " (y346o): Linear(in_features=100, out_features=1, bias=True)\n", " (y347o): Linear(in_features=100, out_features=1, bias=True)\n", " (y348o): Linear(in_features=100, out_features=1, bias=True)\n", " (y349o): Linear(in_features=100, out_features=1, bias=True)\n", " (y350o): Linear(in_features=100, out_features=1, bias=True)\n", " (y351o): Linear(in_features=100, out_features=1, bias=True)\n", " (y352o): Linear(in_features=100, out_features=1, bias=True)\n", " (y353o): Linear(in_features=100, out_features=1, bias=True)\n", " (y354o): Linear(in_features=100, out_features=1, bias=True)\n", " (y355o): Linear(in_features=100, out_features=1, bias=True)\n", " (y356o): Linear(in_features=100, out_features=1, bias=True)\n", " (y357o): Linear(in_features=100, out_features=1, bias=True)\n", " (y358o): Linear(in_features=100, out_features=1, bias=True)\n", " (y359o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y360o): Linear(in_features=100, out_features=1, bias=True)\n", " (y361o): Linear(in_features=100, out_features=1, bias=True)\n", " (y362o): Linear(in_features=100, out_features=1, bias=True)\n", " (y363o): Linear(in_features=100, out_features=1, bias=True)\n", " (y364o): Linear(in_features=100, out_features=1, bias=True)\n", " (y365o): Linear(in_features=100, out_features=1, bias=True)\n", " (y366o): Linear(in_features=100, out_features=1, bias=True)\n", " (y367o): Linear(in_features=100, out_features=1, bias=True)\n", " (y368o): Linear(in_features=100, out_features=1, bias=True)\n", " (y369o): Linear(in_features=100, out_features=1, bias=True)\n", " (y370o): Linear(in_features=100, out_features=1, bias=True)\n", " (y371o): Linear(in_features=100, out_features=1, bias=True)\n", " (y372o): Linear(in_features=100, out_features=1, bias=True)\n", " (y373o): Linear(in_features=100, out_features=1, bias=True)\n", " (y374o): Linear(in_features=100, out_features=1, bias=True)\n", " (y375o): Linear(in_features=100, out_features=1, bias=True)\n", " (y376o): Linear(in_features=100, out_features=1, bias=True)\n", " (y377o): Linear(in_features=100, out_features=1, bias=True)\n", " (y378o): Linear(in_features=100, out_features=1, bias=True)\n", " (y379o): Linear(in_features=100, out_features=1, bias=True)\n", " (y380o): Linear(in_features=100, out_features=1, bias=True)\n", " (y381o): Linear(in_features=100, out_features=1, bias=True)\n", " (y382o): Linear(in_features=100, out_features=1, bias=True)\n", " (y383o): Linear(in_features=100, out_features=1, bias=True)\n", " (y384o): Linear(in_features=100, out_features=1, bias=True)\n", " (y385o): Linear(in_features=100, out_features=1, bias=True)\n", " (y386o): Linear(in_features=100, out_features=1, bias=True)\n", " (y387o): Linear(in_features=100, out_features=1, bias=True)\n", " (y388o): Linear(in_features=100, out_features=1, bias=True)\n", " (y389o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y390o): Linear(in_features=100, out_features=1, bias=True)\n", " (y391o): Linear(in_features=100, out_features=1, bias=True)\n", " (y392o): Linear(in_features=100, out_features=1, bias=True)\n", " (y393o): Linear(in_features=100, out_features=1, bias=True)\n", " (y394o): Linear(in_features=100, out_features=1, bias=True)\n", " (y395o): Linear(in_features=100, out_features=1, bias=True)\n", " (y396o): Linear(in_features=100, out_features=1, bias=True)\n", " (y397o): Linear(in_features=100, out_features=1, bias=True)\n", " (y398o): Linear(in_features=100, out_features=1, bias=True)\n", " (y399o): Linear(in_features=100, out_features=1, bias=True)\n", " (y400o): Linear(in_features=100, out_features=1, bias=True)\n", " (y401o): Linear(in_features=100, out_features=1, bias=True)\n", " (y402o): Linear(in_features=100, out_features=1, bias=True)\n", " (y403o): Linear(in_features=100, out_features=1, bias=True)\n", " (y404o): Linear(in_features=100, out_features=1, bias=True)\n", " (y405o): Linear(in_features=100, out_features=1, bias=True)\n", " (y406o): Linear(in_features=100, out_features=1, bias=True)\n", " (y407o): Linear(in_features=100, out_features=1, bias=True)\n", " (y408o): Linear(in_features=100, out_features=1, bias=True)\n", " (y409o): Linear(in_features=100, out_features=1, bias=True)\n", " (y410o): Linear(in_features=100, out_features=1, bias=True)\n", " (y411o): Linear(in_features=100, out_features=1, bias=True)\n", " (y412o): Linear(in_features=100, out_features=1, bias=True)\n", " (y413o): Linear(in_features=100, out_features=1, bias=True)\n", " (y414o): Linear(in_features=100, out_features=1, bias=True)\n", " (y415o): Linear(in_features=100, out_features=1, bias=True)\n", " (y416o): Linear(in_features=100, out_features=1, bias=True)\n", " (y417o): Linear(in_features=100, out_features=1, bias=True)\n", " (y418o): Linear(in_features=100, out_features=1, bias=True)\n", " (y419o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y420o): Linear(in_features=100, out_features=1, bias=True)\n", " (y421o): Linear(in_features=100, out_features=1, bias=True)\n", " (y422o): Linear(in_features=100, out_features=1, bias=True)\n", " (y423o): Linear(in_features=100, out_features=1, bias=True)\n", " (y424o): Linear(in_features=100, out_features=1, bias=True)\n", " (y425o): Linear(in_features=100, out_features=1, bias=True)\n", " (y426o): Linear(in_features=100, out_features=1, bias=True)\n", " (y427o): Linear(in_features=100, out_features=1, bias=True)\n", " (y428o): Linear(in_features=100, out_features=1, bias=True)\n", " (y429o): Linear(in_features=100, out_features=1, bias=True)\n", " (y430o): Linear(in_features=100, out_features=1, bias=True)\n", " (y431o): Linear(in_features=100, out_features=1, bias=True)\n", " (y432o): Linear(in_features=100, out_features=1, bias=True)\n", " (y433o): Linear(in_features=100, out_features=1, bias=True)\n", " (y434o): Linear(in_features=100, out_features=1, bias=True)\n", " (y435o): Linear(in_features=100, out_features=1, bias=True)\n", " (y436o): Linear(in_features=100, out_features=1, bias=True)\n", " (y437o): Linear(in_features=100, out_features=1, bias=True)\n", " (y438o): Linear(in_features=100, out_features=1, bias=True)\n", " (y439o): Linear(in_features=100, out_features=1, bias=True)\n", " (y440o): Linear(in_features=100, out_features=1, bias=True)\n", " (y441o): Linear(in_features=100, out_features=1, bias=True)\n", " (y442o): Linear(in_features=100, out_features=1, bias=True)\n", " (y443o): Linear(in_features=100, out_features=1, bias=True)\n", " (y444o): Linear(in_features=100, out_features=1, bias=True)\n", " (y445o): Linear(in_features=100, out_features=1, bias=True)\n", " (y446o): Linear(in_features=100, out_features=1, bias=True)\n", " (y447o): Linear(in_features=100, out_features=1, bias=True)\n", " (y448o): Linear(in_features=100, out_features=1, bias=True)\n", " (y449o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y450o): Linear(in_features=100, out_features=1, bias=True)\n", " (y451o): Linear(in_features=100, out_features=1, bias=True)\n", " (y452o): Linear(in_features=100, out_features=1, bias=True)\n", " (y453o): Linear(in_features=100, out_features=1, bias=True)\n", " (y454o): Linear(in_features=100, out_features=1, bias=True)\n", " (y455o): Linear(in_features=100, out_features=1, bias=True)\n", " (y456o): Linear(in_features=100, out_features=1, bias=True)\n", " (y457o): Linear(in_features=100, out_features=1, bias=True)\n", " (y458o): Linear(in_features=100, out_features=1, bias=True)\n", " (y459o): Linear(in_features=100, out_features=1, bias=True)\n", " (y460o): Linear(in_features=100, out_features=1, bias=True)\n", " (y461o): Linear(in_features=100, out_features=1, bias=True)\n", " (y462o): Linear(in_features=100, out_features=1, bias=True)\n", " (y463o): Linear(in_features=100, out_features=1, bias=True)\n", " (y464o): Linear(in_features=100, out_features=1, bias=True)\n", " (y465o): Linear(in_features=100, out_features=1, bias=True)\n", " (y466o): Linear(in_features=100, out_features=1, bias=True)\n", " (y467o): Linear(in_features=100, out_features=1, bias=True)\n", " (y468o): Linear(in_features=100, out_features=1, bias=True)\n", " (y469o): Linear(in_features=100, out_features=1, bias=True)\n", " (y470o): Linear(in_features=100, out_features=1, bias=True)\n", " (y471o): Linear(in_features=100, out_features=1, bias=True)\n", " (y472o): Linear(in_features=100, out_features=1, bias=True)\n", " (y473o): Linear(in_features=100, out_features=1, bias=True)\n", " (y474o): Linear(in_features=100, out_features=1, bias=True)\n", " (y475o): Linear(in_features=100, out_features=1, bias=True)\n", " (y476o): Linear(in_features=100, out_features=1, bias=True)\n", " (y477o): Linear(in_features=100, out_features=1, bias=True)\n", " (y478o): Linear(in_features=100, out_features=1, bias=True)\n", " (y479o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y480o): Linear(in_features=100, out_features=1, bias=True)\n", " (y481o): Linear(in_features=100, out_features=1, bias=True)\n", " (y482o): Linear(in_features=100, out_features=1, bias=True)\n", " (y483o): Linear(in_features=100, out_features=1, bias=True)\n", " (y484o): Linear(in_features=100, out_features=1, bias=True)\n", " (y485o): Linear(in_features=100, out_features=1, bias=True)\n", " (y486o): Linear(in_features=100, out_features=1, bias=True)\n", " (y487o): Linear(in_features=100, out_features=1, bias=True)\n", " (y488o): Linear(in_features=100, out_features=1, bias=True)\n", " (y489o): Linear(in_features=100, out_features=1, bias=True)\n", " (y490o): Linear(in_features=100, out_features=1, bias=True)\n", " (y491o): Linear(in_features=100, out_features=1, bias=True)\n", " (y492o): Linear(in_features=100, out_features=1, bias=True)\n", " (y493o): Linear(in_features=100, out_features=1, bias=True)\n", " (y494o): Linear(in_features=100, out_features=1, bias=True)\n", " (y495o): Linear(in_features=100, out_features=1, bias=True)\n", " (y496o): Linear(in_features=100, out_features=1, bias=True)\n", " (y497o): Linear(in_features=100, out_features=1, bias=True)\n", " (y498o): Linear(in_features=100, out_features=1, bias=True)\n", " (y499o): Linear(in_features=100, out_features=1, bias=True)\n", " (y500o): Linear(in_features=100, out_features=1, bias=True)\n", " (y501o): Linear(in_features=100, out_features=1, bias=True)\n", " (y502o): Linear(in_features=100, out_features=1, bias=True)\n", " (y503o): Linear(in_features=100, out_features=1, bias=True)\n", " (y504o): Linear(in_features=100, out_features=1, bias=True)\n", " (y505o): Linear(in_features=100, out_features=1, bias=True)\n", " (y506o): Linear(in_features=100, out_features=1, bias=True)\n", " (y507o): Linear(in_features=100, out_features=1, bias=True)\n", " (y508o): Linear(in_features=100, out_features=1, bias=True)\n", " (y509o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y510o): Linear(in_features=100, out_features=1, bias=True)\n", " (y511o): Linear(in_features=100, out_features=1, bias=True)\n", " (y512o): Linear(in_features=100, out_features=1, bias=True)\n", " (y513o): Linear(in_features=100, out_features=1, bias=True)\n", " (y514o): Linear(in_features=100, out_features=1, bias=True)\n", " (y515o): Linear(in_features=100, out_features=1, bias=True)\n", " (y516o): Linear(in_features=100, out_features=1, bias=True)\n", " (y517o): Linear(in_features=100, out_features=1, bias=True)\n", " (y518o): Linear(in_features=100, out_features=1, bias=True)\n", " (y519o): Linear(in_features=100, out_features=1, bias=True)\n", " (y520o): Linear(in_features=100, out_features=1, bias=True)\n", " (y521o): Linear(in_features=100, out_features=1, bias=True)\n", " (y522o): Linear(in_features=100, out_features=1, bias=True)\n", " (y523o): Linear(in_features=100, out_features=1, bias=True)\n", " (y524o): Linear(in_features=100, out_features=1, bias=True)\n", " (y525o): Linear(in_features=100, out_features=1, bias=True)\n", " (y526o): Linear(in_features=100, out_features=1, bias=True)\n", " (y527o): Linear(in_features=100, out_features=1, bias=True)\n", " (y528o): Linear(in_features=100, out_features=1, bias=True)\n", " (y529o): Linear(in_features=100, out_features=1, bias=True)\n", " (y530o): Linear(in_features=100, out_features=1, bias=True)\n", " (y531o): Linear(in_features=100, out_features=1, bias=True)\n", " (y532o): Linear(in_features=100, out_features=1, bias=True)\n", " (y533o): Linear(in_features=100, out_features=1, bias=True)\n", " (y534o): Linear(in_features=100, out_features=1, bias=True)\n", " (y535o): Linear(in_features=100, out_features=1, bias=True)\n", " (y536o): Linear(in_features=100, out_features=1, bias=True)\n", " (y537o): Linear(in_features=100, out_features=1, bias=True)\n", " (y538o): Linear(in_features=100, out_features=1, bias=True)\n", " (y539o): 
Linear(in_features=100, out_features=1, bias=True)\n", " (y540o): Linear(in_features=100, out_features=1, bias=True)\n", " (y541o): Linear(in_features=100, out_features=1, bias=True)\n", " (y542o): Linear(in_features=100, out_features=1, bias=True)\n", " (y543o): Linear(in_features=100, out_features=1, bias=True)\n", " (y544o): Linear(in_features=100, out_features=1, bias=True)\n", " (y545o): Linear(in_features=100, out_features=1, bias=True)\n", " (y546o): Linear(in_features=100, out_features=1, bias=True)\n", " (y547o): Linear(in_features=100, out_features=1, bias=True)\n", " (y548o): Linear(in_features=100, out_features=1, bias=True)\n", " (y549o): Linear(in_features=100, out_features=1, bias=True)\n", " (y550o): Linear(in_features=100, out_features=1, bias=True)\n", " (y551o): Linear(in_features=100, out_features=1, bias=True)\n", " (y552o): Linear(in_features=100, out_features=1, bias=True)\n", " (y553o): Linear(in_features=100, out_features=1, bias=True)\n", " (y554o): Linear(in_features=100, out_features=1, bias=True)\n", " (y555o): Linear(in_features=100, out_features=1, bias=True)\n", " (y556o): Linear(in_features=100, out_features=1, bias=True)\n", " (y557o): Linear(in_features=100, out_features=1, bias=True)\n", " (y558o): Linear(in_features=100, out_features=1, bias=True)\n", " (y559o): Linear(in_features=100, out_features=1, bias=True)\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = ChEMBLMultiTask(560) # number of tasks; must match the model that was saved\n", "# map_location lets weights saved on a GPU be loaded on a CPU-only machine\n", "model.load_state_dict(torch.load(f\"{MAIN_PATH}/{MODEL_FILE}\", map_location=device))\n", "model.eval()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }