{ "cells": [ { "cell_type": "markdown", "id": "heated-violin", "metadata": {}, "source": [ "# \"pytorch-widedeep, deep learning for tabular data IV: Deep Learning vs LightGBM\"\n", "> A thorough comparison between DL algorithms and LightGBM for tabular data for classification and regression problems\n", "\n", "- author: Javier Rodriguez\n", "- toc: true \n", "- badges: true\n", "- comments: true\n" ] }, { "cell_type": "markdown", "id": "subsequent-group", "metadata": {}, "source": [ "Here we go with yet another post in the series. I started planning this posts a few months ago, as soon as I released what it was the last beta version (`0.4.8`) of the library [pytorch-widedeep](https://github.com/jrzaurin/pytorch-widedeep). However, since then, a few things took priority, which meant that to run the hundreds of experiments that I run (probably over 1500), took me considerably more time than I expected. Nevertheless, here we are. \n", "\n", "Let me start by saying thanks to the guys at the [AWS community builders](https://aws.amazon.com/developer/community/community-builders/) and specially to [Cameron](https://www.linkedin.com/in/cameronperon/), for making my life a lot easier around AWS.\n", "\n", "All the Deep Learning models for this project were run on a `p2.xlarge` instance and all the `LightGBM` experiments were run on my Mac `Mid 2015`. \n", "\n", "Once the proper acknowledgments have been made, let me tell you a bit about the context of all those experiments and eventually this post." ] }, { "cell_type": "markdown", "id": "prerequisite-nelson", "metadata": {}, "source": [ "## 1. Introduction: why all this?\n", "\n", "Through the last couple of years, and in particular during the last year, I have been putting a lot of effort in improving [pytorch-widedeep](https://github.com/jrzaurin/pytorch-widedeep). This has been **really** entertaining, and I have learned a lot. However, as I was adding models to the library, especially for the tabular component (see [here](https://pytorch-widedeep.readthedocs.io/en/latest/model_components.html)), I wondered if there was a purpose to it, other than learning those models themselves. You see, I am a scientist in education and I spent over a decade in academia. There we used to do *a lot* of not-very-useful things, cool (sometimes), but not very useful. One of the aspects that drove me to the private sector, a few years back now, was the search for a sense of \"usefulness\", where I could build things that have a scientific aspect and at the same time are useful. With that in mind, I wanted the library to be, forgive the redundancy, useful. Here the adjective \"useful\" can mean a number of things. It could mean directly using the library, or fork the repo and use the code, or just copy and paste some portion of the code for a given project. However eventually, a question that I wanted to answer was: *do these models compare well or even improve the performance of other more \"standard\" models like GBMs?*. Note that I write \"*a question*\" and not \"*the question*\". More on this later in the post.\n", "\n", "Of course, I am not the first to compare Deep Learning (hereafter DL) approaches with GBMs for tabular data, and I won't be the last. In fact, by the time I am writing these lines, a new paper: [Tabular Data: Deep Learning is Not All You Need](https://arxiv.org/pdf/2106.03253.pdf) [1] was published. This post and that paper are certainly very similar, and the conclusion entirely consistent. However, there are some differences. 
They compare DL algorithms against `XGBoost` [2] and `CatBoost` [3], while I use `LightGBM` [4] (see Section 2.3 for an explanation of why I use this algorithm). Also, I would say that three of the four datasets that I use here are a bit more challenging than the datasets in their paper, but that might be just my perception. Finally, with the exception of `TabNet`, the DL models I use here and those in that paper are different. Nonetheless, in the Conclusion section I will write some thoughts on ways to tackle these benchmark/testing exercises. \n", "\n", "Aside from that paper, almost *all* papers that release new models include comprehensive comparisons between DL architectures and GBMs. My main caveats with some of these publications are the following: I often do not manage to reproduce the results in the paper (which of course might be my fault), and I sometimes find that the effort placed in optimizing the DL models is a bit more \"*intense*\" than that placed in optimizing the GBMs. Last but not least, the lack of consistency in the results tables in some papers is, sometimes, confusing. For example, Paper A will use DL Model A and find that it performs better than all GBMs, normally `XGBoost`, `Catboost` and `LightGBM`. Then Paper B will come along with a new DL Model B that will also perform better than all GBMs, but in their paper it turns out that Model A does not beat GBMs anymore.\n", "\n", "Considering all that, I decided to use [pytorch-widedeep](https://github.com/jrzaurin/pytorch-widedeep) and run a sizeable set of experiments comprising different DL models for tabular data and [`LightGBM`](https://lightgbm.readthedocs.io/en/latest/#). \n", "\n", "Before I move on, let me comment on the code \"quality\" in the experiments repo. One has to bear in mind that the goal here is to test algorithms in a rigorous manner, and not to write production code. If you want to see better code you can go to the [pytorch-widedeep](https://github.com/jrzaurin/pytorch-widedeep) repo itself or maybe some other of my repos. Just saying, in case some \"purist\" is tempted to waste the universe's time." ] }, { "cell_type": "markdown", "id": "medical-browser", "metadata": {}, "source": [ "## 2. Datasets and Models\n", "\n", "For the experiments here I have used four datasets and four DL models. \n", "\n", "### 2.1 Datasets\n", "\n", "\n", "1. [Adult Census](https://archive.ics.uci.edu/ml/datasets/adult) (binary classification) \n", "2. [Bank Marketing](https://archive.ics.uci.edu/ml/datasets/Bank+Marketing) (binary classification)\n", "3. [NYC taxi ride duration](https://www.kaggle.com/neomatrix369/nyc-taxi-trip-duration-extended) (regression)\n", "4. [Facebook Comment Volume](https://archive.ics.uci.edu/ml/datasets/Facebook+Comment+Volume+Dataset) (regression)\n", "\n", "The bash script `get_data.sh` in the [repo](https://github.com/jrzaurin/tabulardl-benchmark) has all the info you need to get those datasets in case you wanted to explore them yourself. Of course, all the code used to run the experiments and reproduce the results is also available in that repo. \n", "\n", "Here is some basic information about the datasets:" ] }, { "cell_type": "code", "execution_count": 1, "id": "continent-muscle", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Datasetn_rowsn_colsobjectiveneg_pos_ratio
0adult4522215binary_classification0.3295
1bank_marketing4118820binary_classification0.1270
2nyc_taxi145864426regressionNaN
3facebook_comments_vol19902954regressionNaN
\n", "
" ], "text/plain": [ " Dataset n_rows n_cols objective \\\n", "0 adult 45222 15 binary_classification \n", "1 bank_marketing 41188 20 binary_classification \n", "2 nyc_taxi 1458644 26 regression \n", "3 facebook_comments_vol 199029 54 regression \n", "\n", " neg_pos_ratio \n", "0 0.3295 \n", "1 0.1270 \n", "2 NaN \n", "3 NaN " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "import pandas as pd\n", "\n", "basic_info = pd.read_csv(\"../../tabulardl-benchmark/raw_data/basic_stats_df.csv\")\n", "basic_info[basic_info.Dataset != \"airbnb\"].reset_index(drop=True)" ] }, { "cell_type": "markdown", "id": "historic-combat", "metadata": {}, "source": [ "**Table 1**. Basic information for the datasets used in this post" ] }, { "cell_type": "markdown", "id": "exact-murder", "metadata": {}, "source": [ "There are reasons why I choose those datasets.\n", "\n", "In general, I looked for a binary, multi-class and regression datasets that had a good number of, if not dominated by categorical features. This is because in my experience, DL models for tabular data become more useful and competitive in sizeable datasets where categorical features are present (although [5] suggests that better results are obtained encoding numerical features as well) and moreover if these categorical features have a lot of categories. This is because the embeddings acquire a more significant value, i.e. we learn representations of those categorical features that encode relationships with all other features and also the target for a specific dataset. Note that this does not happen when using GBMs. Even if one used [target encoding](https://maxhalford.github.io/blog/target-encoding/), in reality there is not much of a learning element there (still useful of course). \n", "\n", "Of course, one could take datasets that are dominated by numerical features and bin them somehow to turn them into categorical. However, this seemed a bit too \"forced\" for me. With the idea of keeping the content of this post as close as possible to real use cases, it is hard for me to think of many \"real world\" scenarios where we are provided with datasets dominated by numerical features that are then turned/binned into categorical before being fed to an algorithm. In other words, I did not want to consider datasets where I had to bin the numerical features into categorical just to compare GBMs and DL models. \n", "\n", "On the other hand, I also looked for datasets that were already familiar to me or did not required too much feature engineering to get to a stage where the data could be passed to a model. This way I could perhaps save some time on that aspect and focus a bit more on the experimentation, since I intended to run a large number of experiments. Finally I looked for datasets that, to some extent, resemble as much as possible to datasets that one would find in the \"real world\", but had a tractable size so I could experiment within a reasonable time frame. \n", "\n", "While I did manage to find suitable datasets for binary classification and regression, and I did not find datasets that I particularly liked in the case of multi-class classification (if anyone has any suggestion, please comment below and I am happy to give it a go). Perhaps I will include the [CoverType](https://archive.ics.uci.edu/ml/datasets/covertype) dataset in the future, but the one at the UCI ML repository, not the Kaggle's balanced version. For now, I will move on with those four enumerated above. 
Let me briefly comment on each dataset.\n", "\n", "I would refer to the *Adult Census dataset* as the \"*easiest dataset*\", in the sense that simple models (e.g. a Naive Bayes classifier) will already lead to accuracies of $\\sim$ 84$\\%$ without any feature engineering. Personally, I normally don't find such nice datasets in the real world. However, it is one of the most popular and well known datasets for ML tutorials, posts, etc., and I eventually decided to include it.\n", "\n", "The *Bank Marketing* dataset is also well known. This data is related to direct marketing campaigns based on phone calls, and the goal is to predict whether or not a client will subscribe to a product. In this case it is important to mention a couple of relevant aspects. In the first place, I used the [original dataset](https://archive.ics.uci.edu/ml/datasets/Bank+Marketing), which is a bit imbalanced (the positive to negative class ratio is 0.127). Secondly, you might look around and find that some people obtained better results than those I will show later in the post. All such cases that I found use either a balanced dataset from Kaggle, a feature called `duration`, or both. The `duration` feature, which refers to the duration of the call, is something you only know **after** the call and it highly affects the target. Therefore, I have not used it in my experiments. This dataset resembles a real use case more than the Adult dataset does, in the sense that the data is imbalanced and the prediction is not an easy task at all. Still, the dataset is small and not that imbalanced.\n", "\n", "The *NYC taxi ride duration* dataset is also well known and is the largest of all the datasets I used. Here our goal is to predict the total ride duration of taxi trips in New York City. Instead of getting the dataset from the [Kaggle site](https://www.kaggle.com/c/nyc-taxi-trip-duration) I manually downloaded an extended version from [here](https://www.kaggle.com/neomatrix369/nyc-taxi-trip-duration-extended), where all the feature engineering had already been done. \n", "\n", "Finally, the *Facebook Comment Volume* dataset was another ideal candidate, since it has a good size and all the feature engineering was done for me. Our goal here is to predict the comment volume that posts will receive. In fact, this dataset was originally used to compare decision trees with neural networks. A very detailed description of the dataset and the pre-processing can be found in the [original publication](https://uksim.info/uksim2015/data/8713a015.pdf) [6]. In particular, I used their training Variant - 5 dataset for the experiments in this post, which has 199029 rows and 54 columns.\n", "\n", "All the code for the data preparation steps, before the data is fed to the algorithms, can be found [here](https://github.com/jrzaurin/tabulardl-benchmark/tree/master/prepare_datasets)." ] }, { "cell_type": "markdown", "id": "dressed-drunk", "metadata": {}, "source": [ "### 2.2 The DL Models\n", "\n", "As I mentioned earlier in the post, all DL models were run via [pytorch-widedeep](https://github.com/jrzaurin/pytorch-widedeep). This library offers four wide and deep model [components](https://pytorch-widedeep.readthedocs.io/en/latest/model_components.html): `wide`, `deeptabular`, `deeptext`, `deepimage`. Let me briefly comment on each one of them. 
For more details, please see the [companion posts](https://jrzaurin.github.io/infinitoml/), the [documentation](https://pytorch-widedeep.readthedocs.io/en/latest/model_components.html) or the [source code](https://github.com/jrzaurin/pytorch-widedeep/tree/tabnet/pytorch_widedeep/models) itself.\n", "\n", "1. `wide`: this is just a linear model implemented via an `Embedding` layer.\n", "\n", "\n", "2. `deeptabular`: this component will take care of the \"standard\" tabular data (i.e. categorical and numerical columns) and has 4 alternatives:\n", "    \n", "    2.1 `TabMlp`: a simple, standard MLP. Very similar to, for example, the [tabular api](https://docs.fast.ai/tabular.learner.html) implementation in the fastai library.\n", "    \n", "    2.2 `TabResnet`: similar to the MLP, but instead of dense layers I use Resnet blocks.\n", "    \n", "    2.3 `Tabnet`[7]: this is a very interesting implementation. It is hard to explain it in a few sentences, therefore I strongly suggest reading the [paper](https://arxiv.org/abs/1908.07442). `Tabnet` is meant to be competitive with GBMs and offers model interpretability via feature importance. `pytorch-widedeep`'s implementation of `Tabnet` is fully based on the fantastic [implementation](https://github.com/dreamquark-ai/Tabnet) by the guys at dreamquark-ai, therefore, **ALL** credit to them. I have simply adapted it to work within a Wide and Deep frame and added a couple of extra features, such as internal dropout in the GLU blocks and the possibility of not using ghost batch normalization [8]. \n", "    \n", "    Note that the original implementation allows training in two stages: first, self-supervised training via a standard encoder-decoder approach, and then supervised training or fine-tuning using only the encoder. Only the supervised training (i.e. the encoder) is implemented in `pytorch-widedeep`. The authors showed that unsupervised pre-training improves the performance mostly in low-data regimes or when the unlabeled dataset is much larger than the labeled dataset. Therefore, if you are in one of those scenarios (or simply as a general statement), you had better use dreamquark-ai's implementation. \n", "    \n", "    2.4 `TabTransformer`[9]: this is similar to `TabResnet`, but instead of Resnet blocks the authors used Transformer [10] blocks. Similar to the case of `Tabnet`, the `TabTransformer` allows for a two-stage training process, unsupervised pre-training followed by supervised training or fine-tuning. `pytorch-widedeep`'s implementation of the `TabTransformer` is designed to be used in a \"standard\" way, i.e. supervised training. Note that, consistent with the results of Sercan Ö. Arık and Tomas Pfister for `Tabnet`, the authors found that unsupervised pre-training improves the performance mostly in low-data regimes or when the unlabeled dataset is much larger than the labeled dataset. The `TabTransformer` implementation available in `pytorch-widedeep` is partially based on that at the [autogluon](https://github.com/awslabs/autogluon/tree/058398b61d1b2011f56a9dce149b0989adbbb04a/tabular/src/autogluon/tabular/models/tab_transformer) library and that from Phil Wang [here](https://github.com/lucidrains/tab-transformer-pytorch).\n", "    \n", "    \n", "3. `deeptext`: a standard text classifier/regressor comprising a stack of RNNs (LSTMs or GRUs). In addition, there is the option to add a set of dense layers on top of the stack of RNNs and some other extra features. \n", "\n", "\n", "4. 
`deepimage`: a standard image classifier/regressor using a pretrained network (in particular ResNets) or a sequence of 4 convolution layers. In addition, there is the option to add a set of dense layers on top of the stack of CNNs and some other extra features. " ] }, { "cell_type": "markdown", "id": "obvious-foundation", "metadata": {}, "source": [ "### 2.3 Why `LightGBM`?\n", "\n", "If you have worked with me, or even had a chat with me about some ML project, you will know that one of my favorite algorithms is [`LightGBM`](https://lightgbm.readthedocs.io/en/latest/). I have used it extensively. In fact, the last 3 ML systems that I have productionised all relied on `LightGBM`. It performs similarly to, if not better than, its brothers and sisters (e.g. [`XGBoost`](https://xgboost.readthedocs.io/en/latest/) or [`CatBoost`](https://catboost.ai/)), is significantly faster, and offers support for categorical features (see [here](https://www.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479)), although when it comes to support for categorical features `CatBoost` is probably the superior solution. In addition, it offers the usual flexibility and performance of GBMs. " ] }, { "cell_type": "markdown", "id": "voluntary-organ", "metadata": {}, "source": [ "### 2.4 Experiments setup and other considerations\n", "\n", "As I mentioned earlier in the post, I ran many experiments (not all were recorded and/or made it to the post) for the four datasets, focusing on the different models available for the `deeptabular` component. All the experiments run can be found [here](https://github.com/jrzaurin/tabulardl-benchmark/tree/master/run_experiments) in the repo. \n", "\n", "The experiments not only considered different parameters for the models (i.e. number of units, layers, etc.) but also different optimizers, learning rate schedulers, and training processes. For example, all experiments were run with early stopping, with a `patience` of 30 epochs in most cases. I used three different optimizers (`Adam`[11], `AdamW`[12] and `RAdam`[13]) and three different learning rate schedulers (`ReduceLROnPlateau`, `OneCycleLR`[14], `CyclicLR`[15]). The following command corresponds to one of the experiments run:\n", "\n", "```bash\n", "python adult/adult_tabmlp.py --mlp_hidden_dims [100,50] --mlp_dropout 0.2 --optimizer Adam --early_stop_patience 30 --lr_scheduler CyclicLR --base_lr 5e-4 --max_lr 0.01 --n_cycles 10 --n_epochs 100 --save_results\n", "```\n", "\n", "The command above will run a `TabMlp` model for the adult dataset. Most `args` are straightforward to understand. Perhaps the only interesting aspect to comment on is that this particular experiment was run with a `CyclicLR` scheduler, where the learning rate oscillates between 0.0005 and 0.01, 10 times over 100 epochs (i.e. a cycle every 10 epochs).\n", "\n", "It is worth mentioning that when running the experiments, I assumed that there is an inherent hierarchy in the DL model parameters and training set-ups. Therefore, rather than optimizing all parameters at once, I chose those that I considered more relevant and ran experiments that reproduced that hierarchy. For example, when running a simple `MLP`, I assume that the number of neurons in the layers is a more important parameter than whether or not I use `BatchNorm` in the last layer (the short sketch below illustrates the idea).\n", "\n", "
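As a toy illustration of that idea (the `run_experiment` function below is a hypothetical stand-in returning a fake validation loss, not the actual experiment code), one would first sweep the parameter assumed to matter most and only then tune the secondary one with the first fixed:\n", "\n", "```python\n", "import random\n", "\n", "def run_experiment(mlp_hidden_dims, batchnorm_last):\n", "    # hypothetical stand-in for a full training run; returns a fake validation loss\n", "    random.seed(str((mlp_hidden_dims, batchnorm_last)))\n", "    return random.uniform(0.28, 0.30)\n", "\n", "# stage 1: sweep the parameter assumed to matter most (the hidden dimensions)\n", "stage_1 = {dims: run_experiment(dims, batchnorm_last=False) for dims in [(100, 50), (200, 100), (400, 200)]}\n", "best_dims = min(stage_1, key=stage_1.get)\n", "\n", "# stage 2: with the best hidden dimensions fixed, sweep the secondary parameter\n", "stage_2 = {bn: run_experiment(best_dims, batchnorm_last=bn) for bn in (False, True)}\n", "print(best_dims, min(stage_2, key=stage_2.get))\n", "```\n", "\n", "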
It might be, or surely it is, that the best thing to do is to optimize all parameters at once, but following this \"hierarchical\" approach also gave me a sense of how changing individual parameters affected the performance of the model. Nonetheless, around 100 experiments were run per model and per dataset on average, so the exploration was relatively exhaustive (just relatively). \n", "\n", "On the other hand, `LightGBM` was optimized using [`Optuna`](https://optuna.org/)[16], [`Hyperopt`](https://github.com/hyperopt/hyperopt)[17], or both, choosing the parameters that led to the best metrics. All the code can be found [here](https://github.com/jrzaurin/tabulardl-benchmark). Note that the experiments, and the code in the repo, represent a very detailed and thorough tutorial on how to use `pytorch-widedeep` (if you wanted to use the library). \n", "\n", "It is also worth mentioning that when running the experiments, the early stopping criterion for both the DL models and `LightGBM` was based on the validation loss. Alternatively, one can monitor a metric, such as accuracy or the f1 score. Note that accuracy (or f1) and loss are not necessarily exactly inversely correlated. There might be edge cases where the algorithm is really unsure about some predictions (i.e. predictions are close to the metric threshold, leading to high loss values) yet ends up making the right prediction (higher accuracy). Of course, ideally we want the algorithm to be sure and make the right predictions, but you know, the real world is messy and noisy. Nonetheless, out of curiosity, I tried monitoring metrics in some experiments. Overall, I found that the results were consistent with those obtained monitoring loss values, although slightly better metrics could be achieved in some cases.\n", "\n", "Another relevant piece of information is related to the dimension of the embeddings used to represent the categorical features. As one can imagine, the amount of possibilities here is endless, and I had to find a way to consistently automate the process across all experiments. To that end I decided to use fastai's [empirical rule of thumb](https://github.com/fastai/fastai/blob/90e009b90b9843dde8c02b0268ab9021ebef342f/fastai/tabular/model.py#L10). For a given categorical feature, the embedding dimension will be:\n", "\n", "$$\n", "n_{embed} = min\\big(600, int(1.6 \\times n_{cat}^{0.56})\\big)\n", "$$\n", "\n", "The exception is the `TabTransformer`. The `TabTransformer` treats the categorical features as if they were part of a sequence (i.e. contextual) where the sequence order is irrelevant, i.e. no positional encoding is needed. Therefore, rather than concatenating them \"one beside another\", they are stacked \"one on top of the other\". This means that all categorical features must have the same embedding dimension. Note that this is a bit of an inconvenience when the categorical features in the dataset have a wide range of cardinalities. \n", "\n", "For example, let's say we have a dataset with just 2 categorical features having 50 and 3 different categories respectively. While using embeddings of 16 dimensions, for example, seems appropriate for the former, it certainly seems like an \"over-representation\" in the latter case. One could still use fastai's rule of thumb and pad the embeddings of lower dimension, but that would imply that some of the attention heads will be attending to zeros/nothing throughout the entire training process, which seems like a waste to me; the short sketch below shows the numbers involved.\n", "\n", "
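As an illustration, here is a minimal sketch (plain Python, with hypothetical feature names and cardinalities) of the embedding dimensions implied by fastai's rule of thumb, as opposed to the single common dimension the `TabTransformer` requires:\n", "\n", "```python\n", "def embed_dim(n_cat: int) -> int:\n", "    # fastai's empirical rule of thumb quoted above\n", "    return min(600, int(1.6 * n_cat ** 0.56))\n", "\n", "cardinalities = {'feature_a': 50, 'feature_b': 3}  # hypothetical cardinalities\n", "print({col: embed_dim(n) for col, n in cardinalities.items()})\n", "# -> {'feature_a': 14, 'feature_b': 2} under this rule\n", "\n", "# the TabTransformer instead uses one common dimension for every categorical feature\n", "# (input_dim = 16 in the experiments shown later), regardless of cardinality\n", "```\n", "\n", "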
Despite this potential \"waste\", I am considering bringing this padding as an option for `pytorch-widedeep`'s `TabTransformer` implementation. In the meantime, \"*all*\" `TabTransformer` experiments were run with an additional setup where categorical features with a small number of categories were passed through the `wide` component. \n", "\n", "Finally, for all experiments I used 80% of the data for training and 10% for validation/parameter tuning. These two datasets were then combined for one last training run, and the algorithm was tested on the remaining 10% of the data. The datasets were split at random unless there was a temporal component, in which case I used a chronological train/test split (note that in the case of the *Facebook Comment Volume* dataset I did not use the test set used in the paper; all train, validation and test datasets are splits of the Variant - 5 dataset described in the paper).\n", "\n", "And that's all. Without further ado, let's move on to the results." ] }, { "cell_type": "markdown", "id": "varied-white", "metadata": {}, "source": [ "## 3. Results\n", "\n", "The previous sections provide context to this \"project\" and details on the experiments that I ran. In this section I will simply show the top 5 results for all data and model combinations, along with some comments where I consider them necessary. The complete tables with the results for *all* experiments can be found [here](https://github.com/jrzaurin/tabulardl-benchmark/tree/master/analyze_experiments/leaderboards)." ] }, { "cell_type": "code", "execution_count": 2, "id": "altered-marine", "metadata": {}, "outputs": [], "source": [ "#hide\n", "from pathlib import Path\n", "\n", "import pandas as pd\n", "\n", "TABLES_DIR = Path(\"/Users/javier/Projects/tabulardl-benchmark/analyze_experiments/leaderboards\")" ] }, { "cell_type": "markdown", "id": "metallic-defense", "metadata": {}, "source": [ "### 3.1 Adult Census Dataset\n", "\n", "#### 3.1.1 `TabMlp`" ] }, { "cell_type": "code", "execution_count": 3, "id": "partial-veteran", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0[400,200]relu0.5FalseFalseFalse0.10.00101280.0AdamWReduceLROnPlateau0.00100.012510000.05.00.2857
1[400,200]relu0.5FalseFalseFalse0.00.00051280.0AdamCyclicLR0.00050.012510000.010.00.2860
2[100,50]relu0.2FalseFalseFalse0.00.00041280.0AdamOneCycleLR0.00100.01251000.05.00.2860
3[400,200]relu0.5FalseFalseFalse0.10.00101280.0AdamReduceLROnPlateau0.00100.012510000.05.00.2861
4[400,200]relu0.5FalseFalseFalse0.00.00051280.0RAdamCyclicLR0.00050.012510000.010.00.2862
\n", "
" ], "text/plain": [ " mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm \\\n", "0 [400,200] relu 0.5 False \n", "1 [400,200] relu 0.5 False \n", "2 [100,50] relu 0.2 False \n", "3 [400,200] relu 0.5 False \n", "4 [400,200] relu 0.5 False \n", "\n", " mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size \\\n", "0 False False 0.1 0.0010 128 \n", "1 False False 0.0 0.0005 128 \n", "2 False False 0.0 0.0004 128 \n", "3 False False 0.1 0.0010 128 \n", "4 False False 0.0 0.0005 128 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 AdamW ReduceLROnPlateau 0.0010 0.01 25 \n", "1 0.0 Adam CyclicLR 0.0005 0.01 25 \n", "2 0.0 Adam OneCycleLR 0.0010 0.01 25 \n", "3 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 \n", "4 0.0 RAdam CyclicLR 0.0005 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 0.2857 \n", "1 10000.0 10.0 0.2860 \n", "2 1000.0 5.0 0.2860 \n", "3 10000.0 5.0 0.2861 \n", "4 10000.0 10.0 0.2862 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "adult_tabmlp = pd.read_csv(TABLES_DIR / \"adult_tabmlp.csv\").iloc[:5]\n", "adult_tabmlp.round(4)" ] }, { "cell_type": "markdown", "id": "annoying-tulsa", "metadata": {}, "source": [ "**Table 2**. Results obtained for the Adult Census dataset using `TabMlp`.\n", "\n", "Perhaps the first comment to make relates to the columns/parameters. It is straightforward to understand that not all parameters/columns apply to each experiment/row. For example, parameters/columns like `base_lr`, `max_lr`, `div_factor` or `final_div_factor` apply only when the learning rate scheduler is either `CyclicLR` or `OneCycleLR`. \n", "\n", "On the other hand, the dense layers of the MLP are built using a very similar approach to that in the `fastai` library. This approach offers flexibility in terms of the operations that occur within each dense layer in the MLP (see [here](https://pytorch-widedeep.readthedocs.io/en/latest/model_components.html#pytorch_widedeep.models.tab_mlp.TabMlp) for details). in that context thee columns `mlp_batchnorm_last` and `mlp_linear_first` set the order in which these operations occur. For example, if for a given dense layer we set `mlp_linear_first = True`, the implemented dense layer will look like this: `[LIN -> ACT -> DP]`. On the other hand, If `mlp_linear_first = False` then the dense layer will perform the operations in the following order: `[DP -> LIN -> ACT]`.\n", "\n", "In the case of the Adult census dataset cyclic learning rates schedulers produce very good results. In fact, a one cycle learning rate with the adequate parameters would already lead to an acceptable validation loss in just one epoch (provided that the batch size is small enough), which perhaps illustrates that this dataset is not particularly difficult. Nonetheless the best result (by a negligible amount) was obtained with a `ReduceLROnPlateau` learning rate scheduler. This is actually common across all experiments for the different dataset and is also consistent with my experience running DL models in many different scenarios, for tabular data or text. The `ReduceLROnPlateau` learning rate scheduler was run with `patience` of 10 epochs. 
That `patience` of 10, along with the `EarlyStopping` patience of 30 epochs, means that, when `ReduceLROnPlateau` is used, the learning rate will be reduced 3 times before the experiment is forced to stop.\n", "\n", "For full details on the experiments setup, the model implementations and the meaning behind each parameter/column, please have a look at both `pytorch-widedeep`'s [documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html) and the experiments [repo](https://github.com/jrzaurin/tabulardl-benchmark)." ] }, { "cell_type": "markdown", "id": "saving-franchise", "metadata": {}, "source": [ "#### 3.1.2 `TabResnet`" ] }, { "cell_type": "code", "execution_count": 4, "id": "abroad-conversion", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
blocks_dimsblocks_dropoutmlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0same0.5Nonerelu0.1FalseFalseFalse0.10.0004320.0AdamOneCycleLR0.0010.01251000.05.00.2850
1same0.5Nonerelu0.1FalseFalseFalse0.00.0004320.0AdamOneCycleLR0.0010.01251000.05.00.2853
2same0.5Nonerelu0.1FalseFalseFalse0.10.00041280.0AdamWOneCycleLR0.0010.01251000.05.00.2854
3same0.5Nonerelu0.1FalseFalseFalse0.10.0004640.0AdamWOneCycleLR0.0010.01251000.05.00.2855
4same0.5Nonerelu0.1FalseFalseFalse0.10.0004320.0AdamWOneCycleLR0.0010.01251000.05.00.2856
\n", "
" ], "text/plain": [ " blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout \\\n", "0 same 0.5 None relu 0.1 \n", "1 same 0.5 None relu 0.1 \n", "2 same 0.5 None relu 0.1 \n", "3 same 0.5 None relu 0.1 \n", "4 same 0.5 None relu 0.1 \n", "\n", " mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr \\\n", "0 False False False 0.1 0.0004 \n", "1 False False False 0.0 0.0004 \n", "2 False False False 0.1 0.0004 \n", "3 False False False 0.1 0.0004 \n", "4 False False False 0.1 0.0004 \n", "\n", " batch_size weight_decay optimizer lr_scheduler base_lr max_lr \\\n", "0 32 0.0 Adam OneCycleLR 0.001 0.01 \n", "1 32 0.0 Adam OneCycleLR 0.001 0.01 \n", "2 128 0.0 AdamW OneCycleLR 0.001 0.01 \n", "3 64 0.0 AdamW OneCycleLR 0.001 0.01 \n", "4 32 0.0 AdamW OneCycleLR 0.001 0.01 \n", "\n", " div_factor final_div_factor n_cycles val_loss_or_metric \n", "0 25 1000.0 5.0 0.2850 \n", "1 25 1000.0 5.0 0.2853 \n", "2 25 1000.0 5.0 0.2854 \n", "3 25 1000.0 5.0 0.2855 \n", "4 25 1000.0 5.0 0.2856 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "adult_tabresnet = pd.read_csv(TABLES_DIR / \"adult_tabresnet.csv\").iloc[:5]\n", "adult_tabresnet.round(4)" ] }, { "cell_type": "markdown", "id": "pressing-uncle", "metadata": {}, "source": [ "**Table 3**. Results obtained for the Adult dataset using `TabResnet`.\n", "\n", "`block_dim = same` in Table 3 indicate that the Resnet blocks, which are comprised by dense layers, have the same dimensions than the incoming embeddings (see [here](https://github.com/jrzaurin/pytorch-widedeep/blob/tabnet/pytorch_widedeep/models/tab_resnet.py) for details on the implementation). \n", "\n", "On the other hand, the `TabResnet` model offers the possibility of using an MLP \"on top\" of the Resnet blocks. When `mlp_hidden_dims = None` indicates that no MLP was used and the output of the last Resnet block was \"plugged\" directly into the output neuron. Therefore, as shown in Table 3, the top 5 results obtained using `TabResnet` correspond to architectures that have no MLP. In consequence, all MLP related parameters/columns are redundant for those experiments. \n", "\n", "I find interesting that whether `Adam` or `AdamW`, the best results are obtained using `OneCycleLR`. When using this scheduler, I normally set the number of epochs to be in between 1 and 10. Normally I obtain the best results for a small number of epochs ($\\leq 5$) and a small batch size, which implies that the increase/decrease of the learning rate will be more gradual (i.e. spread over a higher number of steps) as opposed as using large batch sizes. Finally note that the parameter/column `n_cycles` only apply to the `CyclicLR` scheduler. Since it is not used in any of the top 5 experiments it can be ignored in Table 3." ] }, { "cell_type": "markdown", "id": "palestinian-expansion", "metadata": {}, "source": [ "#### 3.1.3 `Tabnet`" ] }, { "cell_type": "code", "execution_count": 5, "id": "burning-curtis", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
n_stepsstep_dimattn_dimghost_bnvirtual_batch_sizemomentumgammadropoutembed_dropoutlrbatch_sizeweight_decaylambda_sparseoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
053232False1280.981.50.10.00.031280.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2916
156464False1280.981.50.20.00.031280.00.0001AdamReduceLROnPlateau0.0010.012510000.05.00.2938
253232False1280.981.50.10.00.031280.00.0001AdamReduceLROnPlateau0.0010.012510000.05.00.2939
356464False1280.981.50.20.00.031280.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2945
456464False1280.981.50.20.00.051280.00.0001RAdamReduceLROnPlateau0.0010.012510000.05.00.2962
\n", "
" ], "text/plain": [ " n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma \\\n", "0 5 32 32 False 128 0.98 1.5 \n", "1 5 64 64 False 128 0.98 1.5 \n", "2 5 32 32 False 128 0.98 1.5 \n", "3 5 64 64 False 128 0.98 1.5 \n", "4 5 64 64 False 128 0.98 1.5 \n", "\n", " dropout embed_dropout lr batch_size weight_decay lambda_sparse \\\n", "0 0.1 0.0 0.03 128 0.0 0.0001 \n", "1 0.2 0.0 0.03 128 0.0 0.0001 \n", "2 0.1 0.0 0.03 128 0.0 0.0001 \n", "3 0.2 0.0 0.03 128 0.0 0.0001 \n", "4 0.2 0.0 0.05 128 0.0 0.0001 \n", "\n", " optimizer lr_scheduler base_lr max_lr div_factor final_div_factor \\\n", "0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "1 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "2 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "3 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "4 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "\n", " n_cycles val_loss_or_metric \n", "0 5.0 0.2916 \n", "1 5.0 0.2938 \n", "2 5.0 0.2939 \n", "3 5.0 0.2945 \n", "4 5.0 0.2962 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "adult_tabnet = pd.read_csv(TABLES_DIR / \"adult_tabnet.csv\").iloc[:5]\n", "adult_tabnet.round(4)" ] }, { "cell_type": "markdown", "id": "global-germany", "metadata": {}, "source": [ "**Table 4**. Results obtained for the Adult dataset using `Tabnet`.\n", "\n", "`Tabnet` has received some attention lately for being competitive with GBMs, and even over-performing them. In addition, it is a very elegant implementation that offers model interpretability via feature importance obtained using attention mechanisms. \n", "\n", "The reality is that for the Adult Census dataset I obtain the worst loss values on the validation set (but as we will see later, not the worst metric). Maybe I simply missed \"that precise\" set of parameters that lead to better results. However, it is worth emphasizing that I have explored `Tabnet` with the same level of detail that any of the other 3 model alternatives. \n", "\n", "On the other hand, it is interesting that, within all the experiments run, the best results are consistently obtained without Ghost batch normalization. Therefore, the parameter/column `virtual_batch_size` can be ignored in Table 4. Similarly, since the best results are all obtained using `ReduceLROnPlateau`, all the parameters related to cyclic learning rate schedulers can be ignored in Table 4. \n", "\n", "Finally, consistent with some experiments I run in the past, the best results obtained using `RAdam` normally involve relatively high learning rates. " ] }, { "cell_type": "markdown", "id": "hungry-orbit", "metadata": {}, "source": [ "#### 3.1.4 `TabTransformer`" ] }, { "cell_type": "code", "execution_count": 6, "id": "respective-working", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
embed_dropoutfull_embed_dropoutshared_embedadd_shared_embedfrac_shared_embedinput_dimn_headsn_blocksdropoutff_hidden_dimtransformer_activationmlp_hidden_dimsmlp_activationmlp_batchnormmlp_batchnorm_lastmlp_linear_firstwith_widelrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
00.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseFalse0.0101280.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2879
10.0FalseFalseFalse816440.1NaNrelusamereluFalseFalseFalseFalse0.0101280.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2885
20.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseTrue0.0101280.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2888
30.0FalseFalseFalse816480.2NaNreluNonereluFalseFalseFalseTrue0.0011280.0AdamWReduceLROnPlateau0.0010.012510000.05.00.2892
40.0FalseFalseFalse816240.1NaNreluNonereluFalseFalseFalseFalse0.0101280.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2894
\n", "
" ], "text/plain": [ " embed_dropout full_embed_dropout shared_embed add_shared_embed \\\n", "0 0.0 False False False \n", "1 0.0 False False False \n", "2 0.0 False False False \n", "3 0.0 False False False \n", "4 0.0 False False False \n", "\n", " frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim \\\n", "0 8 16 4 4 0.1 NaN \n", "1 8 16 4 4 0.1 NaN \n", "2 8 16 4 4 0.1 NaN \n", "3 8 16 4 8 0.2 NaN \n", "4 8 16 2 4 0.1 NaN \n", "\n", " transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm \\\n", "0 relu None relu False \n", "1 relu same relu False \n", "2 relu None relu False \n", "3 relu None relu False \n", "4 relu None relu False \n", "\n", " mlp_batchnorm_last mlp_linear_first with_wide lr batch_size \\\n", "0 False False False 0.010 128 \n", "1 False False False 0.010 128 \n", "2 False False True 0.010 128 \n", "3 False False True 0.001 128 \n", "4 False False False 0.010 128 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 0.2879 \n", "1 10000.0 5.0 0.2885 \n", "2 10000.0 5.0 0.2888 \n", "3 10000.0 5.0 0.2892 \n", "4 10000.0 5.0 0.2894 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "adult_tabtransformer = pd.read_csv(TABLES_DIR / \"adult_tabtransformer.csv\").iloc[:5]\n", "adult_tabtransformer.round(4)" ] }, { "cell_type": "markdown", "id": "periodic-composition", "metadata": {}, "source": [ "**Table 5**. Results obtained for the Adult Census dataset using the `TabTransformer`.\n", "\n", "\n", "As with all the previous models, if you wanted details on the meaning of each parameter/column, please have a look to the [documentation] of the [source code] itself. \n", "\n", "It is perhaps worth mentioning that when feed forward hidden dim (`ff_hidden_dim`) is set to `NaN` the model will default to a `ff_hidden_dim` value equal to 4 times the input embedding dimensions (16 in all the experiments/rows shown in the Table). This will result in a feed forward layer with dimensions `[ff_input_dim -> 4 * ff_input_dim -> ff_input_dim]`. Similarly, when `mlp_hidden_dims = None` the model will default to 4 times the input dimensions, resulting in an MLP of dimensions `[mlp_input_dim -> 4 * mlp_input_dim -> 2* mlp_input_dim -> output_dim]`.\n", "\n", "On In addition, and as mentioned before, the `TabTransformer` was also run with a set up that includes a `wide` component. This is specified by the `with_wide` parameter. \n", "\n", "Is is worth noticing that the best loss values, which are similar to those of the rest of the DL models, are normally obtained using a `RAdam` optimizer. " ] }, { "cell_type": "markdown", "id": "accessible-supervisor", "metadata": {}, "source": [ "#### 3.1.5 DL vs `LightGBM`\n", "\n", "After having gone through the results obtained for each of the DL models, this is the moment of truth, let's see how the DL results compare with those obtained with `LightGBM`." ] }, { "cell_type": "code", "execution_count": 7, "id": "french-italian", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
modelaccruntimebest_epoch_or_ntrees
0lightgbm0.87820.9086408.0
1tabmlp0.8722205.357662.0
2tabtransformer0.8718288.640632.0
3tabnet0.8704422.296726.0
4tabresnet0.8698388.932525.0
\n", "
" ], "text/plain": [ " model acc runtime best_epoch_or_ntrees\n", "0 lightgbm 0.8782 0.9086 408.0\n", "1 tabmlp 0.8722 205.3576 62.0\n", "2 tabtransformer 0.8718 288.6406 32.0\n", "3 tabnet 0.8704 422.2967 26.0\n", "4 tabresnet 0.8698 388.9325 25.0" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "lightgbm_vs_dl_adult = pd.read_csv(TABLES_DIR / \"lightgbm_vs_dl_adult.csv\")\n", "lightgbm_vs_dl_adult.round(4)" ] }, { "cell_type": "markdown", "id": "consecutive-ideal", "metadata": {}, "source": [ "**Table 6**. Results obtained for the Adult Census dataset using four DL models and `LightGBM`. `runtime` units are seconds\n", "\n", "\n", "Let me emphasize again that the metrics shown in Table 6 are *all* obtained, of course, for the test dataset. The `runtime` column shows the training time, in seconds, for the final train dataset (i.e. a dataset comprising 90% of the data) using the best parameters obtained during validation. The DL models where run on a `p2.xlarge` instance on AWS and all the `LightGBM` experiments were run on my Mac Mid 2015. \n", "\n", "They are a few aspects worth commenting. In the first place, all DL models obtain results that are competitive with, but not better than, those of `LightGBM`. Secondly, the best performing DL model (by a rather marginal amount) is the simplest model, the `TabMlp`. And finally, the training time when using `LightGBM` is simply \"*gigantically*\" better than with any of the DL models. " ] }, { "cell_type": "markdown", "id": "electrical-finish", "metadata": {}, "source": [ "### 3.2 Bank Marketing Dataset\n", "\n", "Most of the comments in the previous section apply to the tables shown in this section. \n", "\n", "Note that as I mentioned earlier in the post, the Bank Marketing dataset is slightly imbalanced. Therefore I also run some experiments using the [focal loss](https://arxiv.org/abs/1708.02002?source=post_page---------------------------) [18] (which is accessible in `pytorch_widedeep` via a parameter or as a loss function input. See [here](https://pytorch-widedeep.readthedocs.io/en/latest/trainer.html)). Overall, the results obtained where similar to, but not better than those without the focal loss. This is consistent with my experience with other datasets where I find that the focal loss leads to notably better results when the dataset is highly imbalanced (for example, around 2% positive to negative class ratio). \n", "\n", "#### 3.2.1 `TabMlp`" ] }, { "cell_type": "code", "execution_count": 8, "id": "dense-saskatchewan", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0[100,50]relu0.1TrueTrueFalse0.10.0015120.0AdamWReduceLROnPlateau0.0010.012510000.05.00.2638
1[100,50]relu0.1TrueFalseTrue0.10.0015120.0AdamWReduceLROnPlateau0.0010.012510000.05.00.2639
2[100,50]relu0.1TrueTrueFalse0.10.0015120.0AdamReduceLROnPlateau0.0010.012510000.05.00.2643
3[100,50]relu0.1FalseFalseFalse0.10.0015120.0AdamWReduceLROnPlateau0.0010.012510000.05.00.2643
4[100,50]relu0.1TrueFalseFalse0.10.0015120.0AdamReduceLROnPlateau0.0010.012510000.05.00.2646
\n", "
" ], "text/plain": [ " mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm \\\n", "0 [100,50] relu 0.1 True \n", "1 [100,50] relu 0.1 True \n", "2 [100,50] relu 0.1 True \n", "3 [100,50] relu 0.1 False \n", "4 [100,50] relu 0.1 True \n", "\n", " mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size \\\n", "0 True False 0.1 0.001 512 \n", "1 False True 0.1 0.001 512 \n", "2 True False 0.1 0.001 512 \n", "3 False False 0.1 0.001 512 \n", "4 False False 0.1 0.001 512 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 0.2638 \n", "1 10000.0 5.0 0.2639 \n", "2 10000.0 5.0 0.2643 \n", "3 10000.0 5.0 0.2643 \n", "4 10000.0 5.0 0.2646 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "bank_marketing_tabmlp = pd.read_csv(TABLES_DIR / \"bank_marketing_tabmlp.csv\")\n", "# focal loss values are on a different scale\n", "bank_marketing_tabmlp = bank_marketing_tabmlp[bank_marketing_tabmlp.val_loss_or_metric > 0.2]\n", "(bank_marketing_tabmlp\n", " .sort_values(\"val_loss_or_metric\", ascending=True)\n", " .reset_index(drop=True)\n", " .head(5)).round(4)" ] }, { "cell_type": "markdown", "id": "given-committee", "metadata": {}, "source": [ "**Table 7**. Results obtained for the Bank Marketing dataset using `TabMlp`." ] }, { "cell_type": "markdown", "id": "solved-exercise", "metadata": {}, "source": [ "#### 3.2.2 `TabResnet`" ] }, { "cell_type": "code", "execution_count": 9, "id": "structured-prospect", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
blocks_dimsblocks_dropoutmlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0same0.5Nonerelu0.1FalseFalseFalse0.00.0004640.0AdamOneCycleLR0.0010.01251000.05.00.2660
1[50,50,50,50]0.2Nonerelu0.1FalseFalseFalse0.00.00105120.0AdamReduceLROnPlateau0.0010.012510000.05.00.2661
2same0.5Nonerelu0.1FalseFalseFalse0.00.0004640.0RAdamOneCycleLR0.0010.01251000.05.00.2663
3same0.5Nonerelu0.1FalseFalseFalse0.00.00041280.0RAdamOneCycleLR0.0010.01251000.05.00.2664
4same0.5Nonerelu0.1FalseFalseFalse0.00.00041280.0AdamOneCycleLR0.0010.01251000.05.00.2667
\n", "
" ], "text/plain": [ " blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout \\\n", "0 same 0.5 None relu 0.1 \n", "1 [50,50,50,50] 0.2 None relu 0.1 \n", "2 same 0.5 None relu 0.1 \n", "3 same 0.5 None relu 0.1 \n", "4 same 0.5 None relu 0.1 \n", "\n", " mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr \\\n", "0 False False False 0.0 0.0004 \n", "1 False False False 0.0 0.0010 \n", "2 False False False 0.0 0.0004 \n", "3 False False False 0.0 0.0004 \n", "4 False False False 0.0 0.0004 \n", "\n", " batch_size weight_decay optimizer lr_scheduler base_lr max_lr \\\n", "0 64 0.0 Adam OneCycleLR 0.001 0.01 \n", "1 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 \n", "2 64 0.0 RAdam OneCycleLR 0.001 0.01 \n", "3 128 0.0 RAdam OneCycleLR 0.001 0.01 \n", "4 128 0.0 Adam OneCycleLR 0.001 0.01 \n", "\n", " div_factor final_div_factor n_cycles val_loss_or_metric \n", "0 25 1000.0 5.0 0.2660 \n", "1 25 10000.0 5.0 0.2661 \n", "2 25 1000.0 5.0 0.2663 \n", "3 25 1000.0 5.0 0.2664 \n", "4 25 1000.0 5.0 0.2667 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "bank_marketing_tabresnet = pd.read_csv(TABLES_DIR / \"bank_marketing_tabresnet.csv\").head(5)\n", "bank_marketing_tabresnet.round(4)" ] }, { "cell_type": "markdown", "id": "lyric-wonder", "metadata": {}, "source": [ "**Table 8**. Results obtained for the Bank Marketing dataset using `TabResnet`.\n", "\n", "Again, and very interestingly, `RAdam` optimizer and `OneCycleLR` leading to some of the best results for this DL model. " ] }, { "cell_type": "markdown", "id": "collective-evaluation", "metadata": {}, "source": [ "#### 3.2.3 `Tabnet`" ] }, { "cell_type": "code", "execution_count": 10, "id": "middle-flooring", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
n_stepsstep_dimattn_dimghost_bnvirtual_batch_sizemomentumgammadropoutembed_dropoutlrbatch_sizeweight_decaylambda_sparseoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
051616True1280.751.50.00.00.035120.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2714
151616True640.251.50.00.00.035120.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2722
256464False1280.981.50.20.00.031280.00.0001AdamReduceLROnPlateau0.0010.012510000.05.00.2726
356464False1280.981.50.20.00.031280.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2738
451616True1280.982.00.00.00.035120.00.0001AdamWReduceLROnPlateau0.0010.012510000.05.00.2739
\n", "
" ], "text/plain": [ " n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma \\\n", "0 5 16 16 True 128 0.75 1.5 \n", "1 5 16 16 True 64 0.25 1.5 \n", "2 5 64 64 False 128 0.98 1.5 \n", "3 5 64 64 False 128 0.98 1.5 \n", "4 5 16 16 True 128 0.98 2.0 \n", "\n", " dropout embed_dropout lr batch_size weight_decay lambda_sparse \\\n", "0 0.0 0.0 0.03 512 0.0 0.0001 \n", "1 0.0 0.0 0.03 512 0.0 0.0001 \n", "2 0.2 0.0 0.03 128 0.0 0.0001 \n", "3 0.2 0.0 0.03 128 0.0 0.0001 \n", "4 0.0 0.0 0.03 512 0.0 0.0001 \n", "\n", " optimizer lr_scheduler base_lr max_lr div_factor final_div_factor \\\n", "0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "1 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "2 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "3 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "4 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "\n", " n_cycles val_loss_or_metric \n", "0 5.0 0.2714 \n", "1 5.0 0.2722 \n", "2 5.0 0.2726 \n", "3 5.0 0.2738 \n", "4 5.0 0.2739 " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "bank_marketing_tabnet = pd.read_csv(TABLES_DIR / \"bank_marketing_tabnet.csv\").head(5)\n", "bank_marketing_tabnet.round(4) " ] }, { "cell_type": "markdown", "id": "round-school", "metadata": {}, "source": [ "**Table 9**. Results obtained for the Bank Marketing dataset using `Tabnet`.\n", "\n", "Note the top 5 results obtained with `Tabnet` in this case all have relatively high learning rate values (`lr = 0.03`). Also, and similar to the case of the Adult Census dataset, `Tabnet` produces the worst validation loss values." ] }, { "cell_type": "markdown", "id": "impressed-inspiration", "metadata": {}, "source": [ "#### 3.2.4 `TabTransformer`" ] }, { "cell_type": "code", "execution_count": 11, "id": "rational-fiction", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
embed_dropoutfull_embed_dropoutshared_embedadd_shared_embedfrac_shared_embedinput_dimn_headsn_blocksdropoutff_hidden_dimtransformer_activationmlp_hidden_dimsmlp_activationmlp_batchnormmlp_batchnorm_lastmlp_linear_firstwith_widelrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
00.0FalseFalseFalse832860.1NaNreluNonereluFalseFalseFalseFalse0.0015120.0AdamReduceLROnPlateau0.0010.012510000.05.00.2646
10.0FalseFalseFalse832860.1NaNreluNonereluFalseFalseFalseFalse0.0015120.0AdamWReduceLROnPlateau0.0010.012510000.05.00.2647
20.0FalseTrueFalse416460.1NaNreluNonereluFalseFalseFalseFalse0.0101280.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2668
30.0FalseFalseFalse832860.1NaNreluNonereluFalseFalseFalseFalse0.01010240.0RAdamReduceLROnPlateau0.0010.012510000.05.00.2672
40.0FalseFalseFalse832860.1NaNreluNonereluFalseFalseFalseFalse0.00110240.0AdamReduceLROnPlateau0.0010.012510000.05.00.2672
\n", "
" ], "text/plain": [ " embed_dropout full_embed_dropout shared_embed add_shared_embed \\\n", "0 0.0 False False False \n", "1 0.0 False False False \n", "2 0.0 False True False \n", "3 0.0 False False False \n", "4 0.0 False False False \n", "\n", " frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim \\\n", "0 8 32 8 6 0.1 NaN \n", "1 8 32 8 6 0.1 NaN \n", "2 4 16 4 6 0.1 NaN \n", "3 8 32 8 6 0.1 NaN \n", "4 8 32 8 6 0.1 NaN \n", "\n", " transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm \\\n", "0 relu None relu False \n", "1 relu None relu False \n", "2 relu None relu False \n", "3 relu None relu False \n", "4 relu None relu False \n", "\n", " mlp_batchnorm_last mlp_linear_first with_wide lr batch_size \\\n", "0 False False False 0.001 512 \n", "1 False False False 0.001 512 \n", "2 False False False 0.010 128 \n", "3 False False False 0.010 1024 \n", "4 False False False 0.001 1024 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 0.2646 \n", "1 10000.0 5.0 0.2647 \n", "2 10000.0 5.0 0.2668 \n", "3 10000.0 5.0 0.2672 \n", "4 10000.0 5.0 0.2672 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "bank_marketing_tabtransformer = pd.read_csv(TABLES_DIR / \"bank_marketing_tabtransformer.csv\").head(5)\n", "bank_marketing_tabtransformer.round(4)" ] }, { "cell_type": "markdown", "id": "greatest-consideration", "metadata": {}, "source": [ "**Table 10**. Results obtained for the Bank Marketing dataset using the `TabTransformer`.\n", "\n", "It is perhaps worth noticing that consistent with some of the previous results, the best results obtained here using `RAdam` involve relatively high learning rates (a factor of 10 compared to those obtained using `Adam` or `AdamW`.) " ] }, { "cell_type": "markdown", "id": "grateful-powder", "metadata": {}, "source": [ "#### 3.2.5 DL vs `LightGBM`" ] }, { "cell_type": "code", "execution_count": 12, "id": "composed-startup", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
modelf1aucruntimebest_epoch_or_ntrees
0tabresnet0.42980.650192.517511.0
1tabtransformer0.42000.644031.69384.0
2tabmlp0.38550.62819.57217.0
3lightgbm0.38520.62650.461457.0
4tabnet0.30870.594377.878113.0
\n", "
" ], "text/plain": [ " model f1 auc runtime best_epoch_or_ntrees\n", "0 tabresnet 0.4298 0.6501 92.5175 11.0\n", "1 tabtransformer 0.4200 0.6440 31.6938 4.0\n", "2 tabmlp 0.3855 0.6281 9.5721 7.0\n", "3 lightgbm 0.3852 0.6265 0.4614 57.0\n", "4 tabnet 0.3087 0.5943 77.8781 13.0" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "lightgbm_vs_dl_bank_marketing = pd.read_csv(TABLES_DIR / \"lightgbm_vs_dl_bank_marketing.csv\")\n", "lightgbm_vs_dl_bank_marketing.round(4)" ] }, { "cell_type": "markdown", "id": "complimentary-wealth", "metadata": {}, "source": [ "**Table 11**. Results obtained for the Bank Marketing dataset using four DL models and LightGBM.\n", "\n", "I must admit that the results shown in Table 11 were surprising to me at first (to say the least). I went and run a few DL models again and `LightGBM` multiple times to double check, and finally concluded (spoiler alert) that this is going to be the only case among all experiments I run in this post where DL models perform better than `LightGBM`. In fact, if we joined the experiments here with my experience at work, this is the second time ever that I find that DL models perform better than `LightGBM` (more on this later). Furthermore, the improvement obtained using `TabResnet` or the `TabTransformer` is quite significant to the point that if this was a \"real world\" example, one might consider using a DL model and accept the trade between running time and success metric. \n", "\n", "Of course one could go and dive a bit deeper into `LightGBM`, setting sample weights, or even using a custom loss, but the same can be said about the DL models. Therefore, and overall, I consider the comparison fair. However, I am so surprised that I consider the possibility that I might have a bug in the code that I have not been able to find. Therefore, if anyone goes through the code at some point and finds indeed a bug please let me know 🙂. \n", "\n", "Finally, someone might feel disappointed by `Tabnet`'s performance, as I was. There is a possibility that I have not implemented it correctly, although the code is fully based on that from dreamquark-ai's implementation (**ALL** credit to them) and when tested with easier datasets, I obtain similar results to those with GBMs. I find `Tabnet` to be a very elegant implementation and somehow I believe it should perform better. I will come back to this point in the Conclusions section. " ] }, { "cell_type": "markdown", "id": "useful-circus", "metadata": {}, "source": [ "### 3.3 NYC Taxi trip duration\n", "\n", "As I mentioned earlier this is the largest dataset, and in consequence, I experimented with larger batch sizes. While this might slightly change some of the individual results, I believe it will not change the overall conclusion in this section. \n", "\n", "#### 3.3.1 `TabMlp`" ] }, { "cell_type": "code", "execution_count": 13, "id": "large-driver", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0autorelu0.1FalseFalseTrue0.00.0110240.0AdamReduceLROnPlateau0.0010.012510000.05.079252.7786
1autorelu0.1FalseFalseTrue0.00.0110240.0AdamWReduceLROnPlateau0.0010.012510000.05.079440.6025
2autorelu0.1FalseFalseFalse0.10.0110240.0AdamReduceLROnPlateau0.0010.012510000.05.079477.5653
3autorelu0.1FalseFalseFalse0.10.0110240.0AdamWReduceLROnPlateau0.0010.012510000.05.079710.8551
4autorelu0.1FalseFalseFalse0.00.0110240.0AdamWReduceLROnPlateau0.0010.012510000.05.080214.7197
\n", "
" ], "text/plain": [ " mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm \\\n", "0 auto relu 0.1 False \n", "1 auto relu 0.1 False \n", "2 auto relu 0.1 False \n", "3 auto relu 0.1 False \n", "4 auto relu 0.1 False \n", "\n", " mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size \\\n", "0 False True 0.0 0.01 1024 \n", "1 False True 0.0 0.01 1024 \n", "2 False False 0.1 0.01 1024 \n", "3 False False 0.1 0.01 1024 \n", "4 False False 0.0 0.01 1024 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 79252.7786 \n", "1 10000.0 5.0 79440.6025 \n", "2 10000.0 5.0 79477.5653 \n", "3 10000.0 5.0 79710.8551 \n", "4 10000.0 5.0 80214.7197 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "nyc_taxi_tabmlp = pd.read_csv(TABLES_DIR / \"nyc_taxi_tabmlp.csv\").iloc[:5]\n", "nyc_taxi_tabmlp.round(4)" ] }, { "cell_type": "markdown", "id": "understanding-boutique", "metadata": {}, "source": [ "**Table 12**. Results obtained for the NYC Taxi trip duration dataset using the `TabMlp`.\n", "\n", "The validation loss in this case is the `MSE`. The standard deviation (`std` hereafter) of the target variable in the validation set is $\\sim$599. Given that the `std` is the `RMSE` we would obtain if we always predicted the expected value, we can see that this is not a very powerful model, i.e. the task of predicting taxi trip duration is, indeed, relatively challenging. \n", "\n", "Let's see how the other DL models perform. " ] }, { "cell_type": "markdown", "id": "competitive-rapid", "metadata": {}, "source": [ "#### 3.3.2 `TabResnet`" ] }, { "cell_type": "code", "execution_count": 14, "id": "danish-pearl", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
blocks_dimsblocks_dropoutmlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0same0.5autorelu0.2FalseFalseFalse0.00.0120480.0AdamReduceLROnPlateau0.0010.012510000.0597015.1182
1same0.2autorelu0.1FalseFalseFalse0.00.0110240.0AdamWReduceLROnPlateau0.0010.012510000.0598266.4310
2same0.5autorelu0.2FalseFalseFalse0.00.0420480.0AdamReduceLROnPlateau0.0010.012510000.05100332.3569
3same0.2autorelu0.1FalseFalseFalse0.00.0110240.0AdamReduceLROnPlateau0.0010.012510000.05103006.5603
4same0.5autorelu0.2FalseFalseFalse0.00.0120480.0AdamWReduceLROnPlateau0.0010.012510000.05105967.2627
\n", "
" ], "text/plain": [ " blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout \\\n", "0 same 0.5 auto relu 0.2 \n", "1 same 0.2 auto relu 0.1 \n", "2 same 0.5 auto relu 0.2 \n", "3 same 0.2 auto relu 0.1 \n", "4 same 0.5 auto relu 0.2 \n", "\n", " mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr \\\n", "0 False False False 0.0 0.01 \n", "1 False False False 0.0 0.01 \n", "2 False False False 0.0 0.04 \n", "3 False False False 0.0 0.01 \n", "4 False False False 0.0 0.01 \n", "\n", " batch_size weight_decay optimizer lr_scheduler base_lr max_lr \\\n", "0 2048 0.0 Adam ReduceLROnPlateau 0.001 0.01 \n", "1 1024 0.0 AdamW ReduceLROnPlateau 0.001 0.01 \n", "2 2048 0.0 Adam ReduceLROnPlateau 0.001 0.01 \n", "3 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 \n", "4 2048 0.0 AdamW ReduceLROnPlateau 0.001 0.01 \n", "\n", " div_factor final_div_factor n_cycles val_loss_or_metric \n", "0 25 10000.0 5 97015.1182 \n", "1 25 10000.0 5 98266.4310 \n", "2 25 10000.0 5 100332.3569 \n", "3 25 10000.0 5 103006.5603 \n", "4 25 10000.0 5 105967.2627 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#hide\n", "nyc_taxi_tabresnet = pd.read_csv(TABLES_DIR / \"nyc_taxi_tabresnet.csv\").iloc[:5]\n", "nyc_taxi_tabresnet.round(4)" ] }, { "cell_type": "markdown", "id": "excessive-vitamin", "metadata": {}, "source": [ "**Table 13**. Results obtained for the NYC Taxi trip duration dataset using the `TabResnet`." ] }, { "cell_type": "markdown", "id": "right-consumption", "metadata": {}, "source": [ "#### 3.3.3 `Tabnet`" ] }, { "cell_type": "code", "execution_count": 15, "id": "prompt-singles", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
n_stepsstep_dimattn_dimghost_bnvirtual_batch_sizemomentumgammadropoutembed_dropoutlrbatch_sizeweight_decaylambda_sparseoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0588False1280.751.50.00.00.0110240.00.0001AdamReduceLROnPlateau0.0010.012510000.05144819.1190
1588False1280.981.50.00.00.0110240.00.0001AdamReduceLROnPlateau0.0010.012510000.05146057.8078
2588False1280.501.50.00.00.0110240.00.0001AdamReduceLROnPlateau0.0010.012510000.05146201.3771
351616False1280.981.50.00.00.0110240.00.0001AdamReduceLROnPlateau0.0010.012510000.05146461.7343
4588False1280.251.50.00.00.0110240.00.0001AdamReduceLROnPlateau0.0010.012510000.05148636.8888
\n", "
" ], "text/plain": [ " n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma \\\n", "0 5 8 8 False 128 0.75 1.5 \n", "1 5 8 8 False 128 0.98 1.5 \n", "2 5 8 8 False 128 0.50 1.5 \n", "3 5 16 16 False 128 0.98 1.5 \n", "4 5 8 8 False 128 0.25 1.5 \n", "\n", " dropout embed_dropout lr batch_size weight_decay lambda_sparse \\\n", "0 0.0 0.0 0.01 1024 0.0 0.0001 \n", "1 0.0 0.0 0.01 1024 0.0 0.0001 \n", "2 0.0 0.0 0.01 1024 0.0 0.0001 \n", "3 0.0 0.0 0.01 1024 0.0 0.0001 \n", "4 0.0 0.0 0.01 1024 0.0 0.0001 \n", "\n", " optimizer lr_scheduler base_lr max_lr div_factor final_div_factor \\\n", "0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "1 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "2 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "3 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "4 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "\n", " n_cycles val_loss_or_metric \n", "0 5 144819.1190 \n", "1 5 146057.8078 \n", "2 5 146201.3771 \n", "3 5 146461.7343 \n", "4 5 148636.8888 " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "nyc_taxi_tabnet = pd.read_csv(TABLES_DIR / \"nyc_taxi_tabnet.csv\").iloc[:5]\n", "nyc_taxi_tabnet.round(4)" ] }, { "cell_type": "markdown", "id": "assured-creativity", "metadata": {}, "source": [ "**Table 14**. Results obtained for the NYC Taxi trip duration dataset using the `Tabnet`." ] }, { "cell_type": "markdown", "id": "equivalent-preserve", "metadata": {}, "source": [ "#### 3.3.4 `TabTransformer`" ] }, { "cell_type": "code", "execution_count": 16, "id": "challenging-arctic", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
embed_dropoutfull_embed_dropoutshared_embedadd_shared_embedfrac_shared_embedinput_dimn_headsn_blocksdropoutff_hidden_dimtransformer_activationmlp_hidden_dimsmlp_activationmlp_batchnormmlp_batchnorm_lastmlp_linear_firstwith_widelrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
00.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseFalse0.0110240.0AdamReduceLROnPlateau0.0010.012510000.05180162.4087
10.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseFalse0.012560.0AdamReduceLROnPlateau0.0010.012510000.05186017.1888
20.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseFalse0.015120.0AdamReduceLROnPlateau0.0010.012510000.05196144.0674
30.0FalseFalseFalse832840.4NaNreluNonereluFalseFalseFalseFalse0.0110240.0AdamReduceLROnPlateau0.0010.012510000.05357869.3703
40.0FalseFalseFalse8641640.4NaNreluNonereluFalseFalseFalseFalse0.015120.0AdamReduceLROnPlateau0.0010.012510000.05357884.9043
\n", "
" ], "text/plain": [ " embed_dropout full_embed_dropout shared_embed add_shared_embed \\\n", "0 0.0 False False False \n", "1 0.0 False False False \n", "2 0.0 False False False \n", "3 0.0 False False False \n", "4 0.0 False False False \n", "\n", " frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim \\\n", "0 8 16 4 4 0.1 NaN \n", "1 8 16 4 4 0.1 NaN \n", "2 8 16 4 4 0.1 NaN \n", "3 8 32 8 4 0.4 NaN \n", "4 8 64 16 4 0.4 NaN \n", "\n", " transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm \\\n", "0 relu None relu False \n", "1 relu None relu False \n", "2 relu None relu False \n", "3 relu None relu False \n", "4 relu None relu False \n", "\n", " mlp_batchnorm_last mlp_linear_first with_wide lr batch_size \\\n", "0 False False False 0.01 1024 \n", "1 False False False 0.01 256 \n", "2 False False False 0.01 512 \n", "3 False False False 0.01 1024 \n", "4 False False False 0.01 512 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5 180162.4087 \n", "1 10000.0 5 186017.1888 \n", "2 10000.0 5 196144.0674 \n", "3 10000.0 5 357869.3703 \n", "4 10000.0 5 357884.9043 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "nyc_taxi_tabtransformer = pd.read_csv(TABLES_DIR / \"nyc_taxi_tabtransformer.csv\").iloc[:5]\n", "nyc_taxi_tabtransformer.round(4)" ] }, { "cell_type": "markdown", "id": "hungarian-blowing", "metadata": {}, "source": [ "**Table 15**. Results obtained for the NYC Taxi trip duration dataset using the `TabTransformer`." ] }, { "cell_type": "markdown", "id": "fatal-karen", "metadata": {}, "source": [ "#### 3.3.5 DL vs `LightGBM`" ] }, { "cell_type": "code", "execution_count": 17, "id": "gross-devices", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
modelrmser2runtimebest_epoch_or_ntrees
0lightgbm262.70990.804442.7211504.0
1tabmlp271.34220.7913568.430924.0
2tabresnet292.89080.7569471.265024.0
3tabtransformer336.58260.67895779.031454.0
4tabnet376.05300.59921844.472315.0
\n", "
" ], "text/plain": [ " model rmse r2 runtime best_epoch_or_ntrees\n", "0 lightgbm 262.7099 0.8044 42.7211 504.0\n", "1 tabmlp 271.3422 0.7913 568.4309 24.0\n", "2 tabresnet 292.8908 0.7569 471.2650 24.0\n", "3 tabtransformer 336.5826 0.6789 5779.0314 54.0\n", "4 tabnet 376.0530 0.5992 1844.4723 15.0" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "lightgbm_vs_dl_nyc_taxy = pd.read_csv(TABLES_DIR / \"lightgbm_vs_dl_nyc_taxi.csv\")\n", "lightgbm_vs_dl_nyc_taxy.round(4)" ] }, { "cell_type": "markdown", "id": "viral-elizabeth", "metadata": {}, "source": [ "**Table 16**. Results obtained for the NYC Taxi trip duration dataset using four DL models and LightGBM.\n", "\n", "The `TabTransformer` and `Tabnet` are, in this case, the models that have the worst performance. As I mentioned earlier I will reflect on potential reasons later in the Conclusion section." ] }, { "cell_type": "markdown", "id": "possible-seller", "metadata": {}, "source": [ "### 3.4 Facebook comments volume\n", "\n", "This is the last of the four datasets we will be discussing in this post, a second regression problem.\n", "\n", "#### 3.4.1 `TabMlp`" ] }, { "cell_type": "code", "execution_count": 18, "id": "danish-organic", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0[100,50]relu0.1FalseFalseTrue0.00.0015120.0RAdamReduceLROnPlateau0.0010.012510000.05.032.5931
1[100,50]relu0.1FalseFalseFalse0.00.0015120.0RAdamReduceLROnPlateau0.0010.012510000.05.033.3515
2[200, 100]relu0.1FalseFalseFalse0.00.0012560.0AdamReduceLROnPlateau0.0010.012510000.05.033.4140
3[200, 100]relu0.1FalseFalseFalse0.10.0012560.0AdamReduceLROnPlateau0.0010.012510000.05.033.5679
4[200, 100]relu0.1FalseFalseFalse0.00.0015120.0RAdamReduceLROnPlateau0.0010.012510000.05.033.6284
\n", "
" ], "text/plain": [ " mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm \\\n", "0 [100,50] relu 0.1 False \n", "1 [100,50] relu 0.1 False \n", "2 [200, 100] relu 0.1 False \n", "3 [200, 100] relu 0.1 False \n", "4 [200, 100] relu 0.1 False \n", "\n", " mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size \\\n", "0 False True 0.0 0.001 512 \n", "1 False False 0.0 0.001 512 \n", "2 False False 0.0 0.001 256 \n", "3 False False 0.1 0.001 256 \n", "4 False False 0.0 0.001 512 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "1 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "2 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "3 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 \n", "4 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 5.0 32.5931 \n", "1 10000.0 5.0 33.3515 \n", "2 10000.0 5.0 33.4140 \n", "3 10000.0 5.0 33.5679 \n", "4 10000.0 5.0 33.6284 " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "fb_comments_tabmlp = pd.read_csv(TABLES_DIR / \"fb_comments_tabmlp.csv\").iloc[:5]\n", "fb_comments_tabmlp.round(4)" ] }, { "cell_type": "markdown", "id": "persistent-australian", "metadata": {}, "source": [ "**Table 17**. Results obtained for the Facebook comments volume dataset using `TabMlp`.\n", "\n", "As in the case of the NYC Taxi trip duration, the validation loss is the `MSE` loss. The `std` of the target variable is ~13 in the case of the Facebook comments volume dataset. Therefore, following the same reasoning, we can see that the task of predicting the volume of facebook comments using this particular dataset\n", " is challenging. \n", " \n", " Let's see how the other DL models perform." ] }, { "cell_type": "markdown", "id": "featured-reconstruction", "metadata": {}, "source": [ "#### 3.4.2 `TabResnet`" ] }, { "cell_type": "code", "execution_count": 19, "id": "consolidated-school", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
blocks_dimsblocks_dropoutmlp_hidden_dimsmlp_activationmlp_dropoutmlp_batchnormmlp_batchnorm_lastmlp_linear_firstembed_dropoutlrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
0[100, 100, 100]0.1Nonerelu0.1FalseFalseFalse0.00.00055120.0AdamCyclicLR0.00050.032510000.010.034.4972
1[100, 100, 100]0.1Nonerelu0.1FalseFalseFalse0.00.00055120.0AdamWCyclicLR0.00050.032510000.010.034.8520
2[100, 100, 100]0.1Nonerelu0.1FalseFalseFalse0.00.00055120.0AdamCyclicLR0.00050.032510000.010.034.9504
3[100, 100, 100]0.1Nonerelu0.1FalseFalseFalse0.00.00055120.0AdamCyclicLR0.00050.012510000.010.035.1668
4[100, 100, 100]0.1Nonerelu0.1FalseFalseFalse0.00.00055120.0AdamWCyclicLR0.00050.012510000.010.035.2503
\n", "
" ], "text/plain": [ " blocks_dims blocks_dropout mlp_hidden_dims mlp_activation \\\n", "0 [100, 100, 100] 0.1 None relu \n", "1 [100, 100, 100] 0.1 None relu \n", "2 [100, 100, 100] 0.1 None relu \n", "3 [100, 100, 100] 0.1 None relu \n", "4 [100, 100, 100] 0.1 None relu \n", "\n", " mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first \\\n", "0 0.1 False False False \n", "1 0.1 False False False \n", "2 0.1 False False False \n", "3 0.1 False False False \n", "4 0.1 False False False \n", "\n", " embed_dropout lr batch_size weight_decay optimizer lr_scheduler \\\n", "0 0.0 0.0005 512 0.0 Adam CyclicLR \n", "1 0.0 0.0005 512 0.0 AdamW CyclicLR \n", "2 0.0 0.0005 512 0.0 Adam CyclicLR \n", "3 0.0 0.0005 512 0.0 Adam CyclicLR \n", "4 0.0 0.0005 512 0.0 AdamW CyclicLR \n", "\n", " base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric \n", "0 0.0005 0.03 25 10000.0 10.0 34.4972 \n", "1 0.0005 0.03 25 10000.0 10.0 34.8520 \n", "2 0.0005 0.03 25 10000.0 10.0 34.9504 \n", "3 0.0005 0.01 25 10000.0 10.0 35.1668 \n", "4 0.0005 0.01 25 10000.0 10.0 35.2503 " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "fb_comments_tabresnet = pd.read_csv(TABLES_DIR / \"fb_comments_tabresnet.csv\").iloc[:5]\n", "fb_comments_tabresnet.round(4)" ] }, { "cell_type": "markdown", "id": "small-edinburgh", "metadata": {}, "source": [ "**Table 18**. Results obtained for the Facebook comments volume dataset using `TabResnet`." ] }, { "cell_type": "markdown", "id": "extraordinary-mauritius", "metadata": {}, "source": [ "#### 3.4.3 `Tabnet`" ] }, { "cell_type": "code", "execution_count": 20, "id": "communist-musical", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
n_stepsstep_dimattn_dimghost_bnvirtual_batch_sizemomentumgammadropoutembed_dropoutlrbatch_sizeweight_decaylambda_sparseoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
051616False1280.981.50.00.00.035120.00.0001AdamWReduceLROnPlateau0.0010.012510000.0535.8122
131616False1280.981.50.20.00.035120.00.0001AdamReduceLROnPlateau0.0010.012510000.0537.6417
251616False1280.981.50.00.00.035120.00.0001AdamWReduceLROnPlateau0.0010.012510000.0538.9771
351616False1280.981.50.20.00.035120.00.0001AdamReduceLROnPlateau0.0010.012510000.0539.5899
451616False1280.981.50.00.00.032560.00.0001AdamReduceLROnPlateau0.0010.012510000.0540.9462
\n", "
" ], "text/plain": [ " n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma \\\n", "0 5 16 16 False 128 0.98 1.5 \n", "1 3 16 16 False 128 0.98 1.5 \n", "2 5 16 16 False 128 0.98 1.5 \n", "3 5 16 16 False 128 0.98 1.5 \n", "4 5 16 16 False 128 0.98 1.5 \n", "\n", " dropout embed_dropout lr batch_size weight_decay lambda_sparse \\\n", "0 0.0 0.0 0.03 512 0.0 0.0001 \n", "1 0.2 0.0 0.03 512 0.0 0.0001 \n", "2 0.0 0.0 0.03 512 0.0 0.0001 \n", "3 0.2 0.0 0.03 512 0.0 0.0001 \n", "4 0.0 0.0 0.03 256 0.0 0.0001 \n", "\n", " optimizer lr_scheduler base_lr max_lr div_factor final_div_factor \\\n", "0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "1 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "2 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "3 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "4 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 \n", "\n", " n_cycles val_loss_or_metric \n", "0 5 35.8122 \n", "1 5 37.6417 \n", "2 5 38.9771 \n", "3 5 39.5899 \n", "4 5 40.9462 " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "fb_comments_tabnet = pd.read_csv(TABLES_DIR / \"fb_comments_tabnet.csv\").iloc[:5]\n", "fb_comments_tabnet.round(4)" ] }, { "cell_type": "markdown", "id": "ambient-vietnam", "metadata": {}, "source": [ "**Table 19**. Results obtained for the Facebook comments volume dataset using `Tabnet`." ] }, { "cell_type": "markdown", "id": "minus-donor", "metadata": {}, "source": [ "#### 3.4.4 `TabTransformer`" ] }, { "cell_type": "code", "execution_count": 21, "id": "pending-albuquerque", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
embed_dropoutfull_embed_dropoutshared_embedadd_shared_embedfrac_shared_embedinput_dimn_headsn_blocksdropoutff_hidden_dimtransformer_activationmlp_hidden_dimsmlp_activationmlp_batchnormmlp_batchnorm_lastmlp_linear_firstwith_widelrbatch_sizeweight_decayoptimizerlr_schedulerbase_lrmax_lrdiv_factorfinal_div_factorn_cyclesval_loss_or_metric
00.0FalseFalseFalse816240.1NaNreluNonereluFalseFalseFalseFalse0.000510240.0AdamCyclicLR0.00050.012510000.010.033.0946
10.0FalseFalseFalse816240.1NaNreluNonereluFalseFalseFalseFalse0.000540960.0AdamWOneCycleLR0.00100.01251000.05.033.1283
20.0FalseFalseFalse816240.1NaNreluNonereluFalseFalseFalseFalse0.001010240.0AdamReduceLROnPlateau0.00100.012510000.05.033.2175
30.0FalseFalseFalse816240.1NaNrelusamereluFalseFalseFalseFalse0.001010240.0AdamReduceLROnPlateau0.00100.012510000.05.033.4698
40.0FalseFalseFalse816440.1NaNreluNonereluFalseFalseFalseFalse0.001010240.0AdamReduceLROnPlateau0.00100.012510000.05.033.7950
\n", "
" ], "text/plain": [ " embed_dropout full_embed_dropout shared_embed add_shared_embed \\\n", "0 0.0 False False False \n", "1 0.0 False False False \n", "2 0.0 False False False \n", "3 0.0 False False False \n", "4 0.0 False False False \n", "\n", " frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim \\\n", "0 8 16 2 4 0.1 NaN \n", "1 8 16 2 4 0.1 NaN \n", "2 8 16 2 4 0.1 NaN \n", "3 8 16 2 4 0.1 NaN \n", "4 8 16 4 4 0.1 NaN \n", "\n", " transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm \\\n", "0 relu None relu False \n", "1 relu None relu False \n", "2 relu None relu False \n", "3 relu same relu False \n", "4 relu None relu False \n", "\n", " mlp_batchnorm_last mlp_linear_first with_wide lr batch_size \\\n", "0 False False False 0.0005 1024 \n", "1 False False False 0.0005 4096 \n", "2 False False False 0.0010 1024 \n", "3 False False False 0.0010 1024 \n", "4 False False False 0.0010 1024 \n", "\n", " weight_decay optimizer lr_scheduler base_lr max_lr div_factor \\\n", "0 0.0 Adam CyclicLR 0.0005 0.01 25 \n", "1 0.0 AdamW OneCycleLR 0.0010 0.01 25 \n", "2 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 \n", "3 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 \n", "4 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 \n", "\n", " final_div_factor n_cycles val_loss_or_metric \n", "0 10000.0 10.0 33.0946 \n", "1 1000.0 5.0 33.1283 \n", "2 10000.0 5.0 33.2175 \n", "3 10000.0 5.0 33.4698 \n", "4 10000.0 5.0 33.7950 " ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "fb_comments_tabtransformer = pd.read_csv(TABLES_DIR / \"fb_comments_tabtransformer.csv\").iloc[:5]\n", "fb_comments_tabtransformer.round(4)" ] }, { "cell_type": "markdown", "id": "blond-tiger", "metadata": {}, "source": [ "**Table 20**. Results obtained for the Facebook comments volume dataset using the `TabTransformer`." ] }, { "cell_type": "markdown", "id": "starting-charles", "metadata": {}, "source": [ "#### 3.4.5 DL vs `LightGBM`" ] }, { "cell_type": "code", "execution_count": 22, "id": "threaded-injury", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
modelrmser2runtimebest_epoch_or_ntrees
0lightgbm5.52900.82326.5259687.0
1tabmlp5.90850.7981250.476843.0
2tabtransformer5.92560.7969533.390827.0
3tabresnet6.21380.776770.46619.0
4tabnet6.42850.7610935.020559.0
\n", "
" ], "text/plain": [ " model rmse r2 runtime best_epoch_or_ntrees\n", "0 lightgbm 5.5290 0.8232 6.5259 687.0\n", "1 tabmlp 5.9085 0.7981 250.4768 43.0\n", "2 tabtransformer 5.9256 0.7969 533.3908 27.0\n", "3 tabresnet 6.2138 0.7767 70.4661 9.0\n", "4 tabnet 6.4285 0.7610 935.0205 59.0" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "lightgbm_vs_dl_fb_comments = pd.read_csv(TABLES_DIR / \"lightgbm_vs_dl_fb_comments.csv\")\n", "lightgbm_vs_dl_fb_comments.round(4)" ] }, { "cell_type": "markdown", "id": "governmental-metropolitan", "metadata": {}, "source": [ "**Table 21**. Results obtained for the Facebook comments volume dataset using four DL models and LightGBM." ] }, { "cell_type": "markdown", "id": "increased-rebound", "metadata": {}, "source": [ "## 4. Summary\n", "\n", "I have used four datasets and run over 1500 experiments (meaning runs with a parameter setup) comparing four DL models with `LightGBM`. This is a summary of some of the results.\n", "\n", "\n", "- `LightGBM` wins, and there was never a fight\n", "\n", " With one exception, `LightGBM` performs better than the DL models, and that one exception is precisely that, exceptional. To the experiments run and discussed here I could add two occasions where I used DL for tabular data in companies that I worked with. In particular, the model that is referred here as `TabMlp` with a `wide` component in one case and on its own in the other. \n", " \n", " The Wide & Deep model was used in the context of a recommendation algorithm, shortly after the popular [Wide and Deep](https://arxiv.org/abs/1606.07792) [19] paper was published in 2016. Back then I was using XGBoost to predict a measure of interest and rank offers based on that measure. The Wide and Deep model, implemented then with [`Keras`](https://keras.io/), obtained slightly better MAP and NDCG than XGBoost (almost identical metrics, although slightly lower, were obtained when using just the deep component). Given the number of additional considerations that one needs to take into account as you go to production, we eventually used XGBoost. \n", " \n", " In the second occasion, a more recent project, `TabMlp` on its own obtained very similar, but still lower RMSE and R2 values to those obtained using `LightGBM`. Even though `TabMLP`'s predictions were not directly used, we found the embeddings useful for a number of additional projects and we built a production system around `TabMlp`.\n", "\n", " Up to this point, I have focused on performance as measured by success metrics. However, when it comes to training (and prediction) time, the difference is so significant that makes some of these algorithms, at this stage, just useful for research purposes and/or kaggle competitions. Don't get me wrong, you only push an industry technologically by challenging current solutions and established concepts. I am simply stating that at this stage, in a production environment, it would be hard to envision a robust system built around some of these algorithms. This is the reason why I wrote \"*there was never a fight*\". When you go live, quite often is not only about success metrics but also speed and resilience. Considering altogether it seems to me like DL models for tabular data are still a bit far from being normally inserted in productions systems (but read below). 
\n", " \n", " Finally, you might read here and there that with the proper feature engineering, noise removal, balancing and \"who-knows-what-else\" DL models outperform GBMs. The truth is that in my experience is actually the opposite. When one manages to engineer good, powerful features GBMs perform even better than DL models. This is also consistent with the results in some recent competitions. For example, in the [RecSys Challenge 2020](https://recsys-twitter.com/) the guys at NVIDIA won using [clever featuring engineering](https://medium.com/rapids-ai/winning-solution-of-recsys2020-challenge-gpu-accelerated-feature-engineering-and-training-for-cd67c5a87b1f) (e.g. target oriented encoding) \"plugged\" into XGBoost on steroids (or better, GPUs). I am not sure that using those features and a DL model would actually improve their results.\n", " \n", " Overall, if I joined the results found this post, plus that I have found trying DL models on tabular data on real datasets in the industry, I can only conclude that DL models for tabular data \"are not quite there yet\" in terms of overall performance. \n", "\n", "\n", "- `TabNet` and the `TabTransformer`\n", "\n", " One rather surprising results was the poor performance of `Tabnet`, and perhaps to a lesser extent, the `TabTransformer`. \n", " \n", " One possibility is that I have not found the right set of parameters that lead to good metrics. In fact, the amount of overfitting when using `Tabnet` and `TabTransformer` was very significant, higher than in the case of `TabResnet` and furthermore `TabMlp`. This makes me believe that if I find a better set of regularization parameters, or simply using a different number of embeddings per categorical feature, I might be able to improve the results shown in the tables above. However, I should also say that given the good reception that these algorithms are having and the poor results I obtained, I placed a bit more emphasis in trying some additional parameters. Unfortunately, none of my attempts lead to a significant improvement.\n", " \n", " A second possibility is, of course, that the implementation at `pytorch-widedeep` is wrong. I guess I will find this out as I keep releasing versions and using the package. \n", " \n", " Overall, I find that `TabNet` is the worst performer (and the slowest) and I will certainly devote some extra time in the coming weeks to see if this is related to the input parameters. \n", "\n", "\n", " - Simplicity over complexity.\n", "\n", " It is interesting to see that overall, the DL algorithm that achieves similar performance to that of `LightGBM` is a simple MLP. By the time I write this, I wonder if this is somehow related to the emerging trend that is bringing MLPs back (e.g. [20], [21] or [22]), and the advent of more complex models is simply the result of hype instead of a proper exploration of current solutions. \n", " \n", " Of course, for more complex models, there is more room for exploration and hyperparameter optimization. While this is something I intend to keep exploring, there is a moment in space and time that one wonders \"*is this really worth it?*\". \n", " \n", " Let's see if I manage to answer this question in the next section\n", " \n", "## 5 Conclusion\n", "\n", "When I started thinking of this post a part of me already knew that DL models were, overall, not a real challenge for `LightGBM`. 
If we focused only in performance metrics and running time the only possible conclusion is that DL models for tabular data are still not competition for GBMs in real-world environments. However, at this stage in the industry/market, is that really *the question* to answer? I don't think so.\n", "\n", "This is not a competition, and it should not be, this should be a coalition. The question to answer is: \"how DL models for tabular data can help in the industry and complement the current systems\". Let's reflect a bit on this question. \n", "\n", "In my experience, DL models on tabular data perform best on sizeable dataset that involve many categorical features and these have many categories themselves. In those scenarios, one could just try DL models with an initial aim of using directly the prediction. However, even if the prediction is eventually not used, the embeddings contain a wealth of useful information. Information on how each categorical feature interacts with each other and information on how each categorical features relates to the target variable. These embeddings can be used for a number of additional products.\n", "\n", "For example, let's assume that you have a dataset with metadata for thousands of brands and prices for their corresponding products. Your task is to predict how the price changes over time (i.e. forecasting price). The embeddings for the categorical feature `brand` will give you information about how a particular brand relates to the rest of the columns in the dataset and the target (price). In other words, if given a brand you find the closest brands as defined by embeddings proximity you would be \"naturally\" and directly finding competitors within a given space (assuming that the dataset is representative of the market).\n", "\n", "In additions, GBMs do not allow for transfer learning, but DL models do. Furthermore, and as mentioned in the `TabNet` and the `TabTransformer` papers, self-supervised training leads to better performance in regimes where the data is low or the unlabeled dataset is much larger than the labeled dataset. Therefore, there are scenarios where DL models can be extremely useful. \n", "\n", "For example, let's assume you have a large dataset for a given problem in one country but a much smaller dataset for the exact same problem in another country. Let's also assume that the datasets are, column-wise, rather similar. One could train a DL model using the large dataset and \"transfer the learnings\" to the second, much smaller dataset with the hope of obtaining a much higher performance than just using that small dataset alone. \n", "\n", "There are some other scenarios that I can think of, but I will leave it here. In general, I simply wanted to illustrate that, if you came here to enjoy the fact that GBMs perform better than DL models, I hope you enjoyed the ride (and that you start thinking in a good therapist), but in my opinion, that is not the point. \n", "\n", "**In terms of metrics, GBMs perform better than DL models, that is correct, but the latter bring some functionalities to the table that GBMs don't have and therefore, complement them perfectly.**" ] }, { "cell_type": "markdown", "id": "received-detection", "metadata": {}, "source": [ "## 6. Future Work\n", "\n", "I started thinking in this post months ago. Then some other things took priority in my life (plus a lot of work) and it became a bit of a longer journey. 
I now hope I can get a bit of help from very clever people in my team and improve the Tabular vs DL code in the repo, perhaps automating some processes so I can easily add more datasets in the future. \n", "\n", "Also this has been a good test for the [`pytorch-widedeep`](https://github.com/jrzaurin/pytorch-widedeep) library (if you like it, or find it useful, give it a star please 😊). All the links in this post point towards the `tabnet` branch in the repo, which is the most updated. During the next few days I will merge and release v1 of the package and then update the links and the post. From there, there are a series of algorithms we would like to bring (such as SAINT) and also add some different forms of training.\n", "\n", "Beyond adding more algorithms to the library or improving the benchmark code, I wanted to close this with one final thought. As I mentioned in the beginning of the post, there is an element of inconsistency between papers. Different papers will find different results for all algorithms considered, GBMs or DL-based. When you read them one gets the feeling that there is some rush, some urgency to publish something that obtains SoTA. For someone like me, coming from a different background than computer science, this reminds me, in a sense, of my days as astronomer. For years then I found that most of the publications in my field where not very good, but since all that you are judged for are publications and citations, one would publish anything, and the faster, the better.\n", "\n", "At this stage, leaving publications and citations aside, I think there is an opportunity for some of us, and some companies as well (so that we can use actual real-world data), to collaborate and properly benchmark DL algorithms for tabular data. I believe the potential of these algorithms in the industry is enormous and with proper benchmarks we could learn not only where they perform better, but how to use them more efficiently.\n", "\n", "And that's it! if you made it to here I hope you enjoyed and/or find this useful." ] }, { "cell_type": "markdown", "id": "alleged-blair", "metadata": {}, "source": [ "## References\n", "\n", "[1] Tabular Data: Deep Learning is Not All You Need: Ravid Shwartz-Ziv, Amitai Armon, 2021, [arxiv:2106.03253](https://[arxiv.org/pdf/2106.03253.pdf)\n", "\n", "[2] XGBoost: A Scalable Tree Boosting System. Tianqi Chen, Carlos Guestrin 2016, [arXiv:1603.02754](https://arxiv.org/abs/1603.02754)\n", "\n", "[3] CatBoost: unbiased boosting with categorical features. Liudmila Prokhorenkova, Gleb Gusev, Aleksandr Vorobev, Anna Veronika Dorogush, Andrey Gulin, [arXiv:1706.09516](https://arxiv.org/abs/1706.09516)\n", "\n", "[4] LightGBM: A Highly Efficient Gradient Boosting Decision Tree. Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, 2017, [31st Conference on Neural Information Processing Systems](https://papers.nips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf) \n", "\n", "[5] SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training. Gowthami Somepalli, Micah Goldblum, Avi Schwarzschild, C. Bayan Bruss, Tom Goldstein, 2021, [arXiv:2106.01342](https://arxiv.org/abs/2106.01342)\n", "\n", "[6] Comment Volume Prediction using Neural Networks and Decision Trees, Kamaljot Singh, Ranjeet Kaur, 2015 17th UKSIM-AMSS International Conference on Modelling and Simulation.\n", "\n", "[7] TabNet: Attentive Interpretable Tabular Learning, Sercan O. 
Arik, Tomas Pfister, [arXiv:1908.07442v5](https://arxiv.org/abs/1908.07442)\n", "\n", "[8] Train longer, generalize better: closing the generalization gap in large batch training of neural networks.\n", "Elad Hoffer, Itay Hubara and Daniel Soudry, 2017, [arXiv:1705.08741](https://arxiv.org/abs/1705.08741)\n", "\n", "[9] TabTransformer: Tabular Data Modeling Using Contextual Embeddings. Xin Huang, Ashish Khetan, Milan Cvitkovic, Zohar Karnin, 2020. [arXiv:2012.06678v1](https://arxiv.org/abs/2012.06678)\n", " \n", "[10] Attention Is All You Need, Ashish Vaswani, Noam Shazeer, Niki Parmar, et al., 2017. [arXiv:1706.03762v5](https://arxiv.org/abs/1706.03762)\n", "\n", "[11] Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba, 2014, [arXiv:1412.6980](https://arxiv.org/abs/1412.6980)\n", "\n", "[12] Decoupled Weight Decay Regularization, Ilya Loshchilov, Frank Hutter, 2017.[arXiv:1711.05101](https://arxiv.org/abs/1711.05101)\n", "\n", "[13] On the Variance of the Adaptive Learning Rate and Beyond, Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, Jiawei Han, 2019, [arxiv.org:1908.03265](https://arxiv.org/abs/1908.03265)\n", "\n", "[14] Cyclical Learning Rates for Training Neural Networks, Leslie N. Smith, 2017, [arxiv.org:1506.01186](https://arxiv.org/abs/1506.01186)\n", "\n", "[15] Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, Leslie N. Smith, Nicholay Topin, 2017, [arxiv.org:1708.0712](https://arxiv.org/abs/1708.07120)\n", "\n", "[16] Optuna: A Next-generation Hyperparameter Optimization Framework. Takuya Akiba, Shotaro Sano, Toshihiko Yanase, Takeru Ohta, Masanori Koyama, 2019, [arXiv:1907.10902](https://arxiv.org/abs/1907.10902)\n", "\n", "[17] Algorithms for Hyper-Parameter Optimization, James Bergstra, Rémi Bardenet, Yoshua Bengio, Balázs Kégl, 2011, \n", "[25th Conference on Neural Information Processing Systems](https://papers.nips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf)\n", "\n", "[18] Focal Loss for Dense Object Detection, Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár, 2017, [arxiv.org:1708.02002](https://arxiv.org/abs/1708.02002?source=post_page---------------------------)\n", "\n", "[19] Wide & Deep Learning for Recommender Systems, Heng-Tze Cheng, Levent Koc, Jeremiah Harmsen, et al, 2016, [arxiv.org:1606.07792](https://arxiv.org/abs/1606.07792)\n", "\n", "[20] FNet: Mixing Tokens with Fourier Transforms, James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon, 2021, [arxiv.org:2105.03824](https://arxiv.org/abs/2105.03824)\n", "\n", "[21] Pay Attention to MLPs, Hanxiao Liu, Zihang Dai, David R. So, Quoc V. Le, 2021, [arxiv.org:2105.08050](https://arxiv.org/abs/2105.08050)\n", "\n", "[22] ResMLP: Feedforward networks for image classification with data-efficient training,\n", "Hugo Touvron, Piotr Bojanowski, Mathilde Caron, et al, 2021, [arxiv.org:2105.03404](https://arxiv.org/abs/2105.03404)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }