"""TabNet model class and training logic."""
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
# from torch.utils.data import DataLoader
from pytorch_tabnet.abstract_model_sub import TabSupervisedModel
from pytorch_tabnet.data_handlers import PredictDataset, TBDataLoader
from pytorch_tabnet.utils import check_output_dim, filter_weights, infer_output_dim
@dataclass
class TabNetClassifier(TabSupervisedModel):
    """TabNet model for classification tasks.

    Training targets are remapped to contiguous integer indices through
    ``target_mapper``, and network predictions are mapped back to the original
    class labels through ``preds_mapper``. Both mappings are built in
    ``update_fit_params``.
    """

    # Number of classes; defaults to None and is set by update_fit_params.
    output_dim: Optional[int] = None
    # Not a constructor argument (init=False). When it equals 1 (True also
    # qualifies, since bool is an int), compute_loss weights every class by
    # the inverse of its frequency in the current batch.
    weight: Any = field(init=False, default=0)

    def __post_init__(self) -> None:
        """Initialize the classifier and set the default loss and metric."""
        super().__post_init__()
        self._task: str = "classification"
        # reduction="none" keeps one loss value per sample so that optional
        # sample weights can be applied before averaging in compute_loss.
        self._default_loss: Any = partial(
            torch.nn.functional.cross_entropy,
            reduction="none",
        )
        self._default_metric: str = "accuracy"

    def weight_updater(self, weights: Union[bool, Dict[Union[str, int], Any], Any]) -> Union[bool, Dict[Union[str, int], Any]]:
        """Translate user-supplied class weights to internal class indices.

        Parameters
        ----------
        weights : bool, int, dict or any
            Class weights or an indicator flag. A dict keyed by original class
            labels is re-keyed by the integer indices in ``target_mapper``;
            any other value is returned unchanged.

        Returns
        -------
        bool, int, dict or any
            The re-keyed dict, or ``weights`` itself for non-dict input.

        Raises
        ------
        KeyError
            If a dict key is not one of the known class labels.
        """
        if isinstance(weights, dict):
            return {self.target_mapper[key]: value for key, value in weights.items()}
        return weights

    def prepare_target(self, y: np.ndarray) -> np.ndarray:
        """Map original class labels to their integer indices.

        Parameters
        ----------
        y : np.ndarray
            Target array of original class labels.

        Returns
        -------
        np.ndarray
            Array of the same shape holding indices from ``target_mapper``.
            Labels missing from the mapper become ``None`` (``dict.get``).
        """
        return np.vectorize(self.target_mapper.get)(y)

    def compute_loss(
        self,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        w: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the classification loss for one batch.

        When ``self.weight`` equals 1, each class is weighted by the inverse of
        its frequency in the current batch. The weight vector has exactly one
        entry per class index (``y_pred.shape[1]`` entries), as
        ``cross_entropy`` requires, so batches that do not contain every class
        are handled correctly.

        Parameters
        ----------
        y_pred : torch.Tensor
            Network output (logits); the class dimension is assumed to be
            dim 1, as for ``cross_entropy``.
        y_true : torch.Tensor
            True labels as class indices; cast to ``long`` before use.
        w : Optional[torch.Tensor]
            Optional sample weights multiplied into the per-sample loss.

        Returns
        -------
        torch.Tensor
            Scalar loss: the mean of the (optionally weighted) loss values.
        """
        targets = y_true.long()
        class_weight: Optional[torch.Tensor] = None
        if isinstance(self.weight, int) and self.weight == 1:
            # bincount with minlength yields a count for every class index,
            # including classes absent from this batch (count 0).
            # unique(return_counts=True) would only count present classes and
            # produce a weight tensor that is too short for cross_entropy.
            counts = torch.bincount(targets.reshape(-1), minlength=y_pred.shape[1])
            # Absent classes contribute nothing to the loss; the clamp only
            # prevents a division by zero for them.
            class_weight = 1.0 / counts.clamp(min=1).to(y_pred.dtype)
        loss = self.loss_fn(y_pred, targets, weight=class_weight)
        if w is not None:
            loss = loss * w
        return loss.mean()

    def update_fit_params(  # type: ignore[override]
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        eval_set: List[Tuple[np.ndarray, np.ndarray]],
        weights: Union[bool, Dict[str, Any]],
    ) -> None:
        """Infer the class set and build the label <-> index mappings.

        Parameters
        ----------
        X_train : np.ndarray
            Training data (not used by this method).
        y_train : np.ndarray
            Training targets; the number of classes and the class labels are
            inferred from them with ``infer_output_dim``.
        eval_set : list of (X, y) tuples
            Evaluation sets; each ``y`` is checked against the training labels
            with ``check_output_dim``.
        weights : bool or dict
            Class weights. Currently not used by this method.
        """
        output_dim, train_labels = infer_output_dim(y_train)
        for _X, y in eval_set:
            check_output_dim(train_labels, y)
        self.output_dim = output_dim
        # Binary problems default to AUC, multi-class problems to accuracy.
        self._default_metric = "auc" if self.output_dim == 2 else "accuracy"
        self.classes_: List[Any] = train_labels
        self.target_mapper: Dict[Any, int] = {
            class_label: index for index, class_label in enumerate(self.classes_)
        }
        # Keys are stringified indices: predict_func converts argmax results
        # with astype(str) before looking them up.
        self.preds_mapper: Dict[str, Any] = {
            str(index): class_label for index, class_label in enumerate(self.classes_)
        }

    def stack_batches(
        self,
        list_y_true: List[torch.Tensor],
        list_y_score: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concatenate per-batch labels and scores.

        Parameters
        ----------
        list_y_true : List[torch.Tensor]
            True labels for each batch (joined with ``torch.hstack``).
        list_y_score : List[torch.Tensor]
            Raw network scores for each batch (joined with ``torch.vstack``).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Stacked labels, and stacked scores turned into probabilities with
            a softmax over dim 1.
        """
        y_true = torch.hstack(list_y_true)
        y_score = torch.softmax(torch.vstack(list_y_score), dim=1)
        return y_true, y_score

    def predict_func(self, outputs: np.ndarray) -> np.ndarray:
        """Convert network outputs to original class labels.

        Parameters
        ----------
        outputs : np.ndarray
            Network outputs with classes along axis 1.

        Returns
        -------
        np.ndarray
            Predicted class labels (arg-max index mapped via ``preds_mapper``).
        """
        class_indices = np.argmax(outputs, axis=1)
        # preds_mapper is keyed by str(index), hence the astype(str).
        return np.vectorize(self.preds_mapper.get)(class_indices.astype(str))

    def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Predict class probabilities.

        The network is switched to eval mode and run without gradient
        tracking, batch by batch.

        Parameters
        ----------
        X : torch.Tensor or np.ndarray
            Input data.

        Returns
        -------
        np.ndarray
            Softmax (over dim 1) of the network output, one row per sample.
        """
        self.network.eval()
        dataloader = TBDataLoader(
            name="predict",
            dataset=PredictDataset(X),
            batch_size=self.batch_size,
            predict=True,
        )
        results: List[np.ndarray] = []
        with torch.no_grad():
            for data, _, _ in dataloader:  # type: ignore
                data = data.to(self.device).float()  # type: ignore
                output, _M_loss = self.network(data)
                results.append(torch.softmax(output, dim=1).cpu().numpy())
        return np.vstack(results)
@dataclass
class TabNetRegressor(TabSupervisedModel):
    """TabNet estimator for regression with one or more continuous targets."""

    # NOTE(review): the default is None even though the annotation says int.
    output_dim: int = None

    def __post_init__(self) -> None:
        """Run the parent initialisation, then install regression defaults."""
        super().__post_init__()
        self._task: str = "regression"
        # Element-wise squared error without reduction; compute_loss does the
        # averaging so that optional sample weights can be applied first.
        mse_no_reduction = partial(torch.nn.functional.mse_loss, reduction="none")
        self._default_loss: Any = mse_no_reduction
        self._default_metric: str = "mse"

    def prepare_target(self, y: np.ndarray) -> np.ndarray:
        """Pass regression targets through untouched.

        Parameters
        ----------
        y : np.ndarray
            Target array.

        Returns
        -------
        np.ndarray
            The very object that was passed in (no copy, no mapping).
        """
        return y

    def compute_loss(
        self,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        w: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Reduce the element-wise regression loss to a scalar.

        Parameters
        ----------
        y_pred : torch.Tensor
            Network output.
        y_true : torch.Tensor
            Ground-truth values.
        w : Optional[torch.Tensor]
            Optional sample weights applied to the per-sample loss.

        Returns
        -------
        torch.Tensor
            Mean of the (optionally weighted) per-sample losses.
        """
        elementwise = self.loss_fn(y_pred, y_true)
        # Average across dim 1 so that each sample yields one value.
        per_sample = elementwise.mean(dim=1) if elementwise.dim() > 1 else elementwise
        weighted = per_sample if w is None else per_sample * w
        return weighted.mean()

    def update_fit_params(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        eval_set: List[Tuple[np.ndarray, np.ndarray]],
        weights: Union[bool, np.ndarray],
    ) -> None:
        """Validate the target layout and record regression fit settings.

        Parameters
        ----------
        X_train : np.ndarray
            Training data (not used by this method).
        y_train : np.ndarray
            Training targets; must be two-dimensional.
        eval_set : list
            Evaluation sets (not used by this method).
        weights : bool or np.ndarray
            Sample weights, stored and passed to ``filter_weights``.

        Raises
        ------
        ValueError
            If ``y_train`` is not two-dimensional.
        """
        ndim = len(y_train.shape)
        if ndim != 2:
            raise ValueError(
                "Targets should be 2D : (n_samples, n_regression) "
                f"but y_train.shape={y_train.shape} given.\n"
                "Use reshape(-1, 1) for single regression."
            )
        _n_samples, n_targets = y_train.shape
        self.output_dim = n_targets
        # Regression predictions are used as-is, so no label mapping exists.
        self.preds_mapper = None
        self.updated_weights = weights
        filter_weights(weights)

    def predict_func(self, outputs: np.ndarray) -> np.ndarray:
        """Use raw network outputs directly as regression predictions.

        Parameters
        ----------
        outputs : np.ndarray
            Network outputs.

        Returns
        -------
        np.ndarray
            ``outputs``, unchanged.
        """
        return outputs

    def stack_batches(
        self,
        list_y_true: List[torch.Tensor],
        list_y_score: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Vertically stack per-batch targets and predictions.

        Parameters
        ----------
        list_y_true : List[torch.Tensor]
            Ground-truth values for each batch.
        list_y_score : List[torch.Tensor]
            Predicted values for each batch.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Row-wise concatenation of the targets and of the predictions.
        """
        stacked_true, stacked_score = (
            torch.vstack(chunks) for chunks in (list_y_true, list_y_score)
        )
        return stacked_true, stacked_score
# Alias of TabNetRegressor, which already supports several regression targets
# (output_dim is taken from y_train.shape[1]). Presumably kept so code using
# the multi-task name keeps working — confirm before removing.
MultiTabNetRegressor = TabNetRegressor