Introduction
TabNet is an attentive, interpretable deep learning architecture for tabular data, implemented in PyTorch. This project is a maintained fork of the original DreamQuark TabNet, with improvements for metrics, GPU support, and usability. It is suitable for classification, regression, and multitask learning on tabular datasets.
Installation
Install TabNet using pip:
pip install pytorch-tabnet2
Original Repository
This project is a maintained fork of the original DreamQuark TabNet implementation.
Key Features
Supports classification, regression, multitask, and unsupervised pretraining
GPU acceleration and efficient data handling
Interpretable feature masks and explanations
Flexible API for research and production
For more details, see the original paper: https://arxiv.org/pdf/1908.07442.pdf
Project Changes from the Original Implementation
Key Changes from Original
Removed the PyTorch DataLoader, which accessed each data point individually and limited vectorization, resulting in slow performance. Data is now processed in vectorized batches.
Replaced scikit-learn metrics with torcheval, enabling GPU-accelerated metric computation without moving data to the CPU or converting it to NumPy arrays.
Moved sample weighting from the sampling/data-loading stage into the loss function and metric calculations, which is more flexible and more efficient.
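The first and third changes above can be illustrated with a minimal sketch. This is not the project's actual code: `iter_batches` and `weighted_cross_entropy` are hypothetical helper names used only for illustration. The sketch assumes the whole dataset fits in tensors, so that a mini-batch can be taken with a single indexing operation rather than collated sample by sample, and that per-sample weights are applied when reducing the loss rather than when sampling.

```python
import torch
import torch.nn.functional as F


def iter_batches(X, y, w, batch_size, shuffle=True, generator=None):
    """Yield (X, y, w) mini-batches by indexing whole tensors at once.

    Hypothetical helper: one fancy-indexing call per batch replaces the
    per-sample __getitem__ and collate work a DataLoader would do.
    """
    n = X.shape[0]
    if shuffle:
        order = torch.randperm(n, generator=generator).to(X.device)
    else:
        order = torch.arange(n, device=X.device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], y[idx], w[idx]


def weighted_cross_entropy(logits, targets, sample_weight):
    """Cross-entropy where each sample's loss is scaled by its weight.

    Hypothetical helper: the weights enter the loss reduction instead of
    changing how often a sample is drawn.
    """
    per_sample = F.cross_entropy(logits, targets, reduction="none")
    return (per_sample * sample_weight).sum() / sample_weight.sum()


# Small demonstration on random data (shapes are arbitrary assumptions).
gen = torch.Generator().manual_seed(0)
X_demo = torch.randn(10, 4, generator=gen)
y_demo = torch.randint(0, 2, (10,), generator=gen)
w_demo = torch.ones(10)
linear = torch.nn.Linear(4, 2)  # stand-in for a real model

for xb, yb, wb in iter_batches(X_demo, y_demo, w_demo, batch_size=4, generator=gen):
    loss = weighted_cross_entropy(linear(xb), yb, wb)
```

With this design, changing the class or sample weighting only requires passing a different weight tensor; the batching code is unaffected.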
Key Improvements
Added comprehensive unit tests, achieving over 90% code coverage for improved reliability and maintainability.
Significantly reduced training time on both CPU and GPU, primarily due to the removal of the DataLoader and improved vectorization.
Enabled real-time validation metric calculation on the GPU during training, leveraging torcheval for efficient, on-device evaluation.
Added type annotations throughout the codebase for improved code clarity and static analysis.
Added support for newer versions of Python: 3.10, 3.11, 3.12, and 3.13.