Orchestrator API
specsplit.workers.orchestrator.pipeline
Overlapped Async Orchestrator for Disaggregated Speculative Decoding.
In disaggregated speculative decoding, latency is dominated by network round-trips (Draft → Orchestrator → Target → Orchestrator) and the Target model's forward pass. This module hides that latency by overlapping operations:
While TargetService.VerifyDrafts(Tree N) is in flight:
→ Speculatively fire DraftService.GenerateDrafts(Tree N+1)
assuming the longest branch of Tree N will be accepted.
If the assumption is correct (high acceptance), Tree N+1 is ready immediately. If the assumption is wrong (rejection), Tree N+1 is discarded and we re-draft from the corrected context.
Pipeline Architecture
::
┌──────────────┐    Tree N     ┌─────────────┐
│ Orchestrator │──────────────▶│ Target Svc  │ ← VerifyDrafts(N)
│              │               └─────────────┘
│   (async)    │   Tree N+1    ┌─────────────┐
│              │──────────────▶│ Draft Svc   │ ← GenerateDrafts(N+1)
│              │               │  (specul.)  │   (runs concurrently!)
└──────────────┘               └─────────────┘

Time ──▶
│ verify(N)  ████████████████████│
│ draft(N+1) ████████│           │ ← overlapped!
│            │ verify(N+1) ██████████████│
│            │ draft(N+2)  ██████│       │
The draft(N+1) call runs during verify(N), saving one full
round-trip per iteration when the speculation is correct.
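The overlap above can be sketched with plain asyncio primitives. This is a minimal illustration, not the real pipeline: `verify` and `draft` are stand-ins for the gRPC calls, and the acceptance check is simplified to an equality test.

```python
import asyncio

async def verify(tree):
    # Stand-in for TargetService.VerifyDrafts: simulates network + forward pass.
    await asyncio.sleep(0.02)
    return {"accepted": tree}

async def draft(context):
    # Stand-in for DraftService.GenerateDrafts.
    await asyncio.sleep(0.01)
    return f"tree-for-{context}"

async def overlapped_round(tree_n, predicted_context):
    # Fire the verification of Tree N...
    verify_task = asyncio.create_task(verify(tree_n))
    # ...and, without awaiting it, speculatively draft Tree N+1.
    draft_task = asyncio.create_task(draft(predicted_context))
    result = await verify_task
    if result["accepted"] == tree_n:        # speculation correct
        return await draft_task             # Tree N+1 is already (nearly) ready
    draft_task.cancel()                     # speculation wrong: discard Tree N+1
    return await draft(result["accepted"])  # re-draft from the corrected context

print(asyncio.run(overlapped_round("t0", "t0+best")))  # → tree-for-t0+best
```

Because `draft(N+1)` runs during the (longer) `verify(N)` sleep, a correct speculation makes the next tree available almost as soon as verification returns.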
PipelineResult
dataclass
Final result of the speculative decoding pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| output_tokens | list[int] | The generated token IDs (excluding the prompt). |
| total_rounds | int | Number of draft→verify rounds executed. |
| acceptance_rate | float | Overall acceptance rate across all rounds. |
| speculation_hit_rate | float | Fraction of N+1 speculations that were correct. |
| wall_time_ms | float | Total wall-clock time in milliseconds. |
| telemetry | list[dict[str, Any]] | Collected telemetry spans. |
SpeculativeState
dataclass
Tracks the current state of the speculative pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| generated_tokens | list[int] | All tokens generated so far (accepted + bonus). |
| prompt_ids | list[int] | The original prompt token IDs. |
| total_rounds | int | Number of verify rounds completed. |
| total_accepted | int | Total draft tokens accepted across all rounds. |
| total_path_depth | int | Sum of longest-path depths across all rounds (denominator for the acceptance rate). |
| total_tree_nodes | int | Total draft tree nodes generated (for diagnostics). |
| speculation_hits | int | Number of times the N+1 speculation was correct. |
| speculation_misses | int | Number of times Tree N+1 was discarded. |
| is_finished | bool | Whether generation has reached a stop condition. |
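The derived metrics in `PipelineResult` follow directly from these counters. A minimal sketch, using the field names from the table above; the exact arithmetic (and the zero-division guards) is an assumption about how the rates are defined:

```python
from dataclasses import dataclass, field

@dataclass
class SpeculativeState:
    # Simplified sketch mirroring the attribute table above.
    generated_tokens: list[int] = field(default_factory=list)
    total_rounds: int = 0
    total_accepted: int = 0
    total_path_depth: int = 0   # documented denominator for acceptance rate
    speculation_hits: int = 0
    speculation_misses: int = 0

    @property
    def acceptance_rate(self) -> float:
        # Accepted draft tokens over the longest-path depth sum.
        return self.total_accepted / max(self.total_path_depth, 1)

    @property
    def speculation_hit_rate(self) -> float:
        # Fraction of N+1 speculations that survived verification.
        total = self.speculation_hits + self.speculation_misses
        return self.speculation_hits / max(total, 1)

state = SpeculativeState(total_accepted=30, total_path_depth=40,
                         speculation_hits=7, speculation_misses=3)
print(state.acceptance_rate, state.speculation_hit_rate)  # → 0.75 0.7
```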
run_speculative_loop_async(draft_stub, target_stub, prompt_ids, config=None, session_id='default', eos_token_id=None, vocab_bridge=None, telemetry=None)
async
Run the overlapped speculative decoding loop.
This is the main entry point for the async pipeline. It generates
tokens by repeatedly:
1. Firing VerifyDrafts(Tree N) as an async task.
2. Speculatively firing GenerateDrafts(Tree N+1) assuming
the longest branch of Tree N is accepted.
3. Awaiting verification.
4. If speculation was correct: use Tree N+1 directly.
5. If speculation was wrong: discard Tree N+1, flush Draft cache,
re-draft from corrected context.
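The five steps above can be sketched as a loop. The stub names `generate_drafts`/`verify_drafts`, the `longest_branch` tree field, and the stop condition are placeholders, not the real RPC signatures:

```python
import asyncio

async def speculative_loop(generate_drafts, verify_drafts, prompt_ids, max_rounds=64):
    """Minimal sketch of the overlapped loop; real code adds timeouts,
    EOS handling, cache flushes, and telemetry."""
    context = list(prompt_ids)
    tree_n = await generate_drafts(context)  # bootstrap Tree 0
    for _ in range(max_rounds):
        verify_task = asyncio.create_task(verify_drafts(tree_n))      # step 1
        predicted = context + tree_n["longest_branch"]
        draft_task = asyncio.create_task(generate_drafts(predicted))  # step 2
        accepted = await verify_task                                  # step 3
        context += accepted
        if accepted == tree_n["longest_branch"]:                      # step 4: hit
            tree_n = await draft_task
        else:                                                         # step 5: miss
            draft_task.cancel()
            tree_n = await generate_drafts(context)  # re-draft (cache flush elided)
    return context
```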
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| draft_stub | Any | gRPC stub for the Draft Service (sync or async). | required |
| target_stub | Any | gRPC stub for the Target Service (sync or async). | required |
| prompt_ids | list[int] | Tokenized input prompt. | required |
| config | OrchestratorConfig \| None | Pipeline configuration (defaults, timeouts, limits). | None |
| session_id | str | Session ID for KV cache reuse on the Target Worker. | 'default' |
| eos_token_id | int \| None | Token ID that signals end-of-generation. | None |
| vocab_bridge | Any \| None | Optional VocabBridge mapping token IDs if draft/target differ. | None |
Returns:
| Type | Description |
|---|---|
| PipelineResult | A PipelineResult with the generated tokens, per-run metrics, and timing information. |
Example::
import asyncio
result = asyncio.run(run_speculative_loop_async(
draft_stub=draft_channel,
target_stub=target_channel,
prompt_ids=[101, 2003, 1037],
))
print(f"Generated {len(result.output_tokens)} tokens "
f"in {result.total_rounds} rounds "
f"({result.acceptance_rate:.1%} accepted)")
specsplit.workers.orchestrator.client
Orchestrator — manages the async draft→verify ping-pong pipeline.
The Orchestrator is the user-facing entry point. It sends prompts to the
Draft Worker, forwards the resulting token trees to the Target Worker for
verification, and iterates until the maximum output length or round limit
is reached.
Usage::
python -m specsplit.workers.orchestrator.client --prompt "Once upon a time"
Orchestrator
Manages the speculative decoding pipeline between Draft and Target workers.
The orchestrator runs a loop:
- Send prompt context to Draft Worker → receive draft tree.
- Forward draft tree to Target Worker → receive accepted tokens.
- Append accepted tokens to the output.
- If a correction token was sampled, append it and reset the draft cache.
- Repeat until `max_output_tokens` or `max_rounds` is reached.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | OrchestratorConfig \| None | Orchestrator configuration (addresses, timeouts, limits). | None |
| model_name | str | HuggingFace model name for the tokenizer. | 'gpt2' |
chat_session()
Create a new stateful ConversationSession.
close()
async
Close gRPC channels and release resources.
connect()
Establish async gRPC channels to Draft and Target workers.
export_telemetry(path=None)
Export the most recent run report to a JSON file.
run(prompt)
Run the pipeline and return the generated text.
Thin wrapper around :meth:`run_with_result_sync` for callers that only
need the output string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prompt | str | The user's input text prompt. | required |

Returns:

| Type | Description |
|---|---|
| str | The generated output text. |
run_with_result(prompt, session_id=None)
async
Run the full speculative decoding pipeline for a given prompt.
Tokenizes the prompt, executes the async speculative loop over gRPC, and decodes the resulting tokens back to a string.
Generates a unique session ID per call when KV caching is enabled,
preventing cross-prompt cache pollution, and sends an EndSession RPC
in a `finally` block to prevent cache leaks (Issue 7).
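The session-lifecycle pattern described above (unique ID per call, cleanup in `finally`) looks roughly like this; `pipeline` and `end_session` are stand-ins for the async loop and the real EndSession RPC:

```python
import uuid

def run_with_session(pipeline, end_session, session_id=None, kv_caching=True):
    # Generate a unique per-call ID unless the caller supplied one, so two
    # prompts never share (and pollute) the same KV cache entry.
    owns_session = session_id is None and kv_caching
    if owns_session:
        session_id = f"req-{uuid.uuid4().hex}"
    try:
        return pipeline(session_id)
    finally:
        if owns_session:
            end_session(session_id)  # always runs, even if pipeline raises
```

Only sessions created here are torn down here; caller-supplied IDs are left for the caller to manage.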
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prompt | str | The user's input text prompt. | required |
| session_id | str \| None | Optional caller-supplied session ID. If not provided and KV caching is enabled, a unique ID is generated automatically. | None |
Returns:
| Type | Description |
|---|---|
| tuple[str, PipelineResult] | A tuple of (generated output text, PipelineResult with full metrics). |
run_with_result_sync(prompt)
Synchronous wrapper around :meth:`run_with_result`.
Creates a new event loop and runs the async method to completion. Use this from non-async callers (CLI, benchmarks, etc.).
When called from an async context (FastAPI, Jupyter), runs the pipeline in a dedicated thread with its own loop to avoid "event loop already running" deadlock.
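The "dedicated thread when a loop is already running" trick can be sketched as follows; `run_sync` is a hypothetical helper, not the real method:

```python
import asyncio
import concurrent.futures

def run_sync(coro_factory):
    """Run an async callable to completion from sync code, even inside a loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running (CLI, benchmark): the simple path.
        return asyncio.run(coro_factory())
    # A loop is already running (FastAPI handler, Jupyter cell): asyncio.run()
    # would raise "event loop already running", so execute the coroutine in a
    # fresh thread that owns its own event loop, and block until it finishes.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro_factory()).result()
```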
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prompt
|
str
|
The user's input text prompt. |
required |
Returns:
| Type | Description |
|---|---|
tuple[str, PipelineResult]
|
A tuple of (generated output text, PipelineResult with full metrics). |
ConversationSession
A stateful conversation session for multi-turn interactions.
Maintains accumulated token IDs across multiple generate() turns
to avoid O(n^2) re-tokenization. Automatically manages the session ID
and KV cache cleanup on the Target Worker via context manager.
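The token-accumulation idea (avoiding O(n^2) re-tokenization across turns) can be sketched independently of gRPC; `tokenize` and `generate_from_ids` below are stand-ins for the tokenizer and the pipeline call:

```python
class ConversationSketch:
    """Minimal sketch: keep token IDs across turns instead of re-tokenizing
    the whole conversation history on every call."""

    def __init__(self, tokenize, generate_from_ids):
        self._tokenize = tokenize
        self._generate = generate_from_ids
        self._ids: list[int] = []  # accumulated context, grows each turn

    def generate(self, user_prompt: str) -> list[int]:
        # Only the *new* turn is tokenized; prior turns are already IDs.
        self._ids += self._tokenize(user_prompt)
        reply_ids = self._generate(self._ids)
        self._ids += reply_ids     # the reply becomes context for the next turn
        return reply_ids
```

Each turn therefore tokenizes O(turn length) text rather than O(conversation length), and the accumulated IDs line up with the session's KV cache on the Target Worker.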
end()
End the session and explicitly flush the Target Worker's KV cache.
generate(user_prompt)
Sync generation for the next turn in the conversation.
generate_async(user_prompt)
async
Async generation for the next turn in the conversation.
main()
CLI entry point for the orchestrator.