Python Data Science Program
📓 Abrir notebook en GitHub

Clase 231 — Capstone 3: visión por computadora con transfer learning

Parte: 8 — Capstones · Fuente: Géron, Hands-On ML 3ª ed. cap. 14-15 + timm + PyTorch Lightning docs. ⏱️ Duración estimada: 180 min.

🎯 Objetivo

Construir un clasificador de imágenes de calidad producción usando transfer learning sobre un backbone moderno (ConvNeXt-tiny / EfficientNetV2-S / ViT-Base/16). Entrenar en dos fases (feature extraction → fine-tuning progresivo), aplicar augmentation moderna (RandAugment, MixUp, CutMix), evaluar con métricas per-clase + slice analysis, y servir vía ONNX + FastAPI con endpoint /predict que recibe imagen en base64. Cerrar el capstone con un fairness check si aplica.

📚 Resultados de aprendizaje

Al finalizar, el estudiante podrá:

🗺️ Fases del capstone

# Fase Por qué importa
1 Dataset + EDA Class imbalance, resolución variable, leakage train/test son los killers silenciosos.
2 Augmentation RandAugment + MixUp suben 2-4 puntos de accuracy gratis; sin esto, fine-tune overfittea en 3 epochs.
3 Backbone preentrenado ConvNeXt-tiny / EffNetV2-S / ViT-B/16 vía timm — pesos ImageNet-21k aceleran 10× la convergencia.
4 Fine-tuning progresivo Feature extraction (head only) → unfreeze último bloque → unfreeze full con LR diferencial (1e-5 backbone, 1e-3 head).
5 Evaluación Accuracy global engaña con imbalance; per-clase F1, confusion matrix y slice analysis revelan los puntos ciegos.
6 Serving Export a ONNX / TorchScript + FastAPI /predict con imagen base64 → JSON {class, prob}.

📖 Definiciones y características

📂 Dataset / recursos

🧪 Ejercicios

  1. Baseline tonto: entrenar una CNN de 2 capas conv from scratch (sin transfer). Reportar accuracy de validation. Esperá <60% en multiclase — establece el piso.
  2. Feature extraction: cargar timm.create_model('convnext_tiny', pretrained=True, num_classes=N). Congelar backbone (for p in model.parameters(): p.requires_grad = False), entrenar solo head 5 epochs. Esperá +20-30 puntos sobre el baseline.
  3. Fine-tuning progresivo: unfreeze último bloque → 5 epochs con LR 1e-4 → unfreeze full → 10 epochs con LR diferencial (1e-5 backbone, 1e-3 head). Loggear con W&B/MLflow.
  4. Augmentation ablation: comparar (a) sin aug, (b) flip + crop, (c) RandAugment, (d) RandAugment + MixUp + CutMix. Reportar curva val_acc.
  5. Serving end-to-end: exportar a ONNX (torch.onnx.export), validar con onnxruntime que la inferencia da la misma probabilidad ±1e-4, levantar FastAPI con endpoint POST /predict que reciba {"image_b64": "..."} y devuelva {"class": "...", "prob": 0.94}. Probar con curl.

📝 Homework verificable

Notebook + repo con:

  1. Dataset elegido + EDA (distribución de clases, resolución, ejemplos por clase).
  2. Pipeline Lightning con 3 fases de entrenamiento (feature extraction → unfreeze parcial → unfreeze full).
  3. Augmentation con Albumentations + verificación de label invariance.
  4. Reporte de métricas: accuracy global, per-class F1, confusion matrix, slice analysis sobre al menos 1 dimensión.
  5. Si el dataset tiene atributos sensibles (rostros, demografía): fairness check con disparate impact + equal opportunity (Clase 224).
  6. Export a ONNX + script serve.py con FastAPI /predict funcionando localmente.
  7. README del repo con resultados, decisiones de diseño y limitaciones.

Criterio de aceptación: accuracy de test > 90% en binario o > 75% en multiclase (10+ clases), el endpoint /predict responde en <500 ms en CPU, y el reporte identifica al menos 1 slice donde el modelo subperforma.

⚠️ Errores comunes

Síntoma / mensaje Causa y cómo arreglar
Accuracy de train sube a 99%, val se queda en 65% Overfitting clásico. Fix: subir augmentation (RandAugment + MixUp), bajar LR del head, EarlyStopping con patience=5, weight decay 0.05.
CUDA out of memory al hacer fine-tune full Batch demasiado grande para AMP off. Fix: activar precision="16-mixed" en Lightning + accumulate_grad_batches=4, o bajar image_size de 384→224.
Val accuracy es altísima pero el modelo en producción falla Leakage train/test (mismas imágenes duplicadas en ambos splits). Fix: deduplicar por hash perceptual (imagehash.phash) antes de splittear.
RuntimeError: shape mismatch al exportar a ONNX El head custom espera shape que el dummy_input no replica. Fix: usar dummy = torch.randn(1, 3, 224, 224) con batch_size=1 y dynamic_axes={'input': {0: 'batch'}}.
Predicciones ONNX difieren de PyTorch en 5to decimal Esperable (precisión float). Fix: validar con np.allclose(out_torch, out_onnx, atol=1e-4), no assert ==.
Una clase tiene F1=0.0 y nadie lo notó Class imbalance (esa clase es 2% del dataset). Fix: WeightedRandomSampler o class_weight en CrossEntropyLoss; mirar per-class F1 siempre, nunca solo accuracy global.

❓ Preguntas frecuentes

❓ ¿ConvNeXt, EfficientNetV2 o ViT?

Para datasets chicos (<10K imágenes), ConvNeXt-tiny o EfficientNetV2-S ganan: las CNN tienen inductive bias (locality, translation equivariance) que ViT no tiene y por eso ViT necesita 10× más datos para igualar. Para datasets grandes (>100K) o si tenés pretraining en ImageNet-21k, ViT-Base/16 suele ser top. Probá los 3 vía timm y comparalos — son 3 líneas de código cada uno.

❓ ¿Feature extraction o fine-tuning?

Feature extraction primero (más rápido, menos riesgo de catastrophic forgetting). Si la accuracy se estanca, hacer fine-tuning progresivo: descongelar el último bloque, después dos, después full, siempre con LR diferencial. Nunca descongelés todo en epoch 1 — el gradiente random del head corrompe los pesos preentrenados.

❓ ¿Por qué Lightning y no PyTorch puro?

Lightning te da gratis: AMP, multi-GPU, gradient accumulation, checkpointing, EarlyStopping, logging a W&B/MLflow, y seed_everything. En un capstone querés invertir el tiempo en el modelo, no en el training loop.

❓ ¿ONNX o TorchScript para serving?

ONNX si querés portabilidad (servir desde C++, Java, Node, navegador con onnxruntime-web). TorchScript si te quedás en Python/C++ con torch instalado y querés el path más simple. Para FastAPI en producción, ONNX + onnxruntime suele ser 2-3× más rápido que torch eager.

❓ ¿Fairness check siempre?

Solo si el dataset tiene atributos sensibles (rostros con edad/género/etnia, datos médicos con demografía). Si clasificás defectos industriales o platos de comida, no aplica. Cuando aplica, es obligatorio — releé Clase 224 antes de publicar el modelo.

🔗 Referencias

📥 Material descargable

➡️ Siguiente clase

Clase 232 — Portafolio público en GitHub Pages y presentación