Wiki IA
LLM et IA Générative

Quantization et Optimisation d'Inférence

Techniques pour réduire la taille des modèles et accélérer l'inférence - quantization, pruning, compilation et déploiement efficace

Quantization et Optimisation d'Inférence

L'optimisation d'inférence permet de déployer des LLM sur des ressources limitées tout en maintenant des performances acceptables. Ces techniques sont essentielles pour le déploiement en production.

Pourquoi optimiser ?

┌─────────────────────────────────────────────────────────────────┐
│                    PROBLÈME DES LLM                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  LLaMA 70B en FP16:                                             │
│  ├── Poids: 140 GB de VRAM                                      │
│  ├── Matériel: 2x A100 80GB minimum                             │
│  ├── Coût: ~$3/heure cloud                                      │
│  └── Latence: élevée                                            │
│                                                                  │
│  LLaMA 70B quantifié 4-bit:                                     │
│  ├── Poids: ~35 GB de VRAM                                      │
│  ├── Matériel: 1x A100 ou RTX 4090                              │
│  ├── Coût: ~$1/heure cloud                                      │
│  └── Latence: réduite de 2-3x                                   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Quantization

La quantization réduit la précision des poids du modèle (de FP16 à INT8 ou INT4).

Types de précision

FormatBitsTaille/paramUsage
FP32324 bytesEntraînement
FP16162 bytesInférence standard
BF16162 bytesEntraînement/Inférence
INT881 byteInférence optimisée
INT440.5 byteInférence agressive

Méthodes de quantization

1. Post-Training Quantization (PTQ)

Quantification après l'entraînement, sans réentraînement.

# Exemple avec bitsandbytes
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Configuration 4-bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Nested quantization
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

2. Quantization-Aware Training (QAT)

Entraînement avec quantification simulée pour de meilleurs résultats.

# Plus coûteux mais meilleure qualité
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./qat-model",
    # Simule la quantification pendant l'entraînement
    fp16=True,
    gradient_checkpointing=True,
)

Formats populaires

GGUF (llama.cpp)

Format optimisé pour CPU et GPU grand public.

# Télécharger un modèle GGUF
huggingface-cli download TheBloke/Llama-2-7B-GGUF \
    llama-2-7b.Q4_K_M.gguf

# Inférence avec llama.cpp
./main -m llama-2-7b.Q4_K_M.gguf \
    -p "Explique la quantization:" \
    -n 256

Nomenclature GGUF:

  • Q4_0: 4-bit basique
  • Q4_K_M: 4-bit avec K-quants, qualité moyenne
  • Q5_K_M: 5-bit, meilleur compromis
  • Q8_0: 8-bit, proche de FP16

GPTQ

Quantification 4-bit avec calibration sur dataset.

from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(
    bits=4,
    dataset="c4",  # Dataset de calibration
    tokenizer=tokenizer,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=gptq_config,
)

AWQ (Activation-aware Weight Quantization)

Préserve les poids importants basé sur les activations.

from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-AWQ",
    fuse_layers=True,  # Fusion de layers pour vitesse
)

Comparaison des méthodes

MéthodeQualitéVitesseFacilitéGPU requis
GGUF Q4★★★☆★★★★★★★★★Non
GPTQ★★★★★★★★★★★☆Oui
AWQ★★★★★★★★★★★★★☆Oui
bitsandbytes★★★☆★★★☆★★★★★Oui

Pruning

Le pruning supprime les connexions ou neurones peu importants.

AVANT PRUNING                    APRÈS PRUNING (50%)

●───●───●───●                   ●───●   ●───●
│ ╲ │ ╱ │ ╲ │                   │   │   │   │
●───●───●───●        ───►       ●   ●───●   ●
│ ╱ │ ╲ │ ╱ │                       │ ╲ │
●───●───●───●                   ●   ●───●   ●

100% connexions                  50% connexions

Types de pruning

import torch.nn.utils.prune as prune

# Pruning non-structuré (poids individuels)
prune.l1_unstructured(model.layer, name='weight', amount=0.3)

# Pruning structuré (neurones entiers)
prune.ln_structured(model.layer, name='weight', amount=0.3, n=2, dim=0)

Compilation et optimisation runtime

TensorRT (NVIDIA)

import tensorrt as trt
import torch_tensorrt

# Compiler pour TensorRT
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(shape=[1, 512])],
    enabled_precisions={torch.float16},
)

vLLM

Inférence optimisée avec PagedAttention.

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,  # Nombre de GPUs
    quantization="awq",      # Quantization intégrée
)

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=256,
)

outputs = llm.generate(["Explique:"], sampling_params)

Text Generation Inference (TGI)

# Démarrer un serveur TGI
docker run --gpus all -p 8080:80 \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id meta-llama/Llama-2-7b-hf \
    --quantize bitsandbytes-nf4

Techniques complémentaires

Flash Attention

Optimise le calcul de l'attention avec moins de mémoire.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

KV-Cache Optimization

# Sliding window attention (Mistral)
# Limite la mémoire du KV-cache
config = MistralConfig(
    sliding_window=4096,  # Fenêtre de contexte glissante
)

Speculative Decoding

Utilise un petit modèle pour proposer des tokens, validés par le grand modèle.

┌─────────────────────────────────────────────────────────────────┐
│                   SPECULATIVE DECODING                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  1. Draft model (petit) génère N tokens rapidement              │
│  2. Target model (grand) valide en parallèle                    │
│  3. Si validés → acceptés, sinon → regenerate                   │
│                                                                  │
│  Speedup: 2-3x sur génération longue                            │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Benchmarks et métriques

Métriques clés

MétriqueDescriptionCible
PerplexityQualité du modèle< +5% vs FP16
Tokens/secVitesse de générationMaximiser
VRAMMémoire GPU utiliséeMinimiser
TTFTTime to First Token< 100ms

Exemple de benchmark

import time

def benchmark_model(model, tokenizer, prompt, n_tokens=100):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    start = time.time()
    outputs = model.generate(
        **inputs,
        max_new_tokens=n_tokens,
        do_sample=False,
    )
    elapsed = time.time() - start

    return {
        "tokens_per_sec": n_tokens / elapsed,
        "time_to_first_token": elapsed / n_tokens,  # Approximation
        "vram_gb": torch.cuda.max_memory_allocated() / 1e9,
    }

Récapitulatif

┌─────────────────────────────────────────────────────────────────┐
│                 OPTIMISATION D'INFÉRENCE                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  RÉDUCTION TAILLE          ACCÉLÉRATION         DÉPLOIEMENT     │
│  ┌───────────────┐        ┌─────────────┐      ┌────────────┐  │
│  │ Quantization  │        │ Flash Attn  │      │ vLLM       │  │
│  │ • INT8/INT4   │        │ • -50% mem  │      │ TGI        │  │
│  │ • GGUF/GPTQ   │        │ • +2x speed │      │ TensorRT   │  │
│  │ • AWQ         │        │             │      │            │  │
│  └───────────────┘        └─────────────┘      └────────────┘  │
│         │                        │                    │         │
│  ┌───────────────┐        ┌─────────────┐      ┌────────────┐  │
│  │ Pruning       │        │ Speculative │      │ Batching   │  │
│  │ • 30-50% less │        │ Decoding    │      │ continu    │  │
│  │ • Structured  │        │ • 2-3x      │      │            │  │
│  └───────────────┘        └─────────────┘      └────────────┘  │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Pour aller plus loin

On this page