Data Augmentation and Callbacks

This guide shows how to use data augmentation and custom callbacks with TabNet. Each example is self-contained and can be run on its own.

Warning

Deprecation Notice: The augmentations parameter is deprecated and will be removed in a future version.

Data Augmentation Example

import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.augmentations import ClassificationSMOTE

# Generate dummy data
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)
X_valid = np.random.rand(20, 10)
y_valid = np.random.randint(0, 2, 20)

# Note: the augmentations parameter of fit() is deprecated and will be removed in a future version
aug = ClassificationSMOTE()
clf = TabNetClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], augmentations=aug)
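
Because the augmentations parameter is deprecated, one possible alternative is to augment the training data yourself before calling fit. The following is a minimal NumPy sketch of a SMOTE-style mixing step, not the library's ClassificationSMOTE implementation: each selected row is replaced by a random convex combination of itself and another row of the same class. The function name mix_same_class and its parameters p, alpha, beta, and seed are hypothetical and chosen only for illustration. Unlike an augmentation applied during training, this version produces one fixed augmented copy of the data.

```python
import numpy as np


def mix_same_class(X, y, p=0.8, alpha=0.5, beta=0.5, seed=0):
    """Hypothetical helper: blend rows with a random same-class partner row.

    p is the probability that an eligible row is modified; alpha and beta
    parameterise the Beta distribution that supplies the mixing weights.
    """
    rng = np.random.default_rng(seed)
    n = X.shape[0]

    # Pair every row with a randomly chosen partner row.
    partner = rng.permutation(n)

    # Only mix pairs that share a label, so the label stays valid.
    same_class = y == y[partner]
    chosen = same_class & (rng.random(n) < p)

    # One mixing weight per row, broadcast across the feature columns.
    lam = rng.beta(alpha, beta, size=(n, 1))

    X_aug = X.copy()
    X_aug[chosen] = lam[chosen] * X[chosen] + (1.0 - lam[chosen]) * X[partner][chosen]
    return X_aug, y.copy()


# Dummy data with the same shapes as the example above.
rng = np.random.default_rng(42)
X = rng.random((100, 10))
y = rng.integers(0, 2, 100)

X_aug, y_aug = mix_same_class(X, y)
```

The augmented arrays can then be passed to fit in place of, or stacked with, the original training data (for example with np.vstack and np.concatenate), without using the augmentations parameter.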

Custom Callback Example

import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.callbacks import Callback

# Generate dummy data
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)
X_valid = np.random.rand(20, 10)
y_valid = np.random.randint(0, 2, 20)

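class PrintEpochCallback(Callback):
    """Print a message each time a training epoch finishes."""

    def on_epoch_end(self, epoch, logs=None):
        # Called by fit() after each epoch; logs, when given, holds that epoch's metrics.
        print(f"Epoch {epoch} ended.")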

clf = TabNetClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], callbacks=[PrintEpochCallback()])
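
A callback can also keep state between hooks. The sketch below, under stated assumptions, records how long each epoch took together with a copy of whatever logs it received. The class name EpochHistoryCallback and its records attribute are hypothetical, the "loss" key in the manual loop is made up for the demonstration, and the import fallback exists only so the hooks can be exercised by hand when pytorch_tabnet is not installed.

```python
import time

try:
    from pytorch_tabnet.callbacks import Callback
except ImportError:
    # Fallback so the hooks below can be exercised without pytorch_tabnet.
    Callback = object


class EpochHistoryCallback(Callback):
    """Hypothetical callback that stores per-epoch duration and logs."""

    def __init__(self):
        super().__init__()
        self.records = []
        self._start = None

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started.
        self._start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        # Store elapsed time and a copy of the logs (empty dict if none given).
        elapsed = time.perf_counter() - self._start
        self.records.append(
            {"epoch": epoch, "seconds": elapsed, "logs": dict(logs or {})}
        )


# Drive the hooks by hand to show what gets recorded.
cb = EpochHistoryCallback()
for epoch in range(3):
    cb.on_epoch_begin(epoch)
    cb.on_epoch_end(epoch, logs={"loss": 1.0 / (epoch + 1)})

for record in cb.records:
    print(record["epoch"], record["logs"])
```

In real training the instance would be passed as callbacks=[cb] to fit, exactly as PrintEpochCallback is above, and cb.records inspected once training finishes.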