HuggingFace Transfer Learning with mentor
This example fine-tunes a HuggingFace pretrained image classifier on
Oxford Flowers-102 (102 flower categories, 1020 training images) using
mentor’s two-stage curriculum: frozen backbone first, then full fine-tune.
It is the primary guide for wrapping HuggingFace models as Mentee objects
and covers the non-obvious complications that arise from doing so.
Quick start
```shell
# Stage 1 + 2: fresh run (downloads model and dataset automatically)
python examples/hf/classify.py train

# Resume from where you left off
python examples/hf/classify.py train -resume_fname ./tmp/mobilenetv2_flowers102.mentor.pt

# Inference
python examples/hf/classify.py inference -img flower.jpg
```
Why Flowers-102?
| Property | Value |
|---|---|
| Training images | 1020 (10 per class) |
| Validation images | 1020 |
| Classes | 102 fine-grained flower species |
| Native resolution | variable, 300–800 px |
| ImageNet overlap | large (many flower classes present) |
With only 10 labelled examples per class, training a classifier from scratch is hopeless. Pretrained ImageNet features (edges, textures, shapes) transfer directly, so a frozen MobileNetV2 backbone plus a fresh linear head converges to useful accuracy within a single epoch. That contrast between training from scratch and reusing pretrained features is exactly what makes Flowers-102 a good knowledge-transfer benchmark.
Training recipe
| | Stage 1 (head only) | Stage 2 (full fine-tune) |
|---|---|---|
| Epochs | 0–10 | 10–40 |
| Backbone | frozen | unfrozen |
| LR | 1e-3 | 1e-4 |
| Optimizer | Adam | Adam |
The checkpoint stores training progress as current_epoch, and the script compares
it with the stage boundary (freeze_epochs), so a resumed run always continues in
the right stage with no manual bookkeeping.
Wrapping HuggingFace models as Mentee
The core pattern
```python
from transformers import AutoModelForImageClassification

import mentor

model = AutoModelForImageClassification.from_pretrained(
    "google/mobilenet_v2_1.0_224",
    num_labels=102,
    ignore_mismatched_sizes=True,  # replaces the ImageNet classification head
)
model = mentor.wrap_as_mentee(
    model,
    trainer=HFClassifier,  # the trainer subclass defined below
    constructor_params={"config": model.config},  # required, see below
)
```
wrap_as_mentee injects the full Mentee API (fit, freeze, save, …)
into the live HF model object without touching its weights or submodules.
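This section does not show how wrap_as_mentee works internally. As a hedged illustration of the general Python technique for adding a mixin to an object that already exists (and for re-inserting a class into its MRO, as resume does later), the sketch below swaps the instance's __class__ for a dynamically created subclass. PretrainedNet, TrainableMixin, init_state, and graft_mixin are hypothetical names, not mentor's API. The object identity and its attribute dict survive unchanged, and the mixin's state has to be set up by hand because the original constructor never ran the mixin's initialiser.

```python
from typing import Any, Dict, List


class PretrainedNet:
    """Stand-in for a live HF model whose weights are set in its own __init__."""

    def __init__(self, weights: List[float]) -> None:
        self.weights = weights


class TrainableMixin:
    """Hypothetical mixin adding training state and API (not mentor's Mentee)."""

    def init_state(self) -> None:
        self._train_history: List[Dict[str, Any]] = []

    @property
    def current_epoch(self) -> int:
        return len(self._train_history)

    def log_epoch(self, **metrics: Any) -> None:
        self._train_history.append(metrics)


def graft_mixin(obj: Any, mixin: type) -> Any:
    """Give an existing instance the mixin's methods without copying it.

    Assumes the mixin provides an init_state() hook for its own attributes.
    """
    base = type(obj)
    # Mixin first in the bases, so its methods take precedence in the MRO.
    grafted = type(f"{mixin.__name__}{base.__name__}", (mixin, base), {})
    obj.__class__ = grafted  # same object and __dict__, new class
    obj.init_state()  # base.__init__ never ran the mixin's setup, so run it now
    return obj


net = PretrainedNet([0.1, 0.2])
same_net = graft_mixin(net, TrainableMixin)
same_net.log_epoch(loss=0.5)
print([cls.__name__ for cls in type(same_net).__mro__])
# ['TrainableMixinPretrainedNet', 'TrainableMixin', 'PretrainedNet', 'object']
```

Swapping __class__ rather than building a new wrapper object keeps every existing reference (optimizers, hooks, submodules) pointing at the same instance, which is the property "without touching its weights or submodules" relies on.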
Why constructor_params={"config": model.config} is required
When mentor saves a checkpoint it records class_name and class_module so that,
at resume time, it can reconstruct the model via:

```python
model_class(**constructor_params)
```

For a plain nn.Module the constructor typically takes simple scalar
arguments. HuggingFace models instead require a PretrainedConfig object as their
first argument, so passing {} results in a TypeError on resume.
The config object is picklable and self-contained: it captures
num_labels, hidden sizes, and all architectural hyperparameters, so
MobileNetV2ForImageClassification(config=checkpoint["constructor_params"]["config"])
faithfully reconstructs the architecture before the saved weights are loaded
on top.
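The record-then-rebuild round trip can be sketched with the standard library alone. The sketch assumes the checkpoint is a plain dict with class_module, class_name, and constructor_params keys; the key names come from the text above, while the dict layout and the record/rebuild helpers are hypothetical. datetime.date stands in for an HF model class because, like one, its constructor has required arguments, so an empty constructor_params reproduces the TypeError. Loading weights afterwards is not modelled.

```python
import datetime
import importlib
from typing import Any, Dict


def record(obj: Any, constructor_params: Dict[str, Any]) -> Dict[str, Any]:
    """Store what is needed to rebuild obj later (hypothetical checkpoint layout)."""
    cls = type(obj)
    return {
        "class_module": cls.__module__,
        "class_name": cls.__qualname__,
        "constructor_params": constructor_params,
    }


def rebuild(checkpoint: Dict[str, Any]) -> Any:
    """Re-import the class by name, then call model_class(**constructor_params)."""
    module = importlib.import_module(checkpoint["class_module"])
    model_class = getattr(module, checkpoint["class_name"])
    return model_class(**checkpoint["constructor_params"])


original = datetime.date(2024, 5, 17)

# The constructor's required arguments were recorded: the object round-trips.
restored = rebuild(record(original, {"year": 2024, "month": 5, "day": 17}))
print(restored == original)  # True

# Nothing recorded, like constructor_params={} for an HF model: TypeError.
try:
    rebuild(record(original, {}))
except TypeError as exc:
    print("resume fails:", exc)
```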
The classifier head must stay architecture-compatible
mentor reconstructs the model by calling model_class(config=...) and then
loading the saved state_dict. If the classifier head was replaced with a
different nn.Module after construction (e.g. wrapping nn.Linear in
nn.Sequential), the saved state_dict keys no longer match the freshly
constructed model and load_state_dict fails. Under the default
tolerate_irresumable_model=True that failure is silent: resume returns a
freshly initialised model instead of your trained one.
```python
import torch.nn as nn

in_features = model.classifier.in_features  # width of the existing head's input

# WRONG: Sequential changes state_dict keys (classifier.1.weight vs classifier.weight)
model.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(in_features, 102))

# RIGHT: keep the architecture that from_pretrained produces;
# num_labels=102 already gives nn.Linear(hidden_size, 102) as the head
```
If you wanted that dropout for regularisation, use weight decay in the optimizer instead:

```python
model.fit(..., weight_decay=1e-4)
```
Or subclass the model so the custom head is part of __init__ and therefore
matches what model_class(config=...) produces.
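The key mismatch can be reproduced without downloading anything. In the sketch below, TinyClassifier and TinyClassifierWithDropoutHead are hypothetical toy modules standing in for an HF model, and mentor is not involved. The first part is the WRONG case: a strict load_state_dict into a freshly constructed model raises RuntimeError, which is the failure the default tolerate_irresumable_model=True turns into a silent fallback. The second part is the subclassing workaround.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Toy stand-in for an HF classifier whose head is a plain nn.Linear."""

    def __init__(self, hidden: int = 8, num_labels: int = 102):
        super().__init__()
        self.backbone = nn.Linear(16, hidden)
        self.classifier = nn.Linear(hidden, num_labels)


class TinyClassifierWithDropoutHead(TinyClassifier):
    """Custom head built in __init__, so calling the class reproduces it."""

    def __init__(self, hidden: int = 8, num_labels: int = 102):
        super().__init__(hidden, num_labels)
        self.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(hidden, num_labels))


# Head swapped after construction: the saved keys gain a ".1." index.
patched = TinyClassifier()
patched.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(8, 102))
saved = patched.state_dict()
print(sorted(k for k in saved if k.startswith("classifier")))
# ['classifier.1.bias', 'classifier.1.weight']

# Rebuilding from the base class expects classifier.weight / classifier.bias instead.
try:
    TinyClassifier().load_state_dict(saved)  # strict=True by default
except RuntimeError as exc:
    print("strict load fails:", str(exc).splitlines()[0])

# Head defined in the subclass's __init__: a fresh instance accepts the weights.
fresh = TinyClassifierWithDropoutHead()
fresh.load_state_dict(saved)
```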
HF models return structured outputs, not bare tensors
AutoModelForImageClassification returns an ImageClassifierOutput object,
not a raw tensor. The default mentor.Classifier trainer computes
loss_fn(model(x), y), which fails because loss_fn expects a logits tensor,
not an output object.
The fix is a one-method subclass that unwraps .logits:
```python
import torch.nn.functional as F

import mentor


class HFClassifier(mentor.Classifier):
    @classmethod
    def default_training_step(cls, model, batch, loss_fn=None):
        x, y = batch
        x, y = x.to(model.device), y.to(model.device)
        logits = model(pixel_values=x).logits  # unwrap structured output
        eff_fn = loss_fn if loss_fn is not None else F.cross_entropy
        loss = eff_fn(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"acc": acc, "loss": loss.item()}
```
Note pixel_values=x: HF image models name their primary input tensor
pixel_values, and passing it by keyword ties the call to that name rather than
to argument order.
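The unwrapping can be exercised without transformers. In the sketch below, FakeClassifierOutput and FakeHFModel are hypothetical stand-ins that copy only the parts of the HF interface the trainer touches (a pixel_values keyword, a .device attribute, an output object with .logits); training_step repeats the body of default_training_step above as a free function. The last lines show the failure the subclass avoids.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class FakeClassifierOutput:
    """Stand-in for transformers' ImageClassifierOutput; only .logits is used."""
    logits: torch.Tensor


class FakeHFModel(nn.Module):
    """Toy model with an HF-style forward(pixel_values=...) signature."""

    def __init__(self, num_labels: int = 5):
        super().__init__()
        self.head = nn.Linear(3 * 4 * 4, num_labels)

    @property
    def device(self) -> torch.device:
        # HF models expose .device; a plain nn.Module does not, so derive it.
        return next(self.parameters()).device

    def forward(self, pixel_values: torch.Tensor) -> FakeClassifierOutput:
        return FakeClassifierOutput(logits=self.head(pixel_values.flatten(1)))


def training_step(
    model: nn.Module,
    batch: Tuple[torch.Tensor, torch.Tensor],
    loss_fn: Optional[Callable] = None,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    x, y = batch
    x, y = x.to(model.device), y.to(model.device)
    logits = model(pixel_values=x).logits  # unwrap the structured output
    eff_fn = loss_fn if loss_fn is not None else F.cross_entropy
    loss = eff_fn(logits, y)
    acc = (logits.argmax(1) == y).float().mean().item()
    return loss, {"acc": acc, "loss": loss.item()}


torch.manual_seed(0)
model = FakeHFModel()
batch = (torch.randn(8, 3, 4, 4), torch.randint(0, 5, (8,)))

loss, metrics = training_step(model, batch)
loss.backward()  # gradients reach the head through the unwrapped logits

# Handing the output object straight to the loss mirrors loss_fn(model(x), y).
try:
    F.cross_entropy(model(pixel_values=batch[0]), batch[1])
except TypeError as exc:
    print("without .logits:", type(exc).__name__)
```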
Leeching preprocessing from the HF processor
HF ships an AutoImageProcessor alongside each model that encodes the
correct resize size, centre-crop, mean, and std. Rather than hardcoding
these values, extract them at startup and build a native torchvision
transform pipeline:
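```python
import torchvision as tv


def make_transform(processor):
    # Resize target: "shortest_edge" for resize-then-crop processors, else "height".
    size = processor.size.get("shortest_edge", processor.size.get("height", 224))
    # Centre-crop target: the processor's crop_size when it defines one.
    crop = getattr(processor, "crop_size", None) or {}
    crop_size = crop.get("height", size)
    return tv.transforms.Compose([
        tv.transforms.Resize(size),
        tv.transforms.CenterCrop(crop_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
```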
Calling the processor object directly inside the DataLoader workers (one call per sample) is 3–4× slower, because of Python-level overhead and PIL → numpy → tensor round-trips. The torchvision pipeline above hands resizing, cropping, and normalisation to compiled PIL and torch kernels inside the worker processes, which makes it the better choice for training throughput.
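The size lookup is where processors differ. Some describe resizing by shortest_edge and keep a separate crop_size, while others give a fixed height and width. A minimal sketch of a lookup that handles both, using SimpleNamespace objects as hypothetical stand-ins for real processors (the numbers are illustrative, not read from any specific checkpoint):

```python
from types import SimpleNamespace
from typing import Any, Tuple


def resolve_sizes(processor: Any) -> Tuple[int, int]:
    """Return (resize, crop) in pixels from an HF-style image processor.

    Uses the same shortest_edge / height fallback as make_transform and
    additionally looks for a crop_size, which is the centre-crop target.
    """
    resize = processor.size.get("shortest_edge", processor.size.get("height", 224))
    crop = getattr(processor, "crop_size", None) or {}
    return resize, crop.get("height", resize)


# Illustrative stand-ins; read the real values from AutoImageProcessor.
resize_then_crop = SimpleNamespace(size={"shortest_edge": 256},
                                   crop_size={"height": 224, "width": 224})
fixed_size = SimpleNamespace(size={"height": 224, "width": 224})

print(resolve_sizes(resize_then_crop))  # (256, 224)
print(resolve_sizes(fixed_size))        # (224, 224)
```

If the crop size is ignored for the first kind of processor, training images are cropped to the resize size rather than to the model's native input size, so training and inference (which calls the processor directly) see different resolutions.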
Curriculum training: two-stage freeze / fine-tune
```python
# Stage 1: freeze backbone, train head only
model.freeze("mobilenet_v2")
model.fit(train_data, val_data=val_data, epochs=10, lr=1e-3)  # plus other fit options

# Stage 2: unfreeze backbone, fine-tune everything
model.unfreeze("mobilenet_v2")
model.fit(train_data, val_data=val_data, epochs=40, lr=1e-4)  # plus other fit options
```
mentor.Mentee.freeze and unfreeze accept regex patterns matched against
the full dotted layer path. Use mtr_checkpoint view to explore the layer
tree:
```shell
mtr_checkpoint view -path ./tmp/mobilenetv2_flowers102.mentor.pt
```
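Regex selection is easy to get subtly wrong when you freeze or unfreeze only part of the backbone. The sketch below rests on two loud assumptions: the parameter names are illustrative (use mtr_checkpoint view to see the real ones), and patterns are matched from the start of the dotted name, as re.match does. If mentor uses re.search instead, unanchored patterns can also match in the middle of a path.

```python
import re
from typing import List

# Illustrative parameter names in the style of an HF MobileNetV2 classifier.
PARAM_NAMES = [
    "mobilenet_v2.conv_stem.first_conv.convolution.weight",
    "mobilenet_v2.layer.1.expand_1x1.convolution.weight",
    "mobilenet_v2.layer.15.expand_1x1.convolution.weight",
    "mobilenet_v2.conv_1x1.convolution.weight",
    "classifier.weight",
    "classifier.bias",
]


def matching(names: List[str], pattern: str) -> List[str]:
    """Names that match pattern from their first character (assumed semantics)."""
    regex = re.compile(pattern)
    return [name for name in names if regex.match(name)]


print(len(matching(PARAM_NAMES, "mobilenet_v2")))   # 4: the whole backbone
print(matching(PARAM_NAMES, r"classifier\."))        # the head only

# Pitfall: an unterminated block index also catches later blocks.
print(len(matching(PARAM_NAMES, r"mobilenet_v2\.layer\.1")))    # 2 (layer.1 and layer.15)
print(len(matching(PARAM_NAMES, r"mobilenet_v2\.layer\.1\.")))  # 1
```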
The epochs argument to fit is a ceiling, not a count. After resuming
at epoch 7, fit(epochs=10) runs 3 more epochs to reach epoch 10 — not 10
additional ones. This means the stage boundary logic works unchanged across
resume:
```python
if model.current_epoch < args.freeze_epochs:
    model.freeze("mobilenet_v2")
    model.fit(..., epochs=args.freeze_epochs)
model.unfreeze("mobilenet_v2")
model.fit(..., epochs=args.total_epochs)
```
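To check that the ceiling rule really makes the stage logic resume-safe, the sketch below simulates a run killed at epoch 7 and then restarted. CeilingModel, Interrupted, and run are hypothetical: the model keeps only an epoch counter (standing in for the value restored from the checkpoint) and implements fit's "train up to, not beyond, epochs" rule. Freezing is not modelled.

```python
from typing import List, Optional


class Interrupted(Exception):
    """Simulates the training process being killed."""


class CeilingModel:
    """Hypothetical stand-in: only the epoch counter and fit's ceiling rule."""

    def __init__(self) -> None:
        self.current_epoch = 0  # in mentor this would come from the checkpoint
        self.log: List[str] = []

    def fit(self, epochs: int, stage: str, stop_at: Optional[int] = None) -> None:
        # epochs is a ceiling: train from current_epoch up to epochs - 1.
        while self.current_epoch < epochs:
            if self.current_epoch == stop_at:
                raise Interrupted
            self.log.append(f"{stage}:{self.current_epoch}")
            self.current_epoch += 1


def run(model: CeilingModel, freeze_epochs: int = 10, total_epochs: int = 40,
        stop_at: Optional[int] = None) -> None:
    # Same control flow as the stage-boundary snippet above.
    if model.current_epoch < freeze_epochs:
        model.fit(freeze_epochs, "head", stop_at)
    model.fit(total_epochs, "full", stop_at)


model = CeilingModel()
try:
    run(model, stop_at=7)  # killed at the start of epoch 7
except Interrupted:
    pass
print(model.current_epoch)  # 7

run(model)  # "resume": stage 1 runs epochs 7-9 only, then stage 2 runs 10-39
print(len(model.log), model.log[7], model.log[10])  # 40 head:7 full:10
```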
Resuming a HuggingFace Mentee checkpoint
```python
model = mentor.Mentee.resume(
    "./tmp/mobilenetv2_flowers102.mentor.pt",
    trainer=HFClassifier,  # trainer is not serialised, so it must be re-supplied
)
```
Two things are re-applied automatically at resume time:
- Mentee mixin: wrap_as_mentee stores the original HF class name in the checkpoint. resume detects this and re-inserts Mentee into the MRO, so the returned object is always a proper Mentee instance.
- Instance state: attributes like _train_history, _frozen_modules, and _lr_coefficients are injected before the state_dict is loaded, because the HF model's __init__ does not call Mentee.__init__ via super().
The trainer argument must be re-supplied because trainer instances are not
serialised inside the checkpoint (they are stateless strategy objects and may
live in user scripts that are not importable at resume time).
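The "not importable at resume time" problem follows from how Python pickles classes: by reference, storing only the module path and class name, never the class body. A minimal standard-library sketch, where the module name my_training_script is hypothetical, shows a class that pickles fine while its defining module is importable and fails to unpickle once it is not:

```python
import pickle
import sys
import types

# A throwaway module standing in for a user's training script (hypothetical name).
script = types.ModuleType("my_training_script")
# Define the class inside that module's namespace, as if the script were imported.
exec("class MyTrainer:\n    pass\n", script.__dict__)
sys.modules["my_training_script"] = script

# Classes pickle by reference: only the module path and class name are stored.
blob = pickle.dumps(script.MyTrainer)
print(b"my_training_script" in blob, b"MyTrainer" in blob)  # True True

# Later, in a process where that script is not importable...
del sys.modules["my_training_script"]
try:
    pickle.loads(blob)
except ModuleNotFoundError as exc:
    print("cannot restore trainer:", exc)
```

Passing trainer= at resume time sidesteps this entirely: the class comes from whatever code is running now, not from a reference frozen into the checkpoint.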
What cannot easily be done
| Limitation | Workaround |
|---|---|
| Custom head architecture diverges from model_class(config=...) | Subclass the HF model so the custom head is built in __init__ |
| Optimizer state not restored by … | Use … |
| Multi-modal HF models (e.g. CLIP, LLaVA) | Override … |
| HF models with multiple outputs (detection, segmentation) | Same: the … |
Source
```python
"""Fine-tune a HuggingFace pretrained image classifier on Oxford Flowers-102.

Demonstrates two-stage transfer learning with mentor:

    Stage 1 (epochs 0–freeze_epochs)
        Backbone frozen, only the classification head is updated.
        Fast convergence: ImageNet features transfer directly to flower species.

    Stage 2 (epochs freeze_epochs–total_epochs)
        Full fine-tune with the backbone unfrozen at a lower learning rate.

Usage::

    # fresh run (downloads model and dataset automatically)
    python examples/hf/classify.py train

    # resume from checkpoint
    python examples/hf/classify.py train -resume_fname ./tmp/mobilenetv2_flowers102.mentor.pt

    # inference on one or more images
    python examples/hf/classify.py inference -img flower.jpg

See examples/hf/README.md for a full explanation of the HuggingFace wrapping
pattern, preprocessing choices, and resume behaviour.
"""
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

import fargv
import mentor


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------

def make_transform(processor: Any) -> tv.transforms.Compose:
    """Build a torchvision transform pipeline from a HuggingFace processor.

    Extracts resize size, centre-crop size, mean, and std from the processor
    config so the DataLoader preprocessing always matches the model's expected
    input, with no hardcoded values. Using native torchvision ops (not the
    processor directly) keeps DataLoader throughput 3-4x higher.
    """
    size: int = processor.size.get("shortest_edge", processor.size.get("height", 224))
    # Resize-then-crop processors keep the crop target in crop_size;
    # fall back to the resize size when the processor has none.
    crop: Dict[str, int] = getattr(processor, "crop_size", None) or {}
    crop_size: int = crop.get("height", size)
    return tv.transforms.Compose([
        tv.transforms.Resize(size),
        tv.transforms.CenterCrop(crop_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])


def load_train_dataloaders(
    processor: Any,
    batch_size: int = 32,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Return (train_loader, val_loader) for Oxford Flowers-102."""
    transform = make_transform(processor)
    train_ds = tv.datasets.Flowers102(
        root="./tmp/flowers102", split="train", download=True, transform=transform
    )
    val_ds = tv.datasets.Flowers102(
        root="./tmp/flowers102", split="val", download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader


# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------

class HFClassifier(mentor.Classifier):
    """Classifier trainer for HuggingFace image classification models.

    The only difference from the built-in :class:`mentor.Classifier` is that
    HF models return a structured ``ImageClassifierOutput`` rather than a bare
    tensor, so ``.logits`` must be unwrapped before computing the loss.
    """

    @classmethod
    def default_training_step(
        cls,
        model: Any,
        batch: Any,
        loss_fn: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        x, y = batch
        x, y = x.to(model.device), y.to(model.device)
        logits = model(pixel_values=x).logits  # unwrap HF structured output
        eff_fn = loss_fn if loss_fn is not None else F.cross_entropy
        loss = eff_fn(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"acc": acc, "loss": loss.item()}


# ---------------------------------------------------------------------------
# Model / processor helpers
# ---------------------------------------------------------------------------

def _load_or_cache_model(
    model_id: str,
    hf_cache: str,
    num_labels: int = 102,
) -> Tuple[Any, Any]:
    """Return ``(model, processor)``, downloading from HF Hub on first call.

    On subsequent calls the model and processor are loaded from *hf_cache*,
    avoiding repeated Hub downloads and allowing offline use.
    """
    if not os.path.exists(hf_cache):
        model = AutoModelForImageClassification.from_pretrained(
            model_id, num_labels=num_labels, ignore_mismatched_sizes=True
        )
        processor = AutoImageProcessor.from_pretrained(model_id)
        os.makedirs(hf_cache, exist_ok=True)
        model.save_pretrained(hf_cache)
        processor.save_pretrained(hf_cache)
    else:
        model = AutoModelForImageClassification.from_pretrained(hf_cache)
        processor = AutoImageProcessor.from_pretrained(hf_cache)
    return model, processor


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main_train() -> None:
    args, _ = fargv.parse({
        "hf_cache": "./tmp/mobilenetv2.hf",
        "resume_fname": "./tmp/mobilenetv2_flowers102.mentor.pt",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "cmd": {
            "train": {
                "model_id": "google/mobilenet_v2_1.0_224",
                "freeze_epochs": 10,
                "total_epochs": 40,
                "lr": 1e-3,
                "num_workers": 4,
                "batch_size": 8,
                "pseudo_batch_size": 4,
            },
            "inference": {
                "img": [],
            },
        },
    })

    if args.cmd == "train":
        if os.path.exists(args.resume_fname):
            # Resume: load Mentee checkpoint, re-supply the trainer (not serialised).
            model = mentor.Mentee.resume(args.resume_fname, trainer=HFClassifier)
            _, processor = _load_or_cache_model(args.model_id, args.hf_cache)
        else:
            # Fresh start: download/cache HF model, wrap as Mentee.
            # constructor_params={"config": model.config} is required so that
            # Mentee.resume can reconstruct the HF architecture at resume time.
            model, processor = _load_or_cache_model(args.model_id, args.hf_cache)
            model = mentor.wrap_as_mentee(
                model,
                trainer=HFClassifier,
                constructor_params={"config": model.config},
            )

        model = model.to(args.device)
        train_loader, val_loader = load_train_dataloaders(
            processor, batch_size=args.batch_size, num_workers=args.num_workers
        )

        # Stage 1: train head only (backbone frozen)
        # fit(epochs=N) is a ceiling, so it resumes seamlessly from any epoch < N.
        if model.current_epoch < args.freeze_epochs:
            model.freeze("mobilenet_v2")
            model.fit(
                train_data=train_loader,
                val_data=val_loader,
                epochs=args.freeze_epochs,
                pseudo_batch_size=args.pseudo_batch_size,
                lr=args.lr,
                checkpoint_path=args.resume_fname,
                verbose=args.verbosity > 0,
                num_workers=args.num_workers,
            )

        # Stage 2: full fine-tune (backbone unfrozen, lower LR)
        model.unfreeze("mobilenet_v2")
        model.fit(
            train_data=train_loader,
            val_data=val_loader,
            epochs=args.total_epochs,
            pseudo_batch_size=args.pseudo_batch_size,
            lr=args.lr / 10,
            checkpoint_path=args.resume_fname,
            verbose=args.verbosity > 0,
            num_workers=args.num_workers,
        )

    elif args.cmd == "inference":
        model = mentor.Mentee.resume(args.resume_fname, trainer=HFClassifier)
        # Only the processor is needed here: the weights come from the Mentee checkpoint.
        processor = AutoImageProcessor.from_pretrained(args.hf_cache)
        model = model.to(args.device)
        model.eval()
        t = time.time()
        with torch.no_grad():
            for img_path in args.img:
                img = Image.open(img_path).convert("RGB")
                inputs = {
                    k: v.to(args.device)
                    for k, v in processor(images=img, return_tensors="pt").items()
                }
                logits = model(**inputs).logits
                pred = logits.argmax(1).item()
                if args.verbosity > 1:
                    print(f"{time.time() - t:.2f}s: ", end="")
                print(f"{img_path}: class {pred}")
        if args.verbosity > 0:
            print(f"Inferred {len(args.img)} images in {time.time() - t:.2f}s")

    else:
        raise ValueError(f"Unknown command {args.cmd!r}")


if __name__ == "__main__":
    main_train()
```