# Building a Mini vLLM from Scratch: A Deep Dive into LLM Inference Optimization Ever wondered what happens under the hood when you run an LLM inference engine like vLLM? I sure did. So I built **nano-vllm** - a minimalistic, educational implementation of a high-performance LLM inference engine from scratch. Think of it as "vLLM for dummies" (myself included). This blog post walks through what I learned and how each optimization technique works. Buckle up - we're going deep! ## The Problem: Why LLM Inference is Tricky Here's the thing about running LLMs: it's not just about doing matrix multiplications. The naive approach of running one request at a time wastes an embarrassing amount of GPU memory and compute. Let me explain why. When an LLM generates text, it works in two phases: 1. **Prefill**: Process the entire prompt at once (compute-bound) 2. **Decode**: Generate tokens one at a time (memory-bound) The decode phase is where things get interesting. For each new token, the model needs to look at ALL previous tokens through the attention mechanism. Without caching, you'd recompute the same thing over and over. Enter the KV cache. But here's the kicker: pre-allocating memory for the KV cache based on the maximum possible sequence length is wildly wasteful. If you have a max length of 2048 tokens but your actual sequence is only 100 tokens, you're wasting 95% of that memory! This is the problem vLLM solved with PagedAttention, and that's what I implemented in nano-vllm. ## Architecture Overview Here's how nano-vllm is organized: ``` nano_vllm/ ├── engine.py # Main inference engine ├── config.py # Model configuration ├── cache.py # KV cache implementations ├── sampler.py # Token sampling ├── core/ │ ├── sequence.py # Request tracking │ ├── scheduler.py # Batch scheduling with priorities │ ├── block.py # Memory blocks for PagedAttention │ └── block_manager.py # Block allocation (like an OS memory manager) ├── attention/ │ ├── paged_attention.py # PagedAttention kernel │ └── flash_attention.py # FlashAttention integration ├── speculative/ │ └── speculative_decoding.py # Draft model speculation ├── educational/ # Learn-by-watching modes │ ├── narrator.py # Plain-English explanations │ ├── xray.py # Tensor visualizations │ └── dashboard.py # Live terminal UI └── model/ ├── loader.py # HuggingFace model loading └── llama.py # Llama implementation (RMSNorm, RoPE, GQA, SwiGLU) ``` Now let's dive into each major optimization! --- ## 1. PagedAttention: The Heart of vLLM ### What's the Problem? Traditional KV cache allocation is like reserving an entire movie theater for one person "just in case" they bring 1999 friends. Wasteful, right? In the naive approach, you pre-allocate a contiguous chunk of memory for each sequence based on the maximum possible length. This leads to: - **Memory fragmentation**: Different sequences finish at different times, leaving holes - **Memory waste**: Most sequences never reach max length - **Limited batch size**: Can't fit as many requests in GPU memory ### How PagedAttention Solves It PagedAttention borrows ideas from operating system virtual memory. Instead of contiguous allocation, it divides the KV cache into fixed-size **blocks** (like memory pages): ```python # From core/block.py @dataclass class Block: """A fixed-size chunk of KV cache memory. Each block stores KV states for `block_size` tokens. """ block_id: int block_size: int = 16 # 16 tokens per block ref_count: int = 1 # For sharing (prefix caching) prefix_hash: Optional[int] = None ``` Each sequence gets a **BlockTable** - a mapping from logical positions to physical blocks: ```python # From core/block.py @dataclass class BlockTable: """Maps logical positions to physical blocks. Like a page table in virtual memory: - Token at position p is in logical block: p // block_size - Slot within block: p % block_size - Physical block: block_ids[p // block_size] Example (block_size=16, sequence has 35 tokens): block_table.block_ids = [5, 12, 3] # 3 physical blocks Token 0-15 -> block 5 Token 16-31 -> block 12 Token 32-34 -> block 3 (slots 0-2) """ block_ids: List[int] block_size: int = 16 ``` The **BlockManager** handles allocation like an OS manages memory: ```python # From core/block_manager.py class BlockManager: """Manages allocation of KV cache blocks. Uses a simple free list (stack) for O(1) allocation/deallocation. """ def allocate_block(self) -> int: if not self.free_blocks: raise RuntimeError("Out of KV cache blocks!") return self.free_blocks.pop() def free_block(self, block_id: int) -> None: block = self.blocks[block_id] if block.decrement_ref() <= 0: self.free_blocks.append(block_id) ``` ### The Paged Attention Computation When computing attention, we gather K and V from non-contiguous blocks: ```python # From attention/paged_attention.py def paged_attention( query: torch.Tensor, key_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads, head_dim] value_cache: torch.Tensor, block_tables: List[BlockTable], context_lens: List[int], block_size: int, num_kv_heads: int, ) -> torch.Tensor: # Gather from blocks for each sequence for batch_idx in range(batch_size): block_table = block_tables[batch_idx] for pos in range(context_len): logical_block = pos // block_size slot_in_block = pos % block_size physical_block = block_table.block_ids[logical_block] # Copy from cache gathered_keys[batch_idx, :, pos, :] = key_cache[physical_block, slot_in_block] gathered_values[batch_idx, :, pos, :] = value_cache[physical_block, slot_in_block] # Standard attention computation attn_weights = torch.matmul(query, gathered_keys.transpose(-2, -1)) * scale # ... apply masking, softmax, and compute output ``` ### Why This Matters PagedAttention enables: - **Near-zero memory waste**: Only allocate what you need - **Memory sharing**: Common prefixes can share blocks (prefix caching) - **Higher throughput**: Fit more requests in memory = more parallelism --- ## 2. Continuous Batching: No More Waiting Around ### The Traditional Batching Problem Old-school batching waits for ALL sequences in a batch to finish before starting new ones. If you have: - Request A: 50 tokens to generate - Request B: 5 tokens to generate Request B finishes fast but has to wait for Request A. The GPU sits idle! ### Continuous Batching Solution nano-vllm schedules at **iteration granularity**: - New requests join mid-batch - Completed requests leave immediately - GPU stays busy Here's the scheduler in action: ```python # From core/scheduler.py class Scheduler: """Manages sequences through their lifecycle: - WAITING: In queue - RUNNING: Being processed - SWAPPED: Preempted - FINISHED: Done """ def schedule(self) -> SchedulerOutputs: outputs = SchedulerOutputs() # 1. Handle preemption if high-priority request waiting if self.enable_preemption and self.block_manager: self._handle_preemption(outputs) # 2. Continue running sequences (decode) for seq in self.running: if seq.is_chunked_prefill(): outputs.chunked_prefill_sequences.append(seq) else: outputs.decode_sequences.append(seq) # 3. Admit new sequences from waiting queue while can_admit_more(): seq = self._pop_waiting() seq.status = SequenceStatus.RUNNING outputs.prefill_sequences.append(seq) return outputs ``` The engine processes these in one iteration: ```python # From engine.py def step(self) -> List[GenerationOutput]: """One iteration of continuous batching.""" scheduler_outputs = self.scheduler.schedule() # Process chunked prefills for seq, num_tokens in zip(chunked_prefill_seqs, chunked_prefill_tokens): self._run_chunked_prefill(seq, num_tokens) # Process full prefills (new sequences) for seq in prefill_sequences: self._run_prefill(seq) # Process decodes (batched together!) if decode_sequences: self._run_decode(decode_sequences) # Return completed sequences return newly_finished ``` --- ## 3. Priority Scheduling & Preemption Sometimes you want VIP treatment for certain requests. nano-vllm supports: ### Priority-Based Scheduling Requests have priorities. Higher priority = processed first: ```python # From core/scheduler.py def _get_priority_key(self, seq: Sequence) -> Tuple[int, float, int]: """Priority key for heap ordering. Lower tuple = higher priority.""" # Negate priority so higher values come first return (-seq.priority, seq.arrival_time, seq.seq_id) # Using a heap for O(log n) scheduling heapq.heappush(self._waiting_heap, (priority_key, sequence)) ``` ### Preemption: Kicking Out Low-Priority Requests When a high-priority request arrives but there's no memory, we can **preempt** low-priority running requests: ```python # From core/scheduler.py def _handle_preemption(self, outputs): """Preempt low-priority sequences for high-priority waiting.""" highest_waiting = self._peek_waiting() while not self.block_manager.can_allocate(blocks_needed) and self.running: # Find lowest priority running sequence lowest_running = min(self.running, key=lambda s: s.priority) if highest_waiting.priority > lowest_running.priority: # Preempt! Free blocks and reset for recompute self.running.remove(lowest_running) self.block_manager.free_sequence_blocks(lowest_running.block_table) lowest_running.reset_for_recompute() self._push_waiting(lowest_running) ``` The preempted sequence goes back to waiting and will be re-prefilled later. It's recompute-based preemption (vs. swapping to CPU memory) - simpler and works well in practice. --- ## 4. Prefix Caching: Share Common Prefixes Many requests start with the same system prompt. Why recompute the same KV cache? ### How It Works Blocks are hashed based on their token content AND their position in the sequence: ```python # From core/block.py def hash_token_block(token_ids: Tuple[int, ...], parent_hash: Optional[int] = None) -> int: """Cumulative hash including the entire prefix chain. This ensures blocks are only shared when the ENTIRE prefix matches. """ if parent_hash is None: return hash(token_ids) return hash((parent_hash, token_ids)) ``` When a new sequence arrives, we check if its prefix blocks already exist: ```python # From core/block_manager.py def allocate_blocks_with_prefix_caching(self, token_ids: List[int]): """Allocate blocks, reusing cached prefix blocks when possible.""" parent_hash = None for block_idx in range(num_full_blocks): block_tokens = tuple(token_ids[start:end]) cache_key = (parent_hash, block_tokens) if cache_key in self.prefix_cache: # Cache hit! Reuse existing block cached_block_id = self.prefix_cache[cache_key] self.blocks[cached_block_id].increment_ref() # Reference counting block_table.append_block(cached_block_id) else: # Cache miss - allocate new block block_id = self.allocate_block() self.prefix_cache[cache_key] = block_id block_table.append_block(block_id) parent_hash = self.blocks[block_id].prefix_hash return block_table, shared_prefix_len ``` Reference counting ensures blocks aren't freed while still in use by other sequences. --- ## 5. Chunked Prefill: Don't Block on Long Prompts A long prompt (say, 4000 tokens) can block the entire batch during prefill. Chunked prefill breaks it into smaller pieces: ```python # From engine.py def _run_chunked_prefill_paged(self, seq: Sequence, num_tokens: int): """Process a chunk of prompt tokens.""" start_pos = seq.num_prefilled_tokens end_pos = start_pos + num_tokens chunk_tokens = seq.prompt_token_ids[start_pos:end_pos] # Allocate blocks for this chunk # ... # Forward pass for this chunk only logits = self.model(input_ids, block_kv_cache=..., start_positions=[start_pos]) # Update progress seq.num_prefilled_tokens = end_pos # Only sample when ALL prompt tokens are processed if seq.num_prefilled_tokens >= len(seq.prompt_token_ids): next_token = self.sampler.sample(logits) seq.append_token(next_token.item()) ``` The scheduler controls how many tokens to prefill per iteration: ```python # max_prefill_tokens limits compute per iteration if prompt_len <= prefill_budget: outputs.prefill_sequences.append(seq) # Full prefill else: outputs.chunked_prefill_sequences.append(seq) # Partial outputs.chunked_prefill_tokens.append(prefill_budget) ``` --- ## 6. FlashAttention: Memory-Efficient Attention Standard attention materializes the full N×N attention matrix. For a 2048-token sequence, that's 4 million elements! FlashAttention uses **tiling** to avoid this. ### Integration in nano-vllm ```python # From attention/flash_attention.py def flash_attention(query, key, value, causal=True): """Use FlashAttention for O(N) memory instead of O(N^2).""" # FlashAttention expects: [batch, seq_len, num_heads, head_dim] query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) output = flash_attn_func(query, key, value, causal=causal) return output.transpose(1, 2) # Unified interface with fallback def attention(query, key, value, use_flash_attn=True, causal=True): if use_flash_attn and FLASH_ATTN_AVAILABLE: return flash_attention(query, key, value, causal) # Fallback to PyTorch SDPA (also optimized!) return F.scaled_dot_product_attention(query, key, value, is_causal=causal) ``` FlashAttention is used in the model's attention layers: ```python # From model/llama.py class LlamaAttention(nn.Module): def __init__(self, config, layer_idx, use_flash_attn=True): self.use_flash_attn = use_flash_attn and is_flash_attn_available() def forward(self, hidden_states, ...): # ... compute Q, K, V and apply RoPE ... # Use unified attention (FlashAttention if available) attn_output = unified_attention( query=query_states, key=key_states, value=value_states, use_flash_attn=self.use_flash_attn, causal=True, ) ``` --- ## 7. Speculative Decoding: Draft and Verify Decoding is slow because we generate one token at a time. What if we could generate multiple tokens per forward pass of the big model? ### The Idea 1. Use a small, fast **draft model** to generate K candidate tokens 2. The big **target model** verifies all K+1 positions in ONE forward pass 3. Accept tokens that match, reject and resample where they don't ```python # From speculative/speculative_decoding.py def _speculative_step(self, current_ids, target_kv_cache, draft_kv_cache, remaining_tokens): """One speculative decoding step.""" K = self.config.num_speculative_tokens # Step 1: Generate K draft tokens (cheap!) draft_tokens, draft_probs = self._generate_draft_tokens(current_ids, draft_kv_cache, K) # Step 2: Verify with target model (ONE forward pass for K+1 tokens!) verify_ids = [[current_ids[-1]] + draft_tokens] target_logits = self.target_model(verify_ids, kv_cache=target_kv_cache) target_probs = F.softmax(target_logits, dim=-1) # Step 3: Accept/reject using rejection sampling accepted_tokens = [] for i, draft_token in enumerate(draft_tokens): target_prob = target_probs[0, i, draft_token].item() draft_prob = draft_probs[i] # Accept if target prob >= draft prob (maintains target distribution!) acceptance_prob = min(1.0, target_prob / draft_prob) if random() < acceptance_prob: accepted_tokens.append(draft_token) else: # Resample from adjusted distribution resampled = sample_from_adjusted(target_probs[0, i], draft_prob, draft_token) accepted_tokens.append(resampled) break # Stop after first rejection # If all accepted, sample one bonus token! if len(accepted_tokens) == len(draft_tokens): bonus_token = sample(target_probs[0, -1]) accepted_tokens.append(bonus_token) return accepted_tokens ``` ### The Magic: No Quality Loss This is rejection sampling - it mathematically guarantees the output distribution is identical to the target model. We're not approximating anything! The speedup depends on: - Draft model speed (should be ~10x faster than target) - Acceptance rate (higher = more tokens per target forward pass) - Value of K (more speculation = more potential gain) --- ## 8. The Llama Model Implementation nano-vllm includes a from-scratch Llama implementation with all the modern bells and whistles: ### RMSNorm (Instead of LayerNorm) ```python # From model/llama.py class RMSNorm(nn.Module): """Root Mean Square Normalization - simpler than LayerNorm.""" def forward(self, x): rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) return x / rms * self.weight ``` ### Rotary Position Embeddings (RoPE) ```python # From model/llama.py def apply_rotary_pos_emb(q, k, cos, sin): """Encode positions by rotating Q and K vectors. The rotation formula: q_rotated = q * cos + rotate_half(q) * sin This lets the model learn relative positions through dot products. """ q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed ``` ### Grouped Query Attention (GQA) ```python # From model/llama.py class LlamaAttention(nn.Module): """GQA: Fewer KV heads than Q heads, saving memory.""" def __init__(self, config): self.num_heads = config.num_attention_heads # e.g., 32 self.num_kv_heads = config.num_key_value_heads # e.g., 8 self.num_kv_groups = self.num_heads // self.num_kv_heads # = 4 # Q projection is larger than K,V projections self.q_proj = nn.Linear(hidden, num_heads * head_dim) self.k_proj = nn.Linear(hidden, num_kv_heads * head_dim) # Smaller! self.v_proj = nn.Linear(hidden, num_kv_heads * head_dim) ``` ### SwiGLU MLP ```python # From model/llama.py class LlamaMLP(nn.Module): """SwiGLU: output = down(silu(gate(x)) * up(x))""" def forward(self, x): gate = F.silu(self.gate_proj(x)) # Swish activation up = self.up_proj(x) return self.down_proj(gate * up) # Gated linear unit ``` --- ## 9. Educational Modes: Learn by Watching One of my favorite features! nano-vllm includes educational modes that explain what's happening during inference: ### Narrator Mode Provides plain-English commentary like watching surgery with an expert: ```bash python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "The capital of France is" --narrate ``` Output: ``` ═══════════════════════════════════════════════════════════════════ INFERENCE ANATOMY - Educational Mode ═══════════════════════════════════════════════════════════════════ Prompt: "The capital of France is" Model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 ═════ ACT 1: TOKENIZATION ═════ Converting your prompt into numbers the model understands... "The capital of France is" ↓ Tokenizer (BPE algorithm) [The] [capital] [of] [France] [is] → [450, 7483, 310, 3444, 338] ═════ ACT 2: PREFILL PHASE ═════ The model reads your entire prompt at once... Processing 5 tokens through 22 layers ✓ Parallel computation (all tokens at once) ✓ Building the KV cache ═════ ACT 3: DECODE PHASE ═════ Now generating one token at a time... Step 1: Predicting token #6 │ Top 5 predictions: │ Paris ████████████████████ 82.3% │ the ███ 7.1% │ located ██ 4.2% └── Sampled: "Paris" (82.3%) ``` ### X-Ray Mode Shows tensor shapes and mathematical operations: ```bash python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "Hello" --xray ``` ### Dashboard Mode A live terminal UI showing real-time progress (requires `rich`): ```bash python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "Hello" --dashboard ``` ### Interactive Tutorial Step-by-step learning experience: ```bash python -m nano_vllm.cli --tutorial ``` --- ## Quick Start ### Installation ```bash pip install -e . # Optional: FlashAttention (for faster inference) pip install flash-attn --no-build-isolation ``` ### Basic Usage ```bash # Single prompt python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "Hello, world" # Multiple prompts (continuous batching) python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "The capital of France is" \ --prompt "The largest planet is" \ --prompt "Python is a" # Priority scheduling python -m nano_vllm.cli --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "Low priority task" --priority 1 \ --prompt "High priority task" --priority 10 # Speculative decoding python -m nano_vllm.speculative.cli \ --target-model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --draft-model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --prompt "The future of AI is" \ --num-speculative-tokens 5 ``` ### Python API ```python from nano_vllm.engine import LLMEngine engine = LLMEngine( model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_paged_attention=True, enable_prefix_caching=True, use_flash_attn=True, ) # Single generation output = engine.generate("What is machine learning?", max_tokens=100) # Batch generation with priorities engine.add_request("Prompt 1", max_tokens=50, priority=1) engine.add_request("Prompt 2", max_tokens=50, priority=10) # Higher priority outputs = engine.run_to_completion() ``` --- ## What I Learned Building nano-vllm taught me: 1. **Memory is the bottleneck** - Most LLM inference optimizations are about memory, not compute 2. **OS concepts apply** - PagedAttention is literally virtual memory for KV cache 3. **Batching is nuanced** - Continuous batching is way more complex than "put things in a batch" 4. **Speculation is powerful** - Getting multiple tokens per expensive forward pass is huge 5. **The devil is in the details** - Causal masking, position IDs, reference counting... endless edge cases ## What's Next Still on the TODO list: - [ ] CUDA graphs for reduced kernel launch overhead - [ ] Tensor parallelism for multi-GPU - [ ] Quantization (AWQ, GPTQ) - [ ] OpenAI-compatible API server ## Resources - [vLLM Paper](https://arxiv.org/abs/2309.06180) - [vLLM Repository](https://github.com/vllm-project/vllm) - [PagedAttention Blog](https://blog.vllm.ai/2023/06/20/vllm.html) - [FlashAttention Paper](https://arxiv.org/abs/2205.14135) - [Speculative Decoding Paper](https://arxiv.org/abs/2211.17192) --- If you found this helpful, give the repo a star! And if you spot bugs or have suggestions, PRs are welcome. Happy inferencing!