Quickstart
Installation
pip install torch-mentor
Or in editable mode from source:
git clone <repo>
cd mentor
pip install -e .
Subclassing Mentee
Every model is a subclass of Mentee. There are two ways
to define training behaviour.
Option A — Built-in trainer (least code)
Assign a Classifier or Regressor trainer to
self.trainer and only implement forward. The trainer supplies
training_step and validation_step automatically.
import torch.nn as nn
from mentor import Mentee, Classifier
class MyClassifier(Mentee):
def __init__(self, num_classes: int = 10):
super().__init__(num_classes=num_classes)
self.fc = nn.Linear(128, num_classes)
self.trainer = Classifier() # cross-entropy loss + accuracy out of the box
def forward(self, x):
return self.fc(x)
import torch.nn as nn
from mentor import Mentee, Regressor
class MyRegressor(Mentee):
def __init__(self, in_features: int = 8):
super().__init__(in_features=in_features)
self.fc = nn.Linear(in_features, 1)
self.trainer = Regressor() # MSE loss + RMSE metric out of the box
def forward(self, x):
return self.fc(x).squeeze(-1)
Option B — Custom training step
Override training_step() and optionally
validation_step() for full control.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mentor import Mentee
class MyClassifier(Mentee):
def __init__(self, num_classes: int = 10):
super().__init__(num_classes=num_classes) # kwargs stored in checkpoint
self.fc = nn.Linear(128, num_classes)
def forward(self, x):
return self.fc(x)
def training_step(self, batch):
x, y = batch
x, y = x.to(self.device), y.to(self.device)
loss = F.cross_entropy(self(x), y)
return loss, {"loss": loss.item()}
def validation_step(self, batch):
x, y = batch
x, y = x.to(self.device), y.to(self.device)
acc = (self(x).argmax(1) == y).float().mean().item()
return {"acc": acc}
Training loop
The loop is the same regardless of whether you use a built-in trainer or custom step methods.
model = MyClassifier(num_classes=10)
_to = model.create_train_objects(lr=1e-3)
opt, sched = _to["optimizer"], _to["lr_scheduler"]
for epoch in range(20):
train_metrics = model.train_epoch(train_loader, opt, sched, verbose=True)
val_metrics = model.validate_epoch(val_loader)
model.save("checkpoint.pt", optimizer=opt, lr_scheduler=sched)
print(f"epoch {model.current_epoch} "
f"loss={train_metrics['loss']:.4f} "
f"acc={val_metrics['acc']:.4f}")
Alternatively, pass the cached model.optimizer / model.lr_scheduler
properties directly instead of unpacking _to:
model.create_train_objects(lr=1e-3)
for epoch in range(20):
model.train_epoch(train_loader, model.optimizer, model.lr_scheduler)
model.validate_epoch(val_loader)
model.save("checkpoint.pt") # optimizer and scheduler auto-saved from cache
Resuming training
from mentor import Mentee
model, opt, sched = Mentee.resume_training(
"checkpoint.pt",
model_class=MyClassifier,
device="cuda",
lr=1e-4,
)
model.train_epoch(train_loader, opt, sched)
Inference-only loading
model = Mentee.resume("checkpoint.pt", model_class=MyClassifier)
model.eval()
with torch.no_grad():
logits = model(x)
Batteries-included inference state
Use register_inference_state() to attach any data
computed from the training set (label names, vocabulary, normalisation
statistics) so the checkpoint is fully self-contained.
# during training
model.register_inference_state("classes", ["cat", "dog", "bird"])
model.save("checkpoint.pt")
# at inference time — no external config needed
model = Mentee.resume("checkpoint.pt", model_class=MyClassifier)
classes = model.get_inference_state("classes")
Gradient accumulation
Pass pseudo_batch_size to train_epoch() to
accumulate gradients over multiple samples before each optimiser step:
# effective batch size = 64, memory footprint of 8
model.train_epoch(train_loader, opt, pseudo_batch_size=8)