TabNet Pretrainer

pytorch_tabnet package initialization.

class pytorch_tabnet.TabNetPretrainer(n_d: int = 8, n_a: int = 8, n_steps: int = 3, gamma: float = 1.3, cat_idxs: List[int] = <factory>, cat_dims: List[int] = <factory>, cat_emb_dim: int | List[int] = 1, n_independent: int = 2, n_shared: int = 2, epsilon: float = 1e-15, momentum: float = 0.02, lambda_sparse: float = 0.001, seed: int = 0, clip_value: int = 1, verbose: int = 1, optimizer_fn: Any = <class 'torch.optim.adam.Adam'>, optimizer_params: Dict = <factory>, scheduler_fn: Any = None, scheduler_params: Dict = <factory>, mask_type: str = 'sparsemax', input_dim: int = None, output_dim: List[int] | int = None, device_name: str = 'auto', n_shared_decoder: int = 1, n_indep_decoder: int = 1, grouped_features: List[List[int]] = <factory>, compile_backend: str = '')

Bases: TabModel

TabNet model for unsupervised pretraining. It is trained to reconstruct input features that have been randomly masked.

compute_loss(output: Tensor, embedded_x: Tensor, obf_vars: Tensor, w: Tensor | None = None) → Tensor

Compute the unsupervised loss for pretraining.

Parameters:
  • output (torch.Tensor) – Network output.

  • embedded_x (torch.Tensor) – Embedded input.

  • obf_vars (torch.Tensor) – Obfuscated variables mask.

  • w (Optional[torch.Tensor]) – Optional sample weights.

Returns:

Loss value.

Return type:

torch.Tensor
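As an illustration of what a loss over these three inputs can look like, the following NumPy sketch computes a masked, variance-normalized reconstruction error. It is a stand-in under stated assumptions, not the library's implementation: the function name masked_reconstruction_loss, the eps constant, and the exact normalization are hypothetical, and the optional sample weights w are left out.

```python
import numpy as np


def masked_reconstruction_loss(output, embedded_x, obf_vars, eps=1e-9):
    """Illustrative masked reconstruction loss (hypothetical, not library code).

    Only entries flagged in obf_vars (1 = obfuscated) contribute, and each
    feature's squared error is scaled by that feature's variance in the batch.
    """
    # Squared error, kept only on obfuscated entries.
    sq_err = ((output - embedded_x) ** 2) * obf_vars
    # Per-feature variance over the batch; constant columns fall back to 1.
    var = embedded_x.var(axis=0)
    var = np.where(var == 0, 1.0, var)
    # Sum normalized errors per sample, then divide by the number of
    # obfuscated features in that sample (eps avoids division by zero).
    per_sample = (sq_err / var).sum(axis=1) / (obf_vars.sum(axis=1) + eps)
    return per_sample.mean()


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
mask = rng.integers(0, 2, size=(8, 4)).astype(float)

print(masked_reconstruction_loss(x, x, mask))        # perfect reconstruction: 0.0
print(masked_reconstruction_loss(x + 1.0, x, mask))  # positive when masked entries are wrong
```

In this sketch, errors on entries that were not obfuscated are ignored, so the model is scored only on the values it could not see.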

fit(X_train: ndarray, eval_set: List[ndarray | List[ndarray]] | None = None, eval_name: List[str] | None = None, loss_fn: Callable | None = None, pretraining_ratio: float = 0.5, weights: int | ndarray = 0, max_epochs: int = 100, patience: int = 10, batch_size: int = 1024, virtual_batch_size: int = 128, num_workers: int = 0, drop_last: bool = True, callbacks: List[Callable] | None = None, pin_memory: bool = True, warm_start: bool = False, *args: List, **kwargs: Dict) → None

Train the TabNet pretrainer model.

Parameters:
  • X_train (np.ndarray) – Training set to reconstruct during self-supervised pretraining.

  • eval_set (list of np.array) – List of evaluation sets.

  • eval_name (list of str) – List of eval set names.

  • loss_fn (callable or None) – PyTorch loss function.

  • pretraining_ratio (float) – Fraction of features to mask for reconstruction, between 0 and 1.

  • weights (int or np.ndarray) – Sampling weights for each example.

  • max_epochs (int) – Maximum number of epochs

  • patience (int) – Early stopping patience

  • batch_size (int) – Training batch size

  • virtual_batch_size (int) – Batch size for Ghost Batch Normalization

  • num_workers (int) – Number of workers for DataLoader

  • drop_last (bool) – Whether to drop the last batch during training if it is incomplete

  • callbacks (list of callable) – Custom callbacks

  • pin_memory (bool) – Whether to use pinned memory

  • warm_start (bool) – Whether to continue training from the state of a previous fit instead of starting from scratch

  • *args (list) – Additional arguments

  • **kwargs (dict) – Additional keyword arguments
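To make pretraining_ratio concrete, the sketch below masks each entry of the input independently with that probability and zeroes it out. This is an assumption about how the masking could be done (an independent Bernoulli draw per entry). The helper name obfuscate is hypothetical and is not part of the documented API.

```python
import numpy as np


def obfuscate(x, pretraining_ratio, rng):
    """Zero out a random subset of entries; return (masked_x, mask)."""
    # mask[i, j] == 1 means feature j of sample i is hidden from the model.
    mask = rng.binomial(1, pretraining_ratio, size=x.shape)
    return x * (1 - mask), mask


rng = np.random.default_rng(0)
X = rng.random((1000, 10))
masked_X, mask = obfuscate(X, pretraining_ratio=0.5, rng=rng)

print(mask.mean())  # close to 0.5 on a sample this large
```

Under this reading, a ratio of 0.5 hides about half of the entries, and the pretrainer is asked to reconstruct them from the visible ones.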

predict(X: ndarray) → Tuple[ndarray, ndarray]

Predict reconstructed values for inputs.

Parameters:

X (np.ndarray) – Input matrix.

Returns:

Reconstructed values and masks.

Return type:

Tuple[np.ndarray, np.ndarray]

set_fit_request(*, X_train: bool | None | str = '$UNCHANGED$', batch_size: bool | None | str = '$UNCHANGED$', callbacks: bool | None | str = '$UNCHANGED$', drop_last: bool | None | str = '$UNCHANGED$', eval_name: bool | None | str = '$UNCHANGED$', eval_set: bool | None | str = '$UNCHANGED$', loss_fn: bool | None | str = '$UNCHANGED$', max_epochs: bool | None | str = '$UNCHANGED$', num_workers: bool | None | str = '$UNCHANGED$', patience: bool | None | str = '$UNCHANGED$', pin_memory: bool | None | str = '$UNCHANGED$', pretraining_ratio: bool | None | str = '$UNCHANGED$', virtual_batch_size: bool | None | str = '$UNCHANGED$', warm_start: bool | None | str = '$UNCHANGED$', weights: bool | None | str = '$UNCHANGED$') → TabNetPretrainer

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • X_train (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for X_train parameter in fit.

  • batch_size (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for batch_size parameter in fit.

  • callbacks (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for callbacks parameter in fit.

  • drop_last (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for drop_last parameter in fit.

  • eval_name (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for eval_name parameter in fit.

  • eval_set (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for eval_set parameter in fit.

  • loss_fn (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for loss_fn parameter in fit.

  • max_epochs (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for max_epochs parameter in fit.

  • num_workers (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for num_workers parameter in fit.

  • patience (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for patience parameter in fit.

  • pin_memory (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for pin_memory parameter in fit.

  • pretraining_ratio (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for pretraining_ratio parameter in fit.

  • virtual_batch_size (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for virtual_batch_size parameter in fit.

  • warm_start (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for warm_start parameter in fit.

  • weights (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for weights parameter in fit.

Returns:

self – The updated object.

Return type:

object

stack_batches(list_output: List[Tensor], list_embedded_x: List[Tensor], list_obfuscation: List[Tensor]) → Tuple[Tensor, Tensor, Tensor]

Stack batches of outputs, embeddings, and obfuscations.

Parameters:
  • list_output (List[torch.Tensor]) – List of outputs.

  • list_embedded_x (List[torch.Tensor]) – List of embedded inputs.

  • list_obfuscation (List[torch.Tensor]) – List of obfuscation masks.

Returns:

Stacked outputs, embeddings, and obfuscations.

Return type:

tuple
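A minimal sketch of the stacking step, using NumPy arrays in place of torch tensors and assuming that "stack" means concatenating the per-batch results along the sample (row) axis. The array contents are made up for illustration.

```python
import numpy as np

# Three batches of 4 rows and 5 columns each (NumPy stand-ins for tensors).
list_output = [np.full((4, 5), i, dtype=float) for i in range(3)]
list_embedded_x = [np.zeros((4, 5)) for _ in range(3)]
list_obfuscation = [np.ones((4, 5)) for _ in range(3)]

# Concatenate along the row axis so each result covers all samples.
output = np.vstack(list_output)
embedded_x = np.vstack(list_embedded_x)
obf_vars = np.vstack(list_obfuscation)

print(output.shape, embedded_x.shape, obf_vars.shape)  # (12, 5) (12, 5) (12, 5)
```

Row order is preserved, so row k of each stacked array still refers to the same sample.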

update_fit_params(weights: ndarray) → None

Update fit parameters for pretraining.

Parameters:

weights (np.ndarray) – Sample weights.

Example

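import numpy as np
from pytorch_tabnet.pretraining import TabNetPretrainer

# Synthetic data: 100 samples, 10 numeric features.
X = np.random.rand(100, 10)
pretrainer = TabNetPretrainer()
# With only 100 rows, the defaults (batch_size=1024, drop_last=True) may leave
# no complete batch to train on, so use a smaller batch size here.
pretrainer.fit(X_train=X, batch_size=32, virtual_batch_size=16)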