HuggingFace Transfer Learning with mentor

This example fine-tunes a HuggingFace pretrained image classifier on Oxford Flowers-102 (102 flower categories, ~1020 training images) using mentor’s two-stage curriculum: frozen backbone first, then full fine-tune.

It is the primary guide for wrapping HuggingFace models as Mentee objects and covers the non-obvious complications that arise from doing so.

Quick start

# Stage 1 + 2 — fresh run (downloads model and dataset automatically)
python examples/hf/classify.py train

# Resume from where you left off
python examples/hf/classify.py train -resume_fname ./tmp/mobilenetv2_flowers102.mentor.pt

# Inference
python examples/hf/classify.py inference -img flower.jpg

Why Flowers-102?

| Property | Value |
| --- | --- |
| Training images | 1020 (~10 per class) |
| Validation images | 1020 |
| Classes | 102 fine-grained flower species |
| Native resolution | variable, 300–800 px |
| ImageNet overlap | large (many flower classes present) |

With only 10 labelled examples per class, training a classifier from scratch is hopeless. Pretrained ImageNet features (edges, textures, shapes) transfer directly, so a frozen MobileNetV2 backbone plus a fresh linear head converges to useful accuracy within a single epoch. This gap between from-scratch and transfer performance is exactly what makes Flowers-102 a good knowledge-transfer benchmark.

Training recipe

| | Stage 1 (head only) | Stage 2 (full fine-tune) |
| --- | --- | --- |
| Epochs | 0–10 | 10–40 |
| Backbone | frozen | unfrozen |
| LR | 1e-3 | 1e-4 |
| Optimizer | Adam | Adam |

The stage boundary is stored in the checkpoint as current_epoch, so the script always resumes at the right stage with no manual bookkeeping.

Wrapping HuggingFace models as Mentee

The core pattern

from transformers import AutoModelForImageClassification
import mentor

model = AutoModelForImageClassification.from_pretrained(
    "google/mobilenet_v2_1.0_224",
    num_labels=102,
    ignore_mismatched_sizes=True,   # replaces the 1000-class head
)
model = mentor.wrap_as_mentee(
    model,
    trainer=HFClassifier,
    constructor_params={"config": model.config},  # required — see below
)

wrap_as_mentee injects the full Mentee API (fit, freeze, save, …) into the live HF model object without touching its weights or submodules.

Why constructor_params={"config": model.config} is required

When mentor saves a checkpoint it records class_name and class_module so it can reconstruct the model at resume time via:

model_class(**constructor_params)

For a plain nn.Module the constructor typically takes simple scalar arguments. HuggingFace models require a PretrainedConfig object as their first argument — passing {} results in a TypeError on resume.

The config object is picklable and self-contained: it captures num_labels, hidden sizes, and all architectural hyperparameters, so MobileNetV2ForImageClassification(config=checkpoint["constructor_params"]["config"]) faithfully reconstructs the architecture before the saved weights are loaded on top.
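The reconstruction step can be pictured with a minimal sketch. It assumes the checkpoint is a plain dict holding class_module, class_name, and constructor_params; the helper rebuild_from_checkpoint and the NeedsConfig class are hypothetical and are not mentor's actual code. A standard-library class stands in for the HF model.

```python
import importlib
from datetime import timedelta


def rebuild_from_checkpoint(checkpoint):
    """Re-import the recorded class and call it with the recorded kwargs.

    Hypothetical helper: it illustrates the general pattern only.
    """
    module = importlib.import_module(checkpoint["class_module"])
    model_class = getattr(module, checkpoint["class_name"])
    return model_class(**checkpoint["constructor_params"])


# datetime.timedelta stands in for a model class that takes keyword arguments.
checkpoint = {
    "class_module": "datetime",
    "class_name": "timedelta",
    "constructor_params": {"hours": 2},
}
rebuilt = rebuild_from_checkpoint(checkpoint)


class NeedsConfig:
    """Stand-in for a class whose constructor requires a config argument."""

    def __init__(self, config):
        self.config = config


def try_build(cls, params):
    """Return the new instance, or the TypeError raised by a missing argument."""
    try:
        return cls(**params)
    except TypeError as err:
        return err
```

Calling try_build(NeedsConfig, {}) returns a TypeError, which mirrors what happens when an HF model is rebuilt with empty constructor_params.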

The classifier head must stay architecture-compatible

mentor reconstructs the model by calling model_class(config=...), then loading the saved state_dict. If the classifier head was replaced with a different nn.Module after construction (e.g. wrapping nn.Linear in nn.Sequential), the state_dict keys change and load_state_dict fails silently under the default tolerate_irresumable_model=True, returning a fresh random model instead.

# WRONG — Sequential changes state_dict keys (classifier.1.weight vs classifier.weight)
model.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(in_features, 102))

# RIGHT — keep the architecture that from_pretrained produces
# num_labels=102 already gives nn.Linear(hidden_size, 102) as the head

If you wanted dropout in the head for regularisation, weight decay in the optimizer is an alternative that leaves the architecture untouched:

model.fit(..., weight_decay=1e-4)

Or subclass the model so the custom head is part of __init__ and therefore matches what model_class(config=...) produces.
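Because a key mismatch can pass silently, it may help to compare key sets before trusting a resumed model. The following is a minimal sketch; diff_state_dict_keys is a hypothetical helper, not part of mentor, and the key names are the ones from the WRONG/RIGHT example above.

```python
def diff_state_dict_keys(expected_keys, saved_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the rebuilt model expects but the checkpoint lacks
    unexpected: keys the checkpoint holds but the rebuilt model does not know
    """
    expected, saved = set(expected_keys), set(saved_keys)
    return sorted(expected - saved), sorted(saved - expected)


# Keys produced by model_class(config=...) versus keys saved after the head
# was replaced with nn.Sequential(nn.Dropout(...), nn.Linear(...)).
rebuilt_keys = ["classifier.weight", "classifier.bias"]
saved_keys = ["classifier.1.weight", "classifier.1.bias"]

missing, unexpected = diff_state_dict_keys(rebuilt_keys, saved_keys)
```

With real models, pass model.state_dict().keys() for the rebuilt model and the keys of the saved state_dict; two empty lists mean the architectures line up.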

HF models return structured outputs, not bare tensors

AutoModelForImageClassification returns an ImageClassifierOutput object, not a raw tensor. The default mentor.Classifier trainer calls loss_fn(model(x), y) which fails because loss_fn expects a tensor.

The fix is a one-method subclass that unwraps .logits:

import torch.nn.functional as F

import mentor


class HFClassifier(mentor.Classifier):
    @classmethod
    def default_training_step(cls, model, batch, loss_fn=None):
        x, y = batch
        x, y = x.to(model.device), y.to(model.device)
        logits = model(pixel_values=x).logits   # unwrap structured output
        # Fall back to cross-entropy when the caller supplies no loss function.
        eff_fn = loss_fn if loss_fn is not None else F.cross_entropy
        loss = eff_fn(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"acc": acc, "loss": loss.item()}

Note the pixel_values=x keyword: HF image models name their primary input pixel_values, and passing it by keyword keeps the call independent of parameter order.
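If the same trainer has to serve models with different return types, the unwrap step can be factored into a small helper. This is a sketch under assumptions: unwrap_logits is a hypothetical name, and the tuple branch assumes the logits come first, which holds only when no loss is returned alongside them.

```python
from types import SimpleNamespace


def unwrap_logits(output):
    """Return the logits from several common model output shapes."""
    if hasattr(output, "logits"):            # structured output object
        return output.logits
    if isinstance(output, dict) and "logits" in output:
        return output["logits"]              # plain dict output
    if isinstance(output, tuple):            # tuple output, logits assumed first
        return output[0]
    return output                            # already a bare value


# SimpleNamespace stands in for a structured output such as ImageClassifierOutput.
structured = SimpleNamespace(logits=[0.1, 0.9])
```

Inside default_training_step this would replace the direct .logits access, at the cost of hiding which output type the model actually returns.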

Leeching preprocessing from the HF processor

HF ships an AutoImageProcessor alongside each model that encodes the correct resize size, centre-crop, mean, and std. Rather than hardcoding these values, extract them at startup and build a native torchvision transform pipeline:

import torchvision as tv


def make_transform(processor):
    # Prefer "shortest_edge", then "height", then fall back to 224.
    size = processor.size.get("shortest_edge", processor.size.get("height", 224))
    return tv.transforms.Compose([
        tv.transforms.Resize(size),
        tv.transforms.CenterCrop(size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])

Calling the processor object directly inside the DataLoader workers (one call per sample) is 3–4× slower because of Python-level overhead and PIL → numpy → tensor round-trips. The torchvision-native pipeline above runs in C++ inside the worker processes and is the correct approach for training throughput.
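The size lookup in make_transform can be exercised on plain dicts, which makes the fallback order easy to check. The sketch below is hypothetical (resolve_sizes is not part of the example script). Its crop_cfg argument rests on an assumption: some HF processors carry a separate crop_size dict that differs from size, in which case using one value for both Resize and CenterCrop feeds the model a different input size than the processor would. Inspect your processor's attributes before relying on either behaviour.

```python
def resolve_sizes(size_cfg, crop_cfg=None, default=224):
    """Return (resize, crop) edge lengths from processor-style size dicts.

    size_cfg: dict such as {"shortest_edge": 256} or {"height": 224, "width": 224}
    crop_cfg: optional dict such as {"height": 224, "width": 224} (assumed layout)
    """
    # Same fallback order as make_transform: shortest_edge, then height, then default.
    resize = size_cfg.get("shortest_edge", size_cfg.get("height", default))
    # Without a crop config, crop to the resize value as make_transform does.
    crop = crop_cfg.get("height", resize) if crop_cfg else resize
    return resize, crop


same_size = resolve_sizes({"shortest_edge": 256})
with_crop = resolve_sizes({"shortest_edge": 256}, {"height": 224, "width": 224})
```

The two returned values would feed tv.transforms.Resize and tv.transforms.CenterCrop respectively.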

Curriculum training: two-stage freeze / fine-tune

# Stage 1 — freeze backbone, train head only
model.freeze("mobilenet_v2")
model.fit(train_data, val_data=val_data, epochs=10, lr=1e-3, ...)

# Stage 2 — unfreeze backbone, fine-tune everything
model.unfreeze("mobilenet_v2")
model.fit(train_data, val_data=val_data, epochs=40, lr=1e-4, ...)

mentor.Mentee.freeze and unfreeze accept regex patterns matched against the full dotted layer path. Use mtr_checkpoint view to explore the layer tree:

mtr_checkpoint view -path ./tmp/mobilenetv2_flowers102.mentor.pt
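To preview which layers a pattern would select, the matching can be approximated with the standard re module. This is a sketch under assumptions: it uses re.search semantics, which may differ from what mentor does internally, and the layer names are illustrative rather than copied from a real checkpoint.

```python
import re


def select_layers(names, pattern):
    """Return the dotted layer names that the regex pattern matches anywhere."""
    rx = re.compile(pattern)
    return [name for name in names if rx.search(name)]


# Illustrative dotted paths; take the real ones from the mtr_checkpoint view output.
names = [
    "mobilenet_v2.conv_stem.first_conv.convolution.weight",
    "mobilenet_v2.layer.0.expand_1x1.convolution.weight",
    "classifier.weight",
    "classifier.bias",
]

backbone = select_layers(names, r"^mobilenet_v2\.")
head = select_layers(names, r"^classifier\.")
```

Anchoring the pattern with ^ and an escaped dot avoids accidentally matching a submodule whose name merely contains the same substring.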

The epochs argument to fit is a ceiling, not a count. After resuming at epoch 7, fit(epochs=10) runs 3 more epochs to reach epoch 10 — not 10 additional ones. This means the stage boundary logic works unchanged across resume:

if model.current_epoch < args.freeze_epochs:
    model.freeze("mobilenet_v2")
    model.fit(..., epochs=args.freeze_epochs, ...)

model.unfreeze("mobilenet_v2")
model.fit(..., epochs=args.total_epochs, ...)
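The ceiling semantics reduce to simple arithmetic, sketched below in plain Python. The helper names epochs_remaining and plan_stages are hypothetical; they model only the bookkeeping described above, not mentor's training loop.

```python
def epochs_remaining(current_epoch, ceiling):
    """Epochs a fit(epochs=ceiling) call would still run from current_epoch."""
    return max(0, ceiling - current_epoch)


def plan_stages(current_epoch, freeze_epochs, total_epochs):
    """Return a list of (backbone_state, epochs_to_run) for the two-stage recipe."""
    plan = []
    stage1 = epochs_remaining(current_epoch, freeze_epochs)
    if stage1:
        plan.append(("frozen", stage1))
        current_epoch = freeze_epochs   # stage 1 ends exactly at the boundary
    stage2 = epochs_remaining(current_epoch, total_epochs)
    if stage2:
        plan.append(("unfrozen", stage2))
    return plan


fresh = plan_stages(0, 10, 40)      # fresh run
resumed = plan_stages(7, 10, 40)    # resumed in the middle of stage 1
late = plan_stages(25, 10, 40)      # resumed in the middle of stage 2
```

A run resumed at epoch 7 therefore finishes 3 frozen epochs before the 30 unfrozen ones, and a run resumed at epoch 25 skips stage 1 entirely.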

Resuming a HuggingFace Mentee checkpoint

model = mentor.Mentee.resume(
    "./tmp/mobilenetv2_flowers102.mentor.pt",
    trainer=HFClassifier,   # trainer is not serialised — must be re-supplied
)

Two things are re-applied automatically at resume time:

  1. Mentee mixin: wrap_as_mentee stores the original HF class name in the checkpoint. resume detects this and re-inserts Mentee into the MRO so the returned object is always a proper Mentee instance.

  2. Instance state: attributes like _train_history, _frozen_modules, and _lr_coefficients are injected before the state_dict is loaded, because HF model __init__ does not call Mentee.__init__ via super().

The trainer argument must be re-supplied because trainer instances are not serialised inside the checkpoint (they are stateless strategy objects and may live in user scripts that are not importable at resume time).
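The two automatic steps can be illustrated with a generic Python sketch of mixin injection. It is not mentor's implementation: MenteeLike, PlainModel, and inject_mixin are hypothetical stand-ins that show why the instance state has to be set by hand when the mixin's __init__ never runs.

```python
class MenteeLike:
    """Hypothetical stand-in for the Mentee mixin."""

    def describe(self):
        # Relies on an attribute that the mixin's own __init__ would normally set.
        return f"epoch {self._current_epoch}"


class PlainModel:
    """Stand-in for an HF model whose __init__ knows nothing about the mixin."""

    def __init__(self, config):
        self.config = config


def inject_mixin(obj, mixin, state):
    """Put mixin first in obj's MRO and set the attributes its methods expect."""
    base = type(obj)
    if not isinstance(obj, mixin):
        # Build a new class (mixin, base) and rebind the live object to it.
        obj.__class__ = type(f"{mixin.__name__}{base.__name__}", (mixin, base), {})
    for name, value in state.items():
        setattr(obj, name, value)
    return obj


model = inject_mixin(
    PlainModel(config={"num_labels": 102}),
    MenteeLike,
    {"_current_epoch": 7},
)
```

Without the state dict, model.describe() would raise AttributeError, which is the failure the instance-state injection at resume time prevents.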

What cannot easily be done

| Limitation | Workaround |
| --- | --- |
| Custom head architecture diverges from from_pretrained output | Subclass the HF model so the custom head is built in __init__ |
| Optimizer state not restored by Mentee.resume | Use Mentee.resume_training instead |
| Multi-modal HF models (e.g. CLIP, LLaVA) | Override default_training_step and default_validate_step for each input modality |
| HF models with multiple outputs (detection, segmentation) | Same: the .logits unwrap pattern extends naturally |

Source

"""Fine-tune a HuggingFace pretrained image classifier on Oxford Flowers-102.

Demonstrates two-stage transfer learning with mentor:

  Stage 1 (epochs 0–freeze_epochs)
      Backbone frozen, only the classification head is updated.
      Fast convergence — ImageNet features transfer directly to flower species.

  Stage 2 (epochs freeze_epochs–total_epochs)
      Full fine-tune with the backbone unfrozen at a lower learning rate.

Usage::

    # fresh run (downloads model and dataset automatically)
    python examples/hf/classify.py train

    # resume from checkpoint
    python examples/hf/classify.py train -resume_fname ./tmp/mobilenetv2_flowers102.mentor.pt

    # inference on one or more images
    python examples/hf/classify.py inference -img flower.jpg

See examples/hf/README.md for a full explanation of the HuggingFace wrapping
pattern, preprocessing choices, and resume behaviour.
"""
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

import fargv
import mentor


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------

def make_transform(processor: Any) -> tv.transforms.Compose:
    """Build a torchvision transform pipeline from a HuggingFace processor.

    Extracts resize size, mean, and std from the processor config so the
    DataLoader preprocessing always matches the model's expected input —
    no hardcoded values.  Using native torchvision ops (not the processor
    directly) keeps DataLoader throughput 3-4x higher.
    """
    size: int = processor.size.get("shortest_edge", processor.size.get("height", 224))
    return tv.transforms.Compose([
        tv.transforms.Resize(size),
        tv.transforms.CenterCrop(size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])


def load_train_dataloaders(
    processor: Any,
    batch_size: int = 32,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Return (train_loader, val_loader) for Oxford Flowers-102."""
    transform = make_transform(processor)
    train_ds = tv.datasets.Flowers102(
        root="./tmp/flowers102", split="train", download=True, transform=transform
    )
    val_ds = tv.datasets.Flowers102(
        root="./tmp/flowers102", split="val", download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader


# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------

class HFClassifier(mentor.Classifier):
    """Classifier trainer for HuggingFace image classification models.

    The only difference from the built-in :class:`mentor.Classifier` is that
    HF models return a structured ``ImageClassifierOutput`` rather than a bare
    tensor, so ``.logits`` must be unwrapped before computing the loss.
    """

    @classmethod
    def default_training_step(
        cls,
        model: Any,
        batch: Any,
        loss_fn: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        x, y = batch
        x, y = x.to(model.device), y.to(model.device)
        logits = model(pixel_values=x).logits   # unwrap HF structured output
        eff_fn = loss_fn if loss_fn is not None else F.cross_entropy
        loss = eff_fn(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"acc": acc, "loss": loss.item()}


# ---------------------------------------------------------------------------
# Model / processor helpers
# ---------------------------------------------------------------------------

def _load_or_cache_model(
    model_id: str,
    hf_cache: str,
    num_labels: int = 102,
) -> Tuple[Any, Any]:
    """Return ``(model, processor)``, downloading from HF Hub on first call.

    On subsequent calls the model and processor are loaded from *hf_cache*,
    avoiding repeated Hub downloads and allowing offline use.
    """
    if not os.path.exists(hf_cache):
        model = AutoModelForImageClassification.from_pretrained(
            model_id, num_labels=num_labels, ignore_mismatched_sizes=True
        )
        processor = AutoImageProcessor.from_pretrained(model_id)
        os.makedirs(hf_cache, exist_ok=True)
        model.save_pretrained(hf_cache)
        processor.save_pretrained(hf_cache)
    else:
        model = AutoModelForImageClassification.from_pretrained(hf_cache)
        processor = AutoImageProcessor.from_pretrained(hf_cache)
    return model, processor


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main_train() -> None:
    args, _ = fargv.parse({
        "hf_cache":     "./tmp/mobilenetv2.hf",
        "resume_fname": "./tmp/mobilenetv2_flowers102.mentor.pt",
        "device":       "cuda" if torch.cuda.is_available() else "cpu",
        "cmd": {
            "train": {
                "model_id":        "google/mobilenet_v2_1.0_224",
                "freeze_epochs":   10,
                "total_epochs":    40,
                "lr":              1e-3,
                "num_workers":     4,
                "batch_size":      8,
                "pseudo_batch_size": 4,
            },
            "inference": {
                "img": [],
            },
        },
    })

    if args.cmd == "train":
        if os.path.exists(args.resume_fname):
            # Resume: load Mentee checkpoint, re-supply the trainer (not serialised).
            model = mentor.Mentee.resume(args.resume_fname, trainer=HFClassifier)
            _, processor = _load_or_cache_model(args.model_id, args.hf_cache)
        else:
            # Fresh start: download/cache HF model, wrap as Mentee.
            # constructor_params={"config": model.config} is required so that
            # Mentee.resume can reconstruct the HF architecture at resume time.
            model, processor = _load_or_cache_model(args.model_id, args.hf_cache)
            model = mentor.wrap_as_mentee(
                model,
                trainer=HFClassifier,
                constructor_params={"config": model.config},
            )

        model = model.to(args.device)
        train_loader, val_loader = load_train_dataloaders(
            processor, batch_size=args.batch_size, num_workers=args.num_workers
        )

        # Stage 1: train head only (backbone frozen)
        # fit(epochs=N) is a ceiling — resumes seamlessly from any epoch < N.
        if model.current_epoch < args.freeze_epochs:
            model.freeze("mobilenet_v2")
            model.fit(
                train_data=train_loader,
                val_data=val_loader,
                epochs=args.freeze_epochs,
                pseudo_batch_size=args.pseudo_batch_size,
                lr=args.lr,
                checkpoint_path=args.resume_fname,
                verbose=args.verbosity > 0,
                num_workers=args.num_workers,
            )

        # Stage 2: full fine-tune (backbone unfrozen, lower LR)
        model.unfreeze("mobilenet_v2")
        model.fit(
            train_data=train_loader,
            val_data=val_loader,
            epochs=args.total_epochs,
            pseudo_batch_size=args.pseudo_batch_size,
            lr=args.lr / 10,
            checkpoint_path=args.resume_fname,
            verbose=args.verbosity > 0,
            num_workers=args.num_workers,
        )

    elif args.cmd == "inference":
        model = mentor.Mentee.resume(args.resume_fname, trainer=HFClassifier)
        _, processor = _load_or_cache_model("", args.hf_cache)
        model = model.to(args.device)
        model.eval()
        t = time.time()
        with torch.no_grad():
            for img_path in args.img:
                img = Image.open(img_path).convert("RGB")
                inputs = {
                    k: v.to(args.device)
                    for k, v in processor(images=img, return_tensors="pt").items()
                }
                logits = model(**inputs).logits
                pred = logits.argmax(1).item()
                if args.verbosity > 1:
                    print(f"{time.time() - t:.2f}s: ", end="")
                print(f"{img_path}: class {pred}")
        if args.verbosity > 0:
            print(f"Inferred {len(args.img)} images in {time.time() - t:.2f}s")

    else:
        raise ValueError(f"Unknown command {args.cmd!r}")


if __name__ == "__main__":
    main_train()