Draft Worker API
specsplit.workers.draft.engine
Draft Engine — autoregressive speculative token tree generation.
The DraftEngine wraps a small, fast language model and generates
speculative token trees of depth K. These trees are sent to the Target
Worker for verification.
Architecture Notes
- Session-based KV caching: maintains a dict[session_id, DraftCacheState] mapping that stores KV state per session. Combined with a per-session lock, this isolates concurrent requests handled by the gRPC server from one another.
- Cross-round KV cache reuse is supported: each call checks whether the prompt prefix matches the cached state (by token IDs, not just length) and extends the cache incrementally when possible.
- Token trees are represented as nested lists of TokenNode-like dicts for easy protobuf conversion.
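The per-session layout described above can be sketched as a dictionary of cache states, each paired with its own lock. This is a minimal, hypothetical sketch: SessionRegistry, CacheState, acquire, and end are illustrative names, and the real DraftEngine may protect its session map differently.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class CacheState:
    """Hypothetical, simplified stand-in for DraftCacheState."""
    kv_cache: Any = None
    cached_prompt_ids: List[int] = field(default_factory=list)


class SessionRegistry:
    """Hypothetical per-session store: one cache state and one lock per session_id."""

    def __init__(self) -> None:
        self._states: Dict[str, CacheState] = {}
        self._locks: Dict[str, threading.Lock] = {}
        # Guards the two dicts themselves; each per-session lock guards one state.
        self._registry_lock = threading.Lock()

    def acquire(self, session_id: str) -> Tuple[CacheState, threading.Lock]:
        """Return the session's state and lock, creating both on first use."""
        with self._registry_lock:
            if session_id not in self._states:
                self._states[session_id] = CacheState()
                self._locks[session_id] = threading.Lock()
            return self._states[session_id], self._locks[session_id]

    def end(self, session_id: str) -> bool:
        """Drop a session; True if it existed (mirrors the end_session contract)."""
        with self._registry_lock:
            self._locks.pop(session_id, None)
            return self._states.pop(session_id, None) is not None


# Usage: callers hold the session lock while reading or mutating that session's cache.
registry = SessionRegistry()
state, lock = registry.acquire("session-a")
with lock:
    state.cached_prompt_ids.extend([1, 2, 3])
```

Two sessions never share a CacheState, so a long generation in one session only blocks other requests for the same session_id.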
TokenNode
dataclass
In-memory representation of a single node in the draft token tree.
to_dict()
Serialize to a dict (mirrors the protobuf TokenNode message).
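A recursive to_dict() can be sketched as below. The field names token_id, log_prob, and children are assumptions made for illustration; this section does not list TokenNode's fields or the protobuf message schema.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TokenNode:
    # Hypothetical fields; the real TokenNode / protobuf schema may differ.
    token_id: int
    log_prob: float = 0.0
    children: List["TokenNode"] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Recursively serialize this node and its whole subtree to plain dicts."""
        return {
            "token_id": self.token_id,
            "log_prob": self.log_prob,
            "children": [child.to_dict() for child in self.children],
        }


leaf = TokenNode(token_id=42, log_prob=-0.5)
root = TokenNode(token_id=7, log_prob=-0.1, children=[leaf])
print(root.to_dict())
# {'token_id': 7, 'log_prob': -0.1, 'children': [{'token_id': 42, 'log_prob': -0.5, 'children': []}]}
```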
DraftCacheState
dataclass
Per-session KV cache state for the Draft Engine.
Attributes:
| Name | Type | Description |
|---|---|---|
| kv_cache | Any | The HuggingFace |
| cached_prompt_len | int | Number of tokens encoded in the KV cache. |
| cached_prompt_ids | list[int] | Actual token IDs encoded in the KV cache, used to verify prefix match (not just length). |
| cached_last_logits | Tensor \| None | Logits at the end of the cached prefix. |
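The token-ID prefix check that cached_prompt_ids enables can be sketched as a pure function that decides whether the cache can be extended and which tokens still need a forward pass. The name plan_cache_reuse and its return convention are hypothetical, and how the engine handles a mismatch (full reset versus trimming to a common prefix) is not documented here; this sketch simply signals a reset.

```python
from typing import List, Tuple


def plan_cache_reuse(cached_ids: List[int], prompt_ids: List[int]) -> Tuple[bool, List[int]]:
    """Decide how to reuse a cached KV prefix for a new prompt.

    Returns (reuse, tokens_to_feed):
      reuse=True  -> the cache holds an exact prefix of prompt_ids, so only
                     the remaining suffix must be run through the model.
      reuse=False -> the cache is empty or does not match; the caller should
                     reset it and feed the whole prompt.
    """
    n = len(cached_ids)
    # Compare actual token IDs, not just lengths: a same-length prompt with a
    # different token would otherwise silently reuse a stale cache.
    if 0 < n <= len(prompt_ids) and prompt_ids[:n] == cached_ids:
        return True, prompt_ids[n:]
    return False, list(prompt_ids)


print(plan_cache_reuse([1, 2, 3], [1, 2, 3, 4, 5]))  # (True, [4, 5])
print(plan_cache_reuse([1, 2, 9], [1, 2, 3, 4]))     # (False, [1, 2, 3, 4])
```

An empty suffix (prompt identical to the cached prefix) is presumably the case where cached_last_logits lets generation continue without any new forward pass over the prompt.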
DraftEngine
Autoregressive generation engine for the Draft Worker.
This class manages model loading, KV cache state, and speculative
tree generation using a real HuggingFace AutoModelForCausalLM.
Supports per-session KV caching for thread-safe concurrent requests.
When session_id is provided to generate_draft_tree(), each
session gets its own isolated cache state and threading lock.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DraftWorkerConfig \| None | Draft worker configuration (model name, device, etc.). | None |
is_loaded
property
Whether the model has been loaded.
model_vocab_size
property
Vocabulary size exposed by the loaded model or tokenizer.
end_session(session_id)
Terminate a session and free its KV cache.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| session_id | str | The session to terminate. | required |

Returns:

| Type | Description |
|---|---|
| bool | True if the session existed and was removed. |
generate_draft_tree(prompt_ids, k=None, num_beams=None, temperature=None, session_id=None)
Generate a speculative token tree from the given prompt context.
Performs k steps of autoregressive generation using KV caching.
Each "beam" is an independent greedy/sampled chain (no beam-search
coupling). The result is a list of root TokenNode objects,
each heading a flat chain of depth k.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prompt_ids | list[int] | Tokenized prompt (list of vocabulary indices). | required |
| k | int \| None | Tree depth (defaults to | None |
| num_beams | int \| None | Number of independent chains (defaults to | None |
| temperature | float \| None | Sampling temperature (defaults to | None |
| session_id | str \| None | Optional session ID for per-session KV cache isolation. Required for thread-safe concurrent access. | None |
Returns:
| Type | Description |
|---|---|
| list[TokenNode] | A list of root-level TokenNode objects forming the draft forest. |
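The generation loop described above (k autoregressive steps, num_beams independent greedy or sampled chains, each root heading a flat chain) can be illustrated without a real model. This is a minimal sketch under stated assumptions: next_logits stands in for a draft-model forward pass, KV caching is omitted (the full context is passed each step), Node is a hypothetical stand-in for TokenNode, and treating temperature <= 0 as greedy decoding is an assumption rather than documented engine behavior.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Node:
    # Hypothetical stand-in for TokenNode: a token and its draft log-probability.
    token_id: int
    log_prob: float
    children: List["Node"] = field(default_factory=list)


def log_softmax(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled log-softmax (temperature <= 0 leaves logits unscaled)."""
    scaled = [x / temperature for x in logits] if temperature > 0 else list(logits)
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def draft_forest(
    next_logits: Callable[[List[int]], List[float]],
    prompt_ids: List[int],
    k: int,
    num_beams: int,
    temperature: float,
    rng: random.Random,
) -> List[Node]:
    """Build num_beams independent chains of depth k (no beam-search coupling)."""
    roots: List[Node] = []
    for _ in range(num_beams):
        context = list(prompt_ids)  # each chain extends its own copy of the prompt
        root: Optional[Node] = None
        parent: Optional[Node] = None
        for _ in range(k):
            log_probs = log_softmax(next_logits(context), temperature)
            if temperature <= 0:
                # Greedy step: take the highest-probability token.
                tok = max(range(len(log_probs)), key=lambda i: log_probs[i])
            else:
                # Sampled step: draw from the temperature-scaled distribution.
                weights = [math.exp(lp) for lp in log_probs]
                tok = rng.choices(range(len(log_probs)), weights=weights)[0]
            node = Node(tok, log_probs[tok])
            if parent is None:
                root = node
            else:
                parent.children.append(node)
            parent = node
            context.append(tok)
        if root is not None:
            roots.append(root)
    return roots


# Toy 5-token "model" that strongly prefers (last token + 1) mod 5.
def toy_logits(context: List[int]) -> List[float]:
    nxt = (context[-1] + 1) % 5
    return [3.0 if t == nxt else 0.0 for t in range(5)]


greedy = draft_forest(toy_logits, [0], k=3, num_beams=2, temperature=0.0, rng=random.Random(0))
```

With temperature 0.0 both chains here are 1, 2, 3: independent greedy chains collapse onto the same path, so in this sketch diversity across beams comes only from sampling.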
load_model()
Load the draft model and tokenizer via AutoModelForCausalLM.
reset_cache(session_id=None)
Clear the KV cache (e.g., on prompt change or verification failure).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| session_id | str \| None | If provided, clears only that session's cache. If None, clears the default (singleton) cache. | None |
specsplit.workers.draft.service
Draft Worker gRPC service bindings.
Exposes the DraftService gRPC server that wraps the DraftEngine
for network-accessible speculative generation.
DraftServiceServicer
Bases: DraftServiceServicer
gRPC servicer implementing the DraftService RPC interface.
Each RPC call is wrapped with a telemetry span for distributed tracing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| engine | DraftEngine | The draft generation engine. | required |
| telemetry | TelemetryLogger \| None | Optional telemetry logger for span collection. | None |
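Wrapping each RPC in a telemetry span can be sketched with a context manager that records the span name, wall-clock duration, and whether the handler succeeded, even when it raises. SpanRecorder, span, and traced_call are hypothetical names; the real TelemetryLogger interface is not documented in this section.

```python
import time
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional, Tuple


class SpanRecorder:
    """Hypothetical stand-in for a telemetry logger that collects spans."""

    def __init__(self) -> None:
        self.spans: List[Tuple[str, float, bool]] = []

    @contextmanager
    def span(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        ok = False
        try:
            yield
            ok = True
        finally:
            # Record (name, elapsed seconds, success flag) whether or not the body raised.
            self.spans.append((name, time.perf_counter() - start, ok))


def traced_call(telemetry: Optional[SpanRecorder], name: str, fn: Callable[..., Any], *args: Any) -> Any:
    """Run fn(*args) inside a span when telemetry is configured, else call it directly."""
    if telemetry is None:
        return fn(*args)
    with telemetry.span(name):
        return fn(*args)


rec = SpanRecorder()
print(traced_call(rec, "GenerateDrafts", sum, [1, 2, 3]))  # 6
print(rec.spans[0][0], rec.spans[0][2])                    # GenerateDrafts True
```

Making telemetry optional mirrors the servicer's TelemetryLogger | None parameter: with no logger, handlers run without span overhead.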
GenerateDrafts(request, context)
Handle a GenerateDrafts RPC call.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| request | DraftRequest | A DraftRequest message. | required |
| context | ServicerContext | gRPC server context. | required |

Returns:

| Type | Description |
|---|---|
| DraftResponse | A DraftResponse message. |
Ping(request, context)
Health check endpoint.
serve(config=None)
Start the Draft Worker gRPC server.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DraftWorkerConfig \| None | Optional configuration override. If | None |