# SpecSplit Architecture

## Overview
SpecSplit is a disaggregated speculative decoding system that splits LLM inference across two networked GPU workers:
- Draft Worker ("The Hare") — a small, fast model that speculatively generates token trees.
- Target Worker ("The Tortoise") — a large, accurate model that verifies draft trees using tree-attention.
The system is coordinated by an Orchestrator that manages the asynchronous ping-pong loop between workers.
## System Diagram

```
           ┌─────────────────────────┐
           │      Orchestrator       │
           │ (Pipeline Coordinator)  │
           └────┬───────────────┬────┘
  prompt +      │               │   accepted tokens
  context       │               │   + correction
                ▼               ▼
      ┌──────────────────┐    ┌──────────────────┐
      │   Draft Worker   │    │  Target Worker   │
      │                  │    │                  │
      │  ┌────────────┐  │    │  ┌────────────┐  │
      │  │ Small LLM  │  │    │  │ Large LLM  │  │
      │  │(e.g. GPT-2)│  │    │  │(e.g. Llama)│  │
      │  └────────────┘  │    │  └────────────┘  │
      │                  │    │                  │
      │   KV Cache ✓     │    │  Session KV ✓    │
      │   Cheap GPU      │    │  Expensive GPU   │
      └──────────────────┘    └──────────────────┘
               ▲                       ▲
               │     gRPC (proto3)     │
               └───────────────────────┘
```
## Data Flow

1. User prompt → Orchestrator tokenizes and sends to the Draft Worker.
2. Draft Worker generates a speculative token tree of depth K using autoregressive sampling with a local KV cache.
3. Draft tree + `session_id` → sent to the Target Worker via gRPC.
4. Target Worker performs a single batched forward pass with tree attention to score all candidate paths simultaneously. If a `session_id` is provided, the existing KV cache for that session is reused and rolled back to the accepted prefix after verification.
5. Verification determines the longest accepted path:
   - Greedy (temperature = 0): accept when `argmax(p_target)` equals the draft token.
   - Stochastic (temperature > 0): rejection sampling over all branches via DFS; at each node, accept if `p_target ≥ p_draft`, otherwise accept with probability `p_target / p_draft`. The longest accepted path across the tree is chosen.
6. Accepted tokens + optional correction → returned to the Orchestrator.
7. Orchestrator appends the accepted tokens and loops back to step 2.
8. When generation completes, the Orchestrator calls `EndSession` to free the Target Worker's KV cache.
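The two acceptance rules in the verification step can be sketched per token as follows. This is an illustrative stand-alone version, not the actual `core/verification.py` code; the function names and the injectable `rng` parameter are hypothetical:

```python
import random


def accept_greedy(draft_token: int, target_argmax: int) -> bool:
    # Greedy (temperature = 0): accept iff the target model's argmax
    # matches the drafted token exactly.
    return draft_token == target_argmax


def accept_stochastic(p_target: float, p_draft: float, rng=random.random) -> bool:
    # Stochastic (temperature > 0): standard speculative-sampling rule.
    # Always accept when the target assigns at least as much probability
    # to the token as the draft model did; otherwise accept with
    # probability p_target / p_draft.
    if p_target >= p_draft:
        return True
    return rng() < p_target / p_draft
```

During tree verification, this per-token rule is applied along each branch; a branch is cut at its first rejected node, and the longest surviving root-to-leaf prefix wins.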
## Protocol Design

The gRPC protocol (`spec_decoding.proto`) defines:

| Service | RPC | Purpose |
|---|---|---|
| `DraftService` | `GenerateDrafts` | Generate speculative token trees |
| `DraftService` | `Ping` | Health check |
| `TargetService` | `VerifyDrafts` | Verify draft trees (with session KV cache) |
| `TargetService` | `EndSession` | Release a session's KV cache |
| `TargetService` | `Ping` | Health check |
### Key Messages

- `TokenNode` — recursive tree node: `{token_id, log_prob, children[]}`
- `TelemetryMetadata` — per-RPC timing: `{span_id, wall_time_ms, model_time_ms, ...}`
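As a rough Python mirror of the `TokenNode` message, the recursive structure can be modeled like this (a sketch for illustration only; the classes generated from the proto look different):

```python
from dataclasses import dataclass, field


@dataclass
class TokenNode:
    # Mirror of the TokenNode proto message: each node carries a drafted
    # token, its draft log-probability, and zero or more child branches.
    token_id: int
    log_prob: float
    children: list["TokenNode"] = field(default_factory=list)


def count_nodes(root: TokenNode) -> int:
    # Total number of drafted tokens in the tree, root included.
    return 1 + sum(count_nodes(c) for c in root.children)


# A small tree: branching factor 2 at the root, one grandchild.
root = TokenNode(42, -0.1, children=[
    TokenNode(7, -0.5, children=[TokenNode(9, -0.3)]),
    TokenNode(8, -1.2),
])
```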
## Component Responsibilities

| Component | Responsibility |
|---|---|
| `core/config.py` | Pydantic settings with env var override (`SPECSPLIT_*`) |
| `core/serialization.py` | Tensor ↔ list conversion at the gRPC boundary |
| `core/telemetry.py` | Nanosecond-precision timing + JSON span export |
| `core/verification.py` | Greedy and stochastic tree verification (argmax or rejection sampling + DFS) |
| `workers/draft/` | Stateful draft generation with KV cache management |
| `workers/target/engine.py` | Session-based KV-cached tree-attention verification |
| `workers/target/tree_attn.py` | Custom tree attention mask + position ID construction |
| `workers/target/kv_cache.py` | Pre-allocated static KV cache (O(1) rollback) |
| `workers/orchestrator/client.py` | CLI entry point + synchronous pipeline |
| `workers/orchestrator/pipeline.py` | Async overlapped draft→verify with speculation |
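To make the tree-attention idea concrete, here is a minimal sketch of ancestor-only mask and depth-based position-ID construction, using plain Python lists in place of GPU tensors. The parent-pointer encoding is an assumption for illustration; the actual `tree_attn.py` implementation may differ:

```python
def build_tree_attention_mask(parents: list[int]) -> list[list[bool]]:
    """Build an ancestor-only attention mask from parent indices.

    parents[i] is the index of node i's parent, or -1 for the root.
    mask[i][j] is True iff node i may attend to node j, i.e. j is node i
    itself or one of its ancestors.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:          # walk up the parent chain to the root
            mask[i][j] = True
            j = parents[j]
    return mask


def tree_position_ids(parents: list[int]) -> list[int]:
    # Position ID = node depth, so sibling branches share a position.
    # Assumes nodes are listed in topological order (parents[i] < i).
    depths: list[int] = []
    for p in parents:
        depths.append(0 if p == -1 else depths[p] + 1)
    return depths
```

With this mask, sibling branches cannot attend to each other, so a single batched forward pass scores every candidate path independently, and depth-based position IDs let siblings share the same position.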
## Design Decisions

- Disaggregated architecture — Draft and Target workers run on separate machines/GPUs, connected over the network. This allows independent scaling and heterogeneous hardware.
- Session-based KV caching — The Target Worker maintains a per-session KV cache (`session_id → KVCacheState`) to avoid prompt recomputation across verification rounds. Sessions are LRU-evicted at `max_sessions`, and caches are freed explicitly via the `EndSession` RPC.
- Pre-allocated static KV cache — `StaticKVCache` in `kv_cache.py` avoids `torch.cat` reallocation by pre-allocating key/value buffers and using slice assignment. Rollback is O(1): a single pointer update.
- Tree-structured speculation — Instead of linear draft sequences, we generate trees (branching factor > 1) to explore multiple hypotheses in parallel, increasing acceptance rates. Custom tree attention masks (`tree_attn.py`) ensure each node only attends to its ancestors.
- Greedy verification math — `verification.py` runs `torch.argmax` and `torch.eq` on-device, then uses iterative DFS on a small boolean mask to find the longest accepted path. Only the final result is synced to the CPU.
- Async overlapped pipeline — `pipeline.py` uses `asyncio.gather` to speculatively draft round N+1 while verifying round N. On a speculation hit, a full gRPC round-trip is saved.
- Pydantic configuration — All settings are type-safe, validated, and overridable via environment variables for easy deployment configuration.
- Structured telemetry — Every RPC call generates a span with nanosecond timing, enabling distributed tracing and performance analysis.
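The draft/verify overlap can be illustrated with stub coroutines standing in for the two RPCs. This is a sketch only (`pipeline.py`'s real logic must also discard or patch the speculative draft when verification rejects tokens); the coroutine names, sleeps, and return values are all hypothetical:

```python
import asyncio


async def draft(round_no: int) -> str:
    # Stand-in for the GenerateDrafts RPC to the Draft Worker.
    await asyncio.sleep(0.01)
    return f"tree-{round_no}"


async def verify(tree: str) -> str:
    # Stand-in for the VerifyDrafts RPC to the Target Worker.
    await asyncio.sleep(0.02)
    return f"accepted({tree})"


async def run_pipeline(rounds: int) -> list[str]:
    results = []
    tree = await draft(0)
    for n in range(rounds):
        # Overlap: verify round n while speculatively drafting round n + 1.
        verified, tree = await asyncio.gather(verify(tree), draft(n + 1))
        results.append(verified)
    return results
```

Because `verify(tree)` and `draft(n + 1)` run concurrently, each loop iteration costs roughly the maximum of the two RPC latencies instead of their sum.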