Introduction

TabNet is an attentive, interpretable deep learning architecture for tabular data, implemented in PyTorch. This project is a maintained fork of the original DreamQuark TabNet, with improvements to metric computation, GPU support, and usability. It is suitable for classification, regression, and multitask learning on tabular datasets.

Installation

Install TabNet using pip:

pip install pytorch-tabnet2

Original Repository

This project is a maintained fork of the original DreamQuark TabNet implementation:

Key Features

  • Supports classification, regression, multitask, and unsupervised pretraining

  • GPU acceleration and efficient data handling

  • Interpretable feature masks and explanations

  • Flexible API for research and production

For more details, see the original paper: https://arxiv.org/pdf/1908.07442.pdf

Project Changes from the Original Implementation

Key Changes from Original

  • Removed the PyTorch DataLoader, which previously accessed each datapoint individually and limited vectorization, resulting in slow performance. Data is now processed in a more efficient, vectorized manner.

  • Replaced scikit-learn metrics with torcheval, enabling fast, GPU-accelerated metric computation without moving data to the CPU or converting it to NumPy arrays.

  • Shifted data weighting from the sampling/data loading stage to the loss function and metric calculations, providing more flexibility and efficiency.
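To make the last point concrete, here is a minimal sketch of what applying sample weights inside the loss, rather than through weighted sampling, can look like. It is written in plain Python for readability and is not this project's implementation, which would operate on tensors. The function name `weighted_binary_cross_entropy` and the choice to normalize by the sum of the weights are assumptions made for illustration.

```python
import math
from typing import Sequence


def weighted_binary_cross_entropy(
    probs: Sequence[float],
    targets: Sequence[int],
    weights: Sequence[float],
    eps: float = 1e-12,
) -> float:
    """Weighted mean of per-sample binary cross-entropy (hypothetical helper).

    Every sample is seen exactly once; its weight only scales how much its
    loss contributes, instead of changing how often it is sampled.
    """
    if not (len(probs) == len(targets) == len(weights)):
        raise ValueError("probs, targets and weights must have the same length")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")

    weighted_sum = 0.0
    for p, y, w in zip(probs, targets, weights):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        loss = -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
        weighted_sum += w * loss

    # Assumption: normalize by the total weight so the scale of the loss
    # does not depend on the absolute magnitude of the weights.
    return weighted_sum / total_weight


# Up-weighting the poorly predicted second sample raises the loss.
uniform = weighted_binary_cross_entropy([0.9, 0.2], [1, 1], [1.0, 1.0])
upweighted = weighted_binary_cross_entropy([0.9, 0.2], [1, 1], [1.0, 3.0])
```

With uniform weights this reduces to the ordinary mean loss, and a weight of zero removes a sample from the objective entirely. The same weights can be reused when computing metrics.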

Key Improvements

  • Added comprehensive unit tests, achieving over 90% code coverage for improved reliability and maintainability.

  • Significantly reduced training time on both CPU and GPU, primarily due to the removal of the DataLoader and improved vectorization.

  • Enabled real-time validation metric calculation on the GPU during training, leveraging torcheval for efficient, on-device evaluation.

  • Added type annotations throughout the codebase for improved code clarity and static analysis.

  • Added support for newer versions of Python: 3.10, 3.11, 3.12, and 3.13.
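The speedup attributed above to removing the DataLoader comes from building each batch with a single indexing operation instead of fetching rows one at a time. The sketch below illustrates that idea with NumPy arrays. It is an assumption about the general approach, not the project's actual code, and `iter_batches` is a hypothetical helper name.

```python
from typing import Iterator, Tuple

import numpy as np


def iter_batches(
    X: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (features, targets) batches using one indexing call per batch."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    n = len(X)
    rng = np.random.default_rng(seed)
    # Shuffle the row order once, then cut it into contiguous index chunks.
    order = rng.permutation(n) if shuffle else np.arange(n)

    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        # Fancy indexing gathers the whole batch at once; there is no
        # per-row access and no collate step.
        yield X[idx], y[idx]


# Usage: 10 rows in batches of 4 gives batches of 4, 4, and 2 rows.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)
batches = list(iter_batches(X, y, batch_size=4))
```

The same pattern carries over to tensors that already live on the GPU, where indexing with an index tensor avoids a round trip through the CPU.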