Custom Metrics and Losses

This guide demonstrates how to use custom evaluation metrics and custom loss functions with TabNet. Each example is standalone.

Custom Evaluation Metric Example

import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.metrics import Metric
import torch
from torcheval.metrics.functional import binary_auroc

class Gini(Metric):
    """Normalized Gini coefficient, computed as ``2 * AUC - 1`` and floored at 0.

    Reported under the name ``"gini"``; larger values are better
    (``_maximize = True``).
    """

    def __init__(self):
        self._name = "gini"
        self._maximize = True

    def __call__(self, y_true, y_score, weights=None):
        """Score binary predictions with the Gini coefficient.

        Args:
            y_true: tensor of binary targets.
            y_score: tensor of scores; a two-column tensor is treated as
                per-class probabilities and only column 1 is scored.
            weights: accepted for interface compatibility; not used here.

        Returns:
            ``max(2 * AUROC - 1, 0.0)`` as a Python float.
        """
        # Detach from the graph and move to CPU as float tensors.
        targets = y_true.detach().cpu().float()
        scores = y_score.detach().cpu().float()

        # Two-column probability output: keep the positive-class column only.
        has_two_columns = scores.ndim == 2 and scores.shape[1] == 2
        if has_two_columns:
            scores = scores[:, 1]

        # binary_auroc takes (input, target); .item() turns the 0-d tensor into a float.
        gini = 2 * binary_auroc(scores, targets).item() - 1
        return max(gini, 0.)

# Random placeholder data: 10 uniform float features and 0/1 labels.
# Each tuple is evaluated left to right, so the draw order is unchanged.
X_train, y_train = np.random.rand(100, 10), np.random.randint(0, 2, 100)
X_valid, y_valid = np.random.rand(20, 10), np.random.randint(0, 2, 20)

# eval_metric takes the metric class itself (Gini), not an instance.
clf = TabNetClassifier()
clf.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=[Gini],
)

Custom Loss Function Example

import numpy as np
import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetRegressor

# Random float32 regression data; targets are kept as 2-D (n_samples, 1) columns.
# rand(n, 1) draws the same values as rand(n).reshape(-1, 1).
X_train = np.random.rand(100, 10).astype(np.float32)
y_train = np.random.rand(100, 1).astype(np.float32)
X_valid = np.random.rand(20, 10).astype(np.float32)
y_valid = np.random.rand(20, 1).astype(np.float32)

def custom_loss(y_true, y_pred):
    """Mean squared error plus a 0.1-weighted L1 penalty on ``y_pred``.

    Returns ``mse(y_pred, y_true) + 0.1 * mean(|y_pred|)`` as a scalar tensor.

    NOTE(review): the L1 penalty is applied to the *second* positional
    argument. PyTorch losses conventionally take ``(input, target)``, and
    upstream pytorch_tabnet invokes ``loss_fn(y_pred, y_true)``. If this
    package does the same, the parameter names here are swapped. The penalty
    would then land on the targets and would not influence training. Confirm
    the calling order before relying on the penalty term.
    """
    # Default reduction="mean" equals reduction="none" followed by .mean().
    mse = nn.functional.mse_loss(y_pred, y_true)
    # Penalizes large absolute values of y_pred.
    l1_penalty = torch.mean(torch.abs(y_pred))
    return mse + 0.1 * l1_penalty

# loss_fn takes the callable itself; the eval set is passed alongside the training data.
reg = TabNetRegressor()
reg.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    loss_fn=custom_loss,
)