"""TabNet classifier model class and training logic."""
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from ..data_handlers import PredictDataset, TBDataLoader
from ..utils import check_output_dim, infer_output_dim
from .abstract_models import TabSupervisedModel
[docs]
@dataclass
class TabNetClassifier(TabSupervisedModel):
"""TabNet model for classification tasks."""
output_dim: int = None
weight: Any = field(init=False, default=0)
def __post_init__(self) -> None:
"""Initialize the classifier and set default loss and metric."""
super(TabNetClassifier, self).__post_init__()
self._task: str = "classification"
self._default_loss: Any = partial(
torch.nn.functional.cross_entropy,
reduction="none",
)
self._default_metric: str = "accuracy"
[docs]
def weight_updater(self, weights: Union[bool, Dict[Union[str, int], Any], Any]) -> Union[bool, Dict[Union[str, int], Any]]:
"""Update class weights for training.
Parameters
----------
weights : bool, dict, or any
Class weights or indicator.
Returns
-------
bool or dict
Updated weights.
"""
if isinstance(weights, int):
return weights # type: ignore
elif isinstance(weights, dict):
return {self.target_mapper[key]: value for key, value in weights.items()}
else:
return weights
[docs]
def prepare_target(self, y: np.ndarray) -> np.ndarray:
"""Map targets using the target mapper.
Parameters
----------
y : np.ndarray
Target array.
Returns
-------
np.ndarray
Mapped target array.
"""
return np.vectorize(self.target_mapper.get)(y)
[docs]
def compute_loss(
self,
y_pred: torch.Tensor,
y_true: torch.Tensor,
w: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute the loss for classification.
Parameters
----------
y_pred : torch.Tensor
Network output.
y_true : torch.Tensor
True labels.
w : Optional[torch.Tensor]
Optional sample weights.
Returns
-------
torch.Tensor
Loss value.
"""
class_count = None
if isinstance(self.weight, int) and self.weight == 1:
_class_num, class_count = y_true.long().unique(return_counts=True)
class_count[class_count == 0] = 1
loss = self.loss_fn(y_pred, y_true.long(), weight=1 / class_count if class_count is not None else None)
if w is not None:
loss = loss * w
return loss.mean()
[docs]
def update_fit_params( # type: ignore[override]
self,
X_train: np.ndarray,
y_train: np.ndarray,
eval_set: List[Tuple[np.ndarray, np.ndarray]],
weights: Union[bool, Dict[str, Any]],
) -> None:
"""Update fit parameters for classification.
Parameters
----------
X_train : np.ndarray
Training data.
y_train : np.ndarray
Training targets.
eval_set : list
List of evaluation sets.
weights : bool or dict
Class weights.
"""
output_dim: int
train_labels: List[Any]
output_dim, train_labels = infer_output_dim(y_train)
for _X, y in eval_set:
check_output_dim(train_labels, y)
self.output_dim: int = output_dim
self._default_metric = "auc" if self.output_dim == 2 else "accuracy"
self.classes_: List[Any] = train_labels
self.target_mapper: Dict[Any, int] = {class_label: index for index, class_label in enumerate(self.classes_)}
self.preds_mapper: Dict[str, Any] = {str(index): class_label for index, class_label in enumerate(self.classes_)}
# self.updated_weights: Union[bool, Dict[Union[str, int], Any]] = self.weight_updater(weights)
[docs]
def stack_batches(
self,
list_y_true: List[torch.Tensor],
list_y_score: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Stack batches of true and predicted values.
Parameters
----------
list_y_true : List[torch.Tensor]
List of true labels for each batch.
list_y_score : List[torch.Tensor]
List of predicted scores for each batch.
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Stacked true labels and predicted scores.
"""
y_true: torch.Tensor = torch.hstack(list_y_true)
y_score: torch.Tensor = torch.vstack(list_y_score)
y_score = torch.nn.Softmax(dim=1)(y_score)
return y_true, y_score
[docs]
def predict_func(self, outputs: np.ndarray) -> np.ndarray:
"""Convert network outputs to class predictions.
Parameters
----------
outputs : np.ndarray
Network outputs.
Returns
-------
np.ndarray
Predicted classes.
"""
outputs = np.argmax(outputs, axis=1)
return np.vectorize(self.preds_mapper.get)(outputs.astype(str))
[docs]
def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
"""Predict class probabilities for classification.
Parameters
----------
X : torch.Tensor or np.ndarray
Input data.
Returns
-------
np.ndarray
Probability predictions.
"""
self.network.eval()
dataloader = TBDataLoader(
name="predict",
dataset=PredictDataset(X),
batch_size=self.batch_size,
# shuffle=False,
predict=True,
)
results: List[np.ndarray] = []
with torch.no_grad():
for _batch_nb, (data, _, _) in enumerate(dataloader): # type: ignore
data = data.to(self.device).float() # type: ignore
output: torch.Tensor
_M_loss: torch.Tensor
output, _M_loss = self.network(data)
predictions: np.ndarray = (
torch.nn.Softmax(dim=1)(output).cpu().detach().numpy()
) # todo: replace with pytorch's torch.vstack
results.append(predictions)
res: np.ndarray = np.vstack(results) # todo: replace with pytorch's torch.vstack
return res