tarun-elango 62197e52c0 ml
Co-authored-by: Copilot <copilot@github.com>
2026-04-30 19:59:29 -04:00


# Distributed Training & Inference
Serving large models efficiently.
This handbook is written for engineers who want to understand how distributed training and distributed inference actually work in production systems. The goal is not to memorize vocabulary. The goal is to build the mental models needed to design systems, diagnose bottlenecks, make tradeoffs, and avoid expensive mistakes.
## Table of Contents
- [1. Why This Topic Exists](#1-why-this-topic-exists)
- [2. First-Principles Mental Model](#2-first-principles-mental-model)
- [3. Hardware and System Foundations](#3-hardware-and-system-foundations)
- [4. Distributed Training from First Principles](#4-distributed-training-from-first-principles)
- [5. Parallelism Strategies](#5-parallelism-strategies)
- [6. Training System Design in Production](#6-training-system-design-in-production)
- [7. Distributed Inference from First Principles](#7-distributed-inference-from-first-principles)
- [8. Serving Large Models Efficiently](#8-serving-large-models-efficiently)
- [9. Tradeoffs and Decision-Making](#9-tradeoffs-and-decision-making)
- [10. Common Mistakes Engineers Make](#10-common-mistakes-engineers-make)
- [11. Debugging and Troubleshooting](#11-debugging-and-troubleshooting)
- [12. Best Practices](#12-best-practices)
- [13. Production Scenarios and Use Cases](#13-production-scenarios-and-use-cases)
- [14. Interview-Level Understanding](#14-interview-level-understanding)
- [15. Implementation Patterns and Tooling Landscape](#15-implementation-patterns-and-tooling-landscape)
- [16. A Practical Design Walkthrough](#16-a-practical-design-walkthrough)
- [17. Failure Cases and How to Avoid Them](#17-failure-cases-and-how-to-avoid-them)
- [18. Quick Reference Checklist](#18-quick-reference-checklist)
- [19. Final Mental Model](#19-final-mental-model)
---
## 1. Why This Topic Exists
Modern models are demanding for two related reasons:
1. They have many parameters.
2. They process large amounts of data and context.
That creates three immediate engineering problems:
- **The model may not fit on one device.**
- **The training job may take too long on one device.**
- **The serving system may not meet latency or cost targets if every request is handled naively.**
Distributed systems solve these problems by spreading work across multiple devices, multiple machines, or both. But distributing work does not magically make things fast. It adds overhead, synchronization, scheduling complexity, network traffic, and failure modes.
The core engineering challenge is this:
> How do you split work across hardware so that the useful compute gained is larger than the coordination cost introduced?
That single question connects almost every design decision in this handbook.
---
## 2. First-Principles Mental Model
Before looking at techniques, start from the physical constraints.
### 2.1 The four limiting resources
Every large-model training or serving system is constrained by some combination of:
- **Compute**: FLOPs available on GPUs or accelerators
- **Memory capacity**: whether weights, activations, optimizer state, and KV cache fit at all
- **Memory bandwidth**: how fast data can be moved between HBM and compute units
- **Communication bandwidth and latency**: how fast devices can exchange tensors
If you forget one of these, you will make the wrong optimization.
Examples:
- A model may fit in GPU memory, but still run slowly because HBM bandwidth is the real bottleneck.
- Training may scale well on 8 GPUs in one node, but scale poorly across 64 GPUs because inter-node communication dominates.
- Inference may show low GPU utilization, which suggests spare capacity, yet still deliver poor user latency because requests are waiting in a queue.
### 2.2 A useful performance equation
For training, a simplified view is:
```text
step_time ~= input_time + forward_backward_compute + communication_time + optimizer_time + idle_time
```
For serving, a simplified view is:
```text
request_latency ~= queue_time + prefill_time + decode_time + network_time
```
Most production optimization work is about shrinking one of these terms without increasing another term too much.
### 2.3 Why distributed systems are hard
Single-device programming mostly asks: "How do I use one machine efficiently?"
Distributed systems ask two harder questions:
1. **How do I divide the work?**
2. **How do I keep the divided pieces coordinated correctly and efficiently?**
That second question is where reality hits:
- gradients must be synchronized
- parameters may need to be gathered or sharded
- pipeline stages may wait for one another
- requests must be routed to the right worker
- failures on one node can stall the whole job
The best distributed design is usually not the one with the most sophisticated diagram. It is the one that minimizes coordination on the critical path.
---
## 3. Hardware and System Foundations
You cannot reason well about distributed ML without understanding the hardware stack.
### 3.1 The practical hierarchy
At a high level:
- **GPU compute units** perform matrix math
- **HBM** stores active tensors at very high bandwidth
- **Local interconnects** like NVLink move data between GPUs in the same node
- **PCIe** connects GPUs, CPUs, NICs, and storage devices
- **NICs and network fabric** like InfiniBand or Ethernet move data across nodes
- **CPU memory and storage** stage data, checkpoints, logs, and datasets
The further data travels, the more expensive it usually becomes.
That leads to a rule worth memorizing:
> Keep the hottest data as local as possible, and move it as infrequently as possible.
### 3.2 Why intra-node and inter-node matter so much
Two GPUs inside one server often communicate much faster than two GPUs on different servers. That is why many systems try to keep tightly coupled communication inside a node when possible.
Practical implication:
- Tensor parallelism often prefers GPUs with fast local links.
- Cross-node tensor parallelism can work, but it can become communication-heavy.
- Data parallelism is often easier to scale across nodes because replicas mostly compute independently until synchronization points.
### 3.3 Memory categories that matter in practice
During training, memory is consumed by:
- model weights
- activations
- gradients
- optimizer state
- temporary buffers used by kernels and communication libraries
During inference, memory is consumed by:
- model weights
- runtime workspaces
- KV cache
- batching overhead
The KV cache is often the hidden reason an apparently "small enough" model still cannot serve enough concurrent users.
### 3.4 A systems view of the stack
```mermaid
flowchart TB
subgraph Storage
DS[Datasets]
CKPT[Checkpoints]
end
subgraph CPU_Node
CPU[CPU]
RAM[System RAM]
NIC[NIC]
end
subgraph GPU_Node
G0[GPU 0]
G1[GPU 1]
G2[GPU 2]
G3[GPU 3]
end
DS --> CPU
CKPT --> CPU
CPU <--> RAM
CPU <--> G0
CPU <--> G1
CPU <--> G2
CPU <--> G3
G0 <--> G1
G1 <--> G2
G2 <--> G3
NIC <--> CPU
NIC <--> NIC2[Remote NIC]
```
Interpretation:
- Reading from storage is relatively slow.
- CPU staging matters more than many beginners expect.
- Local GPU-GPU links are precious.
- Network fabric decides whether multi-node scaling is efficient or painful.
---
## 4. Distributed Training from First Principles
Start with one training step on one GPU.
### 4.1 Single-device training step
For each batch:
1. Load input data.
2. Run forward pass.
3. Compute loss.
4. Run backward pass to compute gradients.
5. Update parameters with the optimizer.
On one GPU, this is conceptually simple because all tensors live in one place.
### 4.2 What changes when you distribute training
Once work is split, some combination of the following must happen:
- data is split across replicas
- parameters are split across devices
- layers are split across stages
- gradients are synchronized
- optimizer state is sharded or replicated
- activations or parameters are moved between devices
This means distributed training is always a balance between:
- **parallel useful compute**
- **extra communication and coordination**
### 4.3 The most important concept: synchronous data parallel training
In synchronous data parallelism, each replica holds the same model weights but processes a different microbatch.
At the end of the backward pass, gradients are averaged across replicas (typically with an all-reduce) so every replica applies the same update.
Why this works:
- Each replica computes an estimate of the gradient from its own data slice.
- Combining those gradients approximates the gradient of the larger global batch.
- If all replicas apply the same aggregated gradient, model parameters stay identical.
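The equivalence can be checked numerically. The sketch below is a toy, pure-Python illustration and not a framework API: it assumes a one-parameter linear model, equal-sized shards, and a hypothetical `all_reduce_mean` helper that stands in for a real averaging all-reduce.

```python
def mse_grad(w, batch):
    """d/dw of mean((w*x - y)^2) over the (x, y) pairs in batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Hypothetical stand-in for an averaging all-reduce:
    every rank receives the same mean of all local values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


data = [(float(i), 3.0 * i + 1.0) for i in range(8)]
w = 0.5

shards = [data[rank::4] for rank in range(4)]    # 4 replicas, 2 samples each
local_grads = [mse_grad(w, shard) for shard in shards]
synced_grads = all_reduce_mean(local_grads)      # identical on every replica
full_batch_grad = mse_grad(w, data)              # gradient of the global batch

print(synced_grads[0], full_batch_grad)
```

With equal-sized shards, the average of the per-replica mean gradients matches the mean gradient over the whole batch, which is why every replica can apply the same update. With unequal shards the average would need to be weighted by shard size.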
### 4.4 Step-by-step data parallel training flow
```mermaid
sequenceDiagram
participant Loader as Data Loader
participant R0 as Rank 0
participant R1 as Rank 1
participant R2 as Rank 2
participant R3 as Rank 3
Loader->>R0: Microbatch A
Loader->>R1: Microbatch B
Loader->>R2: Microbatch C
Loader->>R3: Microbatch D
R0->>R0: Forward + backward
R1->>R1: Forward + backward
R2->>R2: Forward + backward
R3->>R3: Forward + backward
R0-->>R1: Gradient sync
R1-->>R2: Gradient sync
R2-->>R3: Gradient sync
R3-->>R0: Gradient sync
R0->>R0: Optimizer step
R1->>R1: Optimizer step
R2->>R2: Optimizer step
R3->>R3: Optimizer step
```
Important intuition:
- Data parallelism is attractive because most compute is independent.
- The main cost is synchronization.
- If synchronization becomes too expensive, scaling efficiency drops.
### 4.5 Global batch size and gradient accumulation
One of the most common sources of confusion is batch terminology.
```text
global_batch = microbatch_per_device * gradient_accumulation_steps * data_parallel_replicas
```
Example:
- microbatch per GPU = 4
- gradient accumulation steps = 8
- data parallel replicas = 16
Then:
```text
global_batch = 4 * 8 * 16 = 512
```
Why accumulation exists:
- You may want a large effective batch for optimizer stability or throughput.
- But the full batch may not fit in memory at once.
- So you process several microbatches and delay the optimizer update.
Common mistake:
- Engineers increase data parallel replicas and forget to adjust learning rate, warmup, or optimizer settings for the new global batch.
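A small helper makes the batch arithmetic explicit. This is a minimal sketch: the function names are illustrative, and the linear learning-rate scaling shown is only one common heuristic to use as a starting point, not a rule that guarantees stable training.

```python
def global_batch_size(microbatch_per_device: int,
                      grad_accum_steps: int,
                      data_parallel_replicas: int) -> int:
    """Number of samples contributing to one optimizer update."""
    return microbatch_per_device * grad_accum_steps * data_parallel_replicas


def linearly_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling heuristic (an assumption to validate, not a guarantee):
    keep lr / global_batch roughly constant."""
    return base_lr * new_batch / base_batch


old_batch = global_batch_size(4, 8, 16)   # the example above
new_batch = global_batch_size(4, 8, 32)   # doubling replicas doubles the batch
suggested_lr = linearly_scaled_lr(3e-4, old_batch, new_batch)
print(old_batch, new_batch, suggested_lr)
```

Printing the global batch at job start, next to the learning rate and warmup settings, is a cheap way to catch a silent change when the replica count is edited.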
### 4.6 Communication primitives you must understand
Distributed training relies on a small set of collective operations.
| Primitive | What it does | Common use |
| --- | --- | --- |
| Broadcast | One rank sends a tensor to all others | Initial parameter sync |
| All-reduce | Sum or combine tensors and distribute result to all | Gradient synchronization |
| Reduce-scatter | Reduce tensors then shard the result across ranks | Sharded gradient handling |
| All-gather | Each rank shares its shard and all ranks reconstruct full tensor | Parameter gathering in sharded training |
| Send/recv | Point-to-point transfer | Pipeline parallel activation transfer |
A professional mental model:
- **All-reduce** is not a magical "speed up" primitive.
- It is the price you pay to keep replicas mathematically consistent.
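The semantics of the collectives in the table can be sketched with plain lists, where the outer list index plays the role of the rank. These functions are illustrative simulations only; they are not the API of NCCL or any framework, and they assume the tensor length divides evenly by the number of ranks.

```python
def all_reduce(per_rank):
    """Every rank ends with the element-wise sum of all ranks' tensors."""
    total = [sum(column) for column in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank):
    """Rank r ends with only the r-th shard of the element-wise sum."""
    ranks = len(per_rank)
    total = [sum(column) for column in zip(*per_rank)]
    shard = len(total) // ranks   # assumes length is divisible by rank count
    return [total[r * shard:(r + 1) * shard] for r in range(ranks)]


def all_gather(shards):
    """Every rank ends with the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]   # 2 ranks, 4 elements each
print(all_reduce(grads))
print(reduce_scatter(grads))
print(all_gather(reduce_scatter(grads)))
```

In this toy, a reduce-scatter followed by an all-gather produces the same result as an all-reduce. Sharded training exploits that decomposition: it can stop after the reduce-scatter and keep only the local gradient shard.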
### 4.7 Communication algorithms: ring vs tree intuition
You do not need to memorize implementation details, but you should understand the tradeoff.
- **Ring-style collectives** use bandwidth well and are common for large tensors.
- **Tree-style collectives** can reduce latency for smaller messages or specific topologies.
Real-world point:
The best algorithm depends on tensor size, topology, library implementation, and network health. Engineers often assume the math is the hard part. In production, topology and transport are often the hard part.
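A rough way to build this intuition is the textbook latency-bandwidth (alpha-beta) cost model. The sketch below is a simplification under stated assumptions: `alpha` is a fixed per-message latency, `beta` is seconds per byte, the tree variant is an unpipelined binary-tree reduce followed by a broadcast, and the numeric values are hypothetical. Real libraries pipeline, chunk, and adapt to topology, so treat the output as intuition rather than prediction.

```python
import math


def ring_allreduce_time(n_bytes, ranks, alpha, beta):
    """Ring: 2*(p-1) steps, each moving an n/p-sized chunk per rank."""
    steps = 2 * (ranks - 1)
    return steps * (alpha + (n_bytes / ranks) * beta)


def tree_allreduce_time(n_bytes, ranks, alpha, beta):
    """Binary tree: reduce up, then broadcast down, full tensor per hop."""
    hops = 2 * math.ceil(math.log2(ranks))
    return hops * (alpha + n_bytes * beta)


ALPHA = 5e-6           # hypothetical 5 microseconds per message
BETA = 1.0 / 25e9      # hypothetical 25 GB/s link
RANKS = 64

for n_bytes in (1024, 10**9):
    ring = ring_allreduce_time(n_bytes, RANKS, ALPHA, BETA)
    tree = tree_allreduce_time(n_bytes, RANKS, ALPHA, BETA)
    print(n_bytes, "ring" if ring < tree else "tree")
```

Under these assumptions the small message is latency-bound, so fewer hops win, while the large tensor is bandwidth-bound, so sending only `n/p` bytes per step wins.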
---
## 5. Parallelism Strategies
There is no single best strategy. Parallelism methods solve different bottlenecks.
### 5.1 Data parallelism
### How it works
Each worker has a full copy of the model and processes different data.
### Why it is useful
- simple mental model
- widely supported by frameworks
- works well when the model fits on each device
### Where it breaks down
- model no longer fits on one GPU
- gradient synchronization becomes expensive at large scale
- optimizer state replication wastes memory
### Best use
- small to medium models
- fine-tuning jobs
- multi-node scaling when replicas can stay mostly independent
### 5.2 Tensor parallelism
Tensor parallelism splits the computation of a single layer across multiple devices.
Example intuition:
- A large matrix multiplication is partitioned across GPUs.
- Each GPU computes part of the result.
- Partial results are combined through communication.
Why it exists:
- some layers are simply too large to fit on one GPU
- even if they fit, splitting may increase throughput for very large models
Tradeoff:
- compute is parallelized
- but each layer now depends on frequent communication
Important practical insight:
Tensor parallelism usually works best when participating GPUs are connected by very fast links. If tensor-parallel ranks are spread across slow network boundaries, communication can dominate each forward pass.
### 5.3 Pipeline parallelism
Pipeline parallelism splits model layers into stages.
Example:
- GPUs 0-1 hold early layers
- GPUs 2-3 hold middle layers
- GPUs 4-5 hold later layers
- GPUs 6-7 hold final layers
Microbatches are streamed through these stages like an assembly line.
Why it helps:
- enables models larger than one device
- reduces full-model replication
Main problem:
- pipeline bubbles
A bubble is idle time when some pipeline stages are waiting instead of computing.
Practical lesson:
- Balanced stage partitioning matters.
- Uneven layer cost produces idle GPUs.
- Choosing the number of microbatches is part of pipeline tuning, not just a training hyperparameter decision.
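The effect of microbatch count on bubbles can be estimated with a commonly used approximation. This sketch assumes perfectly balanced stages, uniform microbatch cost, and a simple non-interleaved schedule; under those assumptions the idle fraction is `(stages - 1) / (microbatches + stages - 1)`. Real pipelines with uneven stages will do worse.

```python
def bubble_fraction(stages: int, microbatches: int) -> float:
    """Approximate idle fraction of a balanced, non-interleaved pipeline.

    Each stage is busy for `microbatches` slots out of
    `microbatches + stages - 1` total slots.
    """
    return (stages - 1) / (microbatches + stages - 1)


for m in (4, 8, 32):
    print(f"4 stages, {m} microbatches: {bubble_fraction(4, m):.1%} idle")
```

More microbatches amortize the fill and drain phases, but each microbatch also gets smaller, which can reduce kernel efficiency. That is why the microbatch count has to be tuned together with the stage layout.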
### 5.4 Pipeline schedule intuition
In the common 1F1B (one-forward-one-backward) schedule, each stage alternates one microbatch forward pass with one microbatch backward pass once warmup has filled the pipeline.
Why it helps:
- reduces memory pressure compared with running all forward passes first
- improves utilization compared with a naive schedule
Why it is still hard:
- stage balance is difficult
- debugging stalls becomes harder
- recomputation and activation movement complicate memory reasoning
### 5.5 Sequence or context parallelism
When sequence length becomes large, you can shard work along the sequence dimension rather than only along layers or weights.
This matters for:
- long-context LLM training
- attention-heavy workloads
- models where activation memory grows strongly with sequence length
Practical point:
Long-context training often becomes an activation and communication problem before it becomes a pure parameter-storage problem.
### 5.6 Expert parallelism for mixture-of-experts models
In MoE systems, only a subset of experts is activated per token.
That changes the distributed design:
- tokens are routed to selected experts
- expert weights may be distributed across devices
- load balancing becomes critical
Common mistake:
- Engineers think MoE is automatically cheaper because not all parameters are used every time.
- In reality, token routing, load imbalance, and communication can erase much of that theoretical saving.
### 5.7 ZeRO, FSDP, and sharded training
These methods reduce memory duplication by sharding some combination of:
- parameters
- gradients
- optimizer state
### Why sharding matters
In classic data parallel training, each replica may keep:
- full weights
- full gradients
- full optimizer state
That becomes expensive quickly, especially with Adam-style optimizers.
### The intuition behind FSDP-like systems
Instead of keeping the entire model replicated all the time, the system:
1. gathers the parameters needed for a layer
2. computes that layer
3. discards or reshards what is no longer needed
4. reduce-scatters gradients instead of fully replicating them
This saves memory, but it increases communication and runtime complexity.
Practical takeaway:
- Sharding trades memory savings for communication overhead.
- It is often the right trade for large models.
- It is not free.
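A back-of-the-envelope estimate shows why the trade is attractive. The sketch below uses the per-parameter accounting popularized by the ZeRO work for mixed-precision Adam, which is an assumption here: 2 bytes of half-precision weights, 2 bytes of half-precision gradients, and 12 bytes of fp32 optimizer state (master weights, momentum, variance). It ignores activations, temporary buffers, and fragmentation, and the `stage` numbering is illustrative.

```python
def model_state_gib(params: float, replicas: int, stage: int) -> float:
    """Per-GPU model-state memory in GiB.

    stage 0: plain data parallelism (everything replicated)
    stage 1: shard optimizer state
    stage 2: also shard gradients
    stage 3: also shard parameters
    """
    weights, grads, optim = 2.0, 2.0, 12.0   # assumed bytes per parameter
    if stage >= 1:
        optim /= replicas
    if stage >= 2:
        grads /= replicas
    if stage >= 3:
        weights /= replicas
    return params * (weights + grads + optim) / 2**30


for stage in range(4):
    print(f"stage {stage}: {model_state_gib(7e9, 8, stage):.1f} GiB per GPU")
```

The memory numbers only tell half the story: the later stages also add parameter all-gathers on the critical path, which is the communication overhead noted above.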
### 5.8 Activation checkpointing and recomputation
Activation checkpointing saves memory by not storing every intermediate activation from the forward pass. During the backward pass, the missing activations are recomputed from the nearest saved checkpoint.
This is one of the cleanest examples of a deliberate systems tradeoff:
- save memory
- pay extra compute
It is extremely common in large-model training because memory is often the first wall you hit.
### 5.9 Which parallelism should you choose?
```mermaid
flowchart TD
A[Does model fit on one GPU?] -->|Yes| B[Start with data parallelism]
A -->|No| C[Need model sharding]
C --> D[Is layer-wise communication cheap within node?]
D -->|Yes| E[Consider tensor parallelism]
D -->|No| F[Consider pipeline or FSDP]
E --> G[Need more memory savings?]
F --> G
G -->|Yes| H[Add sharding or offload]
G -->|No| I[Optimize batch size and overlap]
H --> J[Validate communication bottlenecks]
I --> J
```
Real engineering answer:
- Start with the simplest strategy that fits memory and time constraints.
- Add complexity only when a concrete bottleneck forces it.
---
## 6. Training System Design in Production
Parallelism strategy is only part of the job. A production training system includes data, scheduling, checkpointing, and observability.
### 6.1 Input pipeline matters more than many teams expect
Fast GPUs do not help if they are waiting for data.
Typical failure modes:
- slow dataset reads from remote storage
- expensive per-sample preprocessing on CPU
- poor shuffling implementation
- worker imbalance
- serialization overhead
Symptoms:
- low GPU utilization even though the model is correct
- step-time variance not explained by compute
- busy CPUs and idle GPUs
Practical fixes:
- pre-tokenize where possible
- cache or stage hot data locally
- increase dataloader worker efficiency carefully
- profile CPU pipeline separately from GPU kernels
### 6.2 Checkpointing is a systems feature, not a training afterthought
At scale, training jobs will fail. Nodes reboot. Networks flap. Schedulers preempt jobs. If checkpointing is weak, you lose days.
Good checkpoint design considers:
- checkpoint frequency
- write bandwidth
- sharded checkpoint format
- restore speed
- compatibility across parallelism layouts
Common mistake:
- Teams only ask whether checkpoints are being written.
- They do not test restore time, partial restore, or resuming after topology changes.
### 6.3 Fault tolerance and elasticity
Questions to ask:
- What happens if one worker disappears?
- Does the entire job restart?
- Can ranks be reformed?
- Can checkpoints resume onto a different world size?
At small scale, restart-all may be acceptable.
At large scale, restart-all can be painfully expensive.
### 6.4 Observability for training
A serious training stack needs visibility into:
- step time
- data loading time
- communication time
- GPU memory usage
- GPU utilization
- network throughput
- gradient norm health
- loss curves
- checkpoint duration
The mistake is collecting only final training loss. That is like trying to debug a distributed database by reading only the last line of the logs.
### 6.5 Cluster topology awareness
Do not treat a 64-GPU cluster as a flat pool if the topology is hierarchical.
Examples:
- 8 GPUs inside one node with fast links
- nodes connected over a slower fabric
This affects:
- rank placement
- tensor parallel group assignment
- data parallel group assignment
- communication library performance
Professional habit:
- Always know which communication stays within a node and which crosses nodes.
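One way to make that habit concrete is to generate the process groups from the topology and check them. This is a minimal sketch under stated assumptions: ranks are numbered consecutively within each node, `build_groups` is a hypothetical helper rather than a framework function, and only tensor-parallel and data-parallel dimensions are modeled.

```python
def build_groups(num_nodes: int, gpus_per_node: int, tp_size: int):
    """Return (tp_groups, dp_groups) as lists of global ranks.

    Tensor-parallel groups are contiguous ranks, so they stay inside a node.
    Data-parallel groups link the same local slot across TP groups.
    """
    if gpus_per_node % tp_size != 0:
        raise ValueError("tp_size must divide gpus_per_node to stay intra-node")
    world = num_nodes * gpus_per_node
    tp_groups = [list(range(s, s + tp_size)) for s in range(0, world, tp_size)]
    dp_groups = [list(range(offset, world, tp_size)) for offset in range(tp_size)]
    return tp_groups, dp_groups


def nodes_spanned(group, gpus_per_node: int) -> int:
    """How many physical nodes a group of ranks touches."""
    return len({rank // gpus_per_node for rank in group})


tp_groups, dp_groups = build_groups(num_nodes=2, gpus_per_node=8, tp_size=4)
print("TP groups:", tp_groups)
print("DP groups:", dp_groups)
print("max nodes per TP group:", max(nodes_spanned(g, 8) for g in tp_groups))
```

Here the chatty tensor-parallel traffic never leaves a node, while the less frequent data-parallel synchronization is what crosses the slower fabric.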
### 6.6 Training architecture view
```mermaid
flowchart LR
Data[Dataset Storage] --> Prep[Preprocessing and Tokenization]
Prep --> Loader[Distributed Data Loader]
Loader --> Trainer[Trainer Orchestrator]
Trainer --> W0[Worker Group 0]
Trainer --> W1[Worker Group 1]
Trainer --> W2[Worker Group 2]
W0 <--> W1
W1 <--> W2
W2 <--> W0
W0 --> CKPT[Sharded Checkpoints]
W1 --> CKPT
W2 --> CKPT
Trainer --> Obs[Metrics Logs Traces]
W0 --> Obs
W1 --> Obs
W2 --> Obs
```
---
## 7. Distributed Inference from First Principles
Training optimizes a model. Inference sells the product.
In production, serving often becomes more operationally complex than training because the workload is dynamic, latency-sensitive, and user-facing.
### 7.1 The two phases of autoregressive inference
For LLM-style serving, inference has two very different phases:
1. **Prefill**: process the prompt and build attention state
2. **Decode**: generate tokens one step at a time
These phases behave differently.
### Prefill
- more parallel work per request
- often compute-heavy
- benefits from large batch processing
### Decode
- one token at a time
- repeatedly reads model weights and KV cache
- often memory-bandwidth-limited
- sensitive to scheduling efficiency
This distinction explains many serving architectures.
### 7.2 Step-by-step: one generated token
For a request already in decode phase:
1. Read current token and request state.
2. Read relevant weights.
3. Read KV cache from previous tokens.
4. Run attention and MLP layers.
5. Produce logits.
6. Sample or select next token.
7. Append new K and V entries to cache.
8. Repeat until stop condition.
The key insight:
- decode is not just "small forward passes"
- it is repeated stateful execution with strong memory pressure
### 7.3 Throughput vs latency in serving
Serving optimization is always balancing:
- **low latency for one request**
- **high throughput across many requests**
- **acceptable cost per token**
If you batch aggressively, throughput may improve but single-request latency may worsen.
If you prioritize every request immediately, latency may improve for one user but GPU efficiency may collapse.
This is not a bug. It is the central tradeoff.
### 7.4 Why KV cache matters so much
The KV cache stores attention history so the model does not recompute everything from scratch for each generated token.
Approximate memory intuition:
```text
KV_cache_bytes ~= batch_size * context_tokens * 2 * num_layers * kv_heads * head_dim * bytes_per_element
```
What this means in practice:
- long prompts are expensive
- high concurrency is expensive
- large models are expensive
- the limiting resource in serving may be KV cache memory before raw compute becomes the bottleneck
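The formula above is easy to turn into a capacity check. The model shape in this sketch (32 layers, 8 KV heads, head dimension 128, 2-byte elements) is a hypothetical configuration chosen for round numbers, not a description of any specific model.

```python
def kv_cache_bytes(batch_size: int, context_tokens: int, num_layers: int,
                   kv_heads: int, head_dim: int, bytes_per_element: int = 2) -> int:
    """KV cache size; the factor 2 covers the separate K and V tensors."""
    return (batch_size * context_tokens * 2 * num_layers
            * kv_heads * head_dim * bytes_per_element)


# Hypothetical model shape, for illustration only.
SHAPE = dict(num_layers=32, kv_heads=8, head_dim=128, bytes_per_element=2)

per_token = kv_cache_bytes(1, 1, **SHAPE)
per_request = kv_cache_bytes(1, 4096, **SHAPE)
fleet = kv_cache_bytes(32, 4096, **SHAPE)

print(per_token // 1024, "KiB per token")
print(per_request // 2**20, "MiB per 4096-token request")
print(fleet // 2**30, "GiB for 32 concurrent requests")
```

Running the same calculation with the real model shape and the measured context-length distribution gives a quick upper bound on concurrency before any load test.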
### 7.5 Why naive batching is insufficient
If you only batch requests that arrive at exactly the same time, you waste GPU capacity.
Real systems use schedulers that:
- admit new requests into active batches
- handle different prompt lengths
- manage different generation lengths
- keep GPU work dense while respecting latency targets
That is why continuous batching became so important in modern LLM serving.
---
## 8. Serving Large Models Efficiently
This section is the operational core of large-model inference.
### 8.1 Model placement and partitioning
Questions to answer first:
- Can the model fit on one GPU?
- If yes, do we still want multiple GPUs for throughput?
- If no, do we shard by tensor, by pipeline stage, or both?
Simple rule:
- If a model fits and latency matters, keeping it local is often best.
- If it does not fit, choose the least communication-heavy sharding that meets the SLO.
### 8.2 Tensor parallel serving
Tensor parallelism during inference splits layer computation across GPUs.
Why teams use it:
- lets large models fit
- can increase throughput for very large layers
Why teams regret bad configurations:
- each token step may require inter-GPU communication
- if GPUs span slow links, token latency becomes unstable
Engineering heuristic:
- prefer tensor-parallel groups that stay within fast local topology when possible
### 8.3 Pipeline parallel serving
This is used when layers are partitioned across devices or nodes.
Pros:
- supports very large models
- can reduce per-device memory pressure
Cons:
- increased end-to-end coordination
- bubbles and stage imbalance
- more complex scheduler interactions during dynamic serving
Pipeline parallelism is often harder to operate for latency-sensitive user traffic than people expect.
### 8.4 Continuous batching
Continuous batching means the server does not wait for a fixed batch window and run it to completion as a rigid unit. Instead, requests can be inserted and retired as decoding progresses.
Why it helps:
- improves GPU occupancy
- handles heterogeneous request lengths better
- increases tokens per second under real mixed traffic
What makes it hard:
- scheduler complexity
- cache bookkeeping
- fairness between short and long requests
- handling cancellation and timeouts cleanly
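The occupancy benefit can be seen in a toy simulation that counts decode steps. This sketch makes strong simplifying assumptions: all requests are already queued, each decode step costs the same regardless of batch contents, prefill is ignored, and both function names are hypothetical.

```python
from collections import deque


def static_batch_steps(lengths, slots):
    """Run fixed groups to completion; each group costs its longest request."""
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))


def continuous_batch_steps(lengths, slots):
    """Refill free slots at every decode step as requests retire."""
    queue, active, steps = deque(lengths), [], 0
    while queue or active:
        while queue and len(active) < slots:
            active.append(queue.popleft())     # admit into the running batch
        # One decode step: every active request emits a token; finished ones leave.
        active = [remaining - 1 for remaining in active if remaining - 1 > 0]
        steps += 1
    return steps


decode_lengths = [2, 10, 2, 10, 2, 2, 2, 2]   # tokens to generate per request
print("static:", static_batch_steps(decode_lengths, slots=2))
print("continuous:", continuous_batch_steps(decode_lengths, slots=2))
```

In the static case, short requests finish early and their slots sit idle until the longest request in the group completes. The continuous scheduler reuses those slots immediately, which is where the extra throughput comes from.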
### 8.5 Paged attention and memory-efficient KV cache management
A large practical problem in serving is memory fragmentation and inefficient cache layout.
Paged attention-style approaches manage the KV cache in fixed-size blocks, much as an operating system pages virtual memory, instead of reserving one large contiguous buffer per request.
Why that matters:
- requests vary in length
- requests end at different times
- naive allocation wastes memory
- fragmentation reduces effective capacity
Production takeaway:
- Good cache management can significantly improve concurrency without changing the model itself.
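The core bookkeeping idea can be sketched as a block allocator with a per-request block table. `PagedKVAllocator` below is a hypothetical toy, not the implementation or API of any serving engine: it tracks block ownership only and does not store tensors.

```python
class PagedKVAllocator:
    """Toy block allocator: requests grow one fixed-size block at a time."""

    def __init__(self, num_blocks: int, block_tokens: int):
        self.block_tokens = block_tokens
        self.free = list(range(num_blocks))   # ids of unused blocks
        self.tables = {}                      # request id -> list of block ids
        self.lengths = {}                     # request id -> tokens stored

    def append_token(self, request_id: str) -> bool:
        """Reserve room for one more token; False means the pool is exhausted."""
        stored = self.lengths.get(request_id, 0)
        if stored % self.block_tokens == 0:   # no block yet, or last block full
            if not self.free:
                return False                  # caller must queue, preempt, or evict
            self.tables.setdefault(request_id, []).append(self.free.pop())
        self.lengths[request_id] = stored + 1
        return True

    def release(self, request_id: str) -> None:
        """Return all blocks of a finished or cancelled request to the pool."""
        self.free.extend(self.tables.pop(request_id, []))
        self.lengths.pop(request_id, None)


alloc = PagedKVAllocator(num_blocks=4, block_tokens=16)
for _ in range(17):
    alloc.append_token("a")                   # 17 tokens need 2 blocks
alloc.append_token("b")                       # 1 token needs 1 block
print(len(alloc.free), alloc.tables)
alloc.release("a")
print(len(alloc.free))
```

Because blocks are uniform and need not be contiguous, a finished request's memory is immediately reusable by any other request, and waste per request is bounded by less than one block.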
### 8.6 Prefix caching and prompt reuse
If many requests share the same prompt prefix, the server may be able to reuse precomputed state.
Useful for:
- system prompts
- repeated instruction templates
- enterprise workflows with common context wrappers
- agent systems with repeated scaffolding prompts
The win:
- lower prefill cost
- better latency
- lower GPU time per request
The caution:
- cache invalidation and key correctness matter
- tokenization consistency matters
- multi-tenant isolation matters
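Those cautions mostly come down to what goes into the cache key. The sketch below is a hypothetical illustration: the key covers the tenant, a tokenizer version string, and the exact token ids, and `build_fn` stands in for the real prefill computation. Whether tenants may share prefixes at all is a policy decision this toy answers with "no".

```python
import hashlib


def prefix_key(tenant: str, tokenizer_version: str, token_ids) -> str:
    """Key on tenant, tokenizer version, and exact token ids (not raw text)."""
    payload = f"{tenant}|{tokenizer_version}|" + ",".join(map(str, token_ids))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


class PrefixCache:
    """Toy prefix cache; values are whatever build_fn returns."""

    def __init__(self):
        self._store = {}

    def get_or_build(self, tenant, tokenizer_version, token_ids, build_fn):
        key = prefix_key(tenant, tokenizer_version, token_ids)
        hit = key in self._store
        if not hit:
            self._store[key] = build_fn(token_ids)   # stand-in for prefill
        return self._store[key], hit


cache = PrefixCache()
system_prompt = [101, 7592, 2088, 102]               # hypothetical token ids
fake_prefill = lambda ids: {"kv_for_tokens": len(ids)}

_, hit1 = cache.get_or_build("tenant-a", "tok-v1", system_prompt, fake_prefill)
_, hit2 = cache.get_or_build("tenant-a", "tok-v1", system_prompt, fake_prefill)
_, hit3 = cache.get_or_build("tenant-b", "tok-v1", system_prompt, fake_prefill)
print(hit1, hit2, hit3)
```

Keying on token ids rather than text means a tokenizer change cannot silently reuse stale state, and including the tenant keeps one customer's cached context from being served to another.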
### 8.7 Speculative decoding
Speculative decoding uses a smaller draft model or heuristic process to propose tokens, then a larger model verifies them.
Why it can help:
- the expensive model can verify several proposed tokens in a single forward pass instead of running one full pass per generated token
- overall decode throughput can improve
Why it does not always help:
- verification overhead exists
- mismatch between draft and target model can reduce acceptance rate
- infrastructure complexity increases
Good engineering question:
- What is the measured token acceptance rate and net speedup under real traffic, not just synthetic benchmarks?
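A simple analytical model helps frame that measurement. The sketch below assumes each drafted token is accepted independently with the same probability, which is an idealization, and that `draft_cost_ratio` is the cost of one draft-model step relative to one target-model pass. Under those assumptions the expected tokens per target pass is `(1 - a^(k+1)) / (1 - a)` for acceptance rate `a` and draft length `k`.

```python
def expected_tokens_per_pass(accept_rate: float, draft_len: int) -> float:
    """Expected tokens produced per target-model verification pass."""
    if accept_rate >= 1.0:
        return float(draft_len + 1)
    return (1.0 - accept_rate ** (draft_len + 1)) / (1.0 - accept_rate)


def net_speedup(accept_rate: float, draft_len: int, draft_cost_ratio: float) -> float:
    """Speedup over plain decoding once draft-model cost is included."""
    cost_per_round = draft_len * draft_cost_ratio + 1.0
    return expected_tokens_per_pass(accept_rate, draft_len) / cost_per_round


for rate in (0.2, 0.6, 0.8):
    print(rate, round(net_speedup(rate, draft_len=4, draft_cost_ratio=0.1), 2))
```

With a low acceptance rate the result drops below 1.0, meaning the draft work costs more than it saves. Measured acceptance on production prompts, not the draft model's size, decides whether the technique pays off.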
### 8.8 Quantization in serving
Quantization reduces precision of weights and sometimes activations or KV cache.
Why it is so important for serving:
- lower memory footprint
- potentially higher effective throughput
- ability to fit larger models on fewer devices
Tradeoffs:
- possible quality loss
- kernel compatibility issues
- calibration and runtime implementation quality matter
Real-world pattern:
- Many serving stacks use quantization first because it attacks the memory problem directly.
- Teams should still validate downstream quality, not just perplexity or benchmark scores.
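The memory side of the argument is simple arithmetic. This sketch counts weight storage only and deliberately ignores quantization metadata such as scales and zero points, as well as KV cache and runtime workspaces, so real footprints will be somewhat larger.

```python
def weight_gib(num_params: float, bits_per_weight: int) -> float:
    """Approximate weight memory in GiB, excluding scales and zero points."""
    return num_params * bits_per_weight / 8 / 2**30


for bits in (16, 8, 4):
    print(f"70B parameters at {bits}-bit: {weight_gib(70e9, bits):.1f} GiB")
```

The same calculation, compared against per-GPU memory minus the KV cache budget, indicates how many devices a given precision needs before any quality evaluation is run.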
### 8.9 Disaggregated prefill and decode
Because prefill and decode stress the system differently, some architectures separate them.
Example:
- one pool optimized for prompt ingestion and heavy prefill
- another pool optimized for memory-efficient token generation
Why teams do this:
- better resource specialization
- better handling of mixed workloads
- improved scheduling under prompt-length variability
Why it is hard:
- request state transfer
- KV cache movement or reconstruction
- system complexity and debugging overhead
### 8.10 Multi-tenant serving and LoRA multiplexing
In enterprise systems, you may serve:
- one base model
- many tenant-specific adapters
- multiple QoS tiers
- a mixture of interactive and batch traffic
Then the serving problem becomes partly a scheduling and memory residency problem.
Questions that matter:
- Which adapters stay resident?
- Which are loaded on demand?
- How do you isolate noisy neighbors?
- How do you charge cost fairly across tenants?
### 8.11 Autoscaling and admission control
Autoscaling large-model serving is not like autoscaling stateless web servers.
Why:
- model load time is expensive
- warmup matters
- GPU availability is constrained
- queue growth can happen faster than new replicas become useful
That is why good serving stacks usually combine:
- predictive scaling
- warm pools
- queue-aware admission control
- request prioritization
The hard truth:
- Sometimes rejecting or degrading low-priority work is the correct engineering choice.
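A queue-aware admission check can be very small. The sketch below is a hypothetical policy, not a recommendation of specific numbers: it estimates wait time from queued tokens and measured throughput, and sheds each priority class once the estimate exceeds that class's threshold, so low-priority work is rejected first.

```python
def estimated_wait_s(queued_tokens: int, tokens_per_sec: float) -> float:
    """Rough time for the current backlog to drain at measured throughput."""
    return queued_tokens / tokens_per_sec


def admit(queued_tokens: int, tokens_per_sec: float,
          priority: str, shed_threshold_s: dict) -> bool:
    """Admit a request only if the estimated wait is within its class threshold."""
    return estimated_wait_s(queued_tokens, tokens_per_sec) <= shed_threshold_s[priority]


# Hypothetical thresholds: batch traffic is shed long before interactive traffic.
THRESHOLDS = {"interactive": 10.0, "batch": 2.0}

backlog, throughput = 20_000, 4_000.0        # 5 seconds of queued work
print("interactive:", admit(backlog, throughput, "interactive", THRESHOLDS))
print("batch:", admit(backlog, throughput, "batch", THRESHOLDS))
```

Rejecting early with a clear signal lets clients retry or fall back, which is usually better than accepting work that will time out after consuming GPU time.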
### 8.12 Serving architecture diagram
```mermaid
flowchart LR
Client[Client Applications] --> GW[API Gateway]
GW --> Router[Request Router]
Router --> Sched[Continuous Batch Scheduler]
Sched --> Prefill[Prefill Workers]
Sched --> Decode[Decode Workers]
Prefill <--> Cache[Prefix and KV Cache Layer]
Decode <--> Cache
Decode --> TP0[Tensor Parallel Shard 0]
Decode --> TP1[Tensor Parallel Shard 1]
TP0 <--> TP1
Decode --> Out[Token Stream Output]
Router --> Obs[Metrics Traces Logs]
Sched --> Obs
Decode --> Obs
```
### 8.13 Prefill-decode request flow
```mermaid
sequenceDiagram
participant U as User
participant R as Router
participant P as Prefill Pool
participant C as Cache
participant D as Decode Pool
U->>R: Prompt request
R->>P: Route prompt
P->>C: Build KV state
P->>D: Hand off request state
loop Each generated token
D->>C: Read and extend KV state
D->>U: Stream token
end
```
---
## 9. Tradeoffs and Decision-Making
The following tradeoffs appear repeatedly in real systems.
### 9.1 Memory vs communication
- Sharding saves memory but increases communication.
- Replication reduces communication during compute but wastes memory.
Engineering question:
- Is memory your hard wall, or is communication already your dominant cost?
### 9.2 Throughput vs latency
- Larger batches improve throughput.
- Smaller batches often improve latency.
Engineering question:
- Is this workload interactive, offline batch, or mixed?
### 9.3 Simplicity vs peak efficiency
- Simpler systems are easier to debug and operate.
- Complex systems may win benchmarks but lose reliability.
Engineering question:
- Is the incremental gain worth the new operational surface area?
### 9.4 Cost vs quality
- Smaller or quantized models are cheaper.
- Larger or higher-precision models may improve quality.
Engineering question:
- Which quality metric actually matters for the product?
### 9.5 A practical decision table
| Scenario | Usually reasonable starting point | Watch out for |
| --- | --- | --- |
| 7B model, low-latency chat | Single GPU or small tensor-parallel group, continuous batching, prefix cache | Queue growth and cache fragmentation |
| 70B model, interactive serving | Tensor parallel within node, quantization, strong scheduler, possible prefill/decode split | Cross-node latency spikes |
| Large-scale pretraining | Hybrid DP + TP + PP + sharding, checkpointing, topology-aware placement | Communication overhead and restart cost |
| Fine-tuning with limited budget | Data parallel or FSDP, gradient accumulation, checkpointing | Silent batch-size and optimizer misconfiguration |
| MoE model serving | Expert-aware routing and load balancing | Expert hotspots and communication bursts |
---
## 10. Common Mistakes Engineers Make
### 10.1 Training mistakes
- Scaling data parallelism without rethinking global batch and learning rate
- Looking only at average step time instead of step-time distribution
- Assuming low GPU utilization always means weak kernels instead of data starvation
- Ignoring topology when assigning ranks
- Treating checkpoint writing as sufficient without testing restore
- Using very complex hybrid parallelism before measuring simpler baselines
### 10.2 Inference mistakes
- Optimizing only for raw tokens per second while ignoring p99 latency
- Underestimating KV cache memory and fragmentation
- Sharding across slow links and then wondering why single-token latency is unstable
- Treating prompt-heavy and decode-heavy traffic as the same workload
- Autoscaling too late because model startup time was ignored
- Measuring only synthetic prompts instead of real production request mixes
### 10.3 Organizational mistakes
- ML teams and infra teams optimizing different metrics
- No ownership for scheduler behavior or queueing policy
- No common dashboard that combines model, system, and product metrics
---
## 11. Debugging and Troubleshooting
The best debugging mindset is to isolate whether the bottleneck is:
- compute
- memory capacity
- memory bandwidth
- communication
- data pipeline
- scheduling
- software correctness
### 11.1 Training troubleshooting table
| Symptom | Likely causes | What to check first |
| --- | --- | --- |
| Poor scaling beyond one node | Communication dominates, bad topology placement, network issues | NCCL or collectives timing, rank mapping, network counters |
| GPU OOM | Activations, optimizer state, fragmentation, batch too large | Activation size, optimizer config, checkpointing, sharding |
| Step-time spikes | Checkpoint writes, data stalls, stragglers, retries | Correlate metrics with storage, dataloader, network events |
| Low GPU utilization | Input pipeline bottleneck, CPU preprocessing, synchronization waits | Data loader timing, CPU saturation, comm overlap |
| Divergence after scaling | Effective batch changed, learning-rate schedule mismatch, numerical instability | Global batch math, optimizer hyperparameters, precision settings |
### 11.2 Inference troubleshooting table
| Symptom | Likely causes | What to check first |
| --- | --- | --- |
| High p99 latency | Queueing, bad batch policy, long prompts, cache pressure | Queue time breakdown, prompt length distribution, scheduler logs |
| Low throughput | Small effective batches, memory stalls, poor request mix | Batch occupancy, decode efficiency, GPU memory bandwidth utilization |
| OOM under traffic bursts | KV cache growth, fragmentation, too many long requests | Active context lengths, cache allocator stats, eviction policy |
| Unstable token latency | Cross-node shard communication, noisy neighbors, scheduler churn | Rank placement, interconnect traffic, tenant isolation |
| Long cold starts | Slow weight loading, slow graph compilation, no warm pool | Model load path, image caching, startup profiling |
### 11.3 A practical debugging flow
```mermaid
flowchart TD
A[Latency or throughput problem] --> B{Is queue time large?}
B -->|Yes| C[Inspect admission control, batching, autoscaling]
B -->|No| D{Is GPU memory near limit?}
D -->|Yes| E[Inspect KV cache, fragmentation, model placement]
D -->|No| F{Is inter-device communication high?}
F -->|Yes| G[Inspect sharding strategy and topology]
F -->|No| H[Inspect kernels, input pipeline, and software overhead]
C --> I[Validate real traffic mix]
E --> I
G --> I
H --> I
```
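
The same decision order can be written as a small triage helper. This is only a sketch: the function name `triage`, the input fractions, and all three thresholds are hypothetical and would need tuning against your own metrics.

```python
def triage(queue_time_frac, gpu_mem_frac, comm_time_frac,
           queue_threshold=0.3, mem_threshold=0.9, comm_threshold=0.3):
    """Return the first area to inspect, following the flowchart's order of checks.

    queue_time_frac: share of request latency spent waiting in queue (0..1)
    gpu_mem_frac:    share of GPU memory in use (0..1)
    comm_time_frac:  share of step time spent in inter-device communication (0..1)
    The thresholds are illustrative defaults, not recommendations.
    """
    if queue_time_frac > queue_threshold:
        return "admission control, batching, autoscaling"
    if gpu_mem_frac > mem_threshold:
        return "KV cache, fragmentation, model placement"
    if comm_time_frac > comm_threshold:
        return "sharding strategy and topology"
    return "kernels, input pipeline, software overhead"

# Whatever branch is taken, the next step is to validate against the real traffic mix.
print(triage(queue_time_frac=0.05, gpu_mem_frac=0.95, comm_time_frac=0.10))
```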
### 11.4 Debugging principles that save time
1. Break latency into named components before optimizing anything.
2. Compare one-node behavior to multi-node behavior to isolate communication cost.
3. Reproduce with controlled prompt lengths and batch sizes.
4. Separate cold-start problems from steady-state problems.
5. Validate the math of your parallel configuration before blaming the kernel.
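
Principle 1 can be made concrete with a per-request breakdown. The sketch below assumes the server records four timestamps per request; the `RequestTimestamps` fields and the `latency_breakdown` helper are hypothetical names, not part of any particular serving framework.

```python
from dataclasses import dataclass

@dataclass
class RequestTimestamps:
    arrived: float      # request accepted by the server
    scheduled: float    # request admitted into a running batch
    first_token: float  # first output token emitted (end of prefill)
    finished: float     # last output token emitted

def latency_breakdown(ts, output_tokens):
    """Split end-to-end latency into queue, prefill, and decode components."""
    queue_s = ts.scheduled - ts.arrived
    prefill_s = ts.first_token - ts.scheduled
    decode_s = ts.finished - ts.first_token
    # Time per output token after the first; undefined for single-token outputs.
    tpot_s = decode_s / (output_tokens - 1) if output_tokens > 1 else None
    return {
        "queue_s": queue_s,
        "prefill_s": prefill_s,
        "decode_s": decode_s,
        "ttft_s": queue_s + prefill_s,  # time to first token as the user sees it
        "tpot_s": tpot_s,
        "total_s": ts.finished - ts.arrived,
    }

# Hypothetical request: 0.5 s queued, 0.25 s prefill, 2.0 s to decode 41 tokens.
parts = latency_breakdown(RequestTimestamps(0.0, 0.5, 0.75, 2.75), output_tokens=41)
print(parts)
```

With the components named, a slow request can be attributed to scheduling, prompt processing, or generation before any optimization work begins.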
---
## 12. Best Practices
### 12.1 Training best practices
- Start with a simple baseline and measure it carefully.
- Make topology-aware rank assignments.
- Keep strong observability for compute, memory, and network.
- Use sharded checkpoints at scale.
- Test resume paths regularly, not only during emergencies.
- Treat data pipeline profiling as first-class work.
- Document the exact formula for global batch and optimizer semantics.
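
One way to document the global batch formula is to encode it as a checked function. This is a sketch under common assumptions: only data-parallel replicas see different samples, so tensor-parallel and pipeline-parallel ranks do not multiply the batch. The function names and the example cluster sizes are hypothetical.

```python
def data_parallel_size(world_size, tensor_parallel, pipeline_parallel):
    """Number of model replicas: GPUs left over after model parallelism."""
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by TP * PP")
    return world_size // model_parallel

def global_batch_size(micro_batch_per_replica, grad_accum_steps, dp_size):
    """Samples contributing to one optimizer step."""
    return micro_batch_per_replica * grad_accum_steps * dp_size

# Hypothetical job: 64 GPUs with TP=4 and PP=2 gives 8 replicas.
dp = data_parallel_size(64, tensor_parallel=4, pipeline_parallel=2)
before = global_batch_size(2, grad_accum_steps=16, dp_size=dp)

# Doubling the cluster without touching the config silently doubles the batch.
dp_scaled = data_parallel_size(128, tensor_parallel=4, pipeline_parallel=2)
after = global_batch_size(2, grad_accum_steps=16, dp_size=dp_scaled)
print(before, after)
```

Keeping this calculation in config review makes a silent batch change visible before it invalidates learning-rate tuning.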
### 12.2 Inference best practices
- Measure p50, p95, p99, and queue time separately.
- Profile prompt length and output length distributions from real traffic.
- Budget explicitly for KV cache, not just weights.
- Use continuous batching when request heterogeneity is high.
- Keep hot models and hot adapters warm when possible.
- Prefer simple placement that respects hardware topology.
- Use admission control to protect latency SLOs.
### 12.3 Software-hardware co-design best practices
- Match communication-heavy parallelism to fast interconnect domains.
- Match memory-saving techniques to the actual memory bottleneck.
- Match scheduler policy to workload shape, not benchmark mythology.
- Match autoscaling policy to model warmup time and queue dynamics.
---
## 13. Production Scenarios and Use Cases
### 13.1 Foundation model pretraining
Characteristics:
- massive datasets
- long-running jobs
- expensive failures
- hybrid parallelism almost always required
Primary concerns:
- throughput
- checkpoint resilience
- cluster efficiency
- numerical stability
### 13.2 Enterprise chat assistant
Characteristics:
- user-facing latency requirements
- bursty demand
- repeated prompt prefixes
- strict cost control
Primary concerns:
- p99 latency
- prompt caching
- continuous batching
- safe autoscaling
### 13.3 Retrieval-augmented generation
Characteristics:
- prompt lengths can vary widely
- retrieval latency interacts with model latency
- repeated contexts may be cacheable
Primary concerns:
- long-prefill cost
- prompt assembly efficiency
- request orchestration across systems
### 13.4 Batch offline generation
Characteristics:
- less sensitive to latency
- throughput and cost dominate
- easier to batch aggressively
Primary concerns:
- dense batching
- job scheduling
- checkpoint and retry policy for long-running batches
### 13.5 Edge or constrained deployment
Characteristics:
- limited memory
- limited power
- often weaker interconnects or none at all
Primary concerns:
- quantization
- distillation
- smaller context windows
- simpler runtime stack
---
## 14. Interview-Level Understanding
These are the kinds of questions that expose whether someone really understands the topic.
### 14.1 Why does data parallel training stop scaling perfectly?
Because every replica must synchronize gradients on every step. With a fixed global batch, adding workers shrinks the compute per worker, but the synchronization cost does not shrink with it. Eventually communication overhead, stragglers, and input pipeline inefficiency dominate step time.
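
A rough cost model makes the effect visible. The sketch below assumes a fixed global batch (strong scaling), a bandwidth-only ring all-reduce in which each worker moves about 2(N-1)/N times the gradient size, and no overlap between compute and communication. All numbers are hypothetical, and real systems overlap communication, so treat this as a pessimistic illustration rather than a prediction.

```python
def ring_allreduce_seconds(grad_bytes, workers, link_bytes_per_s):
    """Bandwidth term of a ring all-reduce; latency terms are ignored."""
    if workers == 1:
        return 0.0
    return 2 * (workers - 1) / workers * grad_bytes / link_bytes_per_s

def scaling_efficiency(single_worker_compute_s, grad_bytes, workers, link_bytes_per_s):
    """Speedup over one worker divided by the worker count (1.0 is perfect)."""
    compute_s = single_worker_compute_s / workers  # fixed global batch is split
    comm_s = ring_allreduce_seconds(grad_bytes, workers, link_bytes_per_s)
    step_s = compute_s + comm_s                    # assumes no overlap
    return single_worker_compute_s / (workers * step_s)

# Hypothetical job: 8 s of compute per step on one GPU, 2 GB of gradients,
# 25 GB/s of usable bandwidth per worker.
for n in (1, 8, 64):
    print(n, round(scaling_efficiency(8.0, 2e9, n, 25e9), 3))
```

The communication term stays roughly constant while the compute term shrinks, so efficiency falls as workers are added.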
### 14.2 Why is inference decode often memory-bound?
Because each decode step reads the model weights and the growing KV cache to produce only one new token per sequence. That is far less arithmetic per byte moved than prefill, where many prompt tokens are processed in a single pass. Memory bandwidth can therefore become the bottleneck even when raw compute capacity sits idle.
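
A back-of-the-envelope estimate shows the shape of the problem. The sketch assumes a dense model where each decode step reads every weight once and performs roughly 2 FLOPs per parameter per token, and it ignores attention FLOPs and kernel overheads. The model size and bandwidth figure are hypothetical.

```python
def decode_step_floor_seconds(weight_bytes, kv_cache_bytes, mem_bandwidth_bytes_per_s):
    """Lower bound on one decode step if the only cost were reading state once."""
    return (weight_bytes + kv_cache_bytes) / mem_bandwidth_bytes_per_s

def weight_arithmetic_intensity(batch_size, bytes_per_param=2):
    """FLOPs per weight byte read: about 2 FLOPs per parameter per sequence in the batch."""
    return 2 * batch_size / bytes_per_param

# Hypothetical setup: 7B parameters in 16-bit, empty KV cache, 2 TB/s of memory bandwidth.
floor_s = decode_step_floor_seconds(7e9 * 2, 0, 2e12)
print(f"step floor: {floor_s * 1e3:.1f} ms")
print("intensity at batch 1:", weight_arithmetic_intensity(1))
print("intensity at batch 64:", weight_arithmetic_intensity(64))
```

Because the weights are read once per step regardless of batch size, larger decode batches raise arithmetic intensity, which is one reason batching policy matters so much for serving.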
### 14.3 When would you prefer FSDP over pure tensor parallelism?
When memory replication is the primary issue and you want to reduce replicated parameters, gradients, or optimizer state. FSDP-like approaches let larger models train without fully replicating that state on every device, at the cost of more communication and runtime complexity.
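
The replication cost is easy to estimate. This sketch assumes mixed-precision Adam with the commonly cited 16 bytes of model state per parameter (2 for 16-bit weights, 2 for 16-bit gradients, 12 for 32-bit master weights plus two Adam moments), and it ignores activations and temporary buffers. The function name and model size are hypothetical.

```python
BYTES_WEIGHTS = 2    # 16-bit parameters
BYTES_GRADS = 2      # 16-bit gradients
BYTES_OPTIMIZER = 12 # fp32 master weights + Adam momentum + Adam variance

def model_state_gib_per_gpu(num_params, num_gpus, shard):
    """Model-state memory per GPU, either fully replicated or fully sharded."""
    total_bytes = num_params * (BYTES_WEIGHTS + BYTES_GRADS + BYTES_OPTIMIZER)
    per_gpu_bytes = total_bytes / num_gpus if shard else total_bytes
    return per_gpu_bytes / 2**30

# Hypothetical 7B-parameter model on 8 GPUs.
replicated = model_state_gib_per_gpu(7e9, 8, shard=False)
sharded = model_state_gib_per_gpu(7e9, 8, shard=True)
print(f"replicated: {replicated:.1f} GiB per GPU, fully sharded: {sharded:.1f} GiB per GPU")
```

Under these assumptions the replicated state alone exceeds a typical single-GPU memory budget, while the sharded state leaves room for activations.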
### 14.4 Why is topology awareness important?
Because not all communication links are equal. If a communication-heavy parallel group crosses a slow boundary, performance can collapse even if total GPU count looks sufficient on paper.
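
A simple placement check catches the most common version of this mistake. The sketch assumes ranks are numbered so that `rank // gpus_per_node` gives the node index and that tensor-parallel groups are formed from consecutive ranks; both are assumptions about your launcher, and the function names are hypothetical.

```python
def tensor_parallel_groups(world_size, tp_size):
    """Group consecutive ranks into tensor-parallel groups."""
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    return [list(range(start, start + tp_size))
            for start in range(0, world_size, tp_size)]

def groups_crossing_nodes(groups, gpus_per_node):
    """Return the groups whose ranks span more than one node."""
    return [g for g in groups
            if len({rank // gpus_per_node for rank in g}) > 1]

# Hypothetical cluster: 3 nodes with 8 GPUs each.
ok = groups_crossing_nodes(tensor_parallel_groups(24, 8), gpus_per_node=8)
bad = groups_crossing_nodes(tensor_parallel_groups(24, 6), gpus_per_node=8)
print("tp=8 crossing groups:", ok)
print("tp=6 crossing groups:", bad)
```

With a group size that does not divide the node size, some tensor-parallel groups straddle the slower inter-node link even though the GPU count is unchanged.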
### 14.5 Why is continuous batching useful for LLM serving?
Because requests arrive at different times and have different lengths. Continuous batching keeps the GPU work denser than rigid static batching while still allowing requests to enter and leave the active set dynamically.
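
A toy simulation shows the difference in slot occupancy. It is deliberately simplified: it models decode steps only, assumes all requests are already waiting, and treats every step as costing the same regardless of batch contents. The function names and request lengths are hypothetical.

```python
from collections import deque

def static_batching_steps(lengths, slots):
    """Each batch runs until its longest request finishes; idle slots are wasted."""
    steps = 0
    for i in range(0, len(lengths), slots):
        steps += max(lengths[i:i + slots])
    return steps

def continuous_batching_steps(lengths, slots):
    """Free slots are refilled from the queue before every decode step."""
    waiting = deque(lengths)
    active = []  # remaining decode steps for each running request
    steps = 0
    while waiting or active:
        while waiting and len(active) < slots:
            active.append(waiting.popleft())
        steps += 1
        # Requests with one step remaining finish now and leave the active set.
        active = [remaining - 1 for remaining in active if remaining > 1]
    return steps

# Hypothetical mix of short and long generations, two batch slots.
lengths = [2, 10, 2, 10, 2, 2, 2, 2]
print("static:", static_batching_steps(lengths, slots=2))
print("continuous:", continuous_batching_steps(lengths, slots=2))
```

In the static case, short requests hold a slot idle while they wait for the longest request in their batch, which is the waste continuous batching removes.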
### 14.6 What is a good engineering answer to "How do we make it faster?"
First ask: faster in what sense?
- lower single-request latency?
- higher total throughput?
- lower cost per token?
- faster training wall-clock time?
Without a metric, optimization work becomes noise.
---
## 15. Implementation Patterns and Tooling Landscape
You should know the categories even if the exact tool choice varies by company.
### 15.1 Training stack categories
- framework layer: PyTorch, JAX
- distributed runtime: DDP, FSDP, DeepSpeed, Megatron-style stacks
- communication backend: NCCL and related collectives libraries
- orchestration: Kubernetes, Slurm, Ray, custom schedulers
- storage: object stores, distributed filesystems, checkpoint services
- observability: metrics, logs, traces, profilers
### 15.2 Serving stack categories
- model servers: Triton, vLLM, TGI, TensorRT-LLM, custom runtimes
- orchestration: Kubernetes, Ray Serve, custom service meshes
- routing and queueing: API gateways, schedulers, admission control layers
- optimization layers: quantization, cache systems, speculative decoding, adapter routing
The point is not to memorize vendor names. The point is to understand what architectural role each layer plays.
---
## 16. A Practical Design Walkthrough
Suppose you must serve a 70B-class chat model with the following goals:
- interactive latency target
- strong concurrency during business hours
- prompt lengths vary widely
- budget is limited
Reasoning process:
1. Check whether the model fits on one GPU in your target precision. At 16-bit precision, 70B parameters need roughly 140 GB for weights alone, so it likely does not.
2. Choose a tensor-parallel configuration that stays inside fast local topology as much as possible.
3. Estimate KV cache memory under expected concurrency and context lengths.
4. Add quantization if quality is acceptable and it meaningfully increases concurrency.
5. Use continuous batching because request lengths are heterogeneous.
6. Consider prefix caching if prompts share a system template.
7. Measure queue time separately from prefill and decode.
8. Add admission control before traffic spikes force pathological queue growth.
What not to do:
- Jump immediately to a very complex multi-node sharding design without measuring a simpler within-node baseline.
- Benchmark only one fixed prompt length and claim the system is production-ready.
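
Steps 1 and 3 of the reasoning process can be sketched as a quick memory budget. The architecture numbers below (80 layers, 8 KV heads with grouped-query attention, head dimension 128) are assumptions chosen to resemble a typical 70B-class model, and the 4 x 80 GB node and 10% reserve are hypothetical. Substitute your own model config and hardware.

```python
def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_value=2):
    """Keys and values for one token across all layers."""
    return 2 * layers * kv_heads * head_dim * bytes_per_value

def max_cached_tokens(total_gpu_bytes, model_weight_bytes, kv_per_token, reserve_frac=0.1):
    """Tokens of KV cache that fit after weights and a reserve for activations."""
    usable = total_gpu_bytes * (1 - reserve_frac) - model_weight_bytes
    return max(0, int(usable // kv_per_token))

weights_16bit = 70e9 * 2                    # about 140 GB
fits_on_one_gpu = weights_16bit <= 80e9     # step 1: single 80 GB GPU

per_token = kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128)
tokens = max_cached_tokens(4 * 80e9, weights_16bit, per_token)  # step 3
concurrent_4k = tokens // 4096              # sequences at 4096 tokens of context each

print(fits_on_one_gpu, per_token, tokens, concurrent_4k)
```

Rerunning the same budget with quantized weights shows how much extra concurrency step 4 would buy, which is the number to weigh against any quality loss.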
---
## 17. Failure Cases and How to Avoid Them
### 17.1 Distributed training failure cases
- **Collective hangs**: one rank diverges in control flow or crashes before a synchronization point.
- **Checkpoint corruption or unusable format**: writes succeed, but restore fails or is too slow.
- **Scaling cliff**: training scales well within one node, then efficiency collapses across nodes.
- **Silent batch change**: training configuration changes effective global batch and invalidates optimizer tuning.
How to avoid them:
- validate rank consistency and barriers carefully
- test restore paths routinely
- benchmark one-node and multi-node separately
- document and verify batch math in config review
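
A restore test does not need to be elaborate to be useful. The sketch below uses JSON files and SHA-256 checksums from the standard library as a stand-in for a real checkpoint format; the file layout, the manifest, and the function names are all hypothetical. The point is only that the restore path is exercised and verified, not just the write path.

```python
import hashlib
import json
import tempfile
from pathlib import Path

def save_sharded(state_shards, directory):
    """Write one file per shard plus a manifest of checksums."""
    directory = Path(directory)
    manifest = {}
    for rank, shard in enumerate(state_shards):
        payload = json.dumps(shard, sort_keys=True).encode()
        name = f"shard_{rank:05d}.json"
        (directory / name).write_bytes(payload)
        manifest[name] = hashlib.sha256(payload).hexdigest()
    (directory / "manifest.json").write_text(json.dumps(manifest))

def restore_sharded(directory):
    """Read every shard listed in the manifest and verify its checksum."""
    directory = Path(directory)
    manifest = json.loads((directory / "manifest.json").read_text())
    shards = []
    for name in sorted(manifest):
        payload = (directory / name).read_bytes()
        if hashlib.sha256(payload).hexdigest() != manifest[name]:
            raise ValueError(f"checksum mismatch in {name}")
        shards.append(json.loads(payload))
    return shards

# Round-trip check of the kind that can run in CI or on a schedule.
example = [{"step": 100, "w": [0.1, 0.2]}, {"step": 100, "w": [0.3, 0.4]}]
with tempfile.TemporaryDirectory() as tmp:
    save_sharded(example, tmp)
    assert restore_sharded(tmp) == example
```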
### 17.2 Serving failure cases
- **Memory collapse under burst traffic**: KV cache grows faster than expected.
- **Latency tail explosion**: queueing and long prompts starve short requests.
- **Cold-start storms**: autoscaler adds replicas too late and all of them are still loading weights.
- **Noisy-neighbor behavior**: one tenant or workload shape hurts everyone else.
How to avoid them:
- enforce admission control and max context policies
- separate workload classes when needed
- maintain warm capacity for hot paths
- monitor tenant-level and request-class-level metrics
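
Admission control combined with a max context policy can be as simple as a token budget. This sketch reserves the worst-case KV footprint (prompt plus maximum new tokens) for each admitted request; the class name, the three outcomes, and the budget values are hypothetical, and a production scheduler would usually be less conservative.

```python
class TokenBudgetAdmission:
    """Admit requests only while their worst-case KV cache footprint fits."""

    def __init__(self, max_cached_tokens, max_context_tokens):
        self.max_cached_tokens = max_cached_tokens
        self.max_context_tokens = max_context_tokens
        self.reserved = {}  # request_id -> reserved tokens

    def try_admit(self, request_id, prompt_tokens, max_new_tokens):
        needed = prompt_tokens + max_new_tokens
        if needed > self.max_context_tokens:
            return "reject"  # violates the max context policy
        if sum(self.reserved.values()) + needed > self.max_cached_tokens:
            return "queue"   # would overcommit the KV cache right now
        self.reserved[request_id] = needed
        return "admit"

    def release(self, request_id):
        """Free the reservation when a request finishes or is cancelled."""
        self.reserved.pop(request_id, None)

ctl = TokenBudgetAdmission(max_cached_tokens=10_000, max_context_tokens=8_192)
print(ctl.try_admit("a", prompt_tokens=4_000, max_new_tokens=1_000))
print(ctl.try_admit("b", prompt_tokens=9_000, max_new_tokens=500))
print(ctl.try_admit("c", prompt_tokens=5_000, max_new_tokens=1_000))
```

Rejecting or queueing early keeps a burst from turning into a memory collapse, at the price of some unused cache when requests finish well short of their maximum length.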
---
## 18. Quick Reference Checklist
Before designing a distributed training system, ask:
- Does the model fit on one GPU?
- Is memory or wall-clock time the main problem?
- Which tensors are replicated today?
- Which communication happens every step?
- What is the global batch exactly?
- How fast can the job recover from failure?
Before designing a large-model serving system, ask:
- Is the workload interactive, batch, or mixed?
- What are the real prompt and output length distributions?
- How much memory is reserved for KV cache?
- What is the queueing policy?
- What is the cold-start time?
- Which metric defines success: latency, throughput, or cost?
---
## 19. Final Mental Model
Distributed training and inference are not primarily about "using more GPUs." They are about matching the structure of the workload to the structure of the hardware.
If you remember only a few ideas, remember these:
1. **Every optimization is a tradeoff between compute, memory, communication, and operational complexity.**
2. **Topology matters.** Fast local links and slow remote links are not interchangeable.
3. **Serving large models is often dominated by scheduling and memory management, not just raw math throughput.**
4. **The right design depends on the workload shape, not on what looked best in someone else's benchmark.**
5. **The best engineers decompose latency and step time into parts before they try to optimize anything.**
When your mental model is strong, new tooling and new frameworks become much easier to evaluate. The names change. The constraints do not.