Source code for pytorch_tabnet.tab_models.multitask

"""Multitask learning utilities for TabNet."""

from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np

# import scipy
import torch

from ..data_handlers import TBDataLoader, UnifiedDataset
from ..error_handlers import check_output_dim, filter_weights
from ..utils import infer_multitask_output

# from torch.utils.data import DataLoader
from .abstract_models import TabSupervisedModel


[docs] @dataclass class TabNetMultiTaskClassifier(TabSupervisedModel): """TabNet model for multitask classification tasks.""" output_dim: List[int] = None def __post_init__(self) -> None: """Initialize the multitask classifier and set default loss and metric.""" super(TabNetMultiTaskClassifier, self).__post_init__() self._task = "classification" self._default_loss = partial(torch.nn.functional.cross_entropy, reduction="none") self._default_metric = "logloss"
[docs] def prepare_target(self, y: np.ndarray) -> np.ndarray: """Map targets for each task using the target mappers. Parameters ---------- y : np.ndarray Target array with shape (n_samples, n_tasks). Returns ------- np.ndarray Mapped target array. """ y_mapped = y.copy() for task_idx in range(y.shape[1]): task_mapper = self.target_mapper[task_idx] y_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y[:, task_idx]) return y_mapped
[docs] def compute_loss( self, y_pred: List[torch.Tensor], y_true: torch.Tensor, w: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Compute the loss according to network output and targets. Parameters ---------- y_pred : list of torch.Tensor Output of network for each task. y_true : torch.Tensor Targets label encoded. w : Optional[torch.Tensor] Optional sample weights. Returns ------- torch.Tensor Output of loss function(s). """ loss = 0.0 y_true = y_true.long() if isinstance(self.loss_fn, list): # if you specify a different loss for each task for task_loss, task_output, task_id in zip(self.loss_fn, y_pred, range(len(self.loss_fn)), strict=False): t_loss = task_loss(task_output, y_true[:, task_id]) if w is not None: t_loss *= w loss += t_loss.mean() else: # same loss function is applied to all tasks for task_id, task_output in enumerate(y_pred): t_loss = self.loss_fn(task_output, y_true[:, task_id]) if w is not None: t_loss *= w loss += t_loss.mean() loss /= len(y_pred) return loss
[docs] def stack_batches( self, list_y_true: List[torch.Tensor], list_y_score: List[List[torch.Tensor]], ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Stack batches of true and predicted values for all tasks. Parameters ---------- list_y_true : List[torch.Tensor] List of true labels for each batch. list_y_score : List[List[torch.Tensor]] List of predicted scores for each batch and task. Returns ------- Tuple[torch.Tensor, List[torch.Tensor]] Stacked true labels and list of stacked predicted scores per task. """ y_true = torch.vstack(list_y_true) y_score = [] for i in range(len(self.output_dim)): score = torch.vstack([x[i] for x in list_y_score]) score = torch.nn.Softmax(dim=1)(score) y_score.append(score) return y_true, y_score
[docs] def update_fit_params( self, X_train: np.ndarray, y_train: np.ndarray, eval_set: List[Tuple[np.ndarray, np.ndarray]], weights: Union[np.ndarray, None], ) -> None: """Update fit parameters for multitask classification. Parameters ---------- X_train : np.ndarray or scipy.sparse.csr_matrix Training data. y_train : np.ndarray Training targets. eval_set : list List of evaluation sets. weights : np.ndarray or None Sample weights. """ output_dim, train_labels = infer_multitask_output(y_train) for _, y in eval_set: for task_idx in range(y.shape[1]): check_output_dim(train_labels[task_idx], y[:, task_idx]) self.output_dim = output_dim self.classes_ = train_labels self.target_mapper = [{class_label: index for index, class_label in enumerate(classes)} for classes in self.classes_] self.preds_mapper = [{str(index): str(class_label) for index, class_label in enumerate(classes)} for classes in self.classes_] filter_weights( weights=weights, )
[docs] def predict(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]: """Predict the most probable class for each task. Parameters ---------- X : torch.Tensor, np.ndarray, or scipy.sparse.csr_matrix Input data. Returns ------- List[np.ndarray] List of predictions for each task. """ self.network.eval() dataloader = TBDataLoader( name="predict", dataset=UnifiedDataset(X), batch_size=self.batch_size, predict=True, ) results: dict = {} with torch.no_grad(): for data, _, _ in dataloader: # type: ignore data = data.float().to(self.device) output, _ = self.network(data) predictions = [ torch.argmax(torch.nn.Softmax(dim=1)(task_output), dim=1).cpu().detach().numpy().reshape(-1) for task_output in output ] for task_idx in range(len(self.output_dim)): results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] # stack all task individually results_ = [np.hstack(task_res) for task_res in results.values()] # map all task individually results_ = [np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str)) for task_idx, task_res in enumerate(results_)] return results_
[docs] def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]: """Predict class probabilities for each task. Parameters ---------- X : torch.Tensor, np.ndarray, or scipy.sparse.csr_matrix Input data. Returns ------- List[np.ndarray] List of probability predictions for each task. """ self.network.eval() dataloader = TBDataLoader( name="predict", dataset=UnifiedDataset(X), batch_size=self.batch_size, predict=True, ) results: dict = {} for data, _, _ in dataloader: # type: ignore data = data.float().to(self.device) output, _ = self.network(data) predictions = [torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() for task_output in output] for task_idx in range(len(self.output_dim)): results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] res = [np.vstack(task_res) for task_res in results.values()] return res