Quickstart

Installation

pip install torch-mentor

Or install in editable mode from source:

git clone <repo>
cd mentor
pip install -e .

Subclassing Mentee

Every model is a subclass of Mentee. There are two ways to define training behaviour.

Option A — Built-in trainer (least code)

Assign a Classifier or Regressor trainer to self.trainer and only implement forward. The trainer supplies training_step and validation_step automatically.

import torch.nn as nn
from mentor import Mentee, Classifier

class MyClassifier(Mentee):
    """Linear classifier whose training and validation steps come from a built-in trainer."""

    def __init__(self, num_classes: int = 10):
        super().__init__(num_classes=num_classes)
        self.fc = nn.Linear(128, num_classes)
        # cross-entropy loss + accuracy out of the box
        self.trainer = Classifier()

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        logits = self.fc(x)
        return logits

The same pattern works for regression with the Regressor trainer:

import torch.nn as nn
from mentor import Mentee, Regressor

class MyRegressor(Mentee):
    """Single-output linear regressor driven by the built-in Regressor trainer."""

    def __init__(self, in_features: int = 8):
        super().__init__(in_features=in_features)
        self.fc = nn.Linear(in_features, 1)
        # MSE loss + RMSE metric out of the box
        self.trainer = Regressor()

    def forward(self, x):
        """Apply the linear layer, then squeeze the trailing size-1 output dimension."""
        prediction = self.fc(x)
        return prediction.squeeze(-1)

Option B — Custom training step

Override training_step() and optionally validation_step() for full control.

import torch
import torch.nn as nn
import torch.nn.functional as F
from mentor import Mentee

class MyClassifier(Mentee):
    """Linear classifier with hand-written training and validation steps."""

    def __init__(self, num_classes: int = 10):
        # kwargs stored in checkpoint
        super().__init__(num_classes=num_classes)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        logits = self.fc(x)
        return logits

    def training_step(self, batch):
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch.

        Returns the loss tensor and a dict holding its scalar value under "loss".
        """
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        loss = F.cross_entropy(self(inputs), targets)
        metrics = {"loss": loss.item()}
        return loss, metrics

    def validation_step(self, batch):
        """Compute the accuracy for one ``(inputs, targets)`` batch.

        Returns a dict holding the mean accuracy as a float under "acc".
        """
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        predicted = self(inputs).argmax(1)
        hits = (predicted == targets).float()
        return {"acc": hits.mean().item()}

Training loop

The loop is the same regardless of whether you use a built-in trainer or custom step methods.

# Build the model, then create the objects used for training.
model = MyClassifier(num_classes=10)
_to = model.create_train_objects(lr=1e-3)

# The returned mapping exposes the optimizer and LR scheduler under these keys.
opt, sched = _to["optimizer"], _to["lr_scheduler"]

for epoch in range(20):
    # Each epoch: one training pass, one validation pass, then a checkpoint.
    train_metrics = model.train_epoch(train_loader, opt, sched, verbose=True)
    val_metrics   = model.validate_epoch(val_loader)
    model.save("checkpoint.pt", optimizer=opt, lr_scheduler=sched)
    # The "loss" and "acc" keys match the metric names used in the custom-step example.
    print(f"epoch {model.current_epoch}  "
          f"loss={train_metrics['loss']:.4f}  "
          f"acc={val_metrics['acc']:.4f}")

Alternatively, pass the cached model.optimizer / model.lr_scheduler properties directly instead of unpacking _to:

# The return value is not kept here; the loop reads the optimizer and
# scheduler back from the model's own properties instead.
model.create_train_objects(lr=1e-3)

for epoch in range(20):
    model.train_epoch(train_loader, model.optimizer, model.lr_scheduler)
    model.validate_epoch(val_loader)
    model.save("checkpoint.pt")   # optimizer and scheduler auto-saved from cache

Resuming training

from mentor import Mentee

# resume_training hands back the model together with an optimizer and a
# scheduler, so the next epoch can be run straight away.
model, opt, sched = Mentee.resume_training(
    "checkpoint.pt",
    model_class=MyClassifier,
    device="cuda",
    lr=1e-4,
)
model.train_epoch(train_loader, opt, sched)

Inference-only loading

# Only the model is loaded here; no optimizer or scheduler is unpacked.
# This snippet needs `import torch`; `x` is an input batch supplied by the caller.
model = Mentee.resume("checkpoint.pt", model_class=MyClassifier)
model.eval()
with torch.no_grad():
    logits = model(x)

Batteries-included inference state

Use register_inference_state() to attach any data computed from the training set (label names, vocabulary, normalisation statistics) so the checkpoint is fully self-contained.

# during training: attach the label names under the key "classes", then save
model.register_inference_state("classes", ["cat", "dog", "bird"])
model.save("checkpoint.pt")

# at inference time — no external config needed
model = Mentee.resume("checkpoint.pt", model_class=MyClassifier)
classes = model.get_inference_state("classes")   # looked up by the same key

Gradient accumulation

Pass pseudo_batch_size to train_epoch() to accumulate gradients over multiple samples before each optimiser step:

# effective batch size = 64, memory footprint of 8
# NOTE(review): how pseudo_batch_size=8 maps onto the 64 / 8 figures above
# depends on the loader's batch size — confirm against train_epoch().
# No scheduler is passed in this call; presumably it is optional — confirm.
model.train_epoch(train_loader, opt, pseudo_batch_size=8)