Custom Metrics and Losses
This guide demonstrates how to use custom evaluation metrics and custom loss functions with TabNet. Each example is standalone.
Custom Evaluation Metric Example
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.metrics import Metric
import torch
from torcheval.metrics.functional import binary_auroc
class Gini(Metric):
    """Normalized Gini coefficient for binary classification.

    Computed as ``2 * ROC_AUC - 1`` and clipped below at 0. Larger values are
    better, so ``_maximize`` is set to True.
    """

    def __init__(self):
        self._name = "gini"
        self._maximize = True

    def __call__(self, y_true, y_score, weights=None):
        """Score one evaluation set.

        Args:
            y_true: Tensor of binary labels; flattened to 1-D so that ``(N,)``
                and ``(N, 1)`` labels both work.
            y_score: Tensor of scores. A 2-column input is treated as class
                probabilities and column 1 (class 1) is used; a 1-column input
                is flattened.
            weights: Optional weights forwarded to ``binary_auroc`` as
                ``weight`` instead of being silently ignored.
                NOTE(review): assumes one weight per sample — confirm against
                the ``Metric`` base class / caller.

        Returns:
            float: ``max(2 * auc - 1, 0.0)``.
        """
        # Detach and move to CPU as float so torcheval receives plain tensors.
        y_true = y_true.detach().cpu().float().reshape(-1)
        y_score = y_score.detach().cpu().float()
        if y_score.ndim == 2 and y_score.shape[1] == 2:
            # Keep only the probability of class 1.
            y_score = y_score[:, 1]
        elif y_score.ndim == 2 and y_score.shape[1] == 1:
            y_score = y_score.reshape(-1)

        weight = None
        if weights is not None:
            weight = torch.as_tensor(weights).detach().cpu().float().reshape(-1)

        # torcheval's positional order is (input=scores, target=labels).
        auc = binary_auroc(y_score, y_true, weight=weight)
        return max(2 * auc.item() - 1, 0.0)
# Generate dummy binary-classification data. A seeded generator keeps the
# example reproducible: the same (small) validation set is drawn on every run.
rng = np.random.default_rng(0)
X_train = rng.random((100, 10))
y_train = rng.integers(0, 2, 100)
X_valid = rng.random((20, 10))
y_valid = rng.integers(0, 2, 20)
# Fit a default classifier; the validation set is scored with the Gini
# metric, passed as the class itself.
clf = TabNetClassifier()
clf.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=[Gini],
)
Custom Loss Function Example
import numpy as np
import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetRegressor
# Generate dummy float32 regression data with a seeded generator so the
# example is reproducible. Targets are 2-D, shape (n_samples, 1).
rng = np.random.default_rng(0)
X_train = rng.random((100, 10), dtype=np.float32)
y_train = rng.random((100, 1), dtype=np.float32)
X_valid = rng.random((20, 10), dtype=np.float32)
y_valid = rng.random((20, 1), dtype=np.float32)
import torch
def custom_loss(y_true, y_pred):
    """Mean squared error plus an L1 penalty on the second argument.

    Returns the scalar tensor
    ``mean((y_pred - y_true) ** 2) + 0.1 * mean(|y_pred|)``.

    NOTE(review): the penalty is applied to whatever is passed second. If
    TabNet invokes ``loss_fn(y_pred, y_true)`` (torch's input-then-target
    order), the penalty would fall on the targets rather than the
    predictions — confirm the library's calling convention.
    """
    l1_weight = 0.1
    # Default ("mean") reduction equals reduction="none" followed by .mean().
    mse = nn.functional.mse_loss(y_pred, y_true)
    penalty = y_pred.abs().mean()
    return mse + l1_weight * penalty
# Fit a default regressor, supplying custom_loss through loss_fn.
reg = TabNetRegressor()
reg.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    loss_fn=custom_loss,
)