Draft Worker API

specsplit.workers.draft.engine

Draft Engine — autoregressive speculative token tree generation.

The DraftEngine wraps a small, fast language model and generates speculative token trees of depth K. These trees are sent to the Target Worker for verification.

Architecture Notes
  • Session-based KV caching: Maintains a dict[session_id, DraftCacheState] mapping that stores KV state per session. This ensures thread-safe operation when the gRPC server handles concurrent requests.
  • Cross-round KV cache reuse is supported: each call checks whether the prompt prefix matches the cached state (by token IDs, not just length) and extends incrementally when possible.
  • Token trees are represented as nested lists of TokenNode-like dicts for easy protobuf conversion.
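The prefix check described above can be sketched as a small helper. The function name and the all-or-nothing reuse policy are illustrative assumptions, not the library's actual API:

```python
def reusable_prefix_len(cached_ids: list[int], prompt_ids: list[int]) -> int:
    """Sketch of the cross-round reuse check: compare token IDs, not lengths.

    Returns how many cached tokens can be reused. Reuse happens only when the
    cached tokens form an exact prefix of the new prompt; a same-length prompt
    with different tokens invalidates the cache entirely.
    """
    n = len(cached_ids)
    if n <= len(prompt_ids) and prompt_ids[:n] == cached_ids:
        return n  # extend incrementally from here
    return 0      # mismatch: recompute the KV cache from scratch
```

Checking IDs rather than lengths is what makes reuse safe when two different prompts happen to have the same length.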

TokenNode dataclass

In-memory representation of a single node in the draft token tree.

to_dict()

Serialize to a dict (mirrors the protobuf TokenNode message).
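A minimal sketch of such a node and its to_dict() serialization. The field names (token_id, logprob, children) are assumptions mirroring a typical protobuf TokenNode message, not the exact schema:

```python
from dataclasses import dataclass, field

@dataclass
class TokenNode:
    token_id: int
    logprob: float = 0.0
    children: list["TokenNode"] = field(default_factory=list)

    def to_dict(self) -> dict:
        # Recursively serialize to nested dicts, mirroring the protobuf message.
        return {
            "token_id": self.token_id,
            "logprob": self.logprob,
            "children": [child.to_dict() for child in self.children],
        }
```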

DraftCacheState dataclass

Per-session KV cache state for the Draft Engine.

Attributes:

  • kv_cache (Any): The HuggingFace past_key_values from the last call.
  • cached_prompt_len (int): Number of tokens encoded in the KV cache.
  • cached_prompt_ids (list[int]): Actual token IDs encoded in the KV cache, used to verify prefix match (not just length).
  • cached_last_logits (Tensor | None): Logits at the end of the cached prefix.
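Given the attributes above, the dataclass might be sketched as follows. Any stands in for the torch Tensor type so the sketch stays import-light, and the defaults are assumptions:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class DraftCacheState:
    kv_cache: Any = None            # HuggingFace past_key_values from the last call
    cached_prompt_len: int = 0      # number of tokens encoded in the KV cache
    cached_prompt_ids: list[int] = field(default_factory=list)  # IDs for prefix checks
    cached_last_logits: Any = None  # logits at the end of the cached prefix
```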

DraftEngine

Autoregressive generation engine for the Draft Worker.

This class manages model loading, KV cache state, and speculative tree generation using a real HuggingFace AutoModelForCausalLM.

Supports per-session KV caching for thread-safe concurrent requests. When session_id is provided to generate_draft_tree(), each session gets its own isolated cache state and threading lock.

Parameters:

  • config (DraftWorkerConfig | None): Draft worker configuration (model name, device, etc.). Default: None.

is_loaded property

Whether the model has been loaded.

model_vocab_size property

Vocabulary size exposed by the loaded model or tokenizer.

end_session(session_id)

Terminate a session and free its KV cache.

Parameters:

  • session_id (str): The session to terminate. Required.

Returns:

  • bool: True if the session existed and was removed.

generate_draft_tree(prompt_ids, k=None, num_beams=None, temperature=None, session_id=None)

Generate a speculative token tree from the given prompt context.

Performs k steps of autoregressive generation using KV caching. Each "beam" is an independent greedy/sampled chain (no beam-search coupling). The result is a list of root TokenNode objects, each heading a flat chain of depth k.

Parameters:

  • prompt_ids (list[int]): Tokenized prompt (list of vocabulary indices). Required.
  • k (int | None): Tree depth (defaults to config.max_draft_tokens). Default: None.
  • num_beams (int | None): Number of independent chains (defaults to config.num_beams). Default: None.
  • temperature (float | None): Sampling temperature (defaults to config.temperature). Default: None.
  • session_id (str | None): Optional session ID for per-session KV cache isolation. Required for thread-safe concurrent access. Default: None.

Returns:

  • list[TokenNode]: A list of root-level TokenNode objects forming the draft forest.
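The independent-chain generation loop can be sketched with a stand-in next_token function. The real engine runs a HuggingFace model forward pass with KV caching at each step; this toy version only shows the control flow, and all names are illustrative:

```python
def draft_chains(next_token, prompt_ids, k, num_beams):
    """Toy sketch: num_beams independent chains, each extended for k steps.

    There is no beam-search coupling: each chain conditions only on the
    prompt plus its own previously drafted tokens.
    """
    forest = []
    for _ in range(num_beams):
        context = list(prompt_ids)
        chain = []
        for _ in range(k):
            token = next_token(context)  # model forward pass in the real engine
            chain.append(token)
            context.append(token)
        forest.append(chain)
    return forest
```

With a deterministic next_token it is easy to see each beam producing a flat chain of depth k.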

load_model()

Load the draft model and tokenizer via AutoModelForCausalLM.

reset_cache(session_id=None)

Clear the KV cache (e.g., on prompt change or verification failure).

Parameters:

  • session_id (str | None): If provided, clears only that session's cache. If None, clears the default (singleton) cache. Default: None.

specsplit.workers.draft.service

Draft Worker gRPC service bindings.

Exposes the DraftService gRPC server that wraps the DraftEngine for network-accessible speculative generation.

DraftServiceServicer

Bases: DraftServiceServicer

gRPC servicer implementing the DraftService RPC interface.

Each RPC call is wrapped with a telemetry span for distributed tracing.

Parameters:

  • engine (DraftEngine): The draft generation engine. Required.
  • telemetry (TelemetryLogger | None): Optional telemetry logger for span collection. Default: None.

GenerateDrafts(request, context)

Handle a GenerateDrafts RPC call.

Parameters:

  • request (DraftRequest): A DraftRequest protobuf message. Required.
  • context (ServicerContext): gRPC server context. Required.

Returns:

  • DraftResponse: A DraftResponse protobuf message with the generated tree.

Ping(request, context)

Health check endpoint.

serve(config=None)

Start the Draft Worker gRPC server.

Parameters:

  • config (DraftWorkerConfig | None): Optional configuration override. If None, reads from environment variables. Default: None.
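The environment-variable fallback might be sketched like this. The field names and env var names (DRAFT_MODEL_NAME, DRAFT_DEVICE, DRAFT_PORT) are assumptions for illustration, not the library's actual configuration keys:

```python
import os
from dataclasses import dataclass

@dataclass
class DraftWorkerConfig:
    # Field names and defaults below are illustrative assumptions.
    model_name: str = "distilgpt2"
    device: str = "cpu"
    port: int = 50051

def config_from_env() -> DraftWorkerConfig:
    """Sketch of the fallback path serve() takes when config is None."""
    return DraftWorkerConfig(
        model_name=os.environ.get("DRAFT_MODEL_NAME", "distilgpt2"),
        device=os.environ.get("DRAFT_DEVICE", "cpu"),
        port=int(os.environ.get("DRAFT_PORT", "50051")),
    )
```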