mentor.trainers

Built-in training strategies for Mentee.

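A MentorTrainer is a pure-Python strategy object (not an nn.Module) that is composed into a Mentee via self.trainer. It separates:

  • State: the optimizer, LR scheduler, and loss function created by create_train_objects().

  • Logic: the forward/loss/metrics computation in default_training_step() and default_validate_step().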

When self.trainer is set on a Mentee, its training_step and validation_step automatically delegate to the trainer’s classmethods, injecting the cached loss_fn.
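As a rough illustration of this delegation, here is a minimal, torch-free sketch. SketchMentee, MeanAbsTrainer and mean_abs_error are hypothetical stand-ins written for this example, not the library's Mentee or MentorTrainer. The rule that an explicit loss_fn beats the cached one follows the resolution order documented for Classifier.

```python
from typing import Callable, Dict, List, Optional, Tuple

Batch = Tuple[List[float], List[float]]
LossFn = Callable[[List[float], List[float]], float]


class MeanAbsTrainer:
    """Hypothetical trainer: state in attributes, logic in a classmethod."""

    def __init__(self) -> None:
        self.loss_fn: Optional[LossFn] = None  # cached default, set later

    @classmethod
    def default_training_step(
        cls, model: "SketchMentee", batch: Batch, loss_fn: Optional[LossFn] = None
    ) -> Tuple[float, Dict[str, float]]:
        inputs, targets = batch
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        return loss, {"neg_loss": -loss}


class SketchMentee:
    """Hypothetical stand-in for Mentee that only shows the delegation."""

    def __init__(self, trainer: MeanAbsTrainer) -> None:
        self.trainer = trainer

    def __call__(self, xs: List[float]) -> List[float]:
        return [2.0 * x for x in xs]  # toy "forward pass"

    def training_step(self, batch: Batch, loss_fn: Optional[LossFn] = None):
        # An explicit loss_fn wins; otherwise inject the trainer's cached one.
        effective = loss_fn if loss_fn is not None else self.trainer.loss_fn
        # Classmethods can also be called through the instance.
        return self.trainer.default_training_step(self, batch, loss_fn=effective)


def mean_abs_error(preds: List[float], targets: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


trainer = MeanAbsTrainer()
trainer.loss_fn = mean_abs_error
model = SketchMentee(trainer)

loss, metrics = model.training_step(([1.0, 2.0], [2.0, 5.0]))
print(loss, metrics)  # 0.5 {'neg_loss': -0.5}
```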

MentorTrainer

class mentor.trainers.MentorTrainer

Bases: ABC

Abstract base class for Mentee training strategies.

A trainer separates state (the optimizer, LR scheduler, and loss function produced by create_train_objects()) from logic (the forward/loss/metrics computation in default_training_step() and default_validate_step()).

State (per-instance; None until create_train_objects() is called):

  • optimizer

  • lr_scheduler

  • loss_fn

Logic (class-level; callable without an instance):

  • default_training_step()

  • default_validate_step()

When a trainer is assigned to a Mentee, the model’s training_step and validation_step automatically route to these classmethods with the cached loss_fn pre-injected.
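The state/logic split can be sketched with the standard abc module. SketchTrainerBase and ConstantTrainer are hypothetical names, the private attribute names are assumptions, and strings stand in for real optimizer, scheduler and loss objects. The point is only that the classmethod works without an instance while the properties stay None until create_train_objects() fills them.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class SketchTrainerBase(ABC):
    """Hypothetical mirror of the state/logic layout; not the library's code."""

    def __init__(self) -> None:
        # State: per-instance, None until create_train_objects() runs.
        self._optimizer: Optional[Any] = None
        self._lr_scheduler: Optional[Any] = None
        self._loss_fn: Optional[Any] = None

    @property
    def optimizer(self) -> Optional[Any]:
        return self._optimizer

    @property
    def lr_scheduler(self) -> Optional[Any]:
        return self._lr_scheduler

    @property
    def loss_fn(self) -> Optional[Any]:
        return self._loss_fn

    # Logic: class-level. abstractmethod must be the innermost decorator.
    @classmethod
    @abstractmethod
    def default_training_step(cls, model: Any, batch: Any, loss_fn: Any = None) -> Tuple[Any, Dict[str, float]]:
        ...

    @abstractmethod
    def create_train_objects(self, model: Any, **kwargs: Any) -> Dict[str, Any]:
        ...


class ConstantTrainer(SketchTrainerBase):
    """Toy subclass; strings stand in for real optimizer/scheduler/loss objects."""

    @classmethod
    def default_training_step(cls, model, batch, loss_fn=None):
        return 0.0, {"score": 1.0}

    def create_train_objects(self, model, **kwargs):
        self._optimizer, self._lr_scheduler, self._loss_fn = "opt", "sched", "loss"
        return {"optimizer": self._optimizer, "lr_scheduler": self._lr_scheduler, "loss_fn": self._loss_fn}


# Logic is usable without any instance.
print(ConstantTrainer.default_training_step(model=None, batch=None))  # (0.0, {'score': 1.0})

# State starts empty and is filled by create_train_objects().
t = ConstantTrainer()
print(t.optimizer)  # None
t.create_train_objects(model=None)
print(t.optimizer, t.loss_fn)  # opt loss
```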

__init__()
Return type:

None

property optimizer: Any | None

Optimizer cached by the last create_train_objects() call.

property lr_scheduler: Any | None

LR scheduler cached by the last create_train_objects() call.

property loss_fn: Any | None

Default loss callable registered by create_train_objects().

abstractmethod classmethod default_training_step(model, batch, loss_fn=None)

Compute the loss and metrics for one training batch.

This is a classmethod so it can be inspected and called without a trainer instance. The Mentee injects the trainer’s cached loss_fn before calling this method.

Parameters:
  • model (Mentee) – The model being trained.

  • batch (Any) – One element from the training DataLoader.

  • loss_fn (callable, optional) – Effective loss function — either an explicit override or the cached default forwarded by the Mentee.

Returns:

  • loss (torch.Tensor) – Scalar differentiable loss.

  • metrics (dict[str, float]) – Scalar metrics. The first key is the principal metric used by validate_epoch() to select the best checkpoint. It must be a metric where higher is better (e.g. accuracy, F1, R²). Never put a loss or error value first — those are lower-is-better and would cause the untrained model to be permanently recorded as “best”.

    Correct:

    return loss, {"acc": acc, "loss": loss.item()}   # acc first
    

    Wrong:

    return loss, {"loss": loss.item(), "acc": acc}   # loss first!
    

Return type:

Tuple[Tensor, Dict[str, float]]
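The first-key rule can be seen with a small sketch of best-checkpoint selection. validate_epoch() itself is not documented on this page, so best_epoch is a hypothetical selector that assumes it maximises the first value of each metrics dict; the history values are made up for illustration.

```python
from typing import Dict, List


def best_epoch(history: List[Dict[str, float]]) -> int:
    """Hypothetical selector: index of the epoch whose first metric is largest."""
    best_idx, best_val = 0, float("-inf")
    for idx, metrics in enumerate(history):
        principal = next(iter(metrics.values()))  # dicts keep insertion order
        if principal > best_val:
            best_idx, best_val = idx, principal
    return best_idx


# Accuracy rises and loss falls as training progresses.
good = [{"acc": 0.10, "loss": 2.30}, {"acc": 0.60, "loss": 1.10}, {"acc": 0.85, "loss": 0.40}]
bad = [{"loss": m["loss"], "acc": m["acc"]} for m in good]  # loss first!

print(best_epoch(good))  # 2: the most accurate epoch
print(best_epoch(bad))   # 0: the untrained epoch, because its loss is largest
```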

classmethod default_validate_step(model, batch, loss_fn=None)

Evaluate one validation batch.

Default implementation calls default_training_step() and strips the loss tensor, returning only the metrics dict. Override when the validation forward pass differs from training.

Parameters:
  • model (Mentee) – The model being evaluated.

  • batch (Any) – One element from the validation DataLoader.

  • loss_fn (callable, optional) – Effective loss function forwarded by the Mentee.

Returns:

Scalar evaluation metrics.

Return type:

dict[str, float]
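A minimal sketch of the default behaviour, using a hypothetical torch-free MeanErrorTrainer: the validate step reuses the training computation and returns only the metrics dict.

```python
from typing import Callable, Dict, List, Optional, Tuple

Batch = Tuple[List[float], List[float]]


class MeanErrorTrainer:
    """Hypothetical trainer used only to illustrate the default validate step."""

    @classmethod
    def default_training_step(
        cls, model: Callable[[float], float], batch: Batch, loss_fn: Optional[Callable] = None
    ) -> Tuple[float, Dict[str, float]]:
        inputs, targets = batch
        errors = [abs(model(x) - y) for x, y in zip(inputs, targets)]
        loss = sum(errors) / len(errors)
        # Principal metric first, and it is higher-is-better.
        return loss, {"neg_mae": -loss, "mae": loss}

    @classmethod
    def default_validate_step(
        cls, model: Callable[[float], float], batch: Batch, loss_fn: Optional[Callable] = None
    ) -> Dict[str, float]:
        # Reuse the training computation, drop the loss, keep only the metrics.
        _loss, metrics = cls.default_training_step(model, batch, loss_fn=loss_fn)
        return metrics


metrics = MeanErrorTrainer.default_validate_step(lambda x: x + 1.0, ([0.0, 1.0], [1.0, 3.0]))
print(metrics)  # {'neg_mae': -0.5, 'mae': 0.5}
```

Calling through cls means a subclass that overrides only default_training_step() automatically gets a matching default_validate_step().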

abstractmethod create_train_objects(model, lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)

Create and cache the optimizer, LR scheduler, and default loss.

Parameters:
  • model (Mentee) – Model whose parameters() are passed to the optimizer.

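  • lr (float) – Learning rate passed to the optimizer (Adam in the built-in trainers).

  • step_size (int) – StepLR period: the learning rate is decayed once every step_size scheduler steps.

  • gamma (float) – StepLR multiplicative decay factor applied at each step_size boundary.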

  • loss_fn (callable, optional) – Loss to register as default.

  • overwrite_default_loss (bool, optional) – Replace an existing cached loss when True.

Returns:

{"optimizer": ..., "lr_scheduler": ..., "loss_fn": ...}

Return type:

dict
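With the default arguments, the learning rate follows StepLR's step schedule: it is multiplied by gamma once every step_size scheduler steps. The closed-form helper step_lr is a hypothetical illustration, not part of mentor. Whether one scheduler step corresponds to one epoch depends on when the Mentee calls scheduler.step(), which this page does not state.

```python
def step_lr(base_lr: float, step: int, step_size: int = 10, gamma: float = 0.1) -> float:
    """Closed form of StepLR: multiply by gamma once every step_size scheduler steps."""
    return base_lr * gamma ** (step // step_size)


# With the defaults lr=0.001, step_size=10, gamma=0.1:
# steps 0-9 use 1e-3, steps 10-19 use 1e-4, steps 20-29 use 1e-5, and so on.
for step in (0, 9, 10, 19, 20):
    print(step, step_lr(0.001, step))
```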

Classifier

class mentor.trainers.Classifier

Bases: MentorTrainer

Training strategy for multi-class classification with cross-entropy loss.

default_training_step() computes cross-entropy loss and top-1 accuracy. create_train_objects() registers nn.CrossEntropyLoss() automatically.

The effective loss passed to default_training_step() is resolved by the owning Mentee in this order:

  1. Explicit loss_fn argument to model.training_step.

  2. loss_fn cached by create_train_objects().

  3. torch.nn.functional.cross_entropy() (stateless fallback).
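The resolution order amounts to a simple fallback chain. resolve_loss is a hypothetical helper for illustration, and the toy loss functions return constants so that the chosen branch is visible.

```python
from typing import Callable, Optional

LossFn = Callable[..., float]


def resolve_loss(explicit: Optional[LossFn], cached: Optional[LossFn], fallback: LossFn) -> LossFn:
    """Pick the effective loss: explicit argument, then cached default, then fallback."""
    if explicit is not None:
        return explicit
    if cached is not None:
        return cached
    return fallback


def fallback_loss(preds, targets) -> float:  # stands in for F.cross_entropy
    return 0.0


def cached_loss(preds, targets) -> float:  # stands in for the cached nn.CrossEntropyLoss()
    return 1.0


def explicit_loss(preds, targets) -> float:  # stands in for a per-call override
    return 2.0


print(resolve_loss(explicit_loss, cached_loss, fallback_loss)(None, None))  # 2.0
print(resolve_loss(None, cached_loss, fallback_loss)(None, None))           # 1.0
print(resolve_loss(None, None, fallback_loss)(None, None))                  # 0.0
```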

Examples

>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)
...         self.fc = nn.Linear(128, num_classes)
...         self.trainer = Classifier()
...     def forward(self, x):
...         return self.fc(x.flatten(1))
>>> model = MyNet(num_classes=5)
>>> _ = model.create_train_objects(lr=1e-3)
>>> isinstance(model.loss_fn, nn.CrossEntropyLoss)
True

classmethod default_training_step(model, batch, loss_fn=None)

Cross-entropy loss and top-1 accuracy.

Parameters:
  • model (Mentee) – Classification model; model(x) returns logits.

  • batch (tuple[Tensor, Tensor]) – (inputs, targets) — targets are class indices (long).

  • loss_fn (callable, optional) – Effective loss; falls back to F.cross_entropy when None.

Return type:

Tuple[Tensor, Dict[str, float]]
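For intuition, this torch-free sketch computes the two quantities the step reports: mean cross-entropy over logits (what F.cross_entropy returns with its default mean reduction) and top-1 accuracy. The helper names are hypothetical, and the "acc"/"loss" keys follow the Correct example earlier on this page rather than the Classifier's actual metric names, which are not documented here.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean of -log softmax(row)[target], using log-sum-exp for numerical stability."""
    total = 0.0
    for row, target in zip(logits, targets):
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


def top1_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose highest logit sits at the target index."""
    hits = sum(
        1 for row, target in zip(logits, targets)
        if max(range(len(row)), key=row.__getitem__) == target
    )
    return hits / len(targets)


logits = [[2.0, 0.0], [0.0, 3.0], [1.0, 0.0]]
targets = [0, 1, 1]  # class indices; the last row is misclassified
loss = cross_entropy(logits, targets)
acc = top1_accuracy(logits, targets)
metrics = {"acc": acc, "loss": loss}  # principal (higher-is-better) metric first
print(metrics)
```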

create_train_objects(model, lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)

Adam + StepLR with nn.CrossEntropyLoss as default loss.

Registers nn.CrossEntropyLoss() automatically when no loss is cached yet and loss_fn is None.

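Parameters:
  • model, lr, step_size, gamma, loss_fn, overwrite_default_loss – Same as MentorTrainer.create_train_objects().
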
Return type:

Dict[str, Any]
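One plausible reading of the default-loss caching rules, written as a hypothetical helper. Only two cases are documented: a default is registered when nothing is cached and loss_fn is None, and overwrite_default_loss=True replaces an existing cached loss. Keeping the cached loss in every other case is an assumption, and the string returned by make_default stands in for nn.CrossEntropyLoss().

```python
from typing import Any, Callable, Optional


def resolve_cached_loss(
    cached: Optional[Any],
    loss_fn: Optional[Any],
    overwrite_default_loss: bool,
    make_default: Callable[[], Any],
) -> Any:
    """Hypothetical reading of how the default loss gets cached."""
    if cached is None:
        # Nothing cached yet: use the given loss, else build the built-in default.
        return loss_fn if loss_fn is not None else make_default()
    if overwrite_default_loss and loss_fn is not None:
        # Explicit replacement of an existing cached loss.
        return loss_fn
    # Otherwise keep what is already cached (assumed behaviour).
    return cached


def make_default() -> str:
    return "CrossEntropyLoss()"  # placeholder for nn.CrossEntropyLoss()


print(resolve_cached_loss(None, None, False, make_default))     # CrossEntropyLoss()
print(resolve_cached_loss("old", "new", False, make_default))   # old
print(resolve_cached_loss("old", "new", True, make_default))    # new
```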

Regressor

class mentor.trainers.Regressor

Bases: MentorTrainer

Training strategy for regression with mean-squared-error loss.

default_training_step() computes MSE loss and RMSE metric. create_train_objects() registers nn.MSELoss() automatically.

Targets are cast to float automatically.

Examples

>>> class MyNet(Mentee):
...     def __init__(self, in_features=10):
...         super().__init__(in_features=in_features)
...         self.fc = nn.Linear(in_features, 1)
...         self.trainer = Regressor()
...     def forward(self, x):
...         return self.fc(x).squeeze(-1)
>>> model = MyNet(in_features=8)
>>> _ = model.create_train_objects(lr=1e-3)
>>> isinstance(model.loss_fn, nn.MSELoss)
True

classmethod default_training_step(model, batch, loss_fn=None)

MSE loss and RMSE metric.

Parameters:
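  • model (Mentee) – Regression model; its output should have the same shape as the targets. For a single-output network, squeeze (N, 1) outputs to (N,), as the example's forward() does, because mismatched shapes can broadcast into an incorrect loss.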

  • batch (tuple[Tensor, Tensor]) – (inputs, targets) — targets cast to float.

  • loss_fn (callable, optional) – Effective loss; falls back to F.mse_loss when None.

Return type:

Tuple[Tensor, Dict[str, float]]
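A torch-free sketch of the two quantities: MSE as the mean of squared errors and RMSE as its square root, with targets cast to float first. mse_and_rmse is a hypothetical helper, and its explicit length check is an illustrative stand-in for the shape requirement on model.

```python
import math
from typing import Sequence, Tuple


def mse_and_rmse(preds: Sequence[float], targets: Sequence[float]) -> Tuple[float, float]:
    """Mean squared error and its square root over equally long sequences."""
    if len(preds) != len(targets):
        # Refuse mismatched lengths instead of silently pairing them up.
        raise ValueError("preds and targets must have the same length")
    targets = [float(t) for t in targets]  # like the Regressor, cast targets to float
    mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
    return mse, math.sqrt(mse)


mse, rmse = mse_and_rmse([2.0, 2.0], [0, 4])  # integer targets are accepted
print(mse, rmse)  # 4.0 2.0
```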

classmethod default_validate_step(model, batch, loss_fn=None)

MSE loss and RMSE metric, with neg_rmse as the principal metric.

Returns neg_rmse (= -rmse) as the first key so that validate_epoch() can select the best epoch by maximising it (a higher neg_rmse means a lower RMSE).

Parameters:
  • model (Any)

  • batch (Any)

  • loss_fn (Any | None)

Return type:

Dict[str, float]
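Why the sign flip works: negating RMSE turns "lower is better" into "higher is better", so maximising the first key selects the lowest-RMSE epoch. regression_val_metrics is hypothetical, and its key set beyond neg_rmse is an assumption.

```python
import math
from typing import Dict


def regression_val_metrics(mse: float) -> Dict[str, float]:
    """Hypothetical validation metrics: neg_rmse first so that higher is better."""
    rmse = math.sqrt(mse)
    return {"neg_rmse": -rmse, "rmse": rmse}


epochs = [regression_val_metrics(m) for m in (4.0, 0.25, 1.0)]  # RMSE 2.0, 0.5, 1.0
# Maximising the first value picks the epoch with the lowest RMSE.
best = max(range(len(epochs)), key=lambda i: next(iter(epochs[i].values())))
print(best, epochs[best]["rmse"])  # 1 0.5
```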

create_train_objects(model, lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)

Adam + StepLR with nn.MSELoss as default loss.

Registers nn.MSELoss() automatically when no loss is cached yet and loss_fn is None.

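Parameters:
  • model, lr, step_size, gamma, loss_fn, overwrite_default_loss – Same as MentorTrainer.create_train_objects().
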
Return type:

Dict[str, Any]