TabNet Pretrainer
pytorch_tabnet package initialization.
- class pytorch_tabnet.TabNetPretrainer(n_d: int = 8, n_a: int = 8, n_steps: int = 3, gamma: float = 1.3, cat_idxs: List[int] = <factory>, cat_dims: List[int] = <factory>, cat_emb_dim: int | List[int] = 1, n_independent: int = 2, n_shared: int = 2, epsilon: float = 1e-15, momentum: float = 0.02, lambda_sparse: float = 0.001, seed: int = 0, clip_value: int = 1, verbose: int = 1, optimizer_fn: Any = <class 'torch.optim.adam.Adam'>, optimizer_params: Dict = <factory>, scheduler_fn: Any = None, scheduler_params: Dict = <factory>, mask_type: str = 'sparsemax', input_dim: int = None, output_dim: List[int] | int = None, device_name: str = 'auto', n_shared_decoder: int = 1, n_indep_decoder: int = 1, grouped_features: List[List[int]] = <factory>, compile_backend: str = '')
Bases: TabModel
Abstract base class for TabNet pretraining models.
- compute_loss(output: Tensor, embedded_x: Tensor, obf_vars: Tensor, w: Tensor | None = None) → Tensor
Compute the unsupervised loss for pretraining.
- Parameters:
output (torch.Tensor) – Network output.
embedded_x (torch.Tensor) – Embedded input.
obf_vars (torch.Tensor) – Obfuscated variables mask.
w (Optional[torch.Tensor]) – Optional sample weights.
- Returns:
Loss value.
- Return type:
torch.Tensor
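The loss above can be pictured as a reconstruction error counted only on masked entries. The sketch below is a NumPy illustration, not the library's code: it assumes a squared error restricted to obfuscated entries and normalized by per-feature variance, in the spirit of the reconstruction loss in the TabNet paper. The function name `masked_reconstruction_loss` and the `eps` argument are hypothetical, and the optional sample weights `w` are left out.

```python
import numpy as np


def masked_reconstruction_loss(output, embedded_x, obf_vars, eps=1e-9):
    """Hypothetical NumPy analogue of an unsupervised pretraining loss."""
    # Squared error, kept only where obf_vars == 1 (the masked entries).
    sq_err = ((output - embedded_x) ** 2) * obf_vars
    # Per-feature variance; constant features fall back to 1 to avoid dividing by zero.
    var = embedded_x.var(axis=0)
    var = np.where(var == 0, 1.0, var)
    # Sum normalized errors per row, then divide by the number of masked entries in that row.
    per_row = (sq_err / var).sum(axis=1) / (obf_vars.sum(axis=1) + eps)
    return per_row.mean()


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
mask = (rng.random((8, 4)) < 0.5).astype(float)

perfect = masked_reconstruction_loss(x, x, mask)       # exact reconstruction
noisy = masked_reconstruction_loss(x + 1.0, x, mask)   # every entry off by 1
```

Errors on entries that were never masked do not contribute, so a model is scored only on the values it had to infer.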
- fit(X_train: ndarray, eval_set: List[ndarray | List[ndarray]] | None = None, eval_name: List[str] | None = None, loss_fn: Callable | None = None, pretraining_ratio: float = 0.5, weights: int | ndarray = 0, max_epochs: int = 100, patience: int = 10, batch_size: int = 1024, virtual_batch_size: int = 128, num_workers: int = 0, drop_last: bool = True, callbacks: List[Callable] | None = None, pin_memory: bool = True, warm_start: bool = False, *args: List, **kwargs: Dict) → None
Train the TabNet pretrainer model.
- Parameters:
X_train (np.ndarray) – Training set to reconstruct during self-supervised pretraining.
eval_set (list of np.ndarray) – List of evaluation sets.
eval_name (list of str) – List of eval set names.
loss_fn (callable or None) – PyTorch loss function.
pretraining_ratio (float) – Fraction of features to mask for reconstruction (between 0 and 1).
weights (int or np.ndarray) – Sampling weights for each example
max_epochs (int) – Maximum number of epochs
patience (int) – Early stopping patience
batch_size (int) – Training batch size
virtual_batch_size (int) – Batch size for Ghost Batch Normalization
num_workers (int) – Number of workers for DataLoader
drop_last (bool) – Whether to drop last batch
callbacks (list of callable) – Custom callbacks
pin_memory (bool) – Whether to use pinned memory
warm_start (bool) – Whether to warm start from previous fit
*args (list) – Additional arguments
**kwargs (dict) – Additional keyword arguments
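To make `pretraining_ratio` concrete, here is a small NumPy sketch of random feature masking. It assumes each entry is hidden independently with probability `pretraining_ratio` (a Bernoulli mask) and replaced by zero. The helper name `obfuscate` is hypothetical, and the library's own masking may differ in detail.

```python
import numpy as np


def obfuscate(x, pretraining_ratio, rng):
    """Hypothetical helper: zero out a random subset of entries."""
    # 1 marks an entry hidden from the encoder, 0 marks a visible entry.
    obf_vars = (rng.random(x.shape) < pretraining_ratio).astype(x.dtype)
    masked_x = (1 - obf_vars) * x
    return masked_x, obf_vars


rng = np.random.default_rng(0)
# Strictly positive toy data, so any zero in the result comes from masking.
X = rng.random((1000, 10)) + 1.0
masked_X, obf_vars = obfuscate(X, pretraining_ratio=0.5, rng=rng)
masked_fraction = obf_vars.mean()
```

With a ratio of 0.5, roughly half of the entries are hidden; a higher ratio makes the reconstruction task harder.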
- predict(X: ndarray) → Tuple[ndarray, ndarray]
Predict reconstructed values for inputs.
- Parameters:
X (np.ndarray) – Input matrix.
- Returns:
Reconstructed values and masks.
- Return type:
Tuple[np.ndarray, np.ndarray]
- set_fit_request(*, X_train: bool | None | str = '$UNCHANGED$', batch_size: bool | None | str = '$UNCHANGED$', callbacks: bool | None | str = '$UNCHANGED$', drop_last: bool | None | str = '$UNCHANGED$', eval_name: bool | None | str = '$UNCHANGED$', eval_set: bool | None | str = '$UNCHANGED$', loss_fn: bool | None | str = '$UNCHANGED$', max_epochs: bool | None | str = '$UNCHANGED$', num_workers: bool | None | str = '$UNCHANGED$', patience: bool | None | str = '$UNCHANGED$', pin_memory: bool | None | str = '$UNCHANGED$', pretraining_ratio: bool | None | str = '$UNCHANGED$', virtual_batch_size: bool | None | str = '$UNCHANGED$', warm_start: bool | None | str = '$UNCHANGED$', weights: bool | None | str = '$UNCHANGED$') → TabNetPretrainer
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
X_train (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for X_train parameter in fit.
batch_size (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for batch_size parameter in fit.
callbacks (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for callbacks parameter in fit.
drop_last (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for drop_last parameter in fit.
eval_name (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for eval_name parameter in fit.
eval_set (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for eval_set parameter in fit.
loss_fn (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for loss_fn parameter in fit.
max_epochs (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for max_epochs parameter in fit.
num_workers (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for num_workers parameter in fit.
patience (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for patience parameter in fit.
pin_memory (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for pin_memory parameter in fit.
pretraining_ratio (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for pretraining_ratio parameter in fit.
virtual_batch_size (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for virtual_batch_size parameter in fit.
warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.
weights (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for weights parameter in fit.
- Returns:
self – The updated object.
- Return type:
object
- stack_batches(list_output: List[Tensor], list_embedded_x: List[Tensor], list_obfuscation: List[Tensor]) → Tuple[Tensor, Tensor, Tensor]
Stack batches of outputs, embeddings, and obfuscations.
- Parameters:
list_output (List[torch.Tensor]) – List of outputs.
list_embedded_x (List[torch.Tensor]) – List of embedded inputs.
list_obfuscation (List[torch.Tensor]) – List of obfuscation masks.
- Returns:
Stacked outputs, embeddings, and obfuscations.
- Return type:
tuple
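Stacking here means concatenating per-batch results along the sample axis. The following NumPy sketch shows the idea, assuming each list holds 2-D arrays that share a feature dimension. The actual method operates on torch tensors, and the helper name `stack_batches_np` is hypothetical.

```python
import numpy as np


def stack_batches_np(list_output, list_embedded_x, list_obfuscation):
    """Hypothetical NumPy analogue: concatenate batches row-wise."""
    return (
        np.vstack(list_output),
        np.vstack(list_embedded_x),
        np.vstack(list_obfuscation),
    )


# Two batches of unequal size (4 rows and 2 rows) with 3 features each.
batches = [np.ones((4, 3)), np.zeros((2, 3))]
output, embedded_x, obf = stack_batches_np(batches, batches, batches)
```

The three results keep their row alignment, so row i of the stacked output still corresponds to row i of the stacked embeddings and masks.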
- update_fit_params(weights: ndarray) → None
Update fit parameters for pretraining.
- Parameters:
weights (np.ndarray) – Sample weights.
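Example
import numpy as np
from pytorch_tabnet.pretraining import TabNetPretrainer

# Toy data: 100 samples with 10 numeric features
X = np.random.rand(100, 10)
pretrainer = TabNetPretrainer()
# Use a batch size below the row count: with the defaults (batch_size=1024,
# drop_last=True), a 100-row set would likely leave no complete batch to train on
pretrainer.fit(X_train=X, batch_size=32, virtual_batch_size=16)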