mentor.mentee
The Mentee class is the single public entry point for
building and training models with mentor.
- class mentor.Mentee(**constructor_params)
Bases: torch.nn.Module
A torch.nn.Module subclass that bundles training, validation, checkpointing, provenance tracking, and inference state in a single .pt file.
Subclass Mentee and implement at minimum forward() and training_step(). Override validation_step() as well when validation should differ from training, since its default delegates to training_step(). All other methods have working defaults or raise NotImplementedError with informative messages.
- Parameters:
**constructor_params (Any) – Keyword arguments stored verbatim in the checkpoint so the model can be re-instantiated without external scaffolding.
Examples
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)
...         self.fc = torch.nn.Linear(128, num_classes)
...     def forward(self, x):
...         return self.fc(x)
...     def training_step(self, sample):
...         x, y = sample
...         loss = torch.nn.functional.cross_entropy(self(x), y)
...         return loss, {"loss": loss.item()}
...     def validation_step(self, sample):
...         x, y = sample
...         acc = (self(x).argmax(1) == y).float().mean().item()
...         return {"acc": acc}
- __init__(**constructor_params)
Initialise internal history buffers and record constructor parameters.
Constructor parameters are stored verbatim in every checkpoint so that resume() can reconstruct the model without any external scaffolding. There are two ways to supply them.
Explicit (classic subclassing)
Pass every argument you want recorded as a keyword argument to super().__init__:

    class MyNet(Mentee):
        def __init__(self, num_classes=10, dropout=0.5):
            super().__init__(num_classes=num_classes, dropout=dropout)
            self.fc = nn.Linear(128, num_classes)
Implicit (zero-boilerplate)
Call super().__init__() with no arguments, or let an intermediate base pass its own kwargs upward. This method always walks the entire call stack, collecting __init__ frames that operate on the same object, and reads the locals of the topmost such frame. The topmost frame always belongs to the most-derived (concrete) class (type(self)), so all user-defined parameters are captured regardless of inheritance depth or whether an intermediate base forwarded explicit kwargs:

    class Base(Mentee):
        def __init__(self, a=1):
            super().__init__()  # Mentee always walks up to Child

    class Child(Base):
        def __init__(self, a=1, b=2):
            super().__init__()
    # Child().constructor_params == {'a': 1, 'b': 2}

The same result holds even when an intermediate base uses explicit passing:

    class Base(Mentee):
        def __init__(self, a=1):
            super().__init__(a=a)  # explicit, but the walk still runs

    class Child(Base):
        def __init__(self, a=1, b=2):
            super().__init__()
    # constructor_params is still {'a': 1, 'b': 2}
The walk stops as soon as either of the following conditions is violated:
- The frame’s code object is named __init__ (this rules out factory functions, class methods, and module-level calls).
- The self local in that frame is the exact same object being constructed here (frame.f_locals['self'] is self), which rules out construction happening inside another object’s __init__.
A third guard prevents capturing locals when Mentee itself is instantiated directly (capture only runs when type(self) is not Mentee).
When no __init__ frame is found (factory function, module-level call), the explicitly provided **constructor_params are kept as-is.
A typical implicit-capture subclass:

    class MyNet(Mentee):
        def __init__(self, num_classes=10, dropout=0.5):
            super().__init__()  # num_classes and dropout captured automatically
            self.fc = nn.Linear(128, num_classes)

The implicit path also captures any **kwargs the child accepted:

    class MyNet(Mentee):
        def __init__(self, num_classes=10, **extra):
            super().__init__()  # num_classes and the contents of extra are all recorded
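The frame walk described above can be sketched with the standard library alone. The class below is a hypothetical stand-in named CaptureBase, not Mentee's actual implementation. It applies the same three guards and assumes that only the declared arguments of the topmost __init__ (plus the contents of its **kwargs) should be recorded.

```python
import inspect
from types import FrameType
from typing import Any, Dict


class CaptureBase:
    """Hypothetical stand-in for Mentee that records constructor arguments."""

    def __init__(self, **constructor_params: Any) -> None:
        params: Dict[str, Any] = dict(constructor_params)
        if type(self) is not CaptureBase:  # guard: skip direct instantiation
            frame = inspect.currentframe()  # may be None off CPython
            caller = frame.f_back if frame is not None else None
            topmost = None
            # Walk outward while each caller is an __init__ running on this object.
            while (
                caller is not None
                and caller.f_code.co_name == "__init__"
                and caller.f_locals.get("self") is self
            ):
                topmost = caller
                caller = caller.f_back
            if topmost is not None:
                params = self._arguments_of(topmost)
            del frame, caller, topmost  # drop frame references promptly
        self._constructor_params = params

    @staticmethod
    def _arguments_of(frame: FrameType) -> Dict[str, Any]:
        """Collect the declared arguments (minus self) and any **kwargs contents."""
        code = frame.f_code
        n_args = code.co_argcount + code.co_kwonlyargcount
        args = {
            name: frame.f_locals[name]
            for name in code.co_varnames[:n_args]
            if name != "self"
        }
        if code.co_flags & inspect.CO_VARKEYWORDS:
            # co_varnames lists named args, then *args (if any), then **kwargs.
            idx = n_args + (1 if code.co_flags & inspect.CO_VARARGS else 0)
            args.update(frame.f_locals[code.co_varnames[idx]])
        return args
```

With this sketch, explicit kwargs forwarded by an intermediate base do not stop the walk, and a Child built inside another object's __init__ records only its own arguments. Note that the sketch records the current values of the arguments, so one reassigned before super().__init__() runs is captured with its new value.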
When implicit capture is skipped
If any of the three conditions above is not met (e.g. Mentee() is instantiated directly, or Mentee.__init__ is called from outside an __init__ context), constructor_params is left as whatever was explicitly passed, which may be an empty dict. No error is raised. If nothing was passed, the checkpoint simply stores {}.
- Parameters:
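**constructor_params (Any) – Keyword arguments to store. They are used as-is when frame introspection finds no qualifying __init__ frame. Otherwise, the locals of the most-derived __init__ frame are recorded, as in the explicit-passing example above.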
- Return type:
None
Notes
Frame introspection relies on inspect.currentframe(), which is guaranteed on CPython (the runtime used by PyTorch in practice) but not mandated by the Python language specification. On alternative implementations such as PyPy, the implicit path may silently fall back to an empty dict. Use explicit passing if portability matters.
Examples
>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__()  # implicit: num_classes captured
...         self.fc = nn.Linear(128, num_classes)
>>> model = MyNet(num_classes=5)
>>> model._constructor_params
{'num_classes': 5}

>>> class MyNet(Mentee):
...     def __init__(self, num_classes=10):
...         super().__init__(num_classes=num_classes)  # explicit
...         self.fc = nn.Linear(128, num_classes)
>>> model = MyNet(num_classes=5)
>>> model._constructor_params
{'num_classes': 5}
- property current_epoch: int
Number of completed training epochs.
- Returns:
Equal to len(self._train_history). Zero on a fresh model.
- Return type:
int
- property total_train_iterations: int
Cumulative number of batches processed across all train_epoch() calls.
Incremented at the end of every epoch, before the LR scheduler step. Persisted in every checkpoint and restored on resume.
- property layer_names: List[str]
Full dotted paths of every parameter-bearing module, in module order.
These are the names matched by freeze(), unfreeze(), set_lr_coefficient(), and select_layers(), and are also the node labels shown by mtr_checkpoint -verbose.
- property device: device
Device on which the model parameters currently reside.
- Returns:
Inferred from the first parameter tensor.
- Return type:
torch.device
- Raises:
StopIteration – If the model has no parameters (a bare Mentee with no submodules).
- property optimizer: Any | None
The optimizer produced by the last create_train_objects() call.
When a trainer is set, returns trainer.optimizer. Otherwise returns the locally cached _optimizer. None until create_train_objects() has been called.
- property lr_scheduler: Any | None
The LR scheduler produced by the last create_train_objects() call.
When a trainer is set, returns trainer.lr_scheduler. Otherwise returns the locally cached _lr_scheduler. None until create_train_objects() has been called.
- property loss_fn: Any | None
The default loss function registered by create_train_objects().
When a trainer is set, returns trainer.loss_fn. Otherwise returns _default_loss_fn. None until a loss has been registered.
- register_inference_state(key, value)
Store an arbitrary picklable object needed at inference time.
Unlike constructor_params, inference state is typically computed from data (e.g. a fitted label encoder, vocabulary, or normalisation statistics) and may be large. It is serialised transparently inside the checkpoint alongside the model weights.
- Parameters:
key (str) – Identifier used to retrieve the value with get_inference_state().
value (Any) – Any picklable Python object (dict, list, tensor, sklearn transformer, …).
- Return type:
None
Examples
>>> model.register_inference_state("classes", ["cat", "dog", "bird"])
>>> model.register_inference_state("mean", torch.tensor([0.485, 0.456, 0.406]))
- get_inference_state(key, default=None)
Retrieve a value previously stored with register_inference_state().
- Parameters:
key (str) – Identifier passed to register_inference_state().
default (Any, optional) – Returned when key is not present. Default is None.
- Returns:
The stored object, or default if the key is absent.
- Return type:
Any
Examples
>>> classes = model.get_inference_state("classes", default=[])
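Conceptually, inference state is a picklable key/value store that travels inside the checkpoint. The sketch below is a hypothetical, torch-free illustration. It uses pickle in place of torch.save() (which is pickle-based) and checks picklability at registration time. That check is a design choice of the sketch, not documented Mentee behaviour.

```python
import io
import pickle
from typing import Any, Dict


class InferenceStateStore:
    """Hypothetical sketch of register_inference_state / get_inference_state."""

    def __init__(self) -> None:
        self._inference_state: Dict[str, Any] = {}

    def register_inference_state(self, key: str, value: Any) -> None:
        pickle.dumps(value)  # fail early if the value cannot be serialised
        self._inference_state[key] = value

    def get_inference_state(self, key: str, default: Any = None) -> Any:
        return self._inference_state.get(key, default)

    def to_bytes(self) -> bytes:
        # Stand-in for embedding the state in a checkpoint file.
        buf = io.BytesIO()
        pickle.dump({"inference_state": self._inference_state}, buf)
        return buf.getvalue()

    @classmethod
    def from_bytes(cls, data: bytes) -> "InferenceStateStore":
        store = cls()
        store._inference_state = pickle.loads(data)["inference_state"]
        return store
```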
- forward(*args, **kwargs)
Forward pass — must be overridden by subclasses.
- Parameters:
*args (Any) – Positional inputs (typically a batch tensor).
**kwargs (Any) – Keyword inputs.
- Returns:
Model output (logits, embeddings, sequences, …).
- Return type:
Any
- Raises:
NotImplementedError – Always raised by the base implementation.
- training_step(sample, loss_fn=None)
Compute the loss for a single training sample or mini-batch.
Called inside train_epoch(). The returned tensor must be differentiable with respect to the model parameters.
- Parameters:
sample (Any) – One element yielded by the training DataLoader.
- Returns:
loss (torch.Tensor) – Scalar loss tensor (requires_grad=True).
metrics (dict[str, float]) – Scalar metrics to accumulate and log. The first key is treated as the principal metric by validate_epoch() for best-model tracking.
- Raises:
NotImplementedError – Always raised by the base implementation.
- Return type:
tuple[torch.Tensor, dict[str, float]]
Examples
>>> def training_step(self, sample):
...     x, y = sample
...     loss = F.cross_entropy(self(x.to(self.device)), y.to(self.device))
...     return loss, {"loss": loss.item()}
- validation_step(sample, loss_fn=None)
Evaluate the model on a single validation sample or mini-batch.
Defaults to calling training_step() with the same arguments, so subclasses that only implement training_step() get validation for free. Override when the validation forward pass differs from training (e.g. different augmentation, TTA, beam search).
Called inside validate_epoch() under torch.no_grad(). The first key of the returned dict is used as the principal metric when comparing epochs for best-model selection.
- Parameters:
sample (Any) – One element yielded by the validation DataLoader.
loss_fn (callable, optional) – Loss function forwarded to training_step().
- Returns:
Scalar evaluation metrics (may include "loss").
- Return type:
dict[str, float]
Examples
>>> # default: no override needed if training_step covers both
>>> def validation_step(self, sample, loss_fn=None):  # custom override
...     x, y = sample
...     logits = self(x.to(self.device))
...     acc = (logits.argmax(1) == y.to(self.device)).float().mean().item()
...     return {"acc": acc}
- preprocess(raw_input)
Transform a raw input into a model-ready tensor.
Override to make the checkpoint self-contained for inference. Use get_inference_state() to access tokenizers, normalisation statistics, or other data-derived artefacts.
- Parameters:
raw_input (Any) – Raw data (PIL image, string, numpy array, …).
- Returns:
Model-ready tensor or batch.
- Return type:
Any
- Raises:
NotImplementedError – Raised by the base implementation.
Examples
>>> def preprocess(self, raw_input):
...     mean = self.get_inference_state("mean")
...     std = self.get_inference_state("std")
...     return (torch.tensor(raw_input) - mean) / std
- decode(model_output)
Transform raw model output into a human-readable result.
Override to make the checkpoint self-contained for inference. Use get_inference_state() to access label maps, alphabets, or beam-search decoders.
- Parameters:
model_output (Any) – Raw output from forward().
- Returns:
Human-readable prediction (class name, decoded string, bounding box, …).
- Return type:
Any
- Raises:
NotImplementedError – Raised by the base implementation.
Examples
>>> def decode(self, model_output):
...     idx = model_output.argmax(1).item()
...     return self.get_inference_state("classes")[idx]
- get_output_schema()
Describe the output space as a serialisable dict.
The returned dict is embedded in the checkpoint and displayed by mtr_checkpoint. Override to self-document what the model produces.
- Returns:
Arbitrary JSON-serialisable description. Common keys: type, num_classes, classes, alphabet. Returns {} by default.
- Return type:
dict
Examples
>>> def get_output_schema(self):
...     return {"type": "classification",
...             "classes": self.get_inference_state("classes")}
- get_preprocessing_info()
Describe preprocessing requirements as a serialisable dict.
The returned dict is embedded in the checkpoint and displayed by mtr_checkpoint. Override to self-document expected inputs.
- Returns:
Arbitrary JSON-serialisable description. Common keys: input_size, mean, std, resize. Returns {} by default.
- Return type:
dict
Examples
>>> def get_preprocessing_info(self):
...     return {"input_size": [1, 28, 28],
...             "mean": [0.1307], "std": [0.3081]}
- create_train_objects(lr=0.001, step_size=10, gamma=0.1, loss_fn=None, overwrite_default_loss=False)
Create training objects and (optionally) set the default loss function.
Returns a dict with "optimizer", "lr_scheduler", and "loss_fn" keys. Calling this method more than once is safe. By default it does not replace a previously set default loss (overwrite_default_loss=False), so a parametric loss that has already been partially trained is preserved across optimizer resets.
Override to substitute a different optimiser or scheduler; the dict structure must be preserved.
- Parameters:
lr (float, optional) – Initial learning rate for Adam. Default is 1e-3.
step_size (int, optional) – Period (in epochs) for the StepLR decay. Default is 10.
gamma (float, optional) – Multiplicative decay factor for StepLR. Default is 0.1.
loss_fn (callable, optional) – Loss function to register as the default. If None and no default is currently set, _default_loss_fn remains None (which means training_step() must either provide one or raise its own error).
overwrite_default_loss (bool, optional) – If True, always replace the existing default loss with the newly supplied loss_fn. If False (default) and a default is already set, the existing default is preserved even when loss_fn is provided. Set to True when intentionally switching loss functions mid-training.
- Returns:
{"optimizer": Adam, "lr_scheduler": StepLR, "loss_fn": <fn or None>}
- Return type:
dict
Examples
>>> train_objs = model.create_train_objects(lr=1e-4, step_size=5,
...                                         loss_fn=nn.CrossEntropyLoss())
>>> train_objs["optimizer"], train_objs["lr_scheduler"]
(Adam ..., StepLR ...)
>>> # second call with overwrite_default_loss=False keeps the first loss
>>> train_objs2 = model.create_train_objects(lr=1e-5)
>>> train_objs2["loss_fn"] is train_objs["loss_fn"]
True
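The rule for the default loss can be written as a small decision function. resolve_default_loss below is a hypothetical helper, not part of mentor. It encodes only what the parameter descriptions above state, and it assumes that overwrite_default_loss=True replaces the default even when loss_fn is None, a case the text does not settle.

```python
from typing import Any, Callable, Optional

LossFn = Optional[Callable[..., Any]]


def resolve_default_loss(
    current: LossFn, new: LossFn, overwrite_default_loss: bool = False
) -> LossFn:
    """Return the loss that should become the registered default."""
    if overwrite_default_loss:
        return new  # always replace (assumed to apply even when new is None)
    if current is None:
        return new  # nothing registered yet: adopt the new loss (possibly None)
    return current  # keep the existing default, even if a new loss was supplied
```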
- train_epoch(dataset, optimizer, lr_scheduler=None, pseudo_batch_size=1, memfail='raise', tensorboard_writer=None, verbose=False, refresh_freq=20, batch_size=None, collate_fn=None, num_workers=0, shuffle=True, amp=False)
Train the model for one full epoch.
Iterates over dataset, calls training_step() for each batch, and accumulates gradients for pseudo_batch_size batches before calling optimizer.step(). Appends the epoch metrics to _train_history, incrementing current_epoch.
dataset may be a DataLoader (used directly) or a Dataset / any sized iterable (wrapped automatically using batch_size, collate_fn, num_workers, and shuffle). When a DataLoader is passed, these four loader kwargs are ignored.
- Parameters:
dataset (DataLoader or Dataset) – Batched DataLoader or an unbatched Dataset to be wrapped.
optimizer (torch.optim.Optimizer) – Optimiser to use for parameter updates.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – Scheduler stepped once at the end of the epoch.
pseudo_batch_size (int, optional) – Number of batches over which gradients are accumulated before each optimizer.step(). Default is 1.
memfail ({'raise', 'skip'}, optional) – Policy when training_step() raises MemoryError. 'raise' propagates immediately; 'skip' counts the failure and continues. Default is 'raise'.
tensorboard_writer (torch.utils.tensorboard.SummaryWriter, optional) – If provided, each metric is logged under train/<metric>.
verbose (bool, optional) – Show a tqdm progress bar. Default is False.
refresh_freq (int, optional) – Progress-bar postfix update interval (in batches). Default is 20.
batch_size (int, optional) – Batch size used when dataset is not a DataLoader. Defaults to 1.
collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader when dataset is not already a DataLoader.
num_workers (int, optional) – Number of DataLoader worker processes. Default is 0 (main-process loading).
shuffle (bool, optional) – Whether to shuffle samples when building a DataLoader from a Dataset. Default is True. Ignored when dataset is already a DataLoader.
amp (bool, optional) – Enable automatic mixed precision via torch.autocast and torch.amp.GradScaler. The scaler is cached on the model as _grad_scaler so its loss scale adapts correctly across epochs. Default is False.
- Returns:
Per-metric averages over the epoch, plus memfails (count of skipped batches).
- Return type:
dict
- Raises:
MemoryError – When memfail is 'raise' and a batch triggers OOM.
Examples
>>> _to = model.create_train_objects(lr=1e-3)
>>> # from a DataLoader (existing usage)
>>> metrics = model.train_epoch(train_loader, _to["optimizer"], pseudo_batch_size=4)
>>> # from a Dataset (new usage)
>>> metrics = model.train_epoch(train_dataset, _to["optimizer"], batch_size=32, shuffle=True)
>>> print(f"epoch {model.current_epoch} loss={metrics['loss']:.4f}")
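The effect of pseudo_batch_size can be illustrated without PyTorch. The sketch below trains a single scalar weight with plain SGD on mean squared error, computing gradients by hand in place of autograd. It assumes two things: the accumulated gradient is averaged over the micro-batches in a group, and a trailing incomplete group still triggers a final step. The handling of the last group is not specified above. All function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def batch_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean squared error of y_hat = w * x over one batch."""
    return sum(2.0 * x * (w * x - y) for x, y in batch) / len(batch)


def train_epoch_sketch(
    w: float, batches: List[Batch], lr: float, pseudo_batch_size: int = 1
) -> Tuple[float, int]:
    """SGD with gradient accumulation; returns (updated weight, optimizer steps)."""
    grad_sum, n_acc, steps = 0.0, 0, 0
    for i, batch in enumerate(batches, start=1):
        grad_sum += batch_grad(w, batch)  # like loss.backward() adding into .grad
        n_acc += 1
        if n_acc == pseudo_batch_size or i == len(batches):
            w -= lr * grad_sum / n_acc  # like optimizer.step() on the averaged gradient
            grad_sum, n_acc = 0.0, 0  # like optimizer.zero_grad()
            steps += 1
    return w, steps
```

Accumulating four single-sample batches with pseudo_batch_size=4 produces the same update as one batch of four. In PyTorch the same averaging is usually obtained by dividing each micro-batch loss by the number of accumulation steps before calling backward(), because .grad accumulates by summation.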
- validate_epoch(dataset, recalculate=False, memfail='raise', tensorboard_writer=None, verbose=False, refresh_freq=20, batch_size=None, collate_fn=None, num_workers=0)
Validate the model at the current epoch.
Results are cached in _validate_history, keyed by epoch. Calling this method twice for the same epoch returns the cached dict without re-running inference, unless recalculate is True.
If the principal metric (first key of the returned dict) exceeds its value at every previous epoch, the current weights are saved to _best_weights_so_far.
dataset may be a DataLoader (used directly) or a Dataset / any sized iterable (wrapped automatically with batch_size, collate_fn, and num_workers). Shuffle is always False for validation.
- Parameters:
dataset (DataLoader or Dataset) – Batched DataLoader or an unbatched Dataset to be wrapped.
recalculate (bool, optional) – Force re-evaluation even if this epoch was already validated. Default is False.
memfail ({'raise', 'skip'}, optional) – Policy for MemoryError inside validation_step(). Default is 'raise'.
tensorboard_writer (torch.utils.tensorboard.SummaryWriter, optional) – If provided, metrics are logged under val/<metric>.
verbose (bool, optional) – Show a tqdm progress bar. Default is False.
refresh_freq (int, optional) – Progress-bar postfix update interval. Default is 20.
batch_size (int, optional) – Batch size used when dataset is not a DataLoader. Defaults to 1.
collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader when dataset is not already a DataLoader.
num_workers (int, optional) – Number of DataLoader worker processes. Default is 0.
- Returns:
Per-metric averages, plus memfails.
- Return type:
dict
- Raises:
MemoryError – When memfail is 'raise' and a batch triggers OOM.
Examples
>>> # from a DataLoader (existing usage)
>>> val_metrics = model.validate_epoch(val_loader)
>>> # from a Dataset (new usage)
>>> val_metrics = model.validate_epoch(val_dataset, batch_size=64)
>>> print(f"acc={val_metrics['acc']:.4f} best_epoch={model._best_epoch_so_far}")
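The caching and best-model bookkeeping described above can be sketched as follows. ValidationTracker and its evaluate callback are hypothetical. The sketch assumes that higher is better for the principal metric and treats the first validated epoch as the initial best. It uses copy.deepcopy as a stand-in for copying a model state dict.

```python
import copy
from typing import Any, Callable, Dict, Optional

Metrics = Dict[str, float]


class ValidationTracker:
    """Hypothetical sketch of validate_epoch()'s cache and best-weight tracking."""

    def __init__(self) -> None:
        self.current_epoch = 0
        self._validate_history: Dict[int, Metrics] = {}
        self._best_epoch_so_far: Optional[int] = None
        self._best_weights_so_far: Optional[Dict[str, Any]] = None

    def validate_epoch(
        self,
        evaluate: Callable[[], Metrics],
        weights: Dict[str, Any],
        recalculate: bool = False,
    ) -> Metrics:
        epoch = self.current_epoch
        if epoch in self._validate_history and not recalculate:
            return self._validate_history[epoch]  # cached result, no re-evaluation
        metrics = evaluate()
        self._validate_history[epoch] = metrics
        principal = next(iter(metrics))  # first key is the principal metric
        previous = [
            m[principal]
            for e, m in self._validate_history.items()
            if e != epoch and principal in m
        ]
        if not previous or metrics[principal] > max(previous):
            self._best_epoch_so_far = epoch
            self._best_weights_so_far = copy.deepcopy(weights)  # snapshot, not a reference
        return metrics
```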
- fit(train_data, val_data=None, epochs=1, lr=0.001, batch_size=None, collate_fn=None, num_workers=0, pseudo_batch_size=1, checkpoint_path=None, tensorboard_dir=None, verbose=False, memfail='raise', device=None, patience=None, amp=False, save_freq=1, validate_freq=1, report_wandb=False, report_gradio=False, training_label=None)
Train and optionally validate for a fixed number of epochs.
A convenience wrapper around train_epoch(), validate_epoch(), and save() that drives the full training loop in one call. It is equivalent to writing the loop manually and is provided for cases where you do not need to insert custom logic between epochs.
If optimizer is None when fit is called, create_train_objects() is called automatically with the supplied lr. If training objects already exist (e.g. from a previous call to create_train_objects() or resume_training()), they are reused unchanged.
- Parameters:
train_data (DataLoader or Dataset) – Training data, passed directly to train_epoch().
val_data (DataLoader or Dataset, optional) – Validation data, passed to validate_epoch() after each epoch. Skipped when None.
epochs (int, optional) – Number of epochs to train. Default is 1.
lr (float, optional) – Learning rate passed to create_train_objects() when no optimizer exists yet. Ignored if training objects are already set up. Default is 1e-3.
batch_size (int, optional) – Batch size used when train_data or val_data is not already a DataLoader.
collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader.
num_workers (int, optional) – DataLoader worker processes. Default is 0.
pseudo_batch_size (int, optional) – Gradient accumulation steps. Default is 1.
checkpoint_path (str or Path, optional) – If provided, save() is called every save_freq epochs.
save_freq (int, optional) – Save frequency in epochs. 1 (default) saves after every epoch; <=0 disables saving entirely.
validate_freq (int, optional) – Validation frequency in epochs. 1 (default) validates after every epoch; <=0 disables validation entirely (including the epoch-0 baseline).
tensorboard_dir (str, optional) – Directory for a SummaryWriter. A writer is created at the start and closed when training ends. Skipped when None.
verbose (bool, optional) – Show tqdm progress bars and per-epoch summary lines. Default is False.
memfail ({'raise', 'skip'}, optional) – OOM policy forwarded to train_epoch() and validate_epoch(). Default is 'raise'.
device (str, optional) – If provided, the model is moved to this device before training starts (e.g. 'cuda', 'cpu').
patience (int, optional) – Early-stopping patience. If the principal validation metric has not improved for patience consecutive epochs, training stops before reaching epochs. Requires val_data to be set; ignored when None (default).
amp (bool, optional) – Enable automatic mixed precision. Forwarded to train_epoch(). Default is False.
report_wandb (bool, optional) – Log metrics to Weights & Biases. Requires wandb to be installed (pip install wandb). When True:
  - A run is initialised with wandb.init if one is not already active; the project name defaults to the model class name.
  - Train metrics are logged under train/<key> and validation metrics under val/<key> once per epoch.
  - The run URL is printed at startup, with an ASCII QR code when the qrcode package is available (pip install qrcode).
  - When verbose is True, the URL is also reprinted at the end of every epoch line.
  Default is False. A RuntimeWarning is emitted when True but wandb is not installed.
report_gradio (bool, optional) – Start a local Gradio dashboard and expose it via a public reverse tunnel (gradio’s share=True). Requires gradio (pip install gradio). The tunnel URL and an ASCII QR code are printed at startup; the dashboard auto-refreshes every 30 s. Default is False.
training_label (str, optional) – Human-readable identifier for this training run. Used as the wandb run name and id, so the same label always resumes the same wandb run. When None (default), an automatic label is derived as f"{ClassName}_{os.getpid()}".
- Returns:
self, so calls can be chained.
- Return type:
Mentee
Examples
>>> model = MyNet()
>>> model.fit(train_loader, val_loader, epochs=10, lr=1e-3,
...           checkpoint_path="run.pt", tensorboard_dir="tb/",
...           verbose=True)
>>> print(f"best epoch: {model._best_epoch_so_far}")
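The interplay of patience and save_freq can be traced with a simulated loop. fit_sketch is hypothetical, and it makes several assumptions:
- A list of scores stands in for the principal validation metric after each epoch, and higher is better.
- Validation runs every epoch, and the epoch-0 baseline is omitted.
- The save check runs before the early-stopping check.

```python
from typing import List, Optional, Sequence, Tuple


def fit_sketch(
    scores: Sequence[float],
    epochs: int,
    patience: Optional[int] = None,
    save_freq: int = 1,
) -> Tuple[int, List[int]]:
    """Return (epochs_trained, epochs_saved) for a simulated fit() run."""
    best = float("-inf")
    stale = 0  # consecutive epochs without improvement
    saved: List[int] = []
    trained = 0
    for epoch in range(1, epochs + 1):
        trained = epoch  # train_epoch() would run here
        score = scores[epoch - 1]  # validate_epoch() would run here
        if score > best:
            best, stale = score, 0
        else:
            stale += 1
        if save_freq > 0 and epoch % save_freq == 0:
            saved.append(epoch)  # save(checkpoint_path) would run here
        if patience is not None and stale >= patience:
            break  # early stop before reaching `epochs`
    return trained, saved
```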
- find_lr(train_data, start_lr=1e-07, end_lr=10.0, num_iter=100, smooth=0.98, diverge_threshold=4.0, batch_size=None, collate_fn=None, num_workers=0, amp=False)
Run the learning-rate range test (Smith 2017).
Sweeps the learning rate geometrically from start_lr to end_lr over num_iter batches, records the smoothed loss at each step, and then restores the model weights so the run has no side-effects.
A fresh optimizer is created for the sweep via a new instance of type(self.trainer) (or a plain Adam when no trainer is set), so neither the cached optimizer nor the trainer state is affected.
- Parameters:
train_data (DataLoader or Dataset) – Data to iterate over; only num_iter batches are consumed.
start_lr (float, optional) – Lower bound of the LR sweep. Default is 1e-7.
end_lr (float, optional) – Upper bound of the LR sweep. Default is 10.0.
num_iter (int, optional) – Number of batches to sweep over. Default is 100.
smooth (float, optional) – Exponential moving-average factor for loss smoothing. Higher values produce a smoother curve. Default is 0.98.
diverge_threshold (float, optional) – Stop early when the smoothed loss exceeds diverge_threshold × best_loss. Default is 4.0.
batch_size (int, optional) – Batch size when train_data is not already a DataLoader.
collate_fn (callable, optional) – Custom collate function forwarded to the DataLoader.
num_workers (int, optional) – DataLoader worker processes. Default is 0.
amp (bool, optional) – Run the sweep with automatic mixed precision. Default is False.
- Returns:
{"lrs": [float, ...], "losses": [float, ...]}, with one entry per completed step, suitable for plotting.
- Return type:
dict
Examples
>>> result = model.find_lr(train_loader, start_lr=1e-6, end_lr=1.0)
>>> import matplotlib.pyplot as plt
>>> plt.semilogx(result["lrs"], result["losses"])
>>> plt.show()
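The sweep itself reduces to three parts: a geometric LR schedule, an exponential moving average of the loss, and a divergence check. In the hypothetical sketch below, loss_at(lr) stands in for running one training step at that learning rate and returning its loss. Bias correction of the moving average is a common choice in LR-finder implementations, but it is an assumption here. So are the exact stopping comparison and the omission of weight restoration.

```python
import math
from typing import Callable, Dict, List


def lr_range_test(
    loss_at: Callable[[float], float],
    start_lr: float = 1e-7,
    end_lr: float = 10.0,
    num_iter: int = 100,
    smooth: float = 0.98,
    diverge_threshold: float = 4.0,
) -> Dict[str, List[float]]:
    """Geometric LR sweep with EMA-smoothed losses and early divergence stop."""
    # Constant ratio so that step num_iter - 1 lands exactly on end_lr.
    ratio = (end_lr / start_lr) ** (1.0 / max(num_iter - 1, 1))
    lrs: List[float] = []
    losses: List[float] = []
    avg = 0.0
    best = math.inf
    for i in range(num_iter):
        lr = start_lr * ratio**i
        avg = smooth * avg + (1.0 - smooth) * loss_at(lr)
        smoothed = avg / (1.0 - smooth ** (i + 1))  # bias-corrected EMA
        lrs.append(lr)
        losses.append(smoothed)
        best = min(best, smoothed)
        if smoothed > diverge_threshold * best:
            break  # loss has blown up; stop the sweep
    return {"lrs": lrs, "losses": losses}
```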
- set_lr_coefficient(coefficient, patterns, optimizer=None)
Set a per-layer learning-rate coefficient for layers matching patterns.
The effective LR for each layer is global_lr * coefficient. _lr_coefficients is the source of truth; it is persisted in every checkpoint and applied automatically by create_train_objects().
If an optimizer is available (via optimizer or the cached trainer/model optimizer), _apply_lr_coefficients() is called immediately to rebuild the param groups with the new coefficient applied. Optimizer momentum / variance state (keyed by parameter id) is preserved.
Note
_apply_lr_coefficients() derives the base learning rate from optimizer.defaults["lr"] (the value passed at optimizer creation), not from the current scheduler-decayed value. Calling this method mid-training therefore resets affected layers to initial_lr * coefficient, discarding any decay. Call set_lr_coefficient() before training begins or at a deliberate phase boundary (e.g. when unfreezing layers with a freshly built optimizer) to avoid this effect.
- Parameters:
coefficient (float) – Multiplier relative to the global LR. 1.0 restores the default (the layer is removed from _lr_coefficients to keep it sparse). 0.0 zeroes the layer’s LR without setting requires_grad=False.
patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.
optimizer (torch.optim.Optimizer, optional) – Optimizer to update. Defaults to the trainer’s or the model’s cached optimizer.
- Returns:
self, for chaining.
- Return type:
Mentee
Examples
>>> model.set_lr_coefficient(0.1, "backbone")          # 10x slower backbone
>>> model.set_lr_coefficient(0.0, ["backbone"])        # zero backbone LR
>>> model.set_lr_coefficient(1.0, r"backbone\..*")     # restore all sub-layers
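The bookkeeping behind set_lr_coefficient() can be sketched with a plain dict. The helpers below are hypothetical:
- update_coefficients mirrors the documented matching and sparsity rules.
- effective_lrs computes base_lr * coefficient per layer, where base_lr plays the role of optimizer.defaults["lr"].

The sketch assumes coefficients apply to exactly the matched layer paths, with no inheritance from parent to child layers.

```python
import re
from typing import Dict, List, Sequence, Union


def update_coefficients(
    coeffs: Dict[str, float],
    layer_names: Sequence[str],
    coefficient: float,
    patterns: Union[str, List[str]],
) -> Dict[str, float]:
    """Set `coefficient` for every layer whose full path matches a pattern."""
    if isinstance(patterns, str):
        patterns = [patterns]
    for name in layer_names:
        if any(re.fullmatch(p, name) for p in patterns):
            if coefficient == 1.0:
                coeffs.pop(name, None)  # 1.0 is the default: keep the dict sparse
            else:
                coeffs[name] = coefficient
    return coeffs


def effective_lrs(
    base_lr: float, layer_names: Sequence[str], coeffs: Dict[str, float]
) -> Dict[str, float]:
    """Per-layer learning rate; layers without an entry use coefficient 1.0."""
    return {name: base_lr * coeffs.get(name, 1.0) for name in layer_names}
```

In PyTorch, per-layer values like these would be passed to the optimizer constructor as param groups of the form {"params": [...], "lr": ...}.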
- select_layers(layer_names)
Return layer paths that match any entry in the layer_names argument, deduplicated and sorted in module traversal order (the same order as the layer_names property).
Each entry in the layer_names argument is matched with re.fullmatch against the full dotted path of every module listed by the layer_names property (e.g. backbone.layer4.0.conv2). Plain strings act as exact-match selectors; regex patterns select groups of layers. In a pattern, an unescaped dot matches any character, so escape the path separator as \. to avoid matching unintended paths. Duplicate matches (a name matched by several patterns) are collapsed to a single entry. The order of the returned list always follows the layer_names property, never the order of the input patterns.
- Parameters:
layer_names (list[str]) – Exact path names or re.fullmatch patterns applied to the full dotted path (e.g. r"backbone\.layer[34]\..*").
- Returns:
Matched layer paths in module order, without duplicates.
- Return type:
list[str]
Examples
For a model whose layer_names is ['backbone', 'backbone.layer4', 'backbone.layer4.0.conv2', 'head']:

    # exact match
    model.select_layers(['backbone.layer4'])
    # -> ['backbone.layer4']

    # regex: all sub-layers of backbone (dot must be escaped)
    model.select_layers([r'backbone\..*'])
    # -> ['backbone.layer4', 'backbone.layer4.0.conv2']

    # input order does not affect output order
    model.select_layers(['head', 'backbone'])
    # -> ['backbone', 'head']

    # duplicate matches collapsed to one entry
    model.select_layers([r'backbone\..*', 'backbone.layer4'])
    # -> ['backbone.layer4', 'backbone.layer4.0.conv2']
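The selection rule above amounts to a single filter over the module-ordered name list. The function below is a hypothetical standalone re-implementation that takes the full name list explicitly. Iterating over the names rather than over the patterns gives module order and deduplication for free.

```python
import re
from typing import List, Sequence


def select_layers(all_layer_names: Sequence[str], patterns: Sequence[str]) -> List[str]:
    """Names matching any pattern via re.fullmatch, in all_layer_names order."""
    compiled = [re.compile(p) for p in patterns]
    # Looping over the names (not the patterns) keeps module order and emits
    # each matching name once, however many patterns match it.
    return [name for name in all_layer_names if any(c.fullmatch(name) for c in compiled)]
```

In this sketch, patterns must be a list. A bare string would be iterated character by character.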
- freeze(patterns, optimizer=None, reset_optimizer_if_needed=False)
Freeze layers selected by re.fullmatch patterns.
Updates _frozen_modules (source of truth) and sets requires_grad=False on the affected parameters. If an optimizer is resolved, the corresponding param groups are left in place. Their parameters stop receiving gradients, so Adam skips them automatically (PyTorch optimizers skip parameters whose .grad is None), and no restructuring is required.
- Parameters:
patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.
optimizer (torch.optim.Optimizer, optional) – Optimizer to update. Defaults to the trainer’s or the model’s cached optimizer.
reset_optimizer_if_needed (bool, optional) – Accepted for API symmetry with unfreeze() and set_lr_coefficient(); currently unused because freezing never requires restructuring the optimizer.
- Returns:
self, for chaining.
- Return type:
Mentee
- unfreeze(patterns, optimizer=None, reset_optimizer_if_needed=False)
Unfreeze layers selected by re.fullmatch patterns.
Updates _frozen_modules (source of truth) and sets requires_grad=True on the affected parameters.
If an optimizer is resolved and the unfrozen layer already has a param group (because it was frozen after the optimizer was built), the group’s parameters are live again and Adam will initialise their state on the first gradient step, so no rebuild is needed. If the layer has no param group (because it was frozen before the optimizer was built), a rebuild is required.
- Parameters:
patterns (str or list[str]) – Exact names or re.fullmatch patterns matched against layer_names.
optimizer (torch.optim.Optimizer, optional) – Optimizer to inspect and possibly rebuild. Defaults to the trainer’s or the model’s cached optimizer.
reset_optimizer_if_needed (bool, optional) – When True and the unfrozen layer has no param group, create_train_objects() is called to rebuild the optimizer (Adam state is reset). When False (default), a RuntimeError is raised instead.
- Returns:
self, for chaining.
- Return type:
Mentee
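Whether unfreezing needs an optimizer rebuild comes down to whether every parameter of the layer already belongs to some param group. plan_unfreeze is a hypothetical helper that works on parameter ids. With a real optimizer, those ids could be gathered as {id(p) for g in optimizer.param_groups for p in g["params"]}. The helper only reports the decision and does not rebuild anything.

```python
from typing import Iterable, Set


def plan_unfreeze(
    layer_param_ids: Iterable[int],
    optimizer_param_ids: Set[int],
    reset_optimizer_if_needed: bool = False,
) -> str:
    """Return "reuse" or "rebuild", or raise RuntimeError, per the documented rule."""
    missing = [pid for pid in layer_param_ids if pid not in optimizer_param_ids]
    if not missing:
        # Layer was frozen after the optimizer was built: its group still exists.
        return "reuse"
    if reset_optimizer_if_needed:
        # Stands in for calling create_train_objects() (optimizer state is reset).
        return "rebuild"
    raise RuntimeError(
        f"{len(missing)} parameter(s) have no param group; "
        "pass reset_optimizer_if_needed=True to rebuild the optimizer"
    )
```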
- save(path, optimizer=None, lr_scheduler=None)
Serialise the full training state to a .pt checkpoint.
All tensors are moved to CPU before saving, so the checkpoint is device-independent. The file contains model weights, training and validation history, provenance metadata, inference state, and (optionally) optimiser and scheduler state.
- Parameters:
path (str, pathlib.Path, or file-like) – Destination file path or any file-like object accepted by torch.save() (e.g. io.BytesIO).
optimizer (torch.optim.Optimizer, optional) – If provided, its state_dict is stored so training can be resumed with exactly the same optimiser state.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – If provided, its state_dict is stored alongside the optimiser.
- Return type:
None
Examples
>>> model.save("checkpoint.pt", optimizer=opt, lr_scheduler=sched)
>>> # or in-memory:
>>> import io
>>> buf = io.BytesIO()
>>> model.save(buf)
>>> buf.seek(0)
0
- classmethod resume(path, model_class=None, tolerate_irresumable_model=True, trainer=None, **kwargs)
Load a checkpoint saved by save() and return the model.
If model_class is None, the class is resolved from the class_module / class_name fields stored in the checkpoint using importlib.import_module().
When the checkpoint was created by wrap_as_mentee(), the stored class name points to the original (unwrapped) class. resume detects this and re-applies the Mentee mixin automatically, so the returned object is always a Mentee instance.
- Parameters:
path (str or pathlib.Path) – Path to the .pt file, or a file-like object.
model_class (type, optional) – Explicit subclass to instantiate. Required when the checkpoint’s module is not importable in the current environment. Also required whenever the tolerate_irresumable_model fallback is triggered; otherwise a ValueError is raised (see Raises).
tolerate_irresumable_model (bool, optional) – When True (default), any failure to load the model falls back to a freshly instantiated model_class. Such failures include a missing file, an unimportable class, a state-dict mismatch, or any other exception. The fresh instance uses the constructor params stored in the checkpoint, or an empty dict when the file is missing or unreadable. When False, any such failure raises immediately.
trainer (MentorTrainer subclass (uninstantiated), optional) – If supplied, an instance of this trainer is assigned to model.trainer after loading. Use this when the checkpoint was created with wrap_as_mentee() and a custom trainer that is not serialised inside the checkpoint.
**kwargs (Any)
- Returns:
Fully restored model (from checkpoint) or a fresh instance.
- Return type:
Mentee
- Raises:
Exception – Any failure, when tolerate_irresumable_model is False.
ValueError – If tolerate_irresumable_model is True but model_class is None when the fallback is triggered.
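The class lookup described above follows the standard importlib pattern. The sketch below is a hypothetical illustration, not mentor's code: resolve_class, its wrapped flag, and the mixin argument are assumptions, and building a subclass with type() is only one plausible way to re-apply a mixin. collections.OrderedDict stands in for a model class so the example runs without PyTorch.

```python
import importlib
from typing import Optional


def resolve_class(class_module: str, class_name: str, wrapped: bool = False,
                  mixin: Optional[type] = None) -> type:
    """Hypothetical sketch: resolve a stored class and optionally re-apply a mixin."""
    # Import the module recorded in the checkpoint, then look the class up by name.
    cls = getattr(importlib.import_module(class_module), class_name)
    if wrapped and mixin is not None:
        # A dynamic subclass puts the mixin first in the MRO, ahead of the original class.
        cls = type(cls.__name__, (mixin, cls), {})
    return cls


class Greeter:
    """Stand-in mixin used only for this illustration."""

    def greet(self) -> str:
        return "hello"


Wrapped = resolve_class("collections", "OrderedDict", wrapped=True, mixin=Greeter)
```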
Examples
>>> model = Mentee.resume("checkpoint.pt", model_class=MyNet)
>>> model.eval()
Start from scratch when no checkpoint exists yet:
>>> model = MyNet.resume("run/checkpoint.pt", model_class=MyNet)
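The fallback rules for tolerate_irresumable_model can be summarised as a small decision function. This is a minimal sketch under stated assumptions: load_with_fallback and its arguments are hypothetical, load is any callable that performs the real loading, and stored_params stands for the constructor params read from the checkpoint (None when the file could not be read).

```python
from typing import Any, Callable, Optional


def load_with_fallback(load: Callable[[], Any], model_class: Optional[type],
                       stored_params: Optional[dict], tolerate: bool = True) -> Any:
    """Hypothetical sketch of the tolerate_irresumable_model rules."""
    try:
        return load()
    except Exception:
        if not tolerate:
            raise  # tolerate_irresumable_model=False: re-raise the original failure
        if model_class is None:
            raise ValueError("model_class is required when the fallback is triggered")
        # Fresh instance from the stored constructor params, or {} if none were read.
        return model_class(**(stored_params or {}))


class Net:
    """Stand-in for a Mentee subclass."""

    def __init__(self, num_classes: int = 10) -> None:
        self.num_classes = num_classes


def missing_checkpoint() -> Any:
    raise FileNotFoundError("run/checkpoint.pt")


fresh = load_with_fallback(missing_checkpoint, Net, stored_params=None)
```

This is why the "start from scratch" example works: a missing file triggers the fallback, and Net is built with its default arguments.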
- classmethod resume_training(path, model_class=None, device=None, tolerate_irresumable_model=True, tolerate_irresumable_trainstate=False, **kwargs)
Load a checkpoint and reconstruct everything needed to continue training.
Restores model weights and history, moves the model to device, calls create_train_objects(), and restores optimiser and scheduler state if present in the checkpoint.
- Parameters:
path (str or pathlib.Path) – Path to the .pt file, or a file-like object.
model_class (type, optional) – Explicit subclass to instantiate (see resume()).
device (str or torch.device, optional) – Target device, e.g. "cuda" or "cpu". If None, the model stays on CPU as loaded.
tolerate_irresumable_model (bool, optional) – When True (default), any failure to load the model (a missing file, an unimportable class, a state-dict mismatch, or any other exception) falls back to a freshly instantiated model_class. model_class must be provided when this fallback is triggered. When False, any such failure raises immediately.
tolerate_irresumable_trainstate (bool, optional) – When False (default), an exception is raised if the checkpoint contains no optimizer state, or if the optimizer, scheduler, or scaler state cannot be restored. Set to True to continue silently with a freshly constructed optimizer instead.
**kwargs (Any) – Passed to create_train_objects() (e.g. lr=1e-4).
- Returns:
(model, optimizer, lr_scheduler): the objects returned by create_train_objects(), preceded by the loaded model.
- Return type:
tuple
- Raises:
FileNotFoundError – If path does not exist and tolerate_irresumable_model is False.
RuntimeError – If the model cannot be loaded and tolerate_irresumable_model is False, or if the training state cannot be restored and tolerate_irresumable_trainstate is False.
ValueError – If tolerate_irresumable_model is True but model_class is None when the fallback is triggered.
Examples
>>> model, opt, sched = Mentee.resume_training(
...     "checkpoint.pt", model_class=MyNet, device="cuda", lr=1e-4
... )
>>> model.train_epoch(train_loader, opt, sched)
Start from scratch when no checkpoint exists yet:
>>> model, opt, sched = MyNet.resume_training(
...     "run/checkpoint.pt",
...     model_class=MyNet,
...     device="cuda",
...     pretrained=True,
...     lr=1e-4,
... )
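The documented behaviour of tolerate_irresumable_trainstate can be sketched as follows. restore_train_state, the "optimizer_state" / "lr_scheduler_state" keys, and StubOptimizer are hypothetical; the only assumption about the optimizer and scheduler is that they expose load_state_dict(), as torch.optim objects do. Scaler state is omitted for brevity.

```python
from typing import Any, Optional


def restore_train_state(checkpoint: dict, optimizer: Any,
                        lr_scheduler: Optional[Any] = None,
                        tolerate: bool = False) -> bool:
    """Hypothetical sketch of the tolerate_irresumable_trainstate rules.

    Returns True when saved state was restored, False when training should
    continue with the freshly constructed optimizer.
    """
    if "optimizer_state" not in checkpoint:
        if tolerate:
            return False
        raise RuntimeError("checkpoint contains no optimizer state")
    try:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        if lr_scheduler is not None and "lr_scheduler_state" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state"])
    except Exception as exc:
        if tolerate:
            # Note: anything loaded before the failure keeps its restored state.
            return False
        raise RuntimeError("could not restore training state") from exc
    return True


class StubOptimizer:
    """Stand-in exposing load_state_dict() like torch.optim.Optimizer."""

    def __init__(self) -> None:
        self.loaded = None

    def load_state_dict(self, state: dict) -> None:
        self.loaded = state


opt = StubOptimizer()
ok = restore_train_state({"optimizer_state": {"lr": 1e-4}}, opt)
```

Wrapping load failures in RuntimeError matches the exception type listed under Raises; a real implementation might instead rebuild the optimizer after a partial failure.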
Helper functions
The following module-level helpers are used internally and exposed for advanced use.
- mentor.mentee._state_dict_architecture_lines(state_dict)
Derive architecture stats from a state_dict without instantiating the model.
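Because a state_dict maps parameter and buffer names to tensors, basic architecture statistics can be read from tensor shapes alone. The function below is a hypothetical sketch, not this helper's actual output format: architecture_lines is an illustrative name, and it accepts either tensors (anything with a .shape) or plain shape tuples, so it runs without PyTorch.

```python
import math
from typing import Any, Dict, List


def architecture_lines(state_dict: Dict[str, Any]) -> List[str]:
    """Hypothetical sketch: summarise a state_dict from tensor shapes alone.

    Values may be tensors (anything with a .shape) or plain shape tuples.
    """
    lines = []
    total = 0
    for name, value in state_dict.items():
        shape = tuple(getattr(value, "shape", value))
        count = math.prod(shape)  # a 0-d (scalar) entry has shape () and 1 value
        total += count
        lines.append(f"{name}: {list(shape)} ({count:,} values)")
    # Buffers (e.g. BatchNorm running stats) are counted too, not just parameters.
    lines.append(f"total: {total:,} values")
    return lines
```

For example, a 128-to-10 linear layer yields "fc.weight: [10, 128] (1,280 values)", "fc.bias: [10] (10 values)", and a total of 1,290.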
- mentor.mentee._get_software_snapshot()
Collect a reproducibility snapshot of the current software environment.
Fields recorded
- python, torch, mentor_version
Version strings for the interpreter and key libraries.
- torchvision, numpy
Version strings when available; "unavailable" otherwise.
- cuda
torch.version.cuda (e.g. "12.1"), or "cpu" for CPU builds.
- platform
OS description from the platform module (e.g. "Linux-6.1 x86_64").
- hostname, user
Machine and user identity.
- main_script
Absolute path to sys.argv[0], the entry-point script.
- git_hash
Full SHA-1 of HEAD; "unavailable" when git is absent.
- git_branch
Current branch name; helps locate the commit in a crowded history.
- git_remote
URL of the origin remote; identifies the repo or fork.
- git_dirty
"true" when there are uncommitted changes (the hash alone is then insufficient for exact reproduction), "false" otherwise.
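A standard-library sketch of how several of these fields could be collected. It is an illustration under stated assumptions, not mentor's implementation: software_snapshot and _git are hypothetical names, the git queries are ordinary git commands, and it omits the torch, torchvision, numpy, cuda and mentor_version fields so it runs without those packages. Reporting git_dirty as "unavailable" when git is absent is also an assumption.

```python
import getpass
import os
import platform
import socket
import subprocess
import sys


def _git(*args: str) -> str:
    """Run a git command; return stripped stdout, or "unavailable" on any failure."""
    try:
        out = subprocess.run(["git", *args], capture_output=True, text=True,
                             check=True, timeout=5)
        return out.stdout.strip()
    except (OSError, subprocess.SubprocessError):  # git missing, not a repo, timeout
        return "unavailable"


def software_snapshot() -> dict:
    """Hypothetical sketch covering a subset of the documented fields."""
    try:
        user = getpass.getuser()
    except Exception:
        user = "unavailable"
    status = _git("status", "--porcelain")
    script = sys.argv[0] if sys.argv and sys.argv[0] else ""
    return {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "hostname": socket.gethostname(),
        "user": user,
        "main_script": os.path.abspath(script) if script else "unavailable",
        "git_hash": _git("rev-parse", "HEAD"),
        "git_branch": _git("rev-parse", "--abbrev-ref", "HEAD"),
        "git_remote": _git("remote", "get-url", "origin"),
        # Any porcelain output means uncommitted changes exist (assumed encoding).
        "git_dirty": "unavailable" if status == "unavailable" else ("true" if status else "false"),
    }


snapshot = software_snapshot()
```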