mentor.mentee

The Mentee class is the single public entry point for building and training models with mentor.

class mentor.Mentee(**constructor_params)

Bases: Module

A torch.nn.Module subclass that bundles training, validation, checkpointing, provenance tracking, and inference state in a single .pt file.

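Subclass Mentee and implement at minimum forward() and training_step(). validation_step() falls back to training_step() by default and is typically overridden when validation differs from training; all other methods have working defaults or raise NotImplementedError with informative messages.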

Parameters:

**constructor_params (Any) – Keyword arguments stored verbatim in the checkpoint so the model can be re-instantiated without external scaffolding.

Examples

>>> import torch
>>> from mentor import Mentee
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)
...         self.fc = torch.nn.Linear(128, num_classes)
...     def forward(self, x):
...         return self.fc(x)
...     def training_step(self, sample, loss_fn=None):
...         x, y = sample
...         loss = torch.nn.functional.cross_entropy(self(x), y)
...         return loss, {"loss": loss.item()}
...     def validation_step(self, sample, loss_fn=None):
...         x, y = sample
...         acc = (self(x).argmax(1) == y).float().mean().item()
...         return {"acc": acc}

__init__(**constructor_params)

Initialise internal history buffers and record constructor parameters.

Constructor parameters are stored verbatim in every checkpoint so that resume() can reconstruct the model without any external scaffolding. There are two ways to supply them:

Explicit (classic subclassing)

Pass every argument you want recorded as a keyword argument to super().__init__:

import torch.nn as nn
from mentor import Mentee

class MyNet(Mentee):
    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__(num_classes=num_classes, dropout=dropout)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(128, num_classes)

Implicit (zero-boilerplate)

Call super().__init__() with no arguments, or let an intermediate base pass its own kwargs upward. This method walks up the call stack through the consecutive __init__ frames that operate on the same object and reads the locals of the topmost such frame. That frame always belongs to the most-derived (concrete) class, type(self), so all user-defined parameters are captured regardless of inheritance depth or of whether an intermediate base forwarded explicit kwargs:

class Base(Mentee):
    def __init__(self, a=1):
        super().__init__()        # Mentee always walks to Child

class Child(Base):
    def __init__(self, a=1, b=2):
        super().__init__()        # constructor_params = {'a': 1, 'b': 2}

The same result holds even when an intermediate base uses explicit passing:

class Base(Mentee):
    def __init__(self, a=1):
        super().__init__(a=a)     # explicit — but walk still runs

class Child(Base):
    def __init__(self, a=1, b=2):
        super().__init__()        # constructor_params still = {'a': 1, 'b': 2}

The walk stops as soon as either condition below is violated:

  1. The frame’s code object is named __init__ (rules out factory functions, class methods, and calls at module level).

  2. The self local in that frame is the exact same object being constructed here (frame.f_locals['self'] is self), ruling out construction happening inside another object’s __init__.

A third guard, type(self) is not Mentee, prevents capturing locals when Mentee itself is instantiated directly.

When no __init__ frame is found (factory function, module-level call), the explicitly provided **constructor_params are kept as-is.

For example, a minimal implicit subclass:

import torch.nn as nn
from mentor import Mentee

class MyNet(Mentee):
    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__()   # num_classes and dropout captured automatically
        self.fc = nn.Linear(128, num_classes)

The implicit path also captures any **kwargs the child accepted:

class MyNet(Mentee):
    def __init__(self, num_classes=10, **extra):
        super().__init__()   # num_classes + contents of extra are all recorded

When implicit capture is skipped

If the three conditions above are not met (e.g. Mentee() is instantiated directly, or Mentee.__init__ is called from outside an __init__ context), constructor_params is left as whatever was explicitly passed — which may be an empty dict. No error is raised; the checkpoint will simply store {}.
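The stack walk described above can be illustrated with a short, stdlib-only sketch. CaptureBase is a hypothetical stand-in for Mentee, not mentor's implementation: it applies the same two stopping conditions and the direct-instantiation guard, and, as a simplification, it records every local of the topmost frame except self and the __class__ cell that zero-argument super() introduces (so a **kwargs parameter would appear as one nested dict rather than being flattened).

```python
import inspect
from typing import Any, Dict


class CaptureBase:
    """Hypothetical stand-in for Mentee's implicit parameter capture."""

    def __init__(self, **constructor_params: Any) -> None:
        captured: Dict[str, Any] = dict(constructor_params)
        frame = inspect.currentframe()  # None on runtimes without frame support
        caller = frame.f_back if frame is not None else None
        topmost = None
        # Condition 1: the frame is an __init__; condition 2: it builds this object.
        while (
            caller is not None
            and caller.f_code.co_name == "__init__"
            and caller.f_locals.get("self") is self
        ):
            topmost = caller
            caller = caller.f_back
        # Guard 3: never capture when the base class itself is instantiated.
        if topmost is not None and type(self) is not CaptureBase:
            captured = {
                name: value
                for name, value in topmost.f_locals.items()
                if name not in ("self", "__class__")
            }
        del frame, caller, topmost  # drop frame references to avoid cycles
        self.constructor_params = captured


class Base(CaptureBase):
    def __init__(self, a=1):
        super().__init__(a=a)  # explicit kwargs, but the walk still reaches Child


class Child(Base):
    def __init__(self, a=1, b=2):
        super().__init__()


print(Child(b=3).constructor_params)        # {'a': 1, 'b': 3}
print(CaptureBase(x=1).constructor_params)  # {'x': 1}
```

Because the loop starts from the caller of CaptureBase.__init__ and stops at the first frame that is not an __init__ on the same object, calling CaptureBase(x=1) at module level finds no qualifying frame and keeps the explicit kwargs.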

Parameters:

**constructor_params (Any) – Keyword arguments to store. They are kept as-is only when implicit capture is skipped (see "When implicit capture is skipped" above); otherwise the locals of the most-derived __init__ frame are recorded.

Return type:

None

Notes

Frame introspection relies on inspect.currentframe(), which is available on CPython (the runtime used by PyTorch in practice) but is not mandated by the Python language specification. On an implementation without Python stack-frame support, inspect.currentframe() returns None and the implicit path silently falls back to an empty dict; use explicit passing if portability matters.

Examples

>>> import torch.nn as nn
>>> from mentor import Mentee
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__()          # implicit: num_classes=10 captured
...         self.fc = nn.Linear(128, num_classes)
>>> model = MyNet(num_classes=5)
>>> model._constructor_params
{'num_classes': 5}
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)   # explicit
...         self.fc = nn.Linear(128, num_classes)
>>> model = MyNet(num_classes=5)
>>> model._constructor_params
{'num_classes': 5}

property current_epoch: int

Number of completed training epochs.

Returns:

Equal to len(self._train_history). Zero on a fresh model.

Return type:

int

property total_train_iterations: int

Cumulative number of batches processed across all train_epoch() calls.

Incremented at the end of every epoch, by the number of batches processed in that epoch, before the LR scheduler steps. Persisted in every checkpoint and restored on resume.

property layer_names: List[str]

Full dotted paths of every parameter-bearing module, in module order.

These are the names accepted by freeze() and unfreeze(), and are also the node labels shown by mtr_checkpoint -verbose.

Returns:

E.g. ['backbone', 'backbone.layer4', 'backbone.layer4.1.bn2', 'head']. Modules with no parameters (ReLU, Dropout, …) are omitted.

Return type:

list[str]

property device: device

Device on which the model parameters currently reside.

Returns:

Inferred from the first parameter tensor.

Return type:

torch.device

Raises:

StopIteration – If the model has no parameters (bare Mentee with no submodules).

property optimizer: Any | None

The optimizer produced by the last create_train_objects() call.

When a trainer is set, returns trainer.optimizer. Otherwise returns the locally cached _optimizer. None until create_train_objects() has been called.

property lr_scheduler: Any | None

The LR scheduler produced by the last create_train_objects() call.

When a trainer is set, returns trainer.lr_scheduler. Otherwise returns the locally cached _lr_scheduler. None until create_train_objects() has been called.

property loss_fn: Any | None

The default loss function registered by create_train_objects().

When a trainer is set, returns trainer.loss_fn. Otherwise returns _default_loss_fn. None until a loss has been registered.

register_inference_state(key, value)

Store an arbitrary picklable object needed at inference time.

Unlike constructor_params, inference state is typically computed from data (e.g. a fitted label encoder, vocabulary, or normalisation statistics) and may be large. It is serialised transparently inside the checkpoint alongside the model weights.
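Because torch.save() serialises Python objects with pickle by default, "picklable" is the practical constraint on value. The stdlib-only sketch below does not use mentor; it round-trips a hypothetical inference-state dict through pickle to show which kinds of values survive. Plain containers come back unchanged, whereas a lambda (a tempting way to write a quick tokenizer) cannot be pickled, so prefer a module-level function or class.

```python
import pickle
from typing import Any, Dict


def roundtrip_inference_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Serialise and reload a dict the way a pickle-based checkpoint would."""
    return pickle.loads(pickle.dumps(state))


state = {"classes": ["cat", "dog", "bird"], "mean": [0.485, 0.456, 0.406]}
restored = roundtrip_inference_state(state)
print(restored["classes"])  # ['cat', 'dog', 'bird']

try:
    roundtrip_inference_state({"tokenize": lambda s: s.split()})
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    # The exact exception type for an unpicklable lambda varies by Python version.
    print("not picklable:", type(exc).__name__)
```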

Parameters:
  • key (str) – Identifier used to retrieve the value with get_inference_state().

  • value (Any) – Any picklable Python object (dict, list, tensor, sklearn transformer, …).

Return type:

None

Examples

>>> model.register_inference_state("classes", ["cat", "dog", "bird"])
>>> model.register_inference_state("mean", torch.tensor([0.485, 0.456, 0.406]))

get_inference_state(key, default=None)

Retrieve a value previously stored with register_inference_state().

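Parameters:
  • key (str) – Identifier previously passed to register_inference_state().

  • default (Any, optional) – Value returned when key has not been registered. Default is None.

Returns: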

The stored object, or default if the key is absent.

Return type:

Any

Examples

>>> classes = model.get_inference_state("classes", default=[])

__repr__()

Return repr(self).

Return type:

str

__str__()

Return str(self).

Return type:

str

forward(*args, **kwargs)

Forward pass — must be overridden by subclasses.

Parameters:
  • *args (Any) – Positional inputs (typically a batch tensor).

  • **kwargs (Any) – Keyword inputs.

Returns:

Model output (logits, embeddings, sequences, …).

Return type:

Any

Raises:

NotImplementedError – Always raised by the base implementation.

training_step(sample, loss_fn=None)

Compute the loss for a single training sample or mini-batch.

Called inside train_epoch(). The returned tensor must be differentiable with respect to the model parameters.

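Parameters:
  • sample (Any) – One element yielded by the training DataLoader.

  • loss_fn (callable, optional) – Loss function to use, or None. validation_step() forwards its own loss_fn argument here.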

Returns:

  • loss (torch.Tensor) – Scalar loss tensor (requires_grad=True).

  • metrics (dict[str, float]) – Scalar metrics to accumulate and log. When validation_step() falls back to this method, the first key is treated as the principal metric by validate_epoch() for best-model tracking.

Raises:

NotImplementedError – Always raised by the base implementation.

Return type:

Tuple[Tensor, Dict[str, float]]

Examples

>>> import torch.nn.functional as F
>>> def training_step(self, sample, loss_fn=None):
...     x, y = sample
...     loss = F.cross_entropy(self(x.to(self.device)), y.to(self.device))
...     return loss, {"loss": loss.item()}

validation_step(sample, loss_fn=None)

Evaluate the model on a single validation sample or mini-batch.

Defaults to calling training_step() with the same arguments and returning its metrics dict, so subclasses that only implement training_step() get validation for free. Override when the validation forward pass differs from training (e.g. different augmentation, TTA, beam search).

Called inside validate_epoch() under torch.no_grad(). The first key of the returned dict is used as the principal metric when comparing epochs for best-model selection.

Parameters:
  • sample (Any) – One element yielded by the validation DataLoader.

  • loss_fn (callable, optional) – Loss function forwarded to training_step().

Returns:

Scalar evaluation metrics (may include "loss").

Return type:

dict[str, float]

Examples

>>> # default: no override needed if training_step covers both
>>> def validation_step(self, sample, loss_fn=None):  # custom override
...     x, y = sample
...     logits = self(x.to(self.device))
...     acc = (logits.argmax(1) == y.to(self.device)).float().mean().item()
...     return {"acc": acc}

preprocess(raw_input)

Transform a raw input into a model-ready tensor.

Override to make the checkpoint self-contained for inference. Use get_inference_state() to access tokenizers, normalisation statistics, or other data-derived artefacts.

Parameters:

raw_input (Any) – Raw data (PIL image, string, numpy array, …).

Returns:

Model-ready tensor or batch.

Return type:

Any

Raises:

NotImplementedError – Raised by the base implementation.

Examples

>>> def preprocess(self, raw_input):
...     mean = self.get_inference_state("mean")
...     std = self.get_inference_state("std")
...     return (torch.tensor(raw_input) - mean) / std

decode(model_output)

Transform raw model output into a human-readable result.

Override to make the checkpoint self-contained for inference. Use get_inference_state() to access label maps, alphabets, or beam-search decoders.

Parameters:

model_output (Any) – Raw output from forward().

Returns:

Human-readable prediction (class name, decoded string, bounding box, …).

Return type:

Any

Raises:

NotImplementedError – Raised by the base implementation.

Examples

>>> def decode(self, model_output):
...     idx = model_output.argmax(1).item()
...     return self.get_inference_state("classes")[idx]

get_output_schema()

Describe the output space as a serialisable dict.

The returned dict is embedded in the checkpoint and displayed by mtr_checkpoint. Override to self-document what the model produces.

Returns:

Arbitrary JSON-serialisable description. Common keys: type, num_classes, classes, alphabet. Returns {} by default.

Return type:

dict[str, Any]

Examples

>>> def get_output_schema(self):
...     return {"type": "classification",
...             "classes": self.get_inference_state("classes")}

get_preprocessing_info()

Describe preprocessing requirements as a serialisable dict.

The returned dict is embedded in the checkpoint and displayed by mtr_checkpoint. Override to self-document expected inputs.

Returns:

Arbitrary JSON-serialisable description. Common keys: input_size, mean, std, resize. Returns {} by default.

Return type:

dict[str, Any]

Examples

>>> def get_preprocessing_info(self):
...     return {"input_size": [1, 28, 28],
...             "mean": [0.1307], "std": [0.3081]}

create_train_objects(lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)

Create training objects and (optionally) set the default loss function.

Returns a dict with "optimizer", "lr_scheduler", and "loss_fn" keys. Calling this method more than once is safe — by default it will not replace a previously set default loss (overwrite_default_loss=False), so a parametric loss that has already been partially trained is preserved across optimizer resets.

Override to substitute a different optimiser or scheduler; the dict structure must be preserved.

Parameters:
  • lr (float, optional) – Initial learning rate for Adam. Default is 1e-3.

  • step_size (int, optional) – Period (in epochs) for the StepLR decay. Default is 10.

  • gamma (float, optional) – Multiplicative decay factor for StepLR. Default is 0.1.

  • loss_fn (callable, optional) – Loss function to register as the default. If None and no default is currently set, _default_loss_fn remains None (which means training_step() must either provide one or raise its own error).

  • overwrite_default_loss (bool, optional) – If True, always replace the existing default loss with the newly supplied loss_fn. If False (default) and a default is already set, the existing default is preserved even when loss_fn is provided. Set to True when intentionally switching loss functions mid-training.
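The interaction between loss_fn and overwrite_default_loss reduces to a small decision rule. The function below is a hypothetical, stdlib-only mirror of that rule as documented above, not mentor's code; it reads "always replace" literally, so passing loss_fn=None together with overwrite_default_loss=True clears the default in this sketch.

```python
from typing import Callable, Optional

LossFn = Optional[Callable[..., object]]


def resolve_default_loss(
    current: LossFn, supplied: LossFn, overwrite_default_loss: bool = False
) -> LossFn:
    """Hypothetical mirror of create_train_objects()'s default-loss rule."""
    if overwrite_default_loss:
        return supplied  # always replace, even with None
    if current is None:
        return supplied  # nothing registered yet: take the new loss (may stay None)
    return current  # keep the existing, possibly partially trained, loss


def squared(a: float, b: float) -> float:
    return (a - b) ** 2


def absolute(a: float, b: float) -> float:
    return abs(a - b)


print(resolve_default_loss(squared, absolute) is squared)  # True
```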

Returns:

{"optimizer": Adam, "lr_scheduler": StepLR, "loss_fn": <fn or None>}

Return type:

dict

Examples

>>> import torch.nn as nn
>>> train_objs = model.create_train_objects(lr=1e-4, step_size=5,
...                                         loss_fn=nn.CrossEntropyLoss())
>>> type(train_objs["optimizer"]).__name__, type(train_objs["lr_scheduler"]).__name__
('Adam', 'StepLR')
>>> # a second call (overwrite_default_loss=False by default) keeps the first loss
>>> train_objs2 = model.create_train_objects(lr=1e-5)
>>> train_objs2["loss_fn"] is train_objs["loss_fn"]
True

train_epoch(dataset, optimizer, lr_scheduler=None, pseudo_batch_size=1, memfail='raise', tensorboard_writer=None, verbose=False, refresh_freq=20, batch_size=None, collate_fn=None, num_workers=0, shuffle=True, amp=False)

Train the model for one full epoch.

Iterates over dataset, calls training_step() for each batch, and accumulates gradients for pseudo_batch_size batches before calling optimizer.step(). Appends the epoch metrics to _train_history, incrementing current_epoch.

dataset may be a DataLoader (used directly) or a Dataset / any sized iterable (wrapped automatically using batch_size, collate_fn, num_workers, and shuffle). When a DataLoader is passed, these four loader kwargs are ignored.

Parameters:
  • dataset (DataLoader or Dataset) – Batched DataLoader or an unbatched Dataset to be wrapped.

  • optimizer (torch.optim.Optimizer) – Optimiser to use for parameter updates.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – Scheduler stepped once at the end of the epoch.

  • pseudo_batch_size (int, optional) – Number of batches over which gradients are accumulated before each optimizer.step(). Default is 1.

  • memfail ({'raise', 'skip'}, optional) – Policy when training_step() raises MemoryError. 'raise' propagates immediately; 'skip' counts the failure and continues. Default is 'raise'.

  • tensorboard_writer (torch.utils.tensorboard.SummaryWriter, optional) – If provided, each metric is logged under train/<metric>.

  • verbose (bool, optional) – Show a tqdm progress bar. Default is False.

  • refresh_freq (int, optional) – Progress-bar postfix update interval (in batches). Default is 20.

  • batch_size (int, optional) – Batch size used when dataset is not a DataLoader. Defaults to 1.

  • collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader when dataset is not already a DataLoader.

  • num_workers (int, optional) – Number of DataLoader worker processes. Default is 0 (main-process loading).

  • shuffle (bool, optional) – Whether to shuffle samples when building a DataLoader from a Dataset. Default is True. Ignored when dataset is already a DataLoader.

  • amp (bool, optional) – Enable automatic mixed precision via torch.autocast and torch.amp.GradScaler. The scaler is cached on the model as _grad_scaler so its loss-scale adapts correctly across epochs. Default is False.
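The accumulation loop and the memfail policy can be sketched without PyTorch by replacing the forward/backward pass and the optimizer update with callbacks. accumulate_epoch, backward_step and optimizer_step are hypothetical names. Two details are not stated in this reference and are assumptions here: the sketch applies a final optimizer step for a trailing partial group of fewer than pseudo_batch_size batches, and it does not rescale the loss by pseudo_batch_size.

```python
from typing import Any, Callable, Dict, Iterable


def accumulate_epoch(
    batches: Iterable[Any],
    backward_step: Callable[[Any], None],
    optimizer_step: Callable[[], None],
    pseudo_batch_size: int = 1,
    memfail: str = "raise",
) -> Dict[str, int]:
    """Hypothetical sketch of gradient accumulation with a memfail policy."""
    if memfail not in ("raise", "skip"):
        raise ValueError(f"unknown memfail policy: {memfail!r}")
    memfails = 0
    pending = 0  # batches whose gradients have accumulated since the last step
    for batch in batches:
        try:
            backward_step(batch)  # stands in for training_step() + loss.backward()
        except MemoryError:
            if memfail == "raise":
                raise
            memfails += 1  # 'skip': count the failure and carry on
            continue
        pending += 1
        if pending == pseudo_batch_size:
            optimizer_step()  # stands in for optimizer.step() + zero_grad()
            pending = 0
    if pending:
        optimizer_step()  # assumption: flush a trailing partial group
    return {"memfails": memfails}
```

With 10 batches and pseudo_batch_size=4, this sketch performs three updates (after 4, 8 and the final 2 batches).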

Returns:

Per-metric averages over the epoch, plus memfails (count of skipped batches).

Return type:

dict[str, float]

Raises:

MemoryError – When memfail is 'raise' and a batch triggers OOM.

Examples

>>> _to = model.create_train_objects(lr=1e-3)
>>> # from a DataLoader
>>> metrics = model.train_epoch(train_loader, _to["optimizer"], pseudo_batch_size=4)
>>> # or from a Dataset (wrapped in a DataLoader automatically)
>>> metrics = model.train_epoch(train_dataset, _to["optimizer"], batch_size=32, shuffle=True)
>>> print(f"epoch {model.current_epoch}  loss={metrics['loss']:.4f}")

validate_epoch(dataset, recalculate=False, memfail='raise', tensorboard_writer=None, verbose=False, refresh_freq=20, batch_size=None, collate_fn=None, num_workers=0)

Validate the model at the current epoch.

Results are cached in _validate_history keyed by epoch. Calling this method twice for the same epoch returns the cached dict without re-running inference, unless recalculate is True.

If the principal metric (first key of the returned dict) exceeds its best value from all previous epochs, the current weights are saved to _best_weights_so_far.

dataset may be a DataLoader (used directly) or a Dataset / any sized iterable (wrapped automatically with batch_size and collate_fn). Shuffle is always False for validation.
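The caching and best-model rules above can be sketched in plain Python. ValidationTracker is a hypothetical helper, not part of mentor, and evaluate() stands in for a full pass of validation_step() over the data. Since Python dicts preserve insertion order, the "first key" is well defined; under the documented "exceeds" comparison a higher value counts as better, so a metric such as accuracy is a natural first key.

```python
from typing import Callable, Dict, Optional

Metrics = Dict[str, float]


class ValidationTracker:
    """Hypothetical sketch of validate_epoch()'s caching and best-epoch tracking."""

    def __init__(self) -> None:
        self.history: Dict[int, Metrics] = {}
        self.best_epoch: Optional[int] = None
        self.best_score = float("-inf")

    def validate(
        self, epoch: int, evaluate: Callable[[], Metrics], recalculate: bool = False
    ) -> Metrics:
        if epoch in self.history and not recalculate:
            return self.history[epoch]  # cached: evaluate() is not called again
        metrics = evaluate()
        self.history[epoch] = metrics
        principal = next(iter(metrics.values()))  # first key is the principal metric
        if principal > self.best_score:  # strictly higher is better
            self.best_score = principal
            self.best_epoch = epoch  # mentor also snapshots the weights at this point
        return metrics
```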

Parameters:
  • dataset (DataLoader or Dataset) – Batched DataLoader or an unbatched Dataset to be wrapped.

  • recalculate (bool, optional) – Force re-evaluation even if this epoch was already validated. Default is False.

  • memfail ({'raise', 'skip'}, optional) – Policy for MemoryError inside validation_step(). Default is 'raise'.

  • tensorboard_writer (torch.utils.tensorboard.SummaryWriter, optional) – If provided, metrics are logged under val/<metric>.

  • verbose (bool, optional) – Show a tqdm progress bar. Default is False.

  • refresh_freq (int, optional) – Progress-bar postfix update interval. Default is 20.

  • batch_size (int, optional) – Batch size used when dataset is not a DataLoader. Defaults to 1.

  • collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader when dataset is not already a DataLoader.

  • num_workers (int, optional) – Number of DataLoader worker processes. Default is 0.

Returns:

Per-metric averages, plus memfails.

Return type:

dict[str, float]

Raises:

MemoryError – When memfail is 'raise' and a batch triggers OOM.

Examples

>>> # from a DataLoader
>>> val_metrics = model.validate_epoch(val_loader)
>>> # or from a Dataset (wrapped in a DataLoader automatically)
>>> val_metrics = model.validate_epoch(val_dataset, batch_size=64)
>>> print(f"acc={val_metrics['acc']:.4f}  best_epoch={model._best_epoch_so_far}")

fit(train_data, val_data=None, epochs=1, lr=0.001, batch_size=None, collate_fn=None, num_workers=0, pseudo_batch_size=1, checkpoint_path=None, tensorboard_dir=None, verbose=False, memfail='raise', device=None, patience=None, amp=False, save_freq=1, validate_freq=1, report_wandb=False, report_gradio=False, training_label=None)

Train and optionally validate for a fixed number of epochs.

A convenience wrapper around train_epoch(), validate_epoch(), and save() that drives the full training loop in one call. It is equivalent to writing the loop manually and is provided for cases where you do not need to insert custom logic between epochs.

If the optimizer property is None when fit() is called, create_train_objects() is called automatically with the supplied lr. If training objects already exist (e.g. from a previous call to create_train_objects() or resume_training()), they are reused unchanged.

Parameters:
  • train_data (DataLoader or Dataset) – Training data — passed directly to train_epoch().

  • val_data (DataLoader or Dataset, optional) – Validation data — passed to validate_epoch() after each epoch. Skipped when None.

  • epochs (int, optional) – Number of epochs to train. Default is 1.

  • lr (float, optional) – Learning rate passed to create_train_objects() when no optimizer exists yet. Ignored if training objects are already set up. Default is 1e-3.

  • batch_size (int, optional) – Batch size used when train_data or val_data is not already a DataLoader.

  • collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader.

  • num_workers (int, optional) – DataLoader worker processes. Default is 0.

  • pseudo_batch_size (int, optional) – Gradient accumulation steps. Default is 1.

  • checkpoint_path (str or Path, optional) – If provided, save() is called once every save_freq epochs.

  • save_freq (int, optional) – Save frequency in epochs. 1 (default) saves after every epoch. <=0 disables saving entirely.

  • validate_freq (int, optional) – Validation frequency in epochs. 1 (default) validates after every epoch. <=0 disables validation entirely (including the epoch-0 baseline).

  • tensorboard_dir (str, optional) – Directory for a SummaryWriter. A writer is created at the start and closed when training ends. Skipped when None.

  • verbose (bool, optional) – Show tqdm progress bars and per-epoch summary lines. Default is False.

  • memfail ({'raise', 'skip'}, optional) – OOM policy forwarded to train_epoch() and validate_epoch(). Default is 'raise'.

  • device (str, optional) – If provided, the model is moved to this device before training starts (e.g. 'cuda', 'cpu').

  • patience (int, optional) – Early-stopping patience. If the principal validation metric has not improved for patience consecutive epochs, training stops before reaching epochs. Requires val_data to be set; ignored when None (default).

  • amp (bool, optional) – Enable automatic mixed precision. Forwarded to train_epoch(). Default is False.

  • report_wandb (bool, optional) –

    Log metrics to Weights & Biases. Requires wandb to be installed (pip install wandb). When True:

    • A run is initialised with wandb.init if one is not already active; the project name defaults to the model class name.

    • Train metrics are logged under train/<key> and validation metrics under val/<key> once per epoch.

    • The run URL is printed at startup, with an ASCII QR code when the qrcode package is available (pip install qrcode).

    • When verbose is True, the URL is also reprinted at the end of every epoch line.

    Default is False. A RuntimeWarning is emitted when True but wandb is not installed.

  • report_gradio (bool, optional) – Start a local Gradio dashboard and expose it via a public reverse tunnel (gradio’s share=True). Requires gradio (pip install gradio). The tunnel URL and an ASCII QR code are printed at startup; the dashboard auto-refreshes every 30 s. Default is False.

  • training_label (str, optional) – Human-readable identifier for this training run. Used as the wandb run name and id so the same label always resumes the same wandb run. When None (default) an automatic label is derived as f"{ClassName}_{os.getpid()}".
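The patience rule, together with one plausible reading of save_freq and validate_freq, fits in a short stdlib-only sketch. Both functions are hypothetical helpers, not part of mentor; the sketch assumes one validation per epoch, uses the same strict "higher is better" comparison as validate_epoch(), and ignores the epoch-0 baseline.

```python
from typing import Optional, Sequence


def epochs_before_stop(scores: Sequence[float], patience: Optional[int]) -> int:
    """Hypothetical sketch: how many epochs run under patience-based early stopping.

    scores[i] is the principal validation metric after epoch i + 1.
    """
    if patience is None:
        return len(scores)  # early stopping disabled
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch  # stop before running the remaining epochs
    return len(scores)


def runs_at(epoch: int, freq: int) -> bool:
    """Assumed reading of save_freq / validate_freq: every freq-th epoch; <= 0 disables."""
    return freq > 0 and epoch % freq == 0


print(epochs_before_stop([0.50, 0.60, 0.60, 0.55, 0.70], patience=2))  # 4
```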

Returns:

self, so calls can be chained.

Return type:

Mentee

Examples

>>> model = MyNet()
>>> model.fit(train_loader, val_loader, epochs=10, lr=1e-3,
...           checkpoint_path="run.pt", tensorboard_dir="tb/",
...           verbose=True)
>>> print(f"best epoch: {model._best_epoch_so_far}")

find_lr(train_data, start_lr=1e-07, end_lr=10.0, num_iter=100, smooth=0.98, diverge_threshold=4.0, batch_size=None, collate_fn=None, num_workers=0, amp=False)

Run the learning-rate range test (Smith 2017).

Sweeps the learning rate geometrically from start_lr to end_lr over num_iter batches, records the smoothed loss at each step, and then restores the model weights so the run has no side-effects.

A fresh optimizer is created for the sweep via a new instance of type(self.trainer) (or a plain Adam when no trainer is set), so neither the cached optimizer nor the trainer state are affected.
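The procedure described above (geometric sweep, exponential smoothing, divergence cut-off) can be sketched without PyTorch by replacing each training step with a function of the learning rate. lr_range_test and loss_at are hypothetical names; the bias correction of the moving average is an assumption borrowed from common implementations of the test, and weight restoration is omitted because nothing is actually trained.

```python
import math
from typing import Callable, Dict, List


def lr_range_test(
    loss_at: Callable[[float], float],
    start_lr: float = 1e-7,
    end_lr: float = 10.0,
    num_iter: int = 100,
    smooth: float = 0.98,
    diverge_threshold: float = 4.0,
) -> Dict[str, List[float]]:
    """Hypothetical LR range test; loss_at(lr) fakes one training step at that LR."""
    ratio = (end_lr / start_lr) ** (1.0 / max(num_iter - 1, 1))  # geometric step
    lrs: List[float] = []
    losses: List[float] = []
    avg = 0.0
    best = math.inf
    for step in range(num_iter):
        lr = start_lr * ratio**step
        avg = smooth * avg + (1.0 - smooth) * loss_at(lr)  # exponential moving average
        smoothed = avg / (1.0 - smooth ** (step + 1))  # bias correction (assumption)
        lrs.append(lr)
        losses.append(smoothed)
        best = min(best, smoothed)
        if smoothed > diverge_threshold * best:
            break  # the loss has diverged: stop early
    return {"lrs": lrs, "losses": losses}


result = lr_range_test(lambda lr: 1.0 + (100 * lr) ** 2)
print(len(result["lrs"]), "steps before divergence")
```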

Parameters:
  • train_data (DataLoader or Dataset) – Data to iterate over — only num_iter batches are consumed.

  • start_lr (float, optional) – Lower bound of the LR sweep. Default is 1e-7.

  • end_lr (float, optional) – Upper bound of the LR sweep. Default is 10.0.

  • num_iter (int, optional) – Number of batches to sweep over. Default is 100.

  • smooth (float, optional) – Exponential moving-average factor for loss smoothing. Higher values produce a smoother curve. Default is 0.98.

  • diverge_threshold (float, optional) – Stop early when the smoothed loss exceeds diverge_threshold × best_loss. Default is 4.0.

  • batch_size (int, optional) – Batch size when train_data is not already a DataLoader.

  • collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader.

  • num_workers (int, optional) – DataLoader worker processes. Default is 0.

  • amp (bool, optional) – Run the sweep with automatic mixed precision. Default is False.

Returns:

{"lrs": [float, ...], "losses": [float, ...]} — one entry per completed step, suitable for plotting.

Return type:

dict

Examples

>>> result = model.find_lr(train_loader, start_lr=1e-6, end_lr=1.0)
>>> import matplotlib.pyplot as plt
>>> plt.semilogx(result["lrs"], result["losses"]); plt.show()

set_lr_coefficient(coefficient, patterns, optimizer=None)

Set a per-layer learning-rate coefficient for layers matching patterns.

The effective LR for each layer is global_lr * coefficient. _lr_coefficients is the source of truth; it is persisted in every checkpoint and applied automatically by create_train_objects().

If an optimizer is available (via optimizer or the cached trainer/model optimizer), _apply_lr_coefficients() is called immediately to rebuild the param groups with the new coefficient applied. Optimizer momentum / variance state (keyed by parameter id) is preserved.

Note

_apply_lr_coefficients() derives the base learning rate from optimizer.defaults["lr"] – the value passed at optimizer creation, not the current scheduler-decayed value. Calling this method mid-training therefore resets affected layers to initial_lr * coefficient, discarding any decay. Call set_lr_coefficient() before training begins or at a deliberate phase boundary (e.g. when unfreezing layers with a freshly built optimizer) to avoid this effect.

Parameters:
  • coefficient (float) – Multiplier relative to the global LR. 1.0 restores the default (the layer is removed from _lr_coefficients to keep it sparse). 0.0 zeroes the layer’s LR without setting requires_grad=False.

  • patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to update. Defaults to the trainer’s or the model’s cached optimizer.
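The sparse bookkeeping behind _lr_coefficients can be sketched with a plain dict. set_coefficient and effective_lr are hypothetical helpers, not mentor's code; the sketch looks up each layer's coefficient by exact name only, because how overlapping entries (for example backbone and backbone.layer4) combine for a given parameter is not documented here.

```python
import re
from typing import Dict, List, Sequence, Union


def set_coefficient(
    coefficients: Dict[str, float],
    layer_names: Sequence[str],
    coefficient: float,
    patterns: Union[str, List[str]],
) -> Dict[str, float]:
    """Hypothetical sketch of the sparse per-layer coefficient update."""
    if isinstance(patterns, str):
        patterns = [patterns]
    for name in layer_names:
        if any(re.fullmatch(p, name) for p in patterns):
            if coefficient == 1.0:
                coefficients.pop(name, None)  # 1.0 is the default: keep the dict sparse
            else:
                coefficients[name] = coefficient
    return coefficients


def effective_lr(base_lr: float, coefficients: Dict[str, float], name: str) -> float:
    # base_lr plays the role of optimizer.defaults["lr"], not the scheduler-decayed LR.
    return base_lr * coefficients.get(name, 1.0)
```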

Returns:

self, for chaining.

Return type:

Mentee

Examples

>>> model.set_lr_coefficient(0.1, "backbone")        # 10x slower backbone
>>> model.set_lr_coefficient(0.0, ["backbone"])       # zero backbone LR
>>> model.set_lr_coefficient(1.0, r"backbone\..*")   # restore all sub-layers

select_layers(layer_names)

Return the layer paths that match any entry in layer_names, deduplicated and in module traversal order (the same order as the layer_names property).

Each entry in the layer_names argument is matched with re.fullmatch against every full dotted module path in the layer_names property (e.g. backbone.layer4.0.conv2). Plain strings act as exact-match selectors; regex patterns select groups of layers. In a pattern an unescaped . matches any character, so escape the dot separator as \. to avoid matching unintended paths. Duplicate matches (a name matched by several patterns) are collapsed to a single entry. The returned list always follows the order of the layer_names property, never the order of the input patterns.

Parameters:

layer_names (list[str]) – Exact path names or re.fullmatch patterns applied to the full dotted path (e.g. r"backbone\.layer[34]\..*").

Returns:

Matched layer paths in module order, without duplicates.

Return type:

list[str]

Examples

For a model whose layer_names is ['backbone', 'backbone.layer4', 'backbone.layer4.0.conv2', 'head']:

# exact match
model.select_layers(['backbone.layer4'])
# -> ['backbone.layer4']

# regex: all sub-layers of backbone (dot must be escaped)
model.select_layers([r'backbone\..*'])
# -> ['backbone.layer4', 'backbone.layer4.0.conv2']

# input order does not affect output order
model.select_layers(['head', 'backbone'])
# -> ['backbone', 'head']

# duplicate matches collapsed to one entry
model.select_layers([r'backbone\..*', 'backbone.layer4'])
# -> ['backbone.layer4', 'backbone.layer4.0.conv2']
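The selection rule above fits in a single list comprehension. This select_layers is a hypothetical standalone function that takes the layer list explicitly instead of reading the property; it assumes that list contains no duplicates.

```python
import re
from typing import List, Sequence


def select_layers(all_layer_names: Sequence[str], patterns: Sequence[str]) -> List[str]:
    """Hypothetical re-implementation of the documented selection rule."""
    compiled = [re.compile(p) for p in patterns]
    # Iterate over the layers, not the patterns: module order and deduplication follow.
    return [name for name in all_layer_names if any(rx.fullmatch(name) for rx in compiled)]


layers = ["backbone", "backbone.layer4", "backbone.layer4.0.conv2", "head"]
print(select_layers(layers, ["head", "backbone"]))  # ['backbone', 'head']
```

Iterating over the layer list rather than over the patterns is what makes the output order independent of the input order and collapses duplicate matches without extra bookkeeping.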

freeze(patterns, optimizer=None, reset_optimizer_if_needed=False)

Freeze layers selected by re.fullmatch patterns.

Updates _frozen_modules (source of truth) and sets requires_grad=False on the affected parameters. If an optimizer is resolved, the corresponding param groups are left in place; their parameters no longer receive gradients, so Adam skips them automatically (PyTorch optimizers skip parameters whose .grad is None, which is what zero_grad() leaves behind by default since PyTorch 2.0). No restructuring is required.

Parameters:
  • patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to update. Defaults to the trainer’s or the model’s cached optimizer.

  • reset_optimizer_if_needed (bool, optional) – Accepted for API symmetry with unfreeze() and set_lr_coefficient(); currently unused because freezing never requires restructuring the optimizer.

Returns:

self, for chaining.

Return type:

Mentee

unfreeze(patterns, optimizer=None, reset_optimizer_if_needed=False)

Unfreeze layers selected by re.fullmatch patterns.

Updates _frozen_modules (source of truth) and sets requires_grad=True on the affected parameters.

If an optimizer is resolved and the unfrozen layer already has a param group (because it was frozen after the optimizer was built), its parameters are live again and Adam resumes updating them on the next step, reusing any existing state or initialising it lazily if there is none; no rebuild is needed. If the layer has no param group (because it was frozen before the optimizer was built), a rebuild is required.

Parameters:
  • patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to inspect and possibly rebuild. Defaults to the trainer’s or the model’s cached optimizer.

  • reset_optimizer_if_needed (bool, optional) – When True and the unfrozen layer has no param group, create_train_objects() is called to rebuild the optimizer (Adam state is reset). When False (default) a RuntimeError is raised instead.
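The optimizer check in unfreeze() amounts to a membership test on the optimizer's param groups. plan_unfreeze is a hypothetical stdlib-only sketch, not mentor's code; it represents parameters by integer ids, whereas a real optimizer.param_groups holds tensors.

```python
from typing import Dict, List, Sequence


def plan_unfreeze(
    layer_param_ids: Sequence[int],
    param_groups: List[Dict[str, List[int]]],
    reset_optimizer_if_needed: bool = False,
) -> str:
    """Hypothetical sketch of unfreeze()'s optimizer check.

    param_groups mimics optimizer.param_groups, with parameter ids under "params".
    """
    grouped = {pid for group in param_groups for pid in group["params"]}
    if all(pid in grouped for pid in layer_param_ids):
        return "reuse"  # already in a group: the parameters simply become live again
    if reset_optimizer_if_needed:
        return "rebuild"  # mentor would call create_train_objects(); Adam state resets
    raise RuntimeError(
        "unfrozen layer has no param group; pass reset_optimizer_if_needed=True"
    )
```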

Returns:

self, for chaining.

Return type:

Mentee

save(path, optimizer=None, lr_scheduler=None)

Serialise the full training state to a .pt checkpoint.

All tensors are moved to CPU before saving so the checkpoint is device-independent. The file contains model weights, training and validation history, provenance metadata, inference state, and (optionally) optimiser and scheduler state.

Parameters:
  • path (str or pathlib.Path) – Destination file path or any file-like object accepted by torch.save() (e.g. io.BytesIO).

  • optimizer (torch.optim.Optimizer, optional) – If provided, its state_dict is stored so training can be resumed with exactly the same optimiser state.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – If provided, its state_dict is stored alongside the optimiser.

Return type:

None

Examples

>>> model.save("checkpoint.pt", optimizer=opt, lr_scheduler=sched)
>>> # or in-memory:
>>> import io; buf = io.BytesIO()
>>> model.save(buf); buf.seek(0)

classmethod resume(path, model_class=None, tolerate_irresumable_model=True, trainer=None, **kwargs)[source]

Load a checkpoint saved by save() and return the model.

If model_class is None, the class is resolved from the class_module / class_name fields stored in the checkpoint using importlib.import_module().
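The lookup described above amounts to an import followed by an attribute access. The following is a minimal sketch; `resolve_class` is a hypothetical name, not part of mentor's API.

```python
import importlib


def resolve_class(class_module: str, class_name: str) -> type:
    """Import class_module and fetch class_name from it."""
    module = importlib.import_module(class_module)  # ModuleNotFoundError if unimportable
    return getattr(module, class_name)  # AttributeError if the class no longer exists


import collections

print(resolve_class("collections", "OrderedDict") is collections.OrderedDict)  # True
```

Both failure modes (a module that cannot be imported, or a class that was renamed or removed) are among the situations where passing model_class explicitly becomes necessary.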

When the checkpoint was created by wrap_as_mentee(), the stored class name points to the original (unwrapped) class. resume detects this and re-applies the Mentee mixin automatically, so the returned object is always a Mentee instance.
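Re-applying a mixin to an existing class can be done by creating a new subclass at runtime with type(). The following is a hedged sketch. `reapply_mixin` is hypothetical, and placing the mixin first in the bases is an assumption made so that its methods take precedence in the MRO. mentor's actual mechanism may differ.

```python
def reapply_mixin(mixin: type, original: type) -> type:
    """Return a subclass of `original` that also inherits from `mixin` (hypothetical helper)."""
    if issubclass(original, mixin):
        return original  # already wrapped, nothing to do
    # Mixin first, so its methods win when both classes define the same name.
    return type(original.__name__, (mixin, original), {})


class Mixin:
    def describe(self) -> str:
        return "mixin"


class Net:
    def forward(self, x):
        return x * 2


Wrapped = reapply_mixin(Mixin, Net)
obj = Wrapped()
print(isinstance(obj, Mixin), isinstance(obj, Net), obj.forward(3))  # True True 6
```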

Parameters:
  • path (str or pathlib.Path) – Path to the .pt file, or a file-like object.

  • model_class (type, optional) – Explicit subclass to instantiate. Required when the checkpoint’s module is not importable in the current environment, and whenever the tolerate_irresumable_model fallback is triggered (a ValueError is raised otherwise).

  • tolerate_irresumable_model (bool, optional) – When True (default), any failure to load the model — including a missing file, an unimportable class, a state-dict mismatch, or any other exception — falls back to a freshly instantiated model_class using the constructor params stored in the checkpoint (or an empty dict when the file is missing or unreadable). When False, any such failure raises immediately.

  • trainer (MentorTrainer subclass (uninstantiated), optional) – If supplied, an instance of this trainer is assigned to model.trainer after loading. Use this when the checkpoint was created with wrap_as_mentee() and a custom trainer that is not serialised inside the checkpoint.

  • kwargs (Any)

Returns:

Fully restored model (from checkpoint) or a fresh instance.

Return type:

Mentee

Raises:
  • Exception – Any failure when tolerate_irresumable_model is False.

  • ValueError – If tolerate_irresumable_model is True but model_class is None when the fallback is triggered.
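The fallback rules above can be summarised in a few lines. This is a minimal sketch, assuming a zero-argument `load` callable that performs the real loading and raises on any failure. `load_or_fallback` and its parameters are hypothetical. `stored_params` plays the role of the constructor params read from the checkpoint, or an empty dict when the file is missing or unreadable.

```python
from typing import Any, Callable, Dict, Optional


def load_or_fallback(load: Callable[[], Any], model_class: Optional[type],
                     stored_params: Dict[str, Any],
                     tolerate_irresumable_model: bool = True) -> Any:
    """Sketch of the tolerate_irresumable_model fallback (hypothetical helper)."""
    try:
        return load()
    except Exception:
        if not tolerate_irresumable_model:
            raise  # strict mode: surface the original failure unchanged
        if model_class is None:
            raise ValueError("model_class is required when the fallback is triggered")
        # Fresh, untrained instance built from whatever constructor params survived.
        return model_class(**stored_params)


class Net:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes


def load_missing():
    raise FileNotFoundError("run/checkpoint.pt")


model = load_or_fallback(load_missing, Net, stored_params={})
print(type(model).__name__, model.num_classes)  # Net 10
```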

Examples

>>> model = Mentee.resume("checkpoint.pt", model_class=MyNet)
>>> model.eval()

Start from scratch when no checkpoint exists yet:

>>> model = MyNet.resume("run/checkpoint.pt", model_class=MyNet)

classmethod resume_training(path, model_class=None, device=None, tolerate_irresumable_model=True, tolerate_irresumable_trainstate=False, **kwargs)[source]

Load a checkpoint and reconstruct everything needed to continue training.

Restores model weights and history, moves the model to device, calls create_train_objects(), and restores optimiser and scheduler state if present in the checkpoint.
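The restoration sequence after the model itself has been loaded can be sketched as follows. This is a hedged sketch under stated assumptions:

* `restore_training` is a hypothetical helper.
* The checkpoint keys "optimizer_state" and "lr_scheduler_state" are invented for illustration.
* The model is assumed to expose a create_train_objects() method returning (optimizer, lr_scheduler).
* Scaler state is left out for brevity.

```python
from typing import Any, Dict, Optional, Tuple


def restore_training(model: Any, checkpoint: Dict[str, Any], device: Optional[str] = None,
                     tolerate_irresumable_trainstate: bool = False,
                     **kwargs: Any) -> Tuple[Any, Any, Any]:
    """Sketch of the resume_training() steps that follow model loading."""
    if device is not None:
        model = model.to(device)  # nn.Module.to() returns the module itself
    # kwargs such as lr=1e-4 go straight to the optimizer/scheduler factory.
    optimizer, lr_scheduler = model.create_train_objects(**kwargs)
    try:
        if "optimizer_state" not in checkpoint:
            raise KeyError("checkpoint contains no optimizer state")
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        if lr_scheduler is not None and "lr_scheduler_state" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state"])
    except Exception as exc:
        if not tolerate_irresumable_trainstate:
            raise RuntimeError("training state could not be restored") from exc
        # A failed load may leave partial state behind, so start over from fresh objects.
        optimizer, lr_scheduler = model.create_train_objects(**kwargs)
    return model, optimizer, lr_scheduler
```

Re-creating the optimizer in the tolerant branch, rather than keeping the one that failed to load, avoids continuing from a half-restored state.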

Parameters:
  • path (str or pathlib.Path) – Path to the .pt file, or a file-like object.

  • model_class (type, optional) – Explicit subclass to instantiate (see resume()).

  • device (str or torch.device, optional) – Target device, e.g. "cuda" or "cpu". If None, the model stays on the CPU, where it was loaded.

  • tolerate_irresumable_model (bool, optional) – When True (default), any failure to load the model — including a missing file, an unimportable class, a state-dict mismatch, or any other exception — falls back to a freshly instantiated model_class. model_class must be provided when this fallback is triggered. When False, any such failure raises immediately.

  • tolerate_irresumable_trainstate (bool, optional) – When False (default) and the checkpoint contains no optimizer state, or the optimizer / scheduler / scaler state cannot be restored, an exception is raised. Set to True to silently continue with a freshly constructed optimizer instead.

  • **kwargs (Any) – Passed to create_train_objects() (e.g. lr=1e-4).

Returns:

(model, optimizer, lr_scheduler) — the same objects returned by create_train_objects(), prepended with the loaded model.

Return type:

tuple

Raises:
  • FileNotFoundError – If path does not exist and tolerate_irresumable_model is False.

  • RuntimeError – If the model cannot be loaded and tolerate_irresumable_model is False, or if the training state cannot be restored and tolerate_irresumable_trainstate is False.

  • ValueError – If tolerate_irresumable_model is True but model_class is None when the fallback is triggered.

Examples

>>> model, opt, sched = Mentee.resume_training(
...     "checkpoint.pt", model_class=MyNet, device="cuda", lr=1e-4
... )
>>> model.train_epoch(train_loader, opt, sched)

Start from scratch when no checkpoint exists yet:

>>> model, opt, sched = MyNet.resume_training(
...     "run/checkpoint.pt",
...     model_class=MyNet,
...     device="cuda",
...     pretrained=True,
...     lr=1e-4,
... )

Helper functions

The following module-level helpers are used internally and exposed for advanced use.

mentor.mentee._to_cpu(obj)[source]

Recursively move tensors in nested dicts/lists to CPU.

Parameters:

obj (Any)

Return type:

Any
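A recursive walk of this kind can be sketched without importing torch. This is a minimal sketch. `to_cpu` is a hypothetical stand-in, and it detects tensors by duck typing (any object with a callable .cpu()) so that it runs without torch. With torch installed, `isinstance(obj, torch.Tensor)` would be the more precise test.

```python
from typing import Any


def to_cpu(obj: Any) -> Any:
    """Recursively copy dicts, lists and tuples, calling .cpu() on tensor-like leaves."""
    cpu = getattr(obj, "cpu", None)
    if callable(cpu):
        return cpu()  # tensor-like leaf: move it
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(value) for value in obj)
    return obj  # numbers, strings and other leaves are returned unchanged
```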

mentor.mentee._state_dict_architecture_lines(state_dict)[source]

Derive architecture stats from a state_dict without instantiating the model.

Parameters:

state_dict (Dict[str, Any])

Return type:

List[str]
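Architecture statistics can be read from tensor shapes alone, which is why no model instance is needed. The following is a hedged sketch. `architecture_lines` and its output format are hypothetical. It groups entries by the first component of each key. It counts elements rather than parameters, because a state_dict also holds buffers such as BatchNorm running statistics.

```python
import math
from typing import Any, Dict, List


def architecture_lines(state_dict: Dict[str, Any]) -> List[str]:
    """Summarise element counts per top-level module from tensor shapes."""
    per_module: Dict[str, int] = {}
    n_tensors = 0
    for name, value in state_dict.items():
        shape = getattr(value, "shape", None)
        if shape is None:
            continue  # skip non-tensor entries
        n_tensors += 1
        top = name.split(".", 1)[0]
        per_module[top] = per_module.get(top, 0) + math.prod(shape)
    lines = [f"{module}: {count:,} elements" for module, count in per_module.items()]
    lines.append(f"total: {sum(per_module.values()):,} elements in {n_tensors} tensors")
    return lines
```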

mentor.mentee._get_software_snapshot()[source]

Collect a reproducibility snapshot of the current software environment.

Fields recorded

python, torch, mentor_version

Version strings for the interpreter and key libraries.

torchvision, numpy

Version strings when available; "unavailable" otherwise.

cuda

torch.version.cuda (e.g. "12.1") or "cpu" for CPU builds.

platform

OS description from the platform module (e.g. "Linux-6.1 x86_64").

hostname, user

Machine and user identity.

main_script

Absolute path to sys.argv[0] — the entry-point script.

git_hash

Full SHA-1 of HEAD; "unavailable" when git is absent.

git_branch

Current branch name; helps locate the commit in a crowded history.

git_remote

URL of the origin remote; identifies the repo/fork.

git_dirty

"true" when there are uncommitted changes (hash insufficient for exact reproduction), "false" otherwise.

Return type:

Dict[str, str]
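The standard-library part of such a snapshot can be sketched as follows. This is a minimal sketch. `software_snapshot` and `_git` are hypothetical names. Library versions (torch, torchvision, numpy, CUDA) are omitted so the sketch runs without those packages. Reporting git_dirty as "unavailable" when git cannot run is an assumption, because the source does not specify that case.

```python
import getpass
import os
import platform
import socket
import subprocess
import sys
from typing import Dict, List


def _git(args: List[str]) -> str:
    """Return the stripped stdout of `git <args>`, or "unavailable" on any failure."""
    try:
        result = subprocess.run(
            ["git", *args], capture_output=True, text=True, timeout=5, check=True
        )
    except (OSError, subprocess.SubprocessError):  # git missing, not a repo, timeout
        return "unavailable"
    return result.stdout.strip()


def software_snapshot() -> Dict[str, str]:
    """Collect the stdlib-only subset of the snapshot fields."""
    try:
        user = getpass.getuser()
    except Exception:  # e.g. no USER/LOGNAME variable and no password-database entry
        user = "unavailable"
    status = _git(["status", "--porcelain"])
    return {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "hostname": socket.gethostname(),
        "user": user,
        "main_script": os.path.abspath(sys.argv[0] if sys.argv else ""),
        "git_hash": _git(["rev-parse", "HEAD"]),
        "git_branch": _git(["rev-parse", "--abbrev-ref", "HEAD"]),
        "git_remote": _git(["remote", "get-url", "origin"]),
        # Any porcelain output means there are uncommitted changes.
        "git_dirty": "unavailable" if status == "unavailable" else ("true" if status else "false"),
    }
```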