Distillation de Modèles
Transférer les connaissances d'un grand modèle vers un modèle plus petit et efficace - techniques, architectures et cas d'usage
Distillation de Modèles
La distillation de connaissances (Knowledge Distillation) permet de créer un modèle compact qui reproduit le comportement d'un modèle plus grand, tout en étant plus rapide et moins coûteux.
Principe
┌─────────────────────────────────────────────────────────────────┐
│ KNOWLEDGE DISTILLATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ TEACHER MODEL STUDENT MODEL │
│ (Grand, lent) (Petit, rapide) │
│ │
│ ┌───────────────┐ ┌───────────────┐ │
│ │ │ │ │ │
│ │ GPT-4 │ ────────► │ GPT-4-mini │ │
│ │ 175B │ Transfer │ 8B │ │
│ │ │ │ │ │
│ └───────────────┘ └───────────────┘ │
│ │
│ Précision: 95% Précision: 92% │
│ Latence: 500ms Latence: 50ms │
│ Coût: $$$ Coût: $ │
│ │
└─────────────────────────────────────────────────────────────────┘Types de distillation
1. Response-based Distillation
Le student apprend à reproduire les sorties du teacher.
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
"""
Combine soft labels (teacher) et hard labels (ground truth)
"""
# Soft targets: distribution de probabilité du teacher
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_predictions = F.log_softmax(student_logits / temperature, dim=-1)
# KL Divergence entre student et teacher
soft_loss = F.kl_div(
soft_predictions,
soft_targets,
reduction='batchmean'
) * (temperature ** 2)
# Cross-entropy avec les vrais labels
hard_loss = F.cross_entropy(student_logits, labels)
# Combinaison pondérée
return alpha * soft_loss + (1 - alpha) * hard_loss2. Feature-based Distillation
Le student apprend les représentations intermédiaires du teacher.
def feature_distillation_loss(student_features, teacher_features):
"""
Aligne les représentations internes
"""
# Projection si dimensions différentes
if student_features.shape != teacher_features.shape:
projection = nn.Linear(student_features.shape[-1], teacher_features.shape[-1])
student_features = projection(student_features)
# MSE sur les features
return F.mse_loss(student_features, teacher_features)3. Relation-based Distillation
Le student apprend les relations entre exemples.
┌─────────────────────────────────────────────────────────────────┐
│ RELATION-BASED DISTILLATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Teacher voit: Student apprend: │
│ │
│ "chat" ←──similaire──► "chaton" Mêmes relations │
│ │ │ entre embeddings │
│ différent similaire │
│ │ │ │
│ ▼ ▼ │
│ "voiture" "félin" │
│ │
└─────────────────────────────────────────────────────────────────┘Distillation pour LLM
Dataset synthétique
Générer des données d'entraînement avec le teacher.
from transformers import pipeline
# Teacher model (API ou local)
teacher = pipeline("text-generation", model="gpt-4")
def generate_training_data(prompts):
"""Génère des exemples avec le teacher"""
dataset = []
for prompt in prompts:
response = teacher(
prompt,
max_new_tokens=512,
temperature=0.7,
)[0]["generated_text"]
dataset.append({
"prompt": prompt,
"response": response,
})
return dataset
# Créer le dataset
prompts = load_prompts("prompts.txt") # 10K+ prompts diversifiés
synthetic_data = generate_training_data(prompts)Entraînement du student
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset
# Student model (plus petit)
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Dataset synthétique
dataset = Dataset.from_list(synthetic_data)
# Fine-tuning sur les réponses du teacher
training_args = TrainingArguments(
output_dir="./distilled-model",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True,
)
trainer = Trainer(
model=student,
args=training_args,
train_dataset=dataset,
)
trainer.train()Exemples de modèles distillés
DistilBERT
┌─────────────────────────────────────────────────────────────────┐
│ DISTILBERT │
├─────────────────────────────────────────────────────────────────┤
│ │
│ BERT-base DistilBERT │
│ ├── 12 layers ──► ├── 6 layers │
│ ├── 110M params ├── 66M params (-40%) │
│ ├── 100% perf ├── 97% perf │
│ └── 417ms └── 218ms (2x faster) │
│ │
└─────────────────────────────────────────────────────────────────┘from transformers import DistilBertModel, DistilBertTokenizer
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")Phi (Microsoft)
Petit modèle entraîné sur données synthétiques de GPT-4.
| Modèle | Params | Source données | Benchmark |
|---|---|---|---|
| Phi-1 | 1.3B | Textbooks (GPT-4) | HumanEval: 50% |
| Phi-2 | 2.7B | Synthetic + Web | MMLU: 56% |
| Phi-3 | 3.8B | Filtered + Synthetic | MMLU: 69% |
Orca / Orca 2 (Microsoft)
Distillation des capacités de raisonnement.
┌─────────────────────────────────────────────────────────────────┐
│ ORCA: PROGRESSIVE LEARNING │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. Collecter traces de raisonnement de GPT-4 │
│ │
│ User: "Combien font 17 × 24?" │
│ GPT-4: "Je décompose: 17 × 24 = 17 × (20 + 4) │
│ = 17 × 20 + 17 × 4 │
│ = 340 + 68 │
│ = 408" │
│ │
│ 2. Entraîner student sur ces traces │
│ │
│ 3. Student apprend le "comment" pas juste le "quoi" │
│ │
└─────────────────────────────────────────────────────────────────┘Alpaca / Vicuna
Distillation de ChatGPT/GPT-4 vers LLaMA.
# Format Alpaca
{
"instruction": "Explique la photosynthèse",
"input": "",
"output": "La photosynthèse est le processus par lequel..."
}
# 52K exemples générés par text-davinci-003
# Coût: ~$500 pour créer le datasetTechniques avancées
Chain-of-Thought Distillation
Transférer les capacités de raisonnement.
def generate_cot_data(teacher, problems):
"""Génère des exemples avec raisonnement explicite"""
data = []
for problem in problems:
prompt = f"""Résous ce problème étape par étape:
{problem}
Réflexion:"""
response = teacher.generate(prompt)
data.append({
"problem": problem,
"reasoning": response,
})
return dataSelf-Distillation
Le modèle s'améliore en apprenant de ses propres meilleures réponses.
┌─────────────────────────────────────────────────────────────────┐
│ SELF-DISTILLATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. Générer N réponses par prompt │
│ 2. Scorer les réponses (reward model) │
│ 3. Garder les meilleures │
│ 4. Fine-tuner sur les meilleures réponses │
│ 5. Répéter │
│ │
│ Iteration 1 ──► Iteration 2 ──► Iteration 3 │
│ Score: 0.6 Score: 0.75 Score: 0.85 │
│ │
└─────────────────────────────────────────────────────────────────┘Constitutional AI Distillation
Intégrer des principes éthiques pendant la distillation.
constitution = [
"Sois honnête et factuel",
"Évite les contenus nuisibles",
"Respecte la vie privée",
]
def constitutional_filter(response):
"""Filtre les réponses selon la constitution"""
for principle in constitution:
if violates(response, principle):
return None
return responsePipeline complet
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
class DistillationPipeline:
def __init__(self, teacher_name, student_name):
self.teacher = AutoModelForCausalLM.from_pretrained(teacher_name)
self.student = AutoModelForCausalLM.from_pretrained(student_name)
self.tokenizer = AutoTokenizer.from_pretrained(student_name)
def generate_synthetic_data(self, prompts, n_samples=10000):
"""Étape 1: Générer données avec teacher"""
data = []
for prompt in prompts[:n_samples]:
with torch.no_grad():
response = self.teacher.generate(
self.tokenizer.encode(prompt, return_tensors="pt"),
max_new_tokens=512,
)
data.append({
"prompt": prompt,
"response": self.tokenizer.decode(response[0]),
})
return data
def train_student(self, data, epochs=3):
"""Étape 2: Entraîner student"""
# ... training loop
pass
def evaluate(self, test_set):
"""Étape 3: Évaluer"""
# ... evaluation
passMétriques d'évaluation
| Métrique | Description | Cible |
|---|---|---|
| Fidelity | Similarité avec teacher | > 90% |
| Compression | Réduction de taille | 5-10x |
| Speedup | Gain de vitesse | 3-5x |
| Task perf | Performance sur tâche | > 95% du teacher |
Coûts et ROI
┌─────────────────────────────────────────────────────────────────┐
│ COÛT DE LA DISTILLATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Investissement initial: │
│ ├── Génération dataset: $500-5000 (API calls) │
│ ├── Fine-tuning student: $100-1000 (GPU) │
│ └── Total: ~$1000-6000 │
│ │
│ Économies en production (1M req/mois): │
│ ├── GPT-4: $30,000/mois │
│ ├── Modèle distillé: $3,000/mois │
│ └── Économie: $27,000/mois │
│ │
│ ROI: 1 semaine │
│ │
└─────────────────────────────────────────────────────────────────┘Pour aller plus loin
- DistilBERT Paper - Hugging Face
- Orca Paper - Microsoft
- Phi Technical Report - Microsoft
- Alpaca - Stanford
Quantization et Optimisation d'Inférence
Techniques pour réduire la taille des modèles et accélérer l'inférence - quantization, pruning, compilation et déploiement efficace
Évaluation et Benchmarks de LLM
Méthodes et benchmarks pour évaluer les performances des grands modèles de langage - MMLU, HumanEval, MT-Bench et métriques clés