Unverified commit 86217bbc, authored by J Javier, committed by GitHub

Merge pull request #49 from jrzaurin/jrzaurin/perceiver

Jrzaurin/perceiver
......@@ -9,7 +9,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)
# pytorch-widedeep
......@@ -24,6 +24,13 @@ using wide and deep models.
**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
The content of this document is organized as follows:
1. [Introduction](#introduction)
2. [The deeptabular component](#the-deeptabular-component)
3. [Installation](#installation)
4. [Quick start (tl;dr)](#quick-start)
### Introduction
``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
......@@ -82,61 +89,58 @@ into:
<img width="300" src="docs/figures/architecture_2_math.png">
</p>
### The ``deeptabular`` component
It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a
linear model. In fact, one of the most interesting functionalities
in ``pytorch-widedeep`` would be the use of the ``deeptabular`` component on
its own, i.e. what one might normally refer to as Deep Learning for Tabular
Data. Currently, ``pytorch-widedeep`` offers the following different models
for that component:
1. **TabMlp**: a simple MLP that receives embeddings representing the
categorical features, concatenated with the continuous features.
2. **TabResnet**: similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
4. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
5. **SAINT**: details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
6. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
When using the ``FT-Transformer``, each continuous feature is "embedded"
(i.e. passed through a 1-layer MLP, with or without an activation function) and
then passed through the attention blocks along with the categorical features.
This is also available in ``pytorch-widedeep``'s ``TabTransformer`` by setting
the parameter ``embed_continuous = True``.
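The continuous-feature "embedding" described above can be sketched in a few
lines of numpy. This is an illustration of the shapes only, not the library's
internals; the weights and dimensions here are made up:

```python
import numpy as np

# Each continuous feature gets its own 1-layer linear projection so that it
# becomes a "token" of the same dimension as the categorical embeddings.
# Weights are random placeholders, purely for shape illustration.
rng = np.random.default_rng(0)
x_cont = np.array([[0.5, -1.2], [1.1, 0.3]])  # (batch=2, n_cont=2)
embed_dim = 8
W = rng.normal(size=(2, embed_dim))  # one weight vector per continuous feature
b = rng.normal(size=(2, embed_dim))  # one bias vector per continuous feature

# value * weights + bias, per feature -> (batch, n_cont, embed_dim)
cont_tokens = x_cont[:, :, None] * W[None, :, :] + b[None, :, :]
print(cont_tokens.shape)  # (2, 2, 8)
```

Each of the two continuous columns now has its own 8-dim token per row, so it
can be stacked with the categorical tokens before the attention blocks.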
7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
on the FastFormer can be found in
[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)
Note that while there are scientific publications for the TabTransformer,
SAINT and the FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.
For details on these models and their options please see the examples in the
Examples folder and the documentation.
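The models above share a common input pipeline: each categorical value is
mapped to a dense embedding and the result is combined with the continuous
features (via concatenation for the MLP-style models, or as tokens for the
attention-based ones). A minimal numpy sketch of the concatenation case, with
made-up cardinalities and column names, not the library's actual code:

```python
import numpy as np

# Illustrative embedding tables for two categorical columns
rng = np.random.default_rng(0)
emb_workclass = rng.normal(size=(10, 4))   # 10 categories -> 4-dim embeddings
emb_education = rng.normal(size=(17, 4))   # 17 categories -> 4-dim embeddings

x_cat = np.array([[3, 12], [7, 1]])           # batch of 2 rows, 2 categorical cols
x_cont = np.array([[0.5, -1.2], [1.1, 0.3]])  # 2 continuous cols

# Look up one embedding per categorical value, then concatenate with the
# continuous columns: this is the input an MLP-style tabular model receives.
mlp_input = np.concatenate(
    [emb_workclass[x_cat[:, 0]], emb_education[x_cat[:, 1]], x_cont], axis=1
)
print(mlp_input.shape)  # (2, 10): two 4-dim embeddings plus 2 continuous cols
```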
Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
``pytorch-widedeep``, it is very likely that users will want to use their own
models for the ``deeptext`` and ``deepimage`` components. That is perfectly
possible as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
``WideDeep`` can be constructed. Again, examples on how to use custom
components can be found in the Examples folder. Just in case,
``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
(pre-trained ResNets or stack of CNNs) models.
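The ``output_dim`` contract is easy to satisfy. A plain-Python sketch of the
idea (``CustomTextModel`` and ``check_component`` are illustrative names, not
part of the ``pytorch-widedeep`` API):

```python
class CustomTextModel:
    """Stand-in for a user-defined deeptext component (no torch needed for
    this sketch). The key requirement is exposing `output_dim`."""

    def __init__(self, vocab_size: int, hidden_dim: int = 64):
        self.vocab_size = vocab_size
        # WideDeep reads this attribute to size the prediction head,
        # so it must equal the size of the last layer of activations.
        self.output_dim = hidden_dim


def check_component(model):
    # Mimics the kind of check the WideDeep constructor relies on
    if not hasattr(model, "output_dim"):
        raise AttributeError("custom components must expose `output_dim`")
    return model.output_dim


deeptext = CustomTextModel(vocab_size=2000, hidden_dim=64)
print(check_component(deeptext))  # 64
```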
### Installation
Install using pip:
......@@ -167,8 +171,8 @@ when running on Mac, present in previous versions, persist on this release
and the data-loaders will not run in parallel. In addition, since `python
3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.7`
and `torch <= 1.6` (with the corresponding, consistent
version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
force this versioning in the `setup.py` file since I expect that all these
issues are fixed in the future. Therefore, after installing
......
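The start-method change mentioned above is easy to verify from a Python
session (the output depends on platform and Python version):

```python
import multiprocessing as mp

# On macOS with Python >= 3.8 this prints 'spawn' (the new default), which is
# what prevents the data-loaders from running in parallel there; on Linux the
# default is still 'fork'.
print(mp.get_start_method())
```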
1.0.9
\ No newline at end of file
......@@ -16,3 +16,4 @@ them to address different problems
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
* `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
* `The Transformer Family <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/10_The_Transformer_Family.ipynb>`__
* `Extracting Embeddings <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/11_Extracting_Embeddings.ipynb>`__
......@@ -23,6 +23,7 @@ Documentation
Dataloaders <dataloaders>
Callbacks <callbacks>
The Trainer <trainer>
Tab2Vec <tab2vec>
Examples <examples>
......
......@@ -17,8 +17,8 @@ on their own and can be imported as:
from pytorch_widedeep.losses import FocalLoss
.. note:: Losses in this module expect the predictions and ground truth to have the
same dimensions for regression and binary classification problems
:math:`(N_{samples}, 1)`. In the case of multiclass classification problems
the ground truth is expected to be a 1D tensor with the corresponding
classes. See Examples below
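As a sketch of this shape convention, here is a hand-rolled binary focal loss
in numpy (an illustration only, not the module's actual ``FocalLoss``
implementation); note that predictions and ground truth share the
:math:`(N_{samples}, 1)` shape:

```python
import numpy as np

def binary_focal_loss(probs, targets, alpha=0.25, gamma=2.0):
    # Both arrays have shape (N, 1), matching the convention in the note above
    assert probs.shape == targets.shape and probs.shape[1] == 1
    p_t = np.where(targets == 1, probs, 1.0 - probs)        # prob of true class
    a_t = np.where(targets == 1, alpha, 1.0 - alpha)        # class weighting
    # well-classified samples (high p_t) are down-weighted by (1 - p_t)^gamma
    return float(np.mean(-a_t * (1.0 - p_t) ** gamma * np.log(p_t)))

probs = np.array([[0.9], [0.2], [0.7]])
targets = np.array([[1], [0], [1]])
print(round(binary_focal_loss(probs, targets), 4))
```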
......
......@@ -2,10 +2,9 @@ Metrics
=======
.. note:: Metrics in this module expect the predictions and ground truth to have the
same dimensions for regression and binary classification problems: :math:`(N_{samples}, 1)`.
In the case of multiclass classification problems the ground truth is expected to be
a 1D tensor with the corresponding classes. See Examples below
We have added the possibility of using the metrics available at the
`torchmetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ library.
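The same shape convention can be sketched with two minimal numpy helpers
(hypothetical names, not the metric classes in this module): binary problems
use :math:`(N_{samples}, 1)` for both arrays, multiclass problems use a 1D
target of class indices:

```python
import numpy as np

def binary_accuracy(probs, targets, thresh=0.5):
    # binary problems: probs and targets both of shape (N, 1)
    return float(((probs >= thresh).astype(int) == targets).mean())

def multiclass_accuracy(probs, targets):
    # multiclass problems: probs (N, C), targets a 1D array of class indices
    return float((probs.argmax(axis=1) == targets).mean())

print(binary_accuracy(np.array([[0.8], [0.3]]), np.array([[1], [1]])))            # 0.5
print(multiclass_accuracy(np.array([[0.1, 0.9], [0.7, 0.3]]), np.array([1, 1])))  # 0.5
```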
......
......@@ -5,9 +5,10 @@ This module contains the four main components that will comprise a Wide and
Deep model, and the ``WideDeep`` "constructor" class. These four components
are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer``, ``SAINT``,
``FTTransformer``, ``TabPerceiver`` and ``TabFastFormer`` can all be used
as the ``deeptabular`` component of the model and simply represent different
alternatives
.. autoclass:: pytorch_widedeep.models.wide.Wide
:exclude-members: forward
......@@ -33,6 +34,18 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
:exclude-members: forward
:members:
.. autoclass:: pytorch_widedeep.models.transformers.ft_transformer.FTTransformer
:exclude-members: forward
:members:
.. autoclass:: pytorch_widedeep.models.transformers.tab_perceiver.TabPerceiver
:exclude-members: forward
:members:
.. autoclass:: pytorch_widedeep.models.transformers.tab_fastformer.TabFastFormer
:exclude-members: forward
:members:
.. autoclass:: pytorch_widedeep.models.deep_text.DeepText
:exclude-members: forward
:members:
......
Tab2Vec
=======
.. autoclass:: pytorch_widedeep.tab2vec.Tab2Vec
:members:
:undoc-members:
......@@ -511,15 +511,17 @@
" )\n",
" (deeptabular): Sequential(\n",
" (0): TabMlp(\n",
" (cat_embed_and_cont): CatEmbeddingsAndCont(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(17, 16, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n",
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
......@@ -589,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 78.90it/s, loss=0.477, metrics={'acc': 0.7763, 'prec': 0.5377}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 124.13it/s, loss=0.387, metrics={'acc': 0.8148, 'prec': 0.6034}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 88.99it/s, loss=0.383, metrics={'acc': 0.8205, 'prec': 0.6525}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 116.64it/s, loss=0.364, metrics={'acc': 0.832, 'prec': 0.6629}] \n",
"epoch 3: 100%|██████████| 611/611 [00:09<00:00, 67.26it/s, loss=0.372, metrics={'acc': 0.8264, 'prec': 0.6683}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 145.81it/s, loss=0.355, metrics={'acc': 0.8343, 'prec': 0.669}] \n",
"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 86.61it/s, loss=0.36, metrics={'acc': 0.8306, 'prec': 0.6784}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 168.01it/s, loss=0.354, metrics={'acc': 0.8323, 'prec': 0.6549}]\n",
"epoch 5: 100%|██████████| 611/611 [00:06<00:00, 87.79it/s, loss=0.357, metrics={'acc': 0.8321, 'prec': 0.6841}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.45it/s, loss=0.352, metrics={'acc': 0.8341, 'prec': 0.6671}]\n"
]
}
],
......@@ -647,16 +649,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:08<00:00, 70.16it/s, loss=0.439, metrics={'acc': 0.7865, 'prec': 0.5561}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 152.45it/s, loss=0.361, metrics={'acc': 0.8349, 'prec': 0.6803}]\n",
"epoch 2: 100%|██████████| 611/611 [00:08<00:00, 70.18it/s, loss=0.373, metrics={'acc': 0.8236, 'prec': 0.6609}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.46it/s, loss=0.354, metrics={'acc': 0.839, 'prec': 0.704}] \n",
"epoch 3: 100%|██████████| 611/611 [00:08<00:00, 70.71it/s, loss=0.363, metrics={'acc': 0.8294, 'prec': 0.6717}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.59it/s, loss=0.353, metrics={'acc': 0.8381, 'prec': 0.6954}]\n",
"epoch 4: 100%|██████████| 611/611 [00:08<00:00, 69.53it/s, loss=0.358, metrics={'acc': 0.8323, 'prec': 0.683}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.64it/s, loss=0.353, metrics={'acc': 0.8339, 'prec': 0.658}] \n",
"epoch 5: 100%|██████████| 611/611 [00:08<00:00, 69.47it/s, loss=0.354, metrics={'acc': 0.8338, 'prec': 0.6851}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.29it/s, loss=0.353, metrics={'acc': 0.8375, 'prec': 0.6764}]\n"
]
}
],
......@@ -668,7 +670,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the `TabTransformer` as the `deeptabular` component"
]
},
{
......@@ -703,7 +705,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
......@@ -717,7 +719,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
......@@ -732,7 +734,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
......@@ -741,17 +743,17 @@
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 77/77 [00:20<00:00, 3.79it/s, loss=0.667, metrics={'acc': 0.6787, 'prec': 0.3306}]\n",
"valid: 100%|██████████| 20/20 [00:01<00:00, 14.78it/s, loss=0.409, metrics={'acc': 0.8033, 'prec': 0.583}] \n",
"epoch 2: 100%|██████████| 77/77 [00:21<00:00, 3.52it/s, loss=0.403, metrics={'acc': 0.8136, 'prec': 0.6326}]\n",
"valid: 100%|██████████| 20/20 [00:01<00:00, 14.32it/s, loss=0.374, metrics={'acc': 0.8224, 'prec': 0.6504}]\n"
]
}
],
......@@ -768,7 +770,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
......@@ -777,7 +779,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
......@@ -786,23 +788,23 @@
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 166.24it/s, loss=0.799, metrics={'acc': 0.5446, 'prec': 0.229}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 261.71it/s, loss=0.56, metrics={'acc': 0.7394, 'prec': 0.4044}] \n",
"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 167.73it/s, loss=0.491, metrics={'acc': 0.7699, 'prec': 0.5417}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 259.04it/s, loss=0.446, metrics={'acc': 0.7939, 'prec': 0.6517}]\n",
"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 168.46it/s, loss=0.425, metrics={'acc': 0.809, 'prec': 0.6809}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 257.24it/s, loss=0.406, metrics={'acc': 0.8201, 'prec': 0.7044}]\n",
"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 169.35it/s, loss=0.397, metrics={'acc': 0.8176, 'prec': 0.6909}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 174.28it/s, loss=0.388, metrics={'acc': 0.8248, 'prec': 0.7048}]\n",
"epoch 5: 100%|██████████| 611/611 [00:04<00:00, 142.70it/s, loss=0.382, metrics={'acc': 0.8239, 'prec': 0.6947}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 207.43it/s, loss=0.378, metrics={'acc': 0.8288, 'prec': 0.6999}]\n"
]
}
],
......
......@@ -1069,7 +1069,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▍ | 41/1001 [00:00<00:02, 402.45it/s]"
]
},
{
......@@ -1083,7 +1083,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1001/1001 [00:02<00:00, 411.97it/s]\n"
]
},
{
......@@ -1173,15 +1173,15 @@
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:13<00:00, 5.35s/it, loss=115]\n",
"valid: 100%|██████████| 7/7 [00:15<00:00, 2.20s/it, loss=108] \n"
]
}
],
......@@ -1201,7 +1201,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
......@@ -1238,7 +1238,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
......@@ -1260,7 +1260,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": false
},
......@@ -1273,19 +1273,21 @@
" (wide_linear): Embedding(357, 1, padding_idx=0)\n",
" )\n",
" (deeptabular): TabMlp(\n",
" (cat_embed_and_cont): CatEmbeddingsAndCont(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_accommodates_catg): Embedding(4, 16, padding_idx=0)\n",
" (emb_layer_bathrooms_catg): Embedding(4, 16, padding_idx=0)\n",
" (emb_layer_bedrooms_catg): Embedding(5, 16, padding_idx=0)\n",
" (emb_layer_beds_catg): Embedding(5, 16, padding_idx=0)\n",
" (emb_layer_cancellation_policy): Embedding(6, 16, padding_idx=0)\n",
" (emb_layer_guests_included_catg): Embedding(4, 16, padding_idx=0)\n",
" (emb_layer_host_listings_count_catg): Embedding(5, 16, padding_idx=0)\n",
" (emb_layer_minimum_nights_catg): Embedding(4, 16, padding_idx=0)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
......@@ -1423,7 +1425,7 @@
")"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1443,7 +1445,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
......@@ -1457,7 +1459,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
......@@ -1470,7 +1472,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
......@@ -1483,7 +1485,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
......@@ -1509,7 +1511,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
......@@ -1535,15 +1537,15 @@
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:08<00:00, 5.12s/it, loss=108]\n",
"valid: 100%|██████████| 7/7 [00:15<00:00, 2.19s/it, loss=92.9]"
]
},
{
......@@ -1575,7 +1577,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
......@@ -1603,7 +1605,7 @@
" 'lr_deephead_0': [0.001, 0.001]}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
......
......@@ -306,10 +306,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:03<00:00, 43.48it/s, loss=0.565, metrics={'acc': 0.7249}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 62.94it/s, loss=0.387, metrics={'acc': 0.8207}]\n",
"epoch 2: 100%|██████████| 153/153 [00:04<00:00, 30.88it/s, loss=0.389, metrics={'acc': 0.8195}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 92.51it/s, loss=0.372, metrics={'acc': 0.8261}] \n"
]
}
],
......@@ -387,7 +387,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 4%|▎ | 7/191 [00:00<00:02, 64.63it/s, loss=1.17, metrics={'acc': 0.418}] "
]
},
{
......@@ -401,9 +401,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 191/191 [00:02<00:00, 68.57it/s, loss=0.584, metrics={'acc': 0.7155}]\n",
"epoch 2: 100%|██████████| 191/191 [00:03<00:00, 62.76it/s, loss=0.39, metrics={'acc': 0.7697}] \n",
"epoch 1: 3%|▎ | 6/191 [00:00<00:03, 56.94it/s, loss=0.403, metrics={'acc': 0.7705}]"
]
},
{
......@@ -417,9 +417,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 191/191 [00:03<00:00, 58.09it/s, loss=0.369, metrics={'acc': 0.7887}]\n",
"epoch 2: 100%|██████████| 191/191 [00:03<00:00, 50.37it/s, loss=0.353, metrics={'acc': 0.8003}]\n",
"epoch 1: 2%|▏ | 4/191 [00:00<00:05, 36.39it/s, loss=0.399, metrics={'acc': 0.8298}]"
]
},
{
......@@ -433,8 +433,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 191/191 [00:05<00:00, 33.73it/s, loss=0.355, metrics={'acc': 0.8377}]\n",
"epoch 2: 100%|██████████| 191/191 [00:05<00:00, 36.01it/s, loss=0.347, metrics={'acc': 0.8396}]\n"
]
}
],
......@@ -488,7 +488,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 3%|▎ | 6/172 [00:00<00:03, 52.52it/s, loss=0.628, metrics={'acc': 0.6977}]"
]
},
{
......@@ -502,9 +502,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:02<00:00, 68.16it/s, loss=0.475, metrics={'acc': 0.7799}]\n",
"epoch 2: 100%|██████████| 172/172 [00:02<00:00, 68.59it/s, loss=0.387, metrics={'acc': 0.8021}]\n",
"epoch 1: 2%|▏ | 4/172 [00:00<00:04, 34.15it/s, loss=0.62, metrics={'acc': 0.8009}] "
]
},
{
......@@ -518,9 +518,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:04<00:00, 38.78it/s, loss=0.392, metrics={'acc': 0.8075}]\n",
"epoch 2: 100%|██████████| 172/172 [00:03<00:00, 49.88it/s, loss=0.354, metrics={'acc': 0.8145}]\n",
"epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.71it/s, loss=0.412, metrics={'acc': 0.8326}]"
]
},
{
......@@ -534,10 +534,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:04<00:00, 38.84it/s, loss=0.354, metrics={'acc': 0.8372}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 81.14it/s, loss=0.348, metrics={'acc': 0.8399}]\n",
"epoch 2: 100%|██████████| 172/172 [00:04<00:00, 37.92it/s, loss=0.345, metrics={'acc': 0.8397}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 46.33it/s, loss=0.347, metrics={'acc': 0.8409}]\n"
]
}
],
@@ -629,15 +629,17 @@
" )\n",
" (deeptabular): Sequential(\n",
" (0): TabResnet(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(17, 16, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n",
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" (cat_embed_and_cont): CatEmbeddingsAndCont(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(17, 16, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n",
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_resnet_blks): DenseResnet(\n",
" (dense_resnet): Sequential(\n",
" (lin1): Linear(in_features=74, out_features=128, bias=True)\n",
@@ -708,10 +710,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:05<00:00, 29.00it/s, loss=0.453, metrics={'acc': 0.7787}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 90.03it/s, loss=0.363, metrics={'acc': 0.8282}]\n",
"epoch 2: 100%|██████████| 172/172 [00:05<00:00, 32.24it/s, loss=0.371, metrics={'acc': 0.8262}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 88.22it/s, loss=0.351, metrics={'acc': 0.8356}]\n"
"epoch 1: 100%|██████████| 172/172 [00:07<00:00, 23.55it/s, loss=0.411, metrics={'acc': 0.8033}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 71.22it/s, loss=0.364, metrics={'acc': 0.8287}]\n",
"epoch 2: 100%|██████████| 172/172 [00:06<00:00, 25.12it/s, loss=0.369, metrics={'acc': 0.827}] \n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 78.16it/s, loss=0.355, metrics={'acc': 0.8342}]\n"
]
}
],
@@ -788,7 +790,7 @@
"outputs": [],
"source": [
"tab_deep_layers = list(\n",
" list(list(list(model_3.deeptabular.children())[0].children())[3].children())[\n",
" list(list(list(model_3.deeptabular.children())[0].children())[1].children())[\n",
" 0\n",
" ].children()\n",
")[::-1][:2]"
@@ -865,14 +867,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 5%|▍ | 8/172 [00:00<00:02, 68.51it/s, loss=0.767, metrics={'acc': 0.5605}]"
"epoch 1: 5%|▍ | 8/172 [00:00<00:02, 71.14it/s, loss=0.719, metrics={'acc': 0.6278}]"
]
},
{
@@ -886,9 +888,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:02<00:00, 75.72it/s, loss=0.489, metrics={'acc': 0.7523}]\n",
"epoch 2: 100%|██████████| 172/172 [00:02<00:00, 64.95it/s, loss=0.383, metrics={'acc': 0.7876}]\n",
"epoch 1: 2%|▏ | 3/172 [00:00<00:07, 22.26it/s, loss=0.402, metrics={'acc': 0.788}] "
"epoch 1: 100%|██████████| 172/172 [00:03<00:00, 56.88it/s, loss=0.496, metrics={'acc': 0.7596}]\n",
"epoch 2: 100%|██████████| 172/172 [00:02<00:00, 68.06it/s, loss=0.386, metrics={'acc': 0.7917}]\n",
"epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.40it/s, loss=0.435, metrics={'acc': 0.7915}]"
]
},
{
@@ -902,8 +904,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:08<00:00, 20.71it/s, loss=0.385, metrics={'acc': 0.7986}]\n",
"epoch 1: 0%| | 0/172 [00:00<?, ?it/s]"
"epoch 1: 100%|██████████| 172/172 [00:04<00:00, 35.74it/s, loss=0.388, metrics={'acc': 0.7992}]\n",
"epoch 1: 2%|▏ | 4/172 [00:00<00:04, 36.11it/s, loss=0.383, metrics={'acc': 0.7994}]"
]
},
{
@@ -917,8 +919,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:13<00:00, 13.08it/s, loss=0.369, metrics={'acc': 0.8058}]\n",
"epoch 1: 2%|▏ | 3/172 [00:00<00:07, 21.56it/s, loss=0.355, metrics={'acc': 0.806}]"
"epoch 1: 100%|██████████| 172/172 [00:04<00:00, 35.56it/s, loss=0.37, metrics={'acc': 0.8054}] \n",
"epoch 1: 2%|▏ | 4/172 [00:00<00:04, 34.93it/s, loss=0.389, metrics={'acc': 0.8055}]"
]
},
{
@@ -932,8 +934,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:07<00:00, 22.34it/s, loss=0.361, metrics={'acc': 0.8108}]\n",
"epoch 1: 1%| | 2/172 [00:00<00:09, 17.34it/s, loss=0.334, metrics={'acc': 0.8581}]"
"epoch 1: 100%|██████████| 172/172 [00:05<00:00, 33.97it/s, loss=0.36, metrics={'acc': 0.8104}] \n",
"epoch 1: 2%|▏ | 3/172 [00:00<00:06, 27.49it/s, loss=0.385, metrics={'acc': 0.8359}]"
]
},
{
@@ -947,9 +949,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:16<00:00, 10.31it/s, loss=0.353, metrics={'acc': 0.8366}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 40.53it/s, loss=0.345, metrics={'acc': 0.8405}]\n",
"epoch 2: 89%|████████▉ | 153/172 [00:06<00:00, 28.91it/s, loss=0.342, metrics={'acc': 0.8399}]"
"epoch 1: 100%|██████████| 172/172 [00:06<00:00, 27.49it/s, loss=0.351, metrics={'acc': 0.838}] \n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 85.72it/s, loss=0.347, metrics={'acc': 0.8395}]\n",
"epoch 2: 100%|██████████| 172/172 [00:06<00:00, 27.62it/s, loss=0.342, metrics={'acc': 0.8394}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 79.38it/s, loss=0.349, metrics={'acc': 0.8358}]\n"
]
}
],
@@ -978,7 +981,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@@ -991,7 +994,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@@ -1000,16 +1003,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 31,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:03<00:00, 51.98it/s, loss=0.406, metrics={'acc': 0.8067}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 89.34it/s, loss=0.356, metrics={'acc': 0.8323}]\n"
]
}
],
"source": [
"trainer_5.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=1, batch_size=256)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@@ -1018,7 +1030,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@@ -1031,9 +1043,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 34,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_5.load_state_dict(torch.load(\"models_dir/model_5.pt\"))"
]
@@ -1047,7 +1070,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@@ -1056,9 +1079,46 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 36,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 3%|▎ | 6/172 [00:00<00:03, 51.73it/s, loss=0.371, metrics={'acc': 0.8247}]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training deeptabular for 2 epochs\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 172/172 [00:03<00:00, 48.47it/s, loss=0.367, metrics={'acc': 0.8287}]\n",
"epoch 2: 100%|██████████| 172/172 [00:03<00:00, 51.73it/s, loss=0.352, metrics={'acc': 0.833}] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tuning finished\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"trainer_6.fit(\n",
" X_wide=X_wide, \n",
@@ -1076,7 +1136,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
......
@@ -1011,7 +1011,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▍ | 42/1001 [00:00<00:02, 419.42it/s]"
" 3%|▎ | 29/1001 [00:00<00:03, 288.35it/s]"
]
},
{
@@ -1025,7 +1025,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1001/1001 [00:02<00:00, 408.24it/s]\n"
"100%|██████████| 1001/1001 [00:03<00:00, 307.23it/s]\n"
]
},
{
@@ -1194,8 +1194,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:03<00:00, 4.94s/it, loss=111]\n",
"valid: 100%|██████████| 7/7 [00:15<00:00, 2.17s/it, loss=94.8]\n"
"epoch 1: 100%|██████████| 25/25 [02:21<00:00, 5.65s/it, loss=120]\n",
"valid: 100%|██████████| 7/7 [00:15<00:00, 2.21s/it, loss=217]\n"
]
}
],
@@ -1735,16 +1735,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 90.27it/s, loss=0.455, metrics={'acc': 0.7866}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.16it/s, loss=0.367, metrics={'acc': 0.8318}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 89.76it/s, loss=0.379, metrics={'acc': 0.8233}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 151.92it/s, loss=0.354, metrics={'acc': 0.8362}]\n",
"epoch 3: 100%|██████████| 611/611 [00:06<00:00, 89.63it/s, loss=0.365, metrics={'acc': 0.8296}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.56it/s, loss=0.35, metrics={'acc': 0.8383}] \n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 87.99it/s, loss=0.357, metrics={'acc': 0.8333}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.60it/s, loss=0.348, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 87.18it/s, loss=0.354, metrics={'acc': 0.835}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 149.32it/s, loss=0.346, metrics={'acc': 0.8402}]\n"
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 83.88it/s, loss=0.402, metrics={'acc': 0.8093}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 153.65it/s, loss=0.359, metrics={'acc': 0.8365}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 89.07it/s, loss=0.365, metrics={'acc': 0.8276}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.09it/s, loss=0.354, metrics={'acc': 0.8367}]\n",
"epoch 3: 100%|██████████| 611/611 [00:06<00:00, 89.26it/s, loss=0.357, metrics={'acc': 0.8338}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.81it/s, loss=0.351, metrics={'acc': 0.8371}]\n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 89.63it/s, loss=0.354, metrics={'acc': 0.8335}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 151.97it/s, loss=0.348, metrics={'acc': 0.8401}]\n",
"epoch 5: 100%|██████████| 611/611 [00:09<00:00, 66.64it/s, loss=0.351, metrics={'acc': 0.8349}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 127.27it/s, loss=0.347, metrics={'acc': 0.8406}]\n"
]
}
],
......
@@ -488,31 +488,33 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
" (wide_linear): Embedding(773, 1, padding_idx=0)\n",
" (wide_linear): Embedding(779, 1, padding_idx=0)\n",
" )\n",
" (deeptabular): Sequential(\n",
" (0): TabMlp(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_age): Embedding(75, 18, padding_idx=0)\n",
" (emb_layer_capital_gain): Embedding(121, 23, padding_idx=0)\n",
" (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n",
" (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
" (emb_layer_hours_per_week): Embedding(96, 20, padding_idx=0)\n",
" (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
" (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
" (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n",
" (cat_embed_and_cont): CatEmbeddingsAndCont(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_age): Embedding(75, 18, padding_idx=0)\n",
" (emb_layer_capital_gain): Embedding(119, 23, padding_idx=0)\n",
" (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n",
" (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
" (emb_layer_hours_per_week): Embedding(97, 21, padding_idx=0)\n",
" (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
" (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
" (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
" (0): Dropout(p=0.1, inplace=False)\n",
" (1): Linear(in_features=138, out_features=200, bias=True)\n",
" (1): Linear(in_features=139, out_features=200, bias=True)\n",
" (2): ReLU(inplace=True)\n",
" )\n",
" (dense_layer_1): Sequential(\n",
@@ -546,9 +548,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.06it/s, loss=0.479, metrics={'acc': 0.7839}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 48.70it/s, loss=0.348, metrics={'acc': 0.8444}]\n",
"epoch 2: 3%|▎ | 4/153 [00:00<00:04, 32.06it/s, loss=0.365, metrics={'acc': 0.8385}]"
"epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.17it/s, loss=0.444, metrics={'acc': 0.7952}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 46.39it/s, loss=0.357, metrics={'acc': 0.838}] \n",
"epoch 2: 2%|▏ | 3/153 [00:00<00:05, 27.83it/s, loss=0.38, metrics={'acc': 0.8305}] "
]
},
{
@@ -556,15 +558,15 @@
"output_type": "stream",
"text": [
"\n",
"Epoch 00001: val_loss improved from inf to 0.34800, saving model to tmp_dir/adult_tabmlp_model_1.p\n"
"Epoch 00001: val_loss improved from inf to 0.35746, saving model to tmp_dir/adult_tabmlp_model_1.p\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 2: 100%|██████████| 153/153 [00:04<00:00, 32.33it/s, loss=0.354, metrics={'acc': 0.8379}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 92.91it/s, loss=0.322, metrics={'acc': 0.8511}] "
"epoch 2: 100%|██████████| 153/153 [00:04<00:00, 31.57it/s, loss=0.352, metrics={'acc': 0.8395}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 88.95it/s, loss=0.325, metrics={'acc': 0.8557}]"
]
},
{
@@ -572,7 +574,7 @@
"output_type": "stream",
"text": [
"\n",
"Epoch 00002: val_loss improved from 0.34800 to 0.32204, saving model to tmp_dir/adult_tabmlp_model_2.p\n",
"Epoch 00002: val_loss improved from 0.35746 to 0.32545, saving model to tmp_dir/adult_tabmlp_model_2.p\n",
"Model weights restored to best epoch: 2\n"
]
},
@@ -844,120 +846,120 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>32823</th>\n",
" <td>33</td>\n",
" <th>3428</th>\n",
" <td>40</td>\n",
" <td>Private</td>\n",
" <td>201988</td>\n",
" <td>Masters</td>\n",
" <td>14</td>\n",
" <td>144778</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>45</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40713</th>\n",
" <td>31</td>\n",
" <td>Private</td>\n",
" <td>231826</td>\n",
" <th>8234</th>\n",
" <td>38</td>\n",
" <td>?</td>\n",
" <td>54953</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Other-service</td>\n",
" <td>Husband</td>\n",
" <td>Divorced</td>\n",
" <td>?</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>52</td>\n",
" <td>Mexico</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16020</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>24126</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Divorced</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Not-in-family</td>\n",
" <th>1129</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>134771</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Never-married</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>55</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32766</th>\n",
" <td>38</td>\n",
" <td>State-gov</td>\n",
" <td>312528</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <th>11866</th>\n",
" <td>47</td>\n",
" <td>Private</td>\n",
" <td>189143</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Divorced</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9713</th>\n",
" <td>40</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>121012</td>\n",
" <td>Prof-school</td>\n",
" <td>15</td>\n",
" <th>39544</th>\n",
" <td>27</td>\n",
" <td>Private</td>\n",
" <td>224849</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Craft-repair</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>1977</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education educational_num \\\n",
"32823 33 Private 201988 Masters 14 \n",
"40713 31 Private 231826 HS-grad 9 \n",
"16020 38 Private 24126 Some-college 10 \n",
"32766 38 State-gov 312528 Bachelors 13 \n",
"9713 40 Self-emp-not-inc 121012 Prof-school 15 \n",
" age workclass fnlwgt education educational_num marital_status \\\n",
"3428 40 Private 144778 HS-grad 9 Married-civ-spouse \n",
"8234 38 ? 54953 HS-grad 9 Divorced \n",
"1129 28 Local-gov 134771 Bachelors 13 Never-married \n",
"11866 47 Private 189143 HS-grad 9 Divorced \n",
"39544 27 Private 224849 HS-grad 9 Married-civ-spouse \n",
"\n",
" marital_status occupation relationship race gender \\\n",
"32823 Married-civ-spouse Prof-specialty Husband White Male \n",
"40713 Married-civ-spouse Other-service Husband White Male \n",
"16020 Divorced Exec-managerial Not-in-family White Female \n",
"32766 Married-civ-spouse Exec-managerial Husband White Male \n",
"9713 Married-civ-spouse Prof-specialty Husband White Male \n",
" occupation relationship race gender capital_gain \\\n",
"3428 Exec-managerial Husband White Male 0 \n",
"8234 ? Not-in-family White Male 0 \n",
"1129 Prof-specialty Own-child White Female 0 \n",
"11866 Farming-fishing Not-in-family White Male 0 \n",
"39544 Craft-repair Husband White Male 0 \n",
"\n",
" capital_gain capital_loss hours_per_week native_country target \n",
"32823 0 0 45 United-States 0 \n",
"40713 0 0 52 Mexico 0 \n",
"16020 0 0 40 United-States 0 \n",
"32766 0 0 37 United-States 0 \n",
"9713 0 1977 50 United-States 1 "
" capital_loss hours_per_week native_country target \n",
"3428 0 50 United-States 1 \n",
"8234 0 30 United-States 0 \n",
"1129 0 55 United-States 0 \n",
"11866 0 40 United-States 0 \n",
"39544 0 40 United-States 0 "
]
},
"execution_count": 23,
@@ -1055,7 +1057,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"predict: 100%|██████████| 20/20 [00:00<00:00, 86.04it/s]\n"
"predict: 100%|██████████| 20/20 [00:00<00:00, 91.44it/s]\n"
]
}
],
@@ -1080,7 +1082,7 @@
{
"data": {
"text/plain": [
"0.8554759467758444"
"0.8511770726714432"
]
},
"execution_count": 32,
......
@@ -752,7 +752,9 @@
"WideDeep(\n",
" (deeptabular): Sequential(\n",
" (0): TabMlp(\n",
" (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (cat_embed_and_cont): CatEmbeddingsAndCont(\n",
" (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
@@ -861,12 +863,12 @@
"Consider using one of the following signatures instead:\n",
"\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n",
" meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n",
"epoch 1: 100%|██████████| 208/208 [00:02<00:00, 76.68it/s, loss=0.225, metrics={'Accuracy': [0.927, 0.8861], 'Precision': 0.9064, 'Recall': [0.927, 0.8861], 'F1': [0.9075, 0.9052]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.00it/s, loss=0.104, metrics={'Accuracy': [0.9626, 0.875], 'Precision': 0.9618, 'Recall': [0.9626, 0.875], 'F1': [0.9804, 0.2886]}] \n",
"epoch 2: 100%|██████████| 208/208 [00:02<00:00, 84.52it/s, loss=0.152, metrics={'Accuracy': [0.9471, 0.9298], 'Precision': 0.9384, 'Recall': [0.9471, 0.9298], 'F1': [0.9384, 0.9383]}]\n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.06it/s, loss=0.0915, metrics={'Accuracy': [0.968, 0.8906], 'Precision': 0.9673, 'Recall': [0.968, 0.8906], 'F1': [0.9833, 0.3258]}] \n",
"epoch 3: 100%|██████████| 208/208 [00:02<00:00, 81.70it/s, loss=0.134, metrics={'Accuracy': [0.949, 0.9407], 'Precision': 0.9448, 'Recall': [0.949, 0.9407], 'F1': [0.9446, 0.9451]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 104.52it/s, loss=0.0847, metrics={'Accuracy': [0.9679, 0.8828], 'Precision': 0.9672, 'Recall': [0.9679, 0.8828], 'F1': [0.9832, 0.3229]}]\n"
"epoch 1: 100%|██████████| 208/208 [00:02<00:00, 73.29it/s, loss=0.232, metrics={'Accuracy': [0.9226, 0.89], 'Precision': 0.9061, 'Recall': [0.9226, 0.89], 'F1': [0.9065, 0.9057]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.70it/s, loss=0.0598, metrics={'Accuracy': [0.981, 0.8672], 'Precision': 0.98, 'Recall': [0.981, 0.8672], 'F1': [0.9898, 0.4341]}] \n",
"epoch 2: 100%|██████████| 208/208 [00:02<00:00, 85.53it/s, loss=0.161, metrics={'Accuracy': [0.9417, 0.9262], 'Precision': 0.9339, 'Recall': [0.9417, 0.9262], 'F1': [0.9344, 0.9335]}]\n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 102.48it/s, loss=0.123, metrics={'Accuracy': [0.9549, 0.9219], 'Precision': 0.9546, 'Recall': [0.9549, 0.9219], 'F1': [0.9766, 0.2647]}]\n",
"epoch 3: 100%|██████████| 208/208 [00:02<00:00, 80.85it/s, loss=0.139, metrics={'Accuracy': [0.948, 0.9413], 'Precision': 0.9446, 'Recall': [0.948, 0.9413], 'F1': [0.9451, 0.9442]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 104.11it/s, loss=0.0722, metrics={'Accuracy': [0.9743, 0.8984], 'Precision': 0.9737, 'Recall': [0.9743, 0.8984], 'F1': [0.9865, 0.3766]}]\n"
]
},
{
@@ -930,42 +932,42 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.225108</td>\n",
" <td>[0.9270001649856567, 0.8861073851585388]</td>\n",
" <td>0.906365</td>\n",
" <td>[0.9270001649856567, 0.8861073851585388]</td>\n",
" <td>[0.9074797034263611, 0.9052220582962036]</td>\n",
" <td>0.103954</td>\n",
" <td>[0.962550163269043, 0.875]</td>\n",
" <td>0.961784</td>\n",
" <td>[0.962550163269043, 0.875]</td>\n",
" <td>[0.9803645014762878, 0.28863346576690674]</td>\n",
" <td>0.232111</td>\n",
" <td>[0.9225957989692688, 0.8899886012077332]</td>\n",
" <td>0.906075</td>\n",
" <td>[0.9225957989692688, 0.8899886012077332]</td>\n",
" <td>[0.9064720273017883, 0.9056749939918518]</td>\n",
" <td>0.059772</td>\n",
" <td>[0.9809635877609253, 0.8671875]</td>\n",
" <td>0.979966</td>\n",
" <td>[0.9809635877609253, 0.8671875]</td>\n",
" <td>[0.989802360534668, 0.4341084957122803]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.152386</td>\n",
" <td>[0.9471125602722168, 0.9297876358032227]</td>\n",
" <td>0.938380</td>\n",
" <td>[0.9471125602722168, 0.9297876358032227]</td>\n",
" <td>[0.9384452104568481, 0.9383144974708557]</td>\n",
" <td>0.091541</td>\n",
" <td>[0.9680188298225403, 0.890625]</td>\n",
" <td>0.967341</td>\n",
" <td>[0.9680188298225403, 0.890625]</td>\n",
" <td>[0.9832653999328613, 0.3257790505886078]</td>\n",
" <td>0.160898</td>\n",
" <td>[0.9417101144790649, 0.9261900186538696]</td>\n",
" <td>0.933944</td>\n",
" <td>[0.9417101144790649, 0.9261900186538696]</td>\n",
" <td>[0.9344058036804199, 0.9334757328033447]</td>\n",
" <td>0.122636</td>\n",
" <td>[0.954935610294342, 0.921875]</td>\n",
" <td>0.954648</td>\n",
" <td>[0.954935610294342, 0.921875]</td>\n",
" <td>[0.9766026139259338, 0.2647385895252228]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.134425</td>\n",
" <td>[0.9490371346473694, 0.9407152533531189]</td>\n",
" <td>0.944841</td>\n",
" <td>[0.9490371346473694, 0.9407152533531189]</td>\n",
" <td>[0.9446273446083069, 0.9450528621673584]</td>\n",
" <td>0.084717</td>\n",
" <td>[0.967949628829956, 0.8828125]</td>\n",
" <td>0.967204</td>\n",
" <td>[0.967949628829956, 0.8828125]</td>\n",
" <td>[0.9831950664520264, 0.3229461908340454]</td>\n",
" <td>0.138879</td>\n",
" <td>[0.9479646682739258, 0.9413018226623535]</td>\n",
" <td>0.944648</td>\n",
" <td>[0.9479646682739258, 0.9413018226623535]</td>\n",
" <td>[0.945061206817627, 0.9442285299301147]</td>\n",
" <td>0.072151</td>\n",
" <td>[0.9743181467056274, 0.8984375]</td>\n",
" <td>0.973653</td>\n",
" <td>[0.9743181467056274, 0.8984375]</td>\n",
" <td>[0.9865424036979675, 0.3766234219074249]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
@@ -973,29 +975,29 @@
],
"text/plain": [
" train_loss train_Accuracy train_Precision \\\n",
"0 0.225108 [0.9270001649856567, 0.8861073851585388] 0.906365 \n",
"1 0.152386 [0.9471125602722168, 0.9297876358032227] 0.938380 \n",
"2 0.134425 [0.9490371346473694, 0.9407152533531189] 0.944841 \n",
"0 0.232111 [0.9225957989692688, 0.8899886012077332] 0.906075 \n",
"1 0.160898 [0.9417101144790649, 0.9261900186538696] 0.933944 \n",
"2 0.138879 [0.9479646682739258, 0.9413018226623535] 0.944648 \n",
"\n",
" train_Recall \\\n",
"0 [0.9270001649856567, 0.8861073851585388] \n",
"1 [0.9471125602722168, 0.9297876358032227] \n",
"2 [0.9490371346473694, 0.9407152533531189] \n",
"0 [0.9225957989692688, 0.8899886012077332] \n",
"1 [0.9417101144790649, 0.9261900186538696] \n",
"2 [0.9479646682739258, 0.9413018226623535] \n",
"\n",
" train_F1 val_loss \\\n",
"0 [0.9074797034263611, 0.9052220582962036] 0.103954 \n",
"1 [0.9384452104568481, 0.9383144974708557] 0.091541 \n",
"2 [0.9446273446083069, 0.9450528621673584] 0.084717 \n",
"0 [0.9064720273017883, 0.9056749939918518] 0.059772 \n",
"1 [0.9344058036804199, 0.9334757328033447] 0.122636 \n",
"2 [0.945061206817627, 0.9442285299301147] 0.072151 \n",
"\n",
" val_Accuracy val_Precision \\\n",
"0 [0.962550163269043, 0.875] 0.961784 \n",
"1 [0.9680188298225403, 0.890625] 0.967341 \n",
"2 [0.967949628829956, 0.8828125] 0.967204 \n",
" val_Accuracy val_Precision \\\n",
"0 [0.9809635877609253, 0.8671875] 0.979966 \n",
"1 [0.954935610294342, 0.921875] 0.954648 \n",
"2 [0.9743181467056274, 0.8984375] 0.973653 \n",
"\n",
" val_Recall val_F1 \n",
"0 [0.962550163269043, 0.875] [0.9803645014762878, 0.28863346576690674] \n",
"1 [0.9680188298225403, 0.890625] [0.9832653999328613, 0.3257790505886078] \n",
"2 [0.967949628829956, 0.8828125] [0.9831950664520264, 0.3229461908340454] "
" val_Recall val_F1 \n",
"0 [0.9809635877609253, 0.8671875] [0.989802360534668, 0.4341084957122803] \n",
"1 [0.954935610294342, 0.921875] [0.9766026139259338, 0.2647385895252228] \n",
"2 [0.9743181467056274, 0.8984375] [0.9865424036979675, 0.3766234219074249] "
]
},
"execution_count": 14,
@@ -1016,7 +1018,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"predict: 100%|██████████| 292/292 [00:00<00:00, 294.65it/s]\n"
"predict: 100%|██████████| 292/292 [00:01<00:00, 286.78it/s]\n"
]
},
{
@@ -1025,15 +1027,15 @@
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.97 0.99 14446\n",
" 1 0.23 0.93 0.36 130\n",
" 0 1.00 0.98 0.99 14446\n",
" 1 0.27 0.93 0.41 130\n",
"\n",
" accuracy 0.97 14576\n",
" macro avg 0.61 0.95 0.67 14576\n",
"weighted avg 0.99 0.97 0.98 14576\n",
" accuracy 0.98 14576\n",
" macro avg 0.63 0.95 0.70 14576\n",
"weighted avg 0.99 0.98 0.98 14576\n",
"\n",
"Actual predicted values:\n",
"(array([0, 1]), array([14039, 537]))\n"
"(array([0, 1]), array([14122, 454]))\n"
]
}
],
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"from pytorch_widedeep.preprocessing import TabPreprocessor\n",
"from pytorch_widedeep.training import Trainer\n",
"from pytorch_widedeep.models import FTTransformer, WideDeep\n",
"from pytorch_widedeep.metrics import Accuracy\n",
"from pytorch_widedeep import Tab2Vec"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>educational-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>income</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Never-married</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>12</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Protective-serv</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>7688</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education educational-num marital-status \\\n",
"0 25 Private 226802 11th 7 Never-married \n",
"1 38 Private 89814 HS-grad 9 Married-civ-spouse \n",
"2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n",
"3 44 Private 160323 Some-college 10 Married-civ-spouse \n",
"4 18 ? 103497 Some-college 10 Never-married \n",
"\n",
" occupation relationship race gender capital-gain capital-loss \\\n",
"0 Machine-op-inspct Own-child Black Male 0 0 \n",
"1 Farming-fishing Husband White Male 0 0 \n",
"2 Protective-serv Husband White Male 0 0 \n",
"3 Machine-op-inspct Husband Black Male 7688 0 \n",
"4 ? Own-child White Female 0 0 \n",
"\n",
" hours-per-week native-country income \n",
"0 40 United-States <=50K \n",
"1 50 United-States <=50K \n",
"2 40 United-States >50K \n",
"3 40 United-States >50K \n",
"4 30 United-States <=50K "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('data/adult/adult.csv.zip')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" <th>native_country</th>\n",
" <th>target</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
" <td>11th</td>\n",
" <td>Never-married</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
" <td>HS-grad</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Protective-serv</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
" <td>Some-college</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>7688</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
" <td>Some-college</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education marital_status \\\n",
"0 25 Private 226802 11th Never-married \n",
"1 38 Private 89814 HS-grad Married-civ-spouse \n",
"2 28 Local-gov 336951 Assoc-acdm Married-civ-spouse \n",
"3 44 Private 160323 Some-college Married-civ-spouse \n",
"4 18 ? 103497 Some-college Never-married \n",
"\n",
" occupation relationship race gender capital_gain capital_loss \\\n",
"0 Machine-op-inspct Own-child Black Male 0 0 \n",
"1 Farming-fishing Husband White Male 0 0 \n",
"2 Protective-serv Husband White Male 0 0 \n",
"3 Machine-op-inspct Husband Black Male 7688 0 \n",
"4 ? Own-child White Female 0 0 \n",
"\n",
" hours_per_week native_country target \n",
"0 40 United-States 0 \n",
"1 50 United-States 0 \n",
"2 40 United-States 1 \n",
"3 40 United-States 1 \n",
"4 30 United-States 0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# For convenience, we'll replace '-' with '_'\n",
"df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
"# binary target\n",
"df['target'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
"df.drop([\"income\", \"educational_num\"], axis=1, inplace=True)\n",
"\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cat_cols, cont_cols = [], []\n",
"for col in df.columns:\n",
"    # 50 is just a number chosen for this example\n",
"    if (df[col].dtype == \"O\" or df[col].nunique() < 50) and col != \"target\":\n",
" cat_cols.append(col)\n",
" elif col != \"target\": \n",
" cont_cols.append(col)\n",
"target_col = \"target\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"target = df[target_col].values\n",
"\n",
"tab_preprocessor = TabPreprocessor(embed_cols=cat_cols, \n",
" continuous_cols=cont_cols, \n",
" for_transformer=True\n",
" )\n",
"X_tab = tab_preprocessor.fit_transform(df)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"ft_transformer = FTTransformer(column_idx=tab_preprocessor.column_idx,\n",
" embed_input=tab_preprocessor.embeddings_input,\n",
" continuous_cols=tab_preprocessor.continuous_cols, \n",
" n_blocks=3, n_heads=6, input_dim=36\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:15<00:00, 10.12it/s, loss=0.355, metrics={'acc': 0.8326}]\n",
"valid: 100%|██████████| 39/39 [00:01<00:00, 26.68it/s, loss=0.308, metrics={'acc': 0.8598}]\n"
]
}
],
"source": [
"model = WideDeep(deeptabular=ft_transformer)\n",
"trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n",
"trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"t2v = Tab2Vec(model=model, tab_preprocessor=tab_preprocessor)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# assuming this is a test set that contains the target col\n",
"X_vec, y = t2v.transform(df.sample(100), target_col=\"target\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 468)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X_vec is the dataframe turned into embeddings\n",
"X_vec.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`468 = input_dim (36) * n_cols (13)`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# ...or if we don't have target col\n",
"X_vec = t2v.transform(df.sample(100))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
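The `(100, 468)` shape in the notebook above follows directly from the model configuration: `Tab2Vec` concatenates one `input_dim`-sized embedding per tabular column. A quick sanity check of that arithmetic (pure Python; the numbers are taken from the notebook):

```python
# FTTransformer was built with input_dim=36 and the adult dataset has
# 13 feature columns after dropping 'income' and 'educational_num',
# so each row is embedded into 36 * 13 = 468 dimensions.
input_dim = 36
n_cols = 13
n_rows = 100  # df.sample(100)

vec_shape = (n_rows, input_dim * n_cols)
print(vec_shape)  # (100, 468)
```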
......@@ -3,8 +3,15 @@ import torch
import pandas as pd
from pytorch_widedeep import Trainer
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import SAINT, Wide, WideDeep, TabTransformer
from pytorch_widedeep.models import (
SAINT,
Wide,
WideDeep,
TabPerceiver,
FTTransformer,
TabFastFormer,
TabTransformer,
)
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.callbacks import (
LRHistory,
......@@ -64,22 +71,60 @@ if __name__ == "__main__":
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
embed_continuous=True,
n_blocks=4,
)
saint = SAINT(
column_idx=prepare_deep.column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
embed_continuous=True,
cont_norm_layer="batchnorm",
n_blocks=4,
)
tab_perceiver = TabPerceiver(
column_idx=prepare_deep.column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
n_latents=6,
latent_dim=16,
n_latent_blocks=4,
n_perceiver_blocks=2,
share_weights=False,
)
tab_fastformer = TabFastFormer(
column_idx=prepare_deep.column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
n_blocks=4,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
ft_transformer = FTTransformer(
column_idx=prepare_deep.column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
input_dim=32,
kv_compression_factor=0.5,
n_blocks=3,
n_heads=4,
)
for tab_model in [tab_transformer, saint]:
for tab_model in [
tab_transformer,
saint,
ft_transformer,
tab_perceiver,
tab_fastformer,
]:
model = WideDeep(wide=wide, deeptabular=tab_model)
wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = RAdam(model.deeptabular.parameters())
        deep_opt = torch.optim.Adam(model.deeptabular.parameters(), lr=0.01)
wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)
......
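The script above pairs each model component with its own optimizer and `StepLR` scheduler. A minimal, self-contained sketch of that per-component pattern in plain `torch` (the two `nn.Linear` modules stand in for the `wide` and `deeptabular` components; sizes and learning rates are illustrative):

```python
import torch
from torch import nn

# Two sub-modules, each with its own optimizer and scheduler, mirroring
# the wide/deeptabular split in the script above.
parts = nn.ModuleDict({"wide": nn.Linear(8, 1), "deep": nn.Linear(8, 1)})
wide_opt = torch.optim.Adam(parts["wide"].parameters(), lr=0.01)
deep_opt = torch.optim.Adam(parts["deep"].parameters(), lr=0.001)
wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)

for _ in range(6):
    # the forward/backward pass would go here; we only exercise the schedules
    wide_opt.step()
    deep_opt.step()
    wide_sch.step()
    deep_sch.step()

# StepLR multiplies the lr by gamma (default 0.1) every `step_size` epochs:
# wide lr: 0.01 * 0.1**(6 // 3) = 1e-4; deep lr: 0.001 * 0.1**(6 // 5) = 1e-4
```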
......@@ -4,7 +4,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)
# pytorch-widedeep
......@@ -63,8 +63,8 @@ when running on Mac, present in previous versions, persist on this release
and the data-loaders will not run in parallel. In addition, since `python
3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python
3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent
not run in parallel. Therefore, for Mac users I recommend using `python 3.7`
and `torch <= 1.6` (with the corresponding, consistent
version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
force this versioning in the `setup.py` file since I expect that all these
issues are fixed in the future. Therefore, after installing
......
......@@ -12,5 +12,6 @@ from pytorch_widedeep.utils import (
deeptabular_utils,
fastai_transforms,
)
from pytorch_widedeep.tab2vec import Tab2Vec
from pytorch_widedeep.version import __version__
from pytorch_widedeep.training import Trainer
......@@ -341,19 +341,19 @@ class ModelCheckpoint(Callback):
weights_out_2.pt, ...``
monitor: str, default="loss"
quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc')
verbose:int, default=0,
    verbose: int, default=0
verbosity mode
save_best_only: bool, default=False,
the latest best model according to the quantity monitored will not be
overwritten.
mode: str, default="auto",
mode: str, default="auto"
If ``save_best_only=True``, the decision to overwrite the current save
file is made based on either the maximization or the minimization of
the monitored quantity. For `'acc'`, this should be `'max'`, for
`'loss'` this should be `'min'`, etc. In `'auto'` mode, the
direction is automatically inferred from the name of the monitored
quantity.
period: int, default=1,
period: int, default=1
Interval (number of epochs) between checkpoints.
max_save: int, default=-1
Maximum number of outputs to save. If -1 will save all outputs
......@@ -425,11 +425,11 @@ class ModelCheckpoint(Callback):
self.monitor_op = np.less
self.best = np.Inf
elif self.mode == "max":
self.monitor_op = np.greater
self.monitor_op = np.greater # type: ignore[assignment]
self.best = -np.Inf
else:
if _is_metric(self.monitor):
self.monitor_op = np.greater
self.monitor_op = np.greater # type: ignore[assignment]
self.best = -np.Inf
else:
self.monitor_op = np.less
......@@ -596,10 +596,10 @@ class EarlyStopping(Callback):
if self.mode == "min":
self.monitor_op = np.less
elif self.mode == "max":
self.monitor_op = np.greater
self.monitor_op = np.greater # type: ignore[assignment]
else:
if _is_metric(self.monitor):
self.monitor_op = np.greater
self.monitor_op = np.greater # type: ignore[assignment]
else:
self.monitor_op = np.less
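Both `ModelCheckpoint` and `EarlyStopping` resolve `mode="auto"` the same way: maximise if the monitored quantity looks like a metric, otherwise minimise (as for a loss). A compact sketch of that decision; the `looks_like_metric` helper here is a stand-in for the library's `_is_metric` and its name list is an assumption, not the actual implementation:

```python
import numpy as np

def resolve_monitor_op(mode: str, monitor: str):
    """Return the comparison used to decide whether `monitor` improved."""
    def looks_like_metric(name: str) -> bool:
        # stand-in for the library's _is_metric helper (assumed name list)
        return any(m in name for m in ("acc", "auc", "f1", "prec", "rec"))

    if mode == "min":
        return np.less
    if mode == "max":
        return np.greater
    # 'auto': metrics are maximised, anything else (e.g. a loss) is minimised
    return np.greater if looks_like_metric(monitor) else np.less

print(resolve_monitor_op("auto", "val_acc") is np.greater)  # True
print(resolve_monitor_op("auto", "val_loss") is np.less)    # True
```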
......
......@@ -6,4 +6,7 @@ from pytorch_widedeep.models.deep_image import DeepImage
from pytorch_widedeep.models.tab_resnet import TabResnet
from pytorch_widedeep.models.tabnet.tab_net import TabNet
from pytorch_widedeep.models.transformers.saint import SAINT
from pytorch_widedeep.models.transformers.tab_perceiver import TabPerceiver
from pytorch_widedeep.models.transformers.ft_transformer import FTTransformer
from pytorch_widedeep.models.transformers.tab_fastformer import TabFastFormer
from pytorch_widedeep.models.transformers.tab_transformer import TabTransformer
......@@ -55,11 +55,12 @@ class DeepImage(nn.Module):
The resnet architecture. One of 18, 34 or 50
freeze_n: int, default = 6
number of layers to freeze. Must be less than or equal to 8. If 8
the entire 'backbone' of the nwtwork will be frozen
the entire 'backbone' of the network will be frozen
head_hidden_dims: List, Optional, default = None
List with the number of neurons per dense layer in the head. e.g: [64,32]
head_activation: str, default = "relu"
Activation function for the dense layers in the head.
Activation function for the dense layers in the head. Currently
``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported
head_dropout: float, default = 0.1
float indicating the dropout between the dense layers.
head_batchnorm: bool, default = False
......
......@@ -20,7 +20,7 @@ class DeepText(nn.Module):
vocab_size: int
number of words in the vocabulary
rnn_type: str, default = 'lstm'
String indicating the type of RNN to use. One of "lstm" or "gru"
String indicating the type of RNN to use. One of ``lstm`` or ``gru``
hidden_dim: int, default = 64
Hidden dim of the RNN
n_layers: int, default = 3
......@@ -30,9 +30,9 @@ class DeepText(nn.Module):
the last layer
bidirectional: bool, default = True
indicates whether the staked RNNs are bidirectional
use_hidden_state: str, default = True,
    use_hidden_state: bool, default = True
Boolean indicating whether to use the final hidden state or the
rnn output as predicting features
RNN output as predicting features
padding_idx: int, default = 1
index of the padding token in the padded-tokenised sequences. I
use the ``fastai`` tokenizer where the token index 0 is reserved
......@@ -48,7 +48,8 @@ class DeepText(nn.Module):
List with the sizes of the stacked dense layers in the head
e.g: [128, 64]
head_activation: str, default = "relu"
Activation function for the dense layers in the head
Activation function for the dense layers in the head. Currently
``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported
head_dropout: float, Optional, default = None
dropout between the dense layers in the head
head_batchnorm: bool, default = False
......
......@@ -5,7 +5,7 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
allowed_activations = ["relu", "leaky_relu", "gelu", "geglu"]
allowed_activations = ["relu", "leaky_relu", "tanh", "gelu", "geglu", "reglu"]
class GEGLU(nn.Module):
......@@ -14,15 +14,25 @@ class GEGLU(nn.Module):
return x * F.gelu(gates)
def _get_activation_fn(activation):
class REGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
        return x * F.relu(gates)
def get_activation_fn(activation):
if activation == "relu":
return nn.ReLU(inplace=True)
if activation == "leaky_relu":
return nn.LeakyReLU(inplace=True)
elif activation == "gelu":
if activation == "tanh":
return nn.Tanh()
if activation == "gelu":
return nn.GELU()
elif activation == "geglu":
if activation == "geglu":
return GEGLU()
if activation == "reglu":
return REGLU()
def dense_layer(
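GLU-style activations such as `GEGLU` and `REGLU` differ from plain element-wise activations: they split the last dimension into two halves and use one half (passed through `gelu`/`relu`, following the ReGLU/GEGLU definitions in "GLU Variants Improve Transformer") to gate the other, so the output is half as wide as the input. A small free-standing sketch:

```python
import torch
import torch.nn.functional as F

def geglu(x: torch.Tensor) -> torch.Tensor:
    # chunk the last dim in two and gate one half with GELU of the other
    x, gates = x.chunk(2, dim=-1)
    return x * F.gelu(gates)

def reglu(x: torch.Tensor) -> torch.Tensor:
    # same idea, with ReLU gating
    x, gates = x.chunk(2, dim=-1)
    return x * F.relu(gates)

x = torch.randn(4, 32)
print(geglu(x).shape, reglu(x).shape)  # torch.Size([4, 16]) torch.Size([4, 16])
```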
......@@ -37,9 +47,9 @@ def dense_layer(
if activation == "geglu":
raise ValueError(
"'geglu' activation is only used as 'transformer_activation' "
"in transformer-based models (TabTransformer and SAINT)"
"in transformer-based models"
)
act_fn = _get_activation_fn(activation)
act_fn = get_activation_fn(activation)
layers = [nn.BatchNorm1d(out if linear_first else inp)] if bn else []
if p != 0:
layers.append(nn.Dropout(p)) # type: ignore[arg-type]
......@@ -48,6 +58,69 @@ def dense_layer(
return nn.Sequential(*layers)
class CatEmbeddingsAndCont(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int, int]],
embed_dropout: float,
continuous_cols: Optional[List[str]],
cont_norm_layer: str,
):
super(CatEmbeddingsAndCont, self).__init__()
self.column_idx = column_idx
self.embed_input = embed_input
self.continuous_cols = continuous_cols
        # Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
if self.embed_input is not None:
self.embed_layers = nn.ModuleDict(
{
"emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
for col, val, dim in self.embed_input
}
)
self.embedding_dropout = nn.Dropout(embed_dropout)
self.emb_out_dim: int = int(
np.sum([embed[2] for embed in self.embed_input])
)
else:
self.emb_out_dim = 0
# Continuous
if self.continuous_cols is not None:
self.cont_idx = [self.column_idx[col] for col in self.continuous_cols]
self.cont_out_dim: int = len(self.continuous_cols)
if cont_norm_layer == "batchnorm":
self.cont_norm: NormLayers = nn.BatchNorm1d(self.cont_out_dim)
elif cont_norm_layer == "layernorm":
self.cont_norm = nn.LayerNorm(self.cont_out_dim)
else:
self.cont_norm = nn.Identity()
else:
self.cont_out_dim = 0
self.output_dim = self.emb_out_dim + self.cont_out_dim
def forward(self, X: Tensor) -> Tuple[Tensor, Any]:
if self.embed_input is not None:
embed = [
self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
for col, _, _ in self.embed_input
]
x_emb = torch.cat(embed, 1)
x_emb = self.embedding_dropout(x_emb)
else:
x_emb = None
if self.continuous_cols is not None:
x_cont = self.cont_norm((X[:, self.cont_idx].float()))
else:
x_cont = None
return x_emb, x_cont
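The `val + 1` sizing in `CatEmbeddingsAndCont` reserves index 0 for padding/unseen categories. What `padding_idx=0` buys can be shown with a single standalone `nn.Embedding` (the sizes below are illustrative):

```python
import torch
from torch import nn

# A column with 4 observed categories is encoded with indices 1..4;
# index 0 is kept for padding/unseen values, hence num_embeddings = 4 + 1.
emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)

seen = emb(torch.tensor([1, 2, 3, 4]))  # learned vectors, shape (4, 3)
unseen = emb(torch.tensor([0]))         # the reserved row

# The padding row is initialised to zeros and excluded from gradient updates.
print(bool(torch.all(unseen == 0)))  # True
```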
class MLP(nn.Module):
def __init__(
self,
......@@ -95,7 +168,7 @@ class TabMlp(nn.Module):
----------
column_idx: Dict
Dict containing the index of the columns that will be passed through
the TabMlp model. Required to slice the tensors. e.g. {'education':
the ``TabMlp`` model. Required to slice the tensors. e.g. {'education':
0, 'relationship': 1, 'workclass': 2, ...}
embed_input: List, Optional, default = None
List of Tuples with the column name, number of unique values and
......@@ -111,7 +184,7 @@ class TabMlp(nn.Module):
List with the number of neurons per dense layer in the mlp.
mlp_activation: str, default = "relu"
Activation function for the dense layers of the MLP. Currently
'relu', 'leaky_relu' and 'gelu' are supported
``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported
mlp_dropout: float or List, default = 0.1
float or List of floats with the dropout between the dense layers.
e.g: [0.5,0.5]
......@@ -128,13 +201,11 @@ class TabMlp(nn.Module):
Attributes
----------
cont_norm: ``nn.Module``
continuous normalization layer
cat_embed_and_cont: ``nn.Module``
This is the module that processes the categorical and continuous columns
tab_mlp: ``nn.Sequential``
mlp model that will receive the concatenation of the embeddings and
the continuous columns
embed_layers: ``nn.ModuleDict``
``ModuleDict`` with the embeddings set up
output_dim: int
The output dimension of the model. This is a required attribute
        necessary to build the WideDeep class
......@@ -169,15 +240,15 @@ class TabMlp(nn.Module):
super(TabMlp, self).__init__()
self.column_idx = column_idx
self.embed_input = embed_input
self.mlp_hidden_dims = mlp_hidden_dims
self.embed_dropout = embed_dropout
self.continuous_cols = continuous_cols
self.cont_norm_layer = cont_norm_layer
self.mlp_activation = mlp_activation
self.mlp_dropout = mlp_dropout
self.mlp_batchnorm = mlp_batchnorm
self.mlp_linear_first = mlp_linear_first
self.embed_input = embed_input
self.embed_dropout = embed_dropout
self.continuous_cols = continuous_cols
self.cont_norm_layer = cont_norm_layer
if self.mlp_activation not in allowed_activations:
raise ValueError(
......@@ -187,35 +258,17 @@ class TabMlp(nn.Module):
)
)
# Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories.
if self.embed_input is not None:
self.embed_layers = nn.ModuleDict(
{
"emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
for col, val, dim in self.embed_input
}
)
self.embedding_dropout = nn.Dropout(embed_dropout)
emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
else:
emb_inp_dim = 0 # type: ignore[assignment]
# Continuous
if self.continuous_cols is not None:
self.cont_idx = [self.column_idx[col] for col in self.continuous_cols]
cont_inp_dim = len(self.continuous_cols)
if self.cont_norm_layer == "batchnorm":
self.cont_norm: NormLayers = nn.BatchNorm1d(cont_inp_dim)
elif self.cont_norm_layer == "layernorm":
self.cont_norm = nn.LayerNorm(cont_inp_dim)
else:
self.cont_norm = nn.Identity()
else:
cont_inp_dim = 0
self.cat_embed_and_cont = CatEmbeddingsAndCont(
column_idx,
embed_input,
embed_dropout,
continuous_cols,
cont_norm_layer,
)
# MLP
input_dim = emb_inp_dim + cont_inp_dim
mlp_hidden_dims = [input_dim] + mlp_hidden_dims # type: ignore[assignment, operator]
mlp_input_dim = self.cat_embed_and_cont.output_dim
mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims
self.tab_mlp = MLP(
mlp_hidden_dims,
mlp_activation,
......@@ -228,18 +281,13 @@ class TabMlp(nn.Module):
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = mlp_hidden_dims[-1]
def forward(self, X: Tensor) -> Tensor: # type: ignore
def forward(self, X: Tensor) -> Tensor:
r"""Forward pass that concatenates the continuous features with the
embeddings. The result is then passed through a series of dense layers
"""
if self.embed_input is not None:
embed = [
self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
for col, _, _ in self.embed_input
]
x = torch.cat(embed, 1)
x = self.embedding_dropout(x)
if self.continuous_cols is not None:
x_cont = self.cont_norm((X[:, self.cont_idx].float()))
x = torch.cat([x, x_cont], 1) if self.embed_input is not None else x_cont
x_emb, x_cont = self.cat_embed_and_cont(X)
if x_emb is not None:
x = x_emb
if x_cont is not None:
x = torch.cat([x, x_cont], 1) if x_emb is not None else x_cont
return self.tab_mlp(x)
from collections import OrderedDict
import numpy as np
import torch
from torch import nn
from torch.nn import Module
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.tab_mlp import MLP, CatEmbeddingsAndCont
class BasicBlock(nn.Module):
# inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L37
def __init__(self, inp: int, out: int, dropout: float = 0.0, resize: Module = None):
super(BasicBlock, self).__init__()
......@@ -86,20 +86,17 @@ class TabResnet(nn.Module):
r"""Defines a so-called ``TabResnet`` model that can be used as the
``deeptabular`` component of a Wide & Deep model.
This class combines embedding representations of the categorical
features with numerical (aka continuous) features. These are then
passed through a series of Resnet blocks. See
``pytorch_widedeep.models.tab_resnet.BasicBlock`` for details
on the structure of each block.
.. note:: Unlike ``TabMlp``, ``TabResnet`` assumes that there are always
categorical columns
This class combines embedding representations of the categorical features
with numerical (aka continuous) features. These are then passed through a
series of Resnet blocks. See
:obj:`pytorch_widedeep.models.tab_resnet.BasicBlock` for details on the
structure of each block.
Parameters
----------
column_idx: Dict
Dict containing the index of the columns that will be passed through
the TabMlp model. Required to slice the tensors. e.g. {'education':
the ``Resnet`` model. Required to slice the tensors. e.g. {'education':
0, 'relationship': 1, 'workclass': 2, ...}
embed_input: List
List of Tuples with the column name, number of unique values and
......@@ -112,11 +109,11 @@ class TabResnet(nn.Module):
Type of normalization layer applied to the continuous features. Options
are: 'layernorm', 'batchnorm' or None.
concat_cont_first: bool, default = True
Boolean indicating whether the continuum columns will be
concatenated with the Embeddings and then passed through the
Resnet blocks (``True``) or, alternatively, will be concatenated
with the result of passing the embeddings through the Resnet
Blocks (``False``)
        If ``True`` the continuous columns will be concatenated with the
        categorical embeddings and then passed through the Resnet blocks. If
        ``False``, the categorical embeddings alone will be passed through the
        Resnet blocks, and their output will then be concatenated with the
        continuous features.
blocks_dims: List, default = [200, 100, 100]
List of integers that define the input and output units of each block.
For example: [200, 100, 100] will generate 2 blocks. The first will
......@@ -128,12 +125,13 @@ class TabResnet(nn.Module):
Block's `"internal"` dropout. This dropout is applied to the first
of the two dense layers that comprise each ``BasicBlock``.
mlp_hidden_dims: List, Optional, default = None
List with the number of neurons per dense layer in the mlp. e.g:
List with the number of neurons per dense layer in the MLP. e.g:
[64, 32]. If ``None`` the output of the Resnet Blocks will be
connected directly to the output neuron(s), i.e. using a MLP is
optional.
mlp_activation: str, default = "relu"
MLP activation function. 'relu', 'leaky_relu' and 'gelu' are supported
Activation function for the dense layers of the MLP. Currently
``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported
mlp_dropout: float, default = 0.1
float with the dropout between the dense layers of the MLP.
mlp_batchnorm: bool, default = False
......@@ -149,13 +147,11 @@ class TabResnet(nn.Module):
Attributes
----------
embed_layers: ``nn.ModuleDict``
``ModuleDict`` with the embeddings setup
cat_embed_and_cont: ``nn.Module``
This is the module that processes the categorical and continuous columns
dense_resnet: ``nn.Sequential``
deep dense Resnet model that will receive the concatenation of the
embeddings and the continuous columns
cont_norm: ``nn.Module``
continuous normalization layer
tab_resnet_mlp: ``nn.Sequential``
if ``mlp_hidden_dims`` is ``True``, this attribute will be an mlp
model that will receive:
......@@ -188,7 +184,7 @@ class TabResnet(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int, int]],
embed_input: Optional[List[Tuple[str, int, int]]] = None,
embed_dropout: float = 0.1,
continuous_cols: Optional[List[str]] = None,
cont_norm_layer: str = "batchnorm",
......@@ -204,6 +200,16 @@ class TabResnet(nn.Module):
):
super(TabResnet, self).__init__()
if len(blocks_dims) < 2:
raise ValueError(
                "'blocks_dims' must contain at least two elements, e.g. [256, 128]"
)
if not concat_cont_first and embed_input is None:
raise ValueError(
                "If 'concat_cont_first = False', 'embed_input' must not be 'None'"
)
self.column_idx = column_idx
self.embed_input = embed_input
self.embed_dropout = embed_dropout
......@@ -218,43 +224,26 @@ class TabResnet(nn.Module):
self.mlp_batchnorm_last = mlp_batchnorm_last
self.mlp_linear_first = mlp_linear_first
if len(self.blocks_dims) < 2:
raise ValueError(
"'blocks' must contain at least two elements, e.g. [256, 128]"
)
# Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories.
self.embed_layers = nn.ModuleDict(
{
"emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
for col, val, dim in self.embed_input
}
self.cat_embed_and_cont = CatEmbeddingsAndCont(
column_idx,
embed_input,
embed_dropout,
continuous_cols,
cont_norm_layer,
)
self.embedding_dropout = nn.Dropout(embed_dropout)
emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
# Continuous
if self.continuous_cols is not None:
self.cont_idx = [self.column_idx[col] for col in self.continuous_cols]
cont_inp_dim = len(self.continuous_cols)
if self.cont_norm_layer == "batchnorm":
self.cont_norm: NormLayers = nn.BatchNorm1d(cont_inp_dim)
elif self.cont_norm_layer == "layernorm":
self.cont_norm = nn.LayerNorm(cont_inp_dim)
else:
self.cont_norm = nn.Identity()
else:
cont_inp_dim = 0
emb_out_dim = self.cat_embed_and_cont.emb_out_dim
cont_out_dim = self.cat_embed_and_cont.cont_out_dim
# DenseResnet
if self.concat_cont_first:
dense_resnet_input_dim = emb_inp_dim + cont_inp_dim
dense_resnet_input_dim = emb_out_dim + cont_out_dim
self.output_dim = blocks_dims[-1]
else:
dense_resnet_input_dim = emb_inp_dim
self.output_dim = cont_inp_dim + blocks_dims[-1]
dense_resnet_input_dim = emb_out_dim
self.output_dim = cont_out_dim + blocks_dims[-1]
self.tab_resnet_blks = DenseResnet(
dense_resnet_input_dim, blocks_dims, blocks_dropout # type: ignore[arg-type]
dense_resnet_input_dim, blocks_dims, blocks_dropout
)
# MLP
......@@ -262,7 +251,7 @@ class TabResnet(nn.Module):
if self.concat_cont_first:
mlp_input_dim = blocks_dims[-1]
else:
mlp_input_dim = cont_inp_dim + blocks_dims[-1]
mlp_input_dim = cont_out_dim + blocks_dims[-1]
mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims
self.tab_resnet_mlp = MLP(
mlp_hidden_dims,
......@@ -274,26 +263,21 @@ class TabResnet(nn.Module):
)
self.output_dim = mlp_hidden_dims[-1]
def forward(self, X: Tensor) -> Tensor: # type: ignore
def forward(self, X: Tensor) -> Tensor:
r"""Forward pass that concatenates the continuous features with the
embeddings. The result is then passed through a series of dense Resnet
blocks"""
embed = [
self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
for col, _, _ in self.embed_input
]
x = torch.cat(embed, 1)
x = self.embedding_dropout(x)
if self.continuous_cols is not None:
x_cont = self.cont_norm((X[:, self.cont_idx].float()))
x_emb, x_cont = self.cat_embed_and_cont(X)
if x_cont is not None:
if self.concat_cont_first:
x = torch.cat([x, x_cont], 1)
x = torch.cat([x_emb, x_cont], 1) if x_emb is not None else x_cont
out = self.tab_resnet_blks(x)
else:
out = torch.cat([self.tab_resnet_blks(x), x_cont], 1)
out = torch.cat([self.tab_resnet_blks(x_emb), x_cont], 1)
else:
out = self.tab_resnet_blks(x)
out = self.tab_resnet_blks(x_emb)
if self.mlp_hidden_dims is not None:
out = self.tab_resnet_mlp(out)
......
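`concat_cont_first` changes where the continuous features enter the network and therefore the model's `output_dim`. A small pure-Python check of the bookkeeping (block sizes are the documented defaults; the embedding/continuous widths are illustrative):

```python
# Dimension bookkeeping for TabResnet, without an optional MLP head.
emb_out_dim, cont_out_dim = 16, 4
blocks_dims = [200, 100, 100]  # default

def output_dim(concat_cont_first: bool) -> int:
    if concat_cont_first:
        # embeddings and continuous features go through the blocks together
        return blocks_dims[-1]
    # continuous features are concatenated *after* the blocks
    return cont_out_dim + blocks_dims[-1]

print(output_dim(True), output_dim(False))  # 100 104
```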
This diff is collapsed.
......@@ -17,7 +17,7 @@ def create_explain_matrix(model: WideDeep) -> csc_matrix:
Examples
--------
>>> from pytorch_widedeep.models import TabNet, WideDeep
>>> from pytorch_widedeep.models.tabnet._utils import create_explain_matrix
>>> embed_input = [("a", 4, 2), ("b", 4, 2), ("c", 4, 2)]
>>> cont_cols = ["d", "e"]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c", "d", "e"])}
......
"""
During the development of the package I realised that there is a typing
inconsistency. The input components of a Wide and Deep model are of type
nn.Module. These change type internally to nn.Sequential. While nn.Sequential
is an instance of nn.Module, the opposite is, of course, not true. This does
not affect any functionality of the package, but it is something that needs
fixing. However, while the fix is simple (simply define new attributes that
are the nn.Sequential objects), its implications are quite wide within the
package (it involves changing a number of tests and tutorials). Therefore, I
will introduce that fix when I do a major release. For now, we live with it.
"""
import warnings
import torch
......@@ -30,51 +42,16 @@ class WideDeep(nn.Module):
Parameters
----------
wide: ``nn.Module``, Optional, default = None
``Wide`` model. I recommend using the ``Wide`` class in this
package. However, it is possible to use a custom model as long as
is consistent with the required architecture, see
:class:`pytorch_widedeep.models.wide.Wide`
deeptabular: ``nn.Module``, Optional, default = None
currently ``pytorch-widedeep`` implements a number of possible
architectures for the ``deeptabular`` component. See the documentation
of the package. I recommend using the ``deeptabular`` components in
this package. However, it is possible to use a custom model as long
as it is consistent with the required architecture.
deeptext: ``nn.Module``, Optional, default = None
Model for the text input. Must be an object of class ``DeepText``
or a custom model as long as it is consistent with the required
......@@ -97,8 +74,8 @@ class WideDeep(nn.Module):
If ``head_hidden_dims`` is not None, dropout between the layers in
``head_hidden_dims``
head_activation: str, default = "relu"
If ``head_hidden_dims`` is not None, activation function of the head
layers. One of ``tanh``, ``relu``, ``gelu`` or ``leaky_relu``
head_batchnorm: bool, default = False
If ``head_hidden_dims`` is not None, specifies if batch
normalization should be included in the head layers
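What ``head_hidden_dims`` and ``head_activation`` describe can be sketched outside the library: the deep components' outputs are concatenated and passed through a small MLP ("head") before the prediction layer. A minimal, illustrative version (not the library's code, and the widths are made up):

```python
import torch
import torch.nn as nn

# Sketch of what the "head" does: concatenate the deep components' outputs
# and pass them through an MLP defined by head_hidden_dims.
deeptabular_out = torch.randn(8, 32)  # e.g. a component with output_dim = 32
deeptext_out = torch.randn(8, 16)     # e.g. a component with output_dim = 16

head_hidden_dims = [128, 64]
inp = deeptabular_out.size(1) + deeptext_out.size(1)
layers = []
for h in head_hidden_dims:
    layers += [nn.Linear(inp, h), nn.ReLU()]  # head_activation="relu"
    inp = h
head = nn.Sequential(*layers)

out = head(torch.cat([deeptabular_out, deeptext_out], dim=1))
print(out.shape)  # torch.Size([8, 64])
```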
......
......@@ -30,12 +30,12 @@ class TabPreprocessor(BasePreprocessor):
continuous_cols: List, default = None
List with the names of the so-called continuous cols
scale: bool, default = True
Bool indicating whether or not to scale/standardise continuous cols.
The user should bear in mind that all the ``deeptabular`` components
available within ``pytorch-widedeep`` also include the
possibility of normalising the input continuous features via a
``BatchNorm`` or a ``LayerNorm`` layer. See
:obj:`pytorch_widedeep.models.transformers._embedding_layers`
auto_embed_dim: bool, default = True
Boolean indicating whether the embedding dimensions will be
automatically defined via fastai's rule of thumb:
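For reference, fastai's published embedding-size rule of thumb looks like the sketch below; the exact variant used by ``pytorch-widedeep`` may differ, so treat this as an assumption rather than the library's implementation:

```python
# fastai's rule of thumb for the embedding dimension of a categorical
# column with n_cat distinct values, capped at 600. This mirrors the rule
# as published in fastai; pytorch-widedeep's exact variant may differ.
def embed_dim_rule(n_cat: int) -> int:
    return min(600, round(1.6 * n_cat ** 0.56))

print(embed_dim_rule(10))    # 6
print(embed_dim_rule(1000))  # 77
```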
......@@ -53,26 +53,27 @@ class TabPreprocessor(BasePreprocessor):
tabular library) and not standardise them any further
for_transformer: bool, default = False
Boolean indicating whether the preprocessed data will be passed to a
transformer-based model
(See :obj:`pytorch_widedeep.models.transformers`). If ``True``, the
param ``embed_cols`` must just be a list containing the categorical
columns, e.g. ``['education', 'relationship', ...]``. This is because
they will all be encoded using embeddings of the same dim.
with_cls_token: bool, default = False
Boolean indicating if a `'[CLS]'` token will be added to the dataset
when using transformer-based models. The final hidden state
corresponding to this token is used as the aggregated representation
for classification and regression tasks. If not, the categorical
(and continuous embeddings if present) will be concatenated before
being passed to the final MLP.
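The ``with_cls_token`` mechanics can be illustrated with plain arrays: a constant ``'[CLS]'`` category is prepended as the first column of every row, so the transformer's first output position can serve as the aggregated row representation. A hypothetical sketch (not the preprocessor's code; index ``0`` is assumed to be reserved for the token):

```python
import numpy as np

# Illustrative only: prepend a reserved '[CLS]' category (index 0 here)
# as the first column of every row of encoded categorical data.
X_cat = np.array([[3, 1], [2, 4]])  # label-encoded categorical columns
cls_col = np.zeros((X_cat.shape[0], 1), dtype=X_cat.dtype)
X_with_cls = np.concatenate([cls_col, X_cat], axis=1)
print(X_with_cls)
# [[0 3 1]
#  [0 2 4]]
```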
shared_embed: bool, default = False
Boolean indicating if the embeddings will be "shared" when using
transformer-based models. The idea behind ``shared_embed`` is
described in the Appendix A in the `TabTransformer paper
<https://arxiv.org/abs/2012.06678>`_: `'The goal of having column
embedding is to enable the model to distinguish the classes in one
column from those in the other columns'`. In other words, the idea is
to let the model learn which column is embedded at the time. See:
:obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`.
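The shared-embedding idea can be sketched in a few lines: a slice of every embedding in a column is replaced by a single vector shared across that column, so the model can tell columns apart. This is an illustrative toy, not the library's ``SharedEmbeddings`` implementation, and the fraction shared is an arbitrary choice:

```python
import torch
import torch.nn as nn

# Toy version of a shared column embedding: the first frac_shared of each
# embedding is overwritten by one per-column vector shared by all rows.
n_categories, embed_dim, frac_shared = 5, 8, 0.25
embed = nn.Embedding(n_categories, embed_dim)
shared = nn.Parameter(torch.zeros(1, int(embed_dim * frac_shared)))

x = torch.tensor([0, 3])  # two rows of one categorical column
e = embed(x)
n_shared = shared.size(1)
e = torch.cat([shared.expand(e.size(0), -1), e[:, n_shared:]], dim=1)
print(e.shape)  # torch.Size([2, 8])
```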
verbose: int, default = 1
Attributes
......