mentor.trainers
Built-in training strategies for Mentee.
A MentorTrainer is a pure-Python strategy object
(not an nn.Module) that is composed into a Mentee via self.trainer.
It separates:
- State: the optimizer, LR scheduler, and default loss function cached by create_train_objects() (exposed as read-only properties).
- Logic: the forward/loss/metrics computation in default_training_step() and default_validate_step() (classmethods, callable without a trainer instance).
When self.trainer is set on a Mentee, its training_step and
validation_step automatically delegate to the trainer’s classmethods,
injecting the cached loss_fn.
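The delegation described above can be illustrated with a small dependency-free sketch. The classes SketchTrainer and SketchMentee are hypothetical stand-ins, not the mentor API. Scalars replace tensors, and create_train_objects is simplified so that it only caches a loss. The sketch shows the two halves of the split. Per-instance state (the cached loss) lives on the trainer. The computation is a classmethod that the model calls with that loss injected, and it can also be called directly on the class.

```python
from typing import Callable, Dict, Optional, Tuple

LossFn = Callable[[float, float], float]


class SketchTrainer:
    """Hypothetical stand-in for a MentorTrainer: a plain object, not an nn.Module."""

    def __init__(self) -> None:
        self._loss_fn: Optional[LossFn] = None  # per-instance state

    @property
    def loss_fn(self) -> Optional[LossFn]:
        return self._loss_fn

    def create_train_objects(self, loss_fn: LossFn) -> None:
        # Simplified: only caches a loss (the documented method also builds
        # an optimizer and an LR scheduler).
        self._loss_fn = loss_fn

    @classmethod
    def default_training_step(cls, model, batch, loss_fn=None) -> Tuple[float, Dict[str, float]]:
        # Class-level logic: needs only the model, the batch and a loss callable.
        x, y = batch
        pred = model.forward(x)
        loss = loss_fn(pred, y)
        return loss, {"neg_abs_err": -abs(pred - y)}


class SketchMentee:
    """Hypothetical model that delegates its training step to self.trainer."""

    def __init__(self, weight: float) -> None:
        self.weight = weight
        self.trainer = SketchTrainer()

    def forward(self, x: float) -> float:
        return self.weight * x

    def training_step(self, batch):
        # Route to the trainer's classmethod, injecting the cached default loss.
        return type(self.trainer).default_training_step(
            self, batch, loss_fn=self.trainer.loss_fn
        )


model = SketchMentee(weight=2.0)
model.trainer.create_train_objects(lambda p, t: (p - t) ** 2)
print(model.training_step((3.0, 5.0)))  # (1.0, {'neg_abs_err': -1.0})
```

Because default_training_step is a classmethod, SketchTrainer.default_training_step(model, batch, loss_fn=...) works without any trainer instance.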
MentorTrainer
- class mentor.trainers.MentorTrainer
Bases: ABC
Abstract base class for Mentee training strategies.
A trainer separates state (the optimizer, LR scheduler, and loss function produced by create_train_objects()) from logic (the forward/loss/metrics computation in default_training_step() and default_validate_step()).
State (per-instance; None until create_train_objects() is called): optimizer, lr_scheduler, loss_fn.
Logic (class-level; callable without an instance):
- default_training_step(cls, model, batch, loss_fn=None): abstract; must be overridden.
- default_validate_step(cls, model, batch, loss_fn=None): by default calls default_training_step() and returns only its metrics dict; override when the validation pass differs from training.
When a trainer is assigned to a Mentee, the model’s training_step and validation_step automatically route to these classmethods with the cached loss_fn pre-injected.
- property optimizer: Any | None
Optimizer cached by the last
create_train_objects()call.
- property lr_scheduler: Any | None
LR scheduler cached by the last
create_train_objects()call.
- property loss_fn: Any | None
Default loss callable registered by
create_train_objects().
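The following is a minimal sketch of the read-only state pattern behind these three properties. CachedTrainObjects is a hypothetical class. Its create_train_objects takes ready-made objects, which is a simplified signature: the documented method takes a model and hyperparameters. Plain placeholders stand in for the real optimizer and scheduler. Each property exposes a private attribute that stays None until create_train_objects() runs. Because the properties have no setter, assigning to one raises AttributeError.

```python
from typing import Any, Callable, Dict, Optional


class CachedTrainObjects:
    """Hypothetical sketch of read-only, lazily populated trainer state."""

    def __init__(self) -> None:
        # Per-instance state: None until create_train_objects() is called.
        self._optimizer: Optional[Any] = None
        self._lr_scheduler: Optional[Any] = None
        self._loss_fn: Optional[Callable] = None

    @property
    def optimizer(self) -> Optional[Any]:
        return self._optimizer

    @property
    def lr_scheduler(self) -> Optional[Any]:
        return self._lr_scheduler

    @property
    def loss_fn(self) -> Optional[Callable]:
        return self._loss_fn

    def create_train_objects(self, optimizer: Any, lr_scheduler: Any, loss_fn: Callable) -> Dict[str, Any]:
        # Simplified signature: placeholders stand in for e.g. Adam / StepLR.
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self._loss_fn = loss_fn
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler, "loss_fn": loss_fn}


state = CachedTrainObjects()
print(state.optimizer)  # None
state.create_train_objects("adam-placeholder", "steplr-placeholder", abs)
print(state.optimizer)  # adam-placeholder
```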
- abstractmethod classmethod default_training_step(model, batch, loss_fn=None)
Compute the loss and metrics for one training batch.
This is a classmethod so it can be inspected and called without a trainer instance. The Mentee injects the trainer’s cached loss_fn before calling this method.
- Parameters:
model (Mentee) – The model being trained.
batch (Any) – One element from the training DataLoader.
loss_fn (callable, optional) – Effective loss function — either an explicit override or the cached default forwarded by the Mentee.
- Returns:
loss (torch.Tensor) – Scalar differentiable loss.
metrics (dict[str, float]) – Scalar metrics. The first key is the principal metric used by validate_epoch() to select the best checkpoint. It must be a metric where higher is better (e.g. accuracy, F1, R²). Never put a loss or error value first: those are lower-is-better, so the untrained model would be recorded as “best” and never replaced.
Correct:
return loss, {"acc": acc, "loss": loss.item()} # acc first
Wrong:
return loss, {"loss": loss.item(), "acc": acc} # loss first!
- Return type:
tuple[torch.Tensor, dict[str, float]]
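The reason for the ordering rule can be shown with a short sketch. The helper best_epoch_by_principal_metric is hypothetical. It mimics the documented behavior, in which the first key of each metrics dict is maximised, and it relies on Python dicts preserving insertion order. The metric values are invented for illustration. Across three epochs, loss falls and accuracy rises.

```python
from typing import Dict, List


def best_epoch_by_principal_metric(history: List[Dict[str, float]]) -> int:
    """Index of the epoch whose FIRST metric is largest (hypothetical helper)."""
    best_idx, best_val = 0, float("-inf")
    for idx, metrics in enumerate(history):
        principal = next(iter(metrics.values()))  # first key, by insertion order
        if principal > best_val:
            best_idx, best_val = idx, principal
    return best_idx


# Illustrative values: loss falls and accuracy rises over three epochs.
good = [{"acc": 0.50, "loss": 1.2}, {"acc": 0.70, "loss": 0.8}, {"acc": 0.85, "loss": 0.4}]
bad = [{"loss": 1.2, "acc": 0.50}, {"loss": 0.8, "acc": 0.70}, {"loss": 0.4, "acc": 0.85}]

print(best_epoch_by_principal_metric(good))  # 2: the best-trained epoch
print(best_epoch_by_principal_metric(bad))   # 0: the earliest, highest-loss epoch wins
```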
- classmethod default_validate_step(model, batch, loss_fn=None)
Evaluate one validation batch.
Default implementation calls default_training_step() and strips the loss tensor, returning only the metrics dict. Override when the validation forward pass differs from training.
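This base-class contract can be sketched with Python's abc module. The sketch declares an abstract classmethod and a concrete default_validate_step that discards the loss. SketchBaseTrainer and ExactMatchTrainer are hypothetical names. The exact-match "accuracy" uses plain lists instead of tensors. The loss it returns is a placeholder.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple


class SketchBaseTrainer(ABC):
    """Hypothetical mirror of the documented base-class contract."""

    @classmethod
    @abstractmethod
    def default_training_step(cls, model, batch, loss_fn=None) -> Tuple[Any, Dict[str, float]]:
        """Subclasses must return (loss, metrics)."""

    @classmethod
    def default_validate_step(cls, model, batch, loss_fn=None) -> Dict[str, float]:
        # Reuse the training computation and drop the loss, keeping only metrics.
        _loss, metrics = cls.default_training_step(model, batch, loss_fn=loss_fn)
        return metrics


class ExactMatchTrainer(SketchBaseTrainer):
    @classmethod
    def default_training_step(cls, model, batch, loss_fn=None):
        inputs, targets = batch
        preds = [model(x) for x in inputs]
        acc = sum(p == t for p, t in zip(preds, targets)) / len(targets)
        return 1.0 - acc, {"acc": acc}  # placeholder loss for the sketch


parity = lambda x: x % 2  # toy "model"
print(ExactMatchTrainer.default_validate_step(parity, ([1, 2, 3, 4], [1, 0, 0, 0])))  # {'acc': 0.75}
```

Any subclass that does not override default_training_step stays abstract, and instantiating it raises TypeError.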
- abstractmethod create_train_objects(model, lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)
Create and cache the optimizer, LR scheduler, and default loss.
- Parameters:
model (Mentee) – Model whose parameters() are passed to the optimizer.
lr (float) – Learning rate for the optimizer (Adam in the standard setup).
step_size (int) – StepLR period: the learning rate is decayed once every step_size scheduler steps (typically epochs).
gamma (float) – StepLR multiplicative decay factor applied at the end of each period.
loss_fn (callable, optional) – Loss to register as default.
overwrite_default_loss (bool, optional) – Replace an existing cached loss when
True.
- Returns:
{"optimizer": ..., "lr_scheduler": ..., "loss_fn": ...}
- Return type:
dict
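The following sketch shows two points behind these parameters. First, it shows one plausible reading of how loss_fn and overwrite_default_loss decide which default loss is cached. This rule is an assumption inferred from the parameter descriptions, not confirmed behavior. Second, it shows the closed-form learning rate produced by a StepLR schedule with the given lr, step_size and gamma. Both function names are hypothetical.

```python
from typing import Callable, Optional


def resolve_default_loss(
    cached: Optional[Callable],
    loss_fn: Optional[Callable],
    builtin: Callable,
    overwrite_default_loss: bool = False,
) -> Callable:
    """Assumed rule for which loss create_train_objects() would cache (hypothetical)."""
    if loss_fn is None:
        # No new loss supplied: keep the cached one, else register the trainer's builtin.
        return cached if cached is not None else builtin
    if cached is None or overwrite_default_loss:
        return loss_fn
    return cached  # an existing default survives unless overwrite is requested


def step_lr(lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Closed-form StepLR rate: multiply by gamma once every step_size epochs."""
    return lr * gamma ** (epoch // step_size)


for epoch in (0, 9, 10, 25):
    print(epoch, step_lr(0.001, 10, 0.1, epoch))
```

With the defaults lr=0.001, step_size=10 and gamma=0.1, epochs 0 to 9 use 0.001 and epochs 10 to 19 use 0.0001, up to floating-point rounding.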
Classifier
- class mentor.trainers.Classifier
Bases: MentorTrainer
Training strategy for multi-class classification with cross-entropy loss.
default_training_step() computes cross-entropy loss and top-1 accuracy. create_train_objects() registers nn.CrossEntropyLoss() automatically.
The effective loss passed to default_training_step() is resolved by the owning Mentee in this order:
1. Explicit loss_fn argument to model.training_step.
2. loss_fn cached by create_train_objects().
3. torch.nn.functional.cross_entropy() (stateless fallback).
Examples
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)
...         self.fc = nn.Linear(128, num_classes)
...         self.trainer = Classifier()
...     def forward(self, x):
...         return self.fc(x.flatten(1))
>>> model = MyNet(num_classes=5)
>>> model.create_train_objects(lr=1e-3)
>>> isinstance(model.loss_fn, nn.CrossEntropyLoss)
True
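The three-step priority order for the effective loss fits in a few lines. The sketch below uses a hypothetical helper named effective_loss. It takes the stateless fallback as a parameter so that it runs without PyTorch. In the documented setup, that fallback would be torch.nn.functional.cross_entropy.

```python
from typing import Callable, Optional


def effective_loss(
    explicit: Optional[Callable],
    cached: Optional[Callable],
    fallback: Callable,
) -> Callable:
    """Pick the loss in the documented priority order (hypothetical helper)."""
    if explicit is not None:  # 1. loss_fn passed to model.training_step
        return explicit
    if cached is not None:    # 2. default cached by create_train_objects()
        return cached
    return fallback           # 3. stateless fallback (F.cross_entropy for Classifier)
```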
- classmethod default_training_step(model, batch, loss_fn=None)
Cross-entropy loss and top-1 accuracy.
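For reference, this sketch shows what "cross-entropy loss and top-1 accuracy" compute on a batch of raw logits. It uses plain Python lists instead of tensors, and the function name ce_and_top1 is hypothetical. The loss matches the mean reduction of cross-entropy over the batch. Ties in the argmax go to the lowest index here. The metrics dict puts the higher-is-better "acc" first, following the ordering rule in MentorTrainer.default_training_step.

```python
import math
from typing import Dict, Sequence, Tuple


def ce_and_top1(
    logits: Sequence[Sequence[float]], targets: Sequence[int]
) -> Tuple[float, Dict[str, float]]:
    """Mean cross-entropy and top-1 accuracy over a batch of raw logits (sketch)."""
    total_loss, correct = 0.0, 0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total_loss += log_sum_exp - row[t]  # equals -log(softmax(row)[t])
        pred = max(range(len(row)), key=row.__getitem__)  # argmax (first on ties)
        correct += int(pred == t)
    n = len(targets)
    loss = total_loss / n
    return loss, {"acc": correct / n, "loss": loss}  # higher-is-better metric first


print(ce_and_top1([[2.0, 0.0], [0.0, 0.0]], [0, 1]))
```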
Regressor
- class mentor.trainers.Regressor
Bases: MentorTrainer
Training strategy for regression with mean-squared-error loss.
default_training_step() computes MSE loss and RMSE metric. create_train_objects() registers nn.MSELoss() automatically. Targets are cast to float automatically.
Examples
>>> class MyNet(Mentee):
...     def __init__(self, in_features=10):
...         super().__init__(in_features=in_features)
...         self.fc = nn.Linear(in_features, 1)
...         self.trainer = Regressor()
...     def forward(self, x):
...         return self.fc(x).squeeze(-1)
>>> model = MyNet(in_features=8)
>>> model.create_train_objects(lr=1e-3)
>>> isinstance(model.loss_fn, nn.MSELoss)
True
- classmethod default_validate_step(model, batch, loss_fn=None)
MSE loss and RMSE metric, with neg_rmse as the principal metric.
Returns neg_rmse (= -rmse) as the first key so that validate_epoch() can select the best epoch by maximising it (a higher neg_rmse means a lower RMSE).
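The neg_rmse convention can be shown in a few lines of plain Python. The function mse_and_neg_rmse is hypothetical, and it assumes the metrics dict also carries a plain "rmse" entry after the principal key. It casts targets to float, computes MSE, and places -RMSE first, so that maximising the first key selects the lowest-error epoch.

```python
import math
from typing import Dict, Sequence, Tuple


def mse_and_neg_rmse(
    preds: Sequence[float], targets: Sequence[float]
) -> Tuple[float, Dict[str, float]]:
    """MSE loss plus metrics whose first key, neg_rmse, is higher-is-better (sketch)."""
    targets = [float(t) for t in targets]  # mirror the documented cast of targets to float
    mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
    rmse = math.sqrt(mse)
    return mse, {"neg_rmse": -rmse, "rmse": rmse}  # assumed extra "rmse" key


print(mse_and_neg_rmse([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 6]))  # (1.0, {'neg_rmse': -1.0, 'rmse': 1.0})
```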