Source code for pytorch_tabnet.tab_models.tab_class

"""TabNet classifier model class and training logic."""

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from ..data_handlers import TBDataLoader, UnifiedDataset
from ..error_handlers import check_output_dim
from ..utils import infer_output_dim
from .abstract_models import TabSupervisedModel


[docs] @dataclass class TabNetClassifier(TabSupervisedModel): """TabNet model for classification tasks.""" output_dim: int = None weight: Any = field(init=False, default=0) def __post_init__(self) -> None: """Initialize the classifier and set default loss and metric.""" super(TabNetClassifier, self).__post_init__() self._task: str = "classification" self._default_loss: Any = partial( torch.nn.functional.cross_entropy, reduction="none", ) self._default_metric: str = "accuracy"
[docs] def weight_updater(self, weights: Union[bool, Dict[Union[str, int], Any], Any]) -> Union[bool, Dict[Union[str, int], Any]]: """Update class weights for training. Parameters ---------- weights : bool, dict, or any Class weights or indicator. Returns ------- bool or dict Updated weights. """ if isinstance(weights, int): return weights # type: ignore elif isinstance(weights, dict): return {self.target_mapper[key]: value for key, value in weights.items()} else: return weights
[docs] def prepare_target(self, y: np.ndarray) -> np.ndarray: """Map targets using the target mapper. Parameters ---------- y : np.ndarray Target array. Returns ------- np.ndarray Mapped target array. """ return np.vectorize(self.target_mapper.get)(y)
[docs] def compute_loss( self, y_pred: torch.Tensor, y_true: torch.Tensor, w: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Compute the loss for classification. Parameters ---------- y_pred : torch.Tensor Network output. y_true : torch.Tensor True labels. w : Optional[torch.Tensor] Optional sample weights. Returns ------- torch.Tensor Loss value. """ class_count = None if isinstance(self.weight, int) and self.weight == 1: _class_num, class_count = y_true.long().unique(return_counts=True) class_count[class_count == 0] = 1 loss = self.loss_fn(y_pred, y_true.long(), weight=1 / class_count if class_count is not None else None) if w is not None: loss = loss * w return loss.mean()
[docs] def update_fit_params( # type: ignore[override] self, X_train: np.ndarray, y_train: np.ndarray, eval_set: List[Tuple[np.ndarray, np.ndarray]], weights: Union[bool, Dict[str, Any]], ) -> None: """Update fit parameters for classification. Parameters ---------- X_train : np.ndarray Training data. y_train : np.ndarray Training targets. eval_set : list List of evaluation sets. weights : bool or dict Class weights. """ output_dim: int train_labels: List[Any] output_dim, train_labels = infer_output_dim(y_train) for _X, y in eval_set: check_output_dim(train_labels, y) self.output_dim: int = output_dim self._default_metric = "auc" if self.output_dim == 2 else "accuracy" self.classes_: List[Any] = train_labels self.target_mapper: Dict[Any, int] = {class_label: index for index, class_label in enumerate(self.classes_)} self.preds_mapper: Dict[str, Any] = {str(index): class_label for index, class_label in enumerate(self.classes_)}
# self.updated_weights: Union[bool, Dict[Union[str, int], Any]] = self.weight_updater(weights)
[docs] def stack_batches( self, list_y_true: List[torch.Tensor], list_y_score: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """Stack batches of true and predicted values. Parameters ---------- list_y_true : List[torch.Tensor] List of true labels for each batch. list_y_score : List[torch.Tensor] List of predicted scores for each batch. Returns ------- Tuple[torch.Tensor, torch.Tensor] Stacked true labels and predicted scores. """ y_true: torch.Tensor = torch.hstack(list_y_true) y_score: torch.Tensor = torch.vstack(list_y_score) y_score = torch.nn.Softmax(dim=1)(y_score) return y_true, y_score
[docs] def predict_func(self, outputs: np.ndarray) -> np.ndarray: """Convert network outputs to class predictions. Parameters ---------- outputs : np.ndarray Network outputs. Returns ------- np.ndarray Predicted classes. """ outputs = np.argmax(outputs, axis=1) return np.vectorize(self.preds_mapper.get)(outputs.astype(str))
[docs] def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> np.ndarray: """Predict class probabilities for classification. Parameters ---------- X : torch.Tensor or np.ndarray Input data. Returns ------- np.ndarray Probability predictions. """ self.network.eval() dataloader = TBDataLoader( name="predict", dataset=UnifiedDataset(X), batch_size=self.batch_size, # shuffle=False, predict=True, ) results: List[np.ndarray] = [] with torch.no_grad(): for _batch_nb, (data, _, _) in enumerate(dataloader): # type: ignore data = data.to(self.device).float() # type: ignore output: torch.Tensor _M_loss: torch.Tensor output, _M_loss = self.network(data) predictions: np.ndarray = ( torch.nn.Softmax(dim=1)(output).cpu().detach().numpy() ) # todo: replace with pytorch's torch.vstack results.append(predictions) res: np.ndarray = np.vstack(results) # todo: replace with pytorch's torch.vstack return res