Python Data Science Program
📓 Abrir notebook en GitHub

Clase 144 — Flash Attention v2/v3, RoPE, GQA: el motor de los LLMs modernos

Parte: 2 — Deep Learning · Fuente: Dao et al. (2022, 2023, 2024) FlashAttention + Su et al. (2021) RoPE + Ainslie et al. (2023) GQA. ⏱️ Duración estimada: 90 min.

🎯 Objetivo

Entender en profundidad las 3 piezas técnicas que hacen que un LLM moderno (Llama 3, Mistral, Qwen, Gemma) sea rápido y memory-efficient: Flash Attention v2/v3 (O(N) memoria + 2-3× speedup), Rotary Position Embeddings (RoPE) (mejor extrapolación), Grouped-Query Attention (GQA) (menos KV cache en inference).

📚 Resultados de aprendizaje

Al finalizar, el estudiante podrá:

🗺️ Temas

📖 Definiciones y características

📂 Dataset / recursos

🧪 Ejercicios

  1. SDPA vs naïve: implementar attention naïve y F.scaled_dot_product_attention. Benchmark.
  2. RoPE: implementar rotation function, verificar propiedad attention(R_θ q, R_φ k) = f(θ - φ).
  3. GQA Vs MHA: con Llama config (n_heads=32, kv_heads=8), inspeccionar shapes.
  4. KV cache: medir VRAM en inference con secuencia 8192 — comparar MHA vs GQA.
  5. FlashAttention v3 en H100: si tenés H100, benchmark vs v2.

📝 Homework verificable

Mini-GPT con piezas modernas:

  1. 6-layer Transformer con: RMSNorm, GQA (4 KV heads / 8 Q heads), RoPE, SwiGLU FFN.
  2. Train next-token sobre Tiny Shakespeare.
  3. Comparar contra mini-GPT clásico (LayerNorm + MHA + Sin PE + GELU FFN).

Criterio de aceptación: mini-GPT moderno entrena más estable + menor memoria; quality comparable.

⚠️ Errores comunes

Síntoma / mensaje Causa y cómo arreglar
flash-attn no instala en CUDA viejo Requiere CUDA 11.6+, GPU Ampere+. Fix: usar SDPA built-in de PyTorch 2.0+.
RoPE con base distinta a 10000 Para extrapolación a contextos largos (32k+) ajustar base (NTK scaling).
MQA da peor calidad Esperado. Fix: GQA es el compromiso.
KV cache OOM en context largo Inherente. Fix: GQA + quantization (Q8 KV).
is_causal=True en SDPA solo aplica si tensor square Fix: passing attn_mask cuando shapes asimétricos.

❓ Preguntas frecuentes

❓ FlashAttention v2 o v3?

v3 si tenés H100. v2 estable para todo lo demás. SDPA de PyTorch elige el mejor disponible.

❓ RoPE absoluto o relativo?

RoPE codifica posición absoluta pero produce comportamiento relativo en attention. Lo mejor de ambos.

❓ GQA en training?

Sí — entrenar con GQA desde el principio. Llama 2 70B y todos los modernos lo hacen.

❓ Combina con sliding window attention?

Sí — Mistral 7B usa GQA + sliding window. Para contextos infinitos.

❓ ¿Y para CV (ViT)?

ViT moderno también usa Flash Attention (timm support). RoPE en algunos (DiT). GQA menos común en CV.

🔗 Referencias

📥 Material descargable

➡️ Siguiente clase

Clase 145 — Hugging Face Transformers (uso práctico)