Target Worker API
specsplit.workers.target.engine
Target Engine — session-based KV-cached tree-attention verification.
The TargetEngine wraps the large, accurate language model. Given a draft
token tree from the Draft Worker, it performs a forward pass using tree
attention to verify which draft tokens are accepted under the target
distribution.
Architecture Notes
- Session-based KV caching: maintains a `dict[session_id, KVCacheState]` mapping that stores a pre-allocated `StaticKVCache` per session. This eliminates prompt recomputation and `torch.cat` reallocation across verification rounds within the same generation session.
- After verification, `rollback_cache` uses `StaticKVCache.rollback()` (an O(1) pointer update) for linear chains, or `StaticKVCache.compact()` for branching trees, to crop the cache to the accepted prefix.
- Uses greedy verification: for each position in the tree, if `argmax(target_logits)` matches the drafted token, it is accepted.
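The greedy acceptance rule can be sketched for the simple linear-chain case. This is a minimal illustration, not the engine's actual `verify_greedy_tree()` implementation: it assumes `target_logits[i]` is the target model's logit row predicting draft position `i`, and stops at the first mismatch, since a rejected token invalidates everything drafted after it.

```python
import torch

def greedy_verify_chain(target_logits: torch.Tensor, draft_tokens: list[int]) -> int:
    """Return the number of accepted draft tokens in a linear chain.

    Sketch only: target_logits has shape [num_draft, vocab], where row i
    predicts the token at draft position i.
    """
    predicted = target_logits.argmax(dim=-1)  # greedy target choice per position
    accepted = 0
    for pos, tok in enumerate(draft_tokens):
        if predicted[pos].item() != tok:
            break  # first mismatch invalidates the rest of the chain
        accepted += 1
    return accepted

# Toy vocab of 5: target prefers tokens [2, 4, 1]; draft proposed [2, 4, 3].
logits = torch.full((3, 5), -1.0)
logits[0, 2] = logits[1, 4] = logits[2, 1] = 1.0
print(greedy_verify_chain(logits, [2, 4, 3]))  # -> 2
```

Tree verification generalizes this by walking accepted children via the topology map instead of a flat position index.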
CacheDesyncError
Bases: RuntimeError
Raised when the target cache length disagrees with the orchestrator's `expected_prefix_length`.
The orchestrator must retry with the full `prompt_token_ids` to rebuild the cache.
VerificationResult
dataclass
Result of verifying a draft tree against the target model.
acceptance_rate
property
Fraction of draft tokens accepted (0.0-1.0).
Uses num_draft_tokens (the total draft candidates presented)
as the denominator for a meaningful acceptance rate.
KVCacheState
dataclass
Per-session KV cache state stored on the Target Worker.
Wraps a `StaticKVCache`, providing O(1) rollback and zero-copy slice-assignment appends.
Attributes:
| Name | Type | Description |
|---|---|---|
| `cache` | `VirtualKVCache \| None` | The pre-allocated static KV cache for this session, or `None` before it is allocated. |
| `seq_len` | `int` | The current cached sequence length (number of tokens whose KV projections are stored). |
| `next_root_logit` | `Tensor \| None` | The logit predicting the token directly following the accepted sequence. Maintained to perfectly align the target model on a cache hit without a duplicate forward pass. |
TargetEngine
Session-aware tree-attention verification engine for the Target Worker.
Maintains a dictionary of session_id → KVCacheState to reuse KV
projections across verification rounds within the same generation
session. After each verification, the cache is rolled back to the
longest accepted prefix via rollback_cache().
The KV cache uses `StaticKVCache` — a pre-allocated, pointer-based cache that provides O(1) rollback and zero-copy appends, compared to HuggingFace's default `DynamicCache`, which requires a `torch.cat` on every append.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `TargetWorkerConfig \| None` | Target worker configuration (model name, device, etc.). | `None` |
active_sessions
property
Number of active sessions with cached KV state.
is_loaded
property
Whether the model has been loaded.
model_vocab_size
property
Vocabulary size exposed by the loaded model or tokenizer.
end_session(session_id)
Terminate a session and free its KV cache from memory.
Thread-safe: acquires the session lock before destroying the cache
to ensure no active verify_draft_tree thread is using the
tensors being deleted (Issue 10).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `session_id` | `str` | The session to terminate. | required |

Returns:
| Type | Description |
|---|---|
| `bool` | `True` if the session existed and was removed, `False` otherwise. |
get_or_create_session(session_id)
Retrieve an existing session cache or create a new one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `session_id` | `str` | Unique identifier for the generation session. | required |

Returns:
| Type | Description |
|---|---|
| `KVCacheState` | The session's `KVCacheState`. |
| `bool` | `True` if an existing cache was found. |
has_session(session_id)
Return True if a session cache currently exists.
has_usable_session(session_id)
Return True if a session exists and currently stores a usable prefix.
load_model()
Load the target model and tokenizer via `AutoModelForCausalLM`.
purge_stale_sessions(ttl_seconds=None)
Remove sessions that have not been accessed within the TTL.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ttl_seconds` | `float \| None` | Override the configured TTL (for testing). | `None` |

Returns:
| Type | Description |
|---|---|
| `int` | Number of purged sessions. |
rollback_cache(session_id, accepted_depth, accepted_tree_indices=None, prefix_length=0)
Compact the KV cache to only the accepted prefix + accepted path.
Uses `VirtualKVCache.rollback()` for O(1) rollback (linear chains)
or `VirtualKVCache.compact()` (branching trees with non-contiguous accepted
positions).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `session_id` | `str` | The session whose cache to roll back. | required |
| `accepted_depth` | `int` | Total number of tokens in the accepted prefix (measured from the start of the full cached sequence). Used for simple O(1) rollback when `accepted_tree_indices` is `None`. | required |
| `accepted_tree_indices` | `list[int] \| None` | BFS indices (0-based into the tree portion of the input) of the accepted path nodes, in order from root to leaf. When provided, the cache is compacted to the prefix plus these accepted positions. | `None` |
| `prefix_length` | `int` | Length of the prompt/prefix portion of the cached sequence. Only used when `accepted_tree_indices` is provided. | `0` |

Raises:
| Type | Description |
|---|---|
| `KeyError` | If the session does not exist. |
| `ValueError` | If the rollback arguments are inconsistent with the cached sequence length. |
shutdown()
Stop the background TTL GC thread.
verify_draft_tree(prompt_ids, flat_token_ids, topology_map, flat_log_probs, flat_draft_probs_full, session_id=None, temperature=0.0, expected_prefix_length=0, new_token_ids=None)
Verify a draft token tree against the target model's distribution.
When a session_id is provided, the engine reuses (or creates) a
KV cache for that session. After verification, the cache is
automatically rolled back to the accepted prefix.
The verification pipeline:
- Build a tree-attention mask via `build_tree_attention()`.
- On cache hits, append any linear `new_token_ids` delta to the session cache so the target prefix matches the draft context.
- Forward the tree tokens through the target model with the current cached KV state.
- Run `verify_greedy_tree()` on the tree logits.
- Roll back the KV cache to the accepted prefix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `prompt_ids` | `list[int]` | Full prompt/prefix token IDs, used when rebuilding from scratch after a miss or cache desync. | required |
| `flat_token_ids` | `list[int]` | Flattened list of draft token IDs. | required |
| `topology_map` | `list[int]` | Parent indices for each node in `flat_token_ids`. | required |
| `flat_log_probs` | `list[float]` | Log probabilities for the draft tokens. | required |
| `flat_draft_probs_full` | `Tensor \| None` | Full draft probabilities if top-k is used. | required |
| `session_id` | `str \| None` | Optional session ID for KV cache reuse. If `None`, no session cache is used. | `None` |
| `new_token_ids` | `list[int] \| None` | Linear delta tokens to append to the cached prefix before tree verification. Used by the orchestrator to carry forward the previous round's bonus/correction token. | `None` |

Returns:
| Type | Description |
|---|---|
| `VerificationResult` | A `VerificationResult` containing the accepted tokens, the correction token, and the acceptance statistics. |
specsplit.workers.target.kv_cache
Pre-allocated Static KV Cache for Disaggregated Speculative Decoding.
Standard HuggingFace past_key_values uses a tuple of per-layer
(key, value) tensors that grow via torch.cat at every decoding
step. This incurs reallocation + copy overhead that is unacceptable
in a latency-sensitive speculative decoding pipeline.
This module provides a `StaticKVCache` that pre-allocates the
full maximum-length buffer once at session start, and then uses simple
slice assignment (`tensor[:, :, ptr:ptr+n, :] = new_data`) for
appends and a single integer update for rollbacks.
Performance Properties
- Append: O(n) in the number of new tokens only (slice write, no reallocation, no `torch.cat`).
- Rollback: O(1) — updates the `seq_len` pointer. Stale data beyond the pointer is ignored and overwritten on the next append.
- Memory: fixed at ``2 * num_layers * batch * num_heads * max_seq_len * head_dim * sizeof(dtype)`` bytes. No dynamic growth.
Tensor Layout
::
key_cache: [num_layers, batch, num_heads, max_seq_len, head_dim]
value_cache: [num_layers, batch, num_heads, max_seq_len, head_dim]
The seq_len pointer tracks how many valid positions are stored.
Only [:, :, :, :seq_len, :] contains meaningful data.
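The append/rollback mechanics above can be sketched in a few lines. This is an illustrative toy (one layer, batch 1, hypothetical names), not the module's actual class:

```python
import torch

class TinyStaticCache:
    """Minimal sketch of the pre-allocated KV cache idea."""

    def __init__(self, num_heads: int, max_seq_len: int, head_dim: int):
        # Fixed-size buffers: allocated once, never grown.
        self.k = torch.zeros(1, num_heads, max_seq_len, head_dim)
        self.v = torch.zeros(1, num_heads, max_seq_len, head_dim)
        self.seq_len = 0
        self.max_seq_len = max_seq_len

    def append(self, new_k: torch.Tensor, new_v: torch.Tensor) -> None:
        n = new_k.shape[2]
        if self.seq_len + n > self.max_seq_len:
            raise ValueError("cache full")
        # Slice assignment: in-place write, no torch.cat, no reallocation.
        self.k[:, :, self.seq_len:self.seq_len + n] = new_k
        self.v[:, :, self.seq_len:self.seq_len + n] = new_v
        self.seq_len += n

    def rollback(self, accepted_length: int) -> None:
        # O(1): only the pointer moves; stale data past it is never read.
        if not 0 <= accepted_length <= self.seq_len:
            raise ValueError("bad rollback length")
        self.seq_len = accepted_length

    def valid_kv(self):
        # Views sliced to the valid prefix, as consumers would read them.
        return self.k[:, :, :self.seq_len], self.v[:, :, :self.seq_len]

cache = TinyStaticCache(num_heads=2, max_seq_len=16, head_dim=4)
cache.append(torch.ones(1, 2, 5, 4), torch.ones(1, 2, 5, 4))
cache.rollback(3)
k, v = cache.valid_kv()
print(k.shape)  # -> torch.Size([1, 2, 3, 4])
```

The real class adds the layer dimension, dtype/device handling, and the HuggingFace `Cache` interface on top of the same two primitives.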
.. note::
This module is **not yet wired** into :class:`TargetEngine`, which
currently uses HuggingFace's standard ``past_key_values`` tuple
with slice-based rollback. ``StaticKVCache`` is prepared as a
future performance optimization to eliminate ``torch.cat``
reallocation overhead in the verification hot path.
StaticKVCache
Bases: `Cache` (falls back to `object` when the `transformers` `Cache` class is unavailable)
Pre-allocated, pointer-based KV cache for a single session.
Allocates fixed-size key and value buffers at initialization. Appends use slice assignment, and rollbacks update a single integer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_layers` | `int` | Number of transformer layers (attention blocks). | required |
| `num_heads` | `int` | Number of attention heads per layer. | required |
| `max_seq_len` | `int` | Maximum sequence length to support. The buffers are allocated to this size and never grow. | required |
| `head_dim` | `int` | Dimension per attention head. | required |
| `batch_size` | `int` | Batch dimension (typically 1 for inference). | `1` |
| `dtype` | `dtype` | Tensor dtype (e.g., `torch.float16`). | `float16` |
| `device` | `device \| str` | Torch device (e.g., `"cuda:0"`). | `'cpu'` |
Example::
cache = StaticKVCache(
num_layers=32, num_heads=32,
max_seq_len=2048, head_dim=128,
dtype=torch.float16, device="cuda:0",
)
# After a forward pass produces new KV projections:
# new_keys shape: [num_layers, batch, num_heads, new_len, head_dim]
# new_values shape: [num_layers, batch, num_heads, new_len, head_dim]
cache.append(new_keys, new_values)
# After verification, roll back to accepted prefix:
cache.rollback(accepted_length=128) # O(1)!
# Get KV for a specific layer to pass into model:
k, v = cache.get_kv_for_layer(layer_idx=0)
# k shape: [batch, num_heads, seq_len, head_dim]
is_full
property
Whether the cache has reached maximum capacity.
remaining_capacity
property
Number of additional positions that can be appended.
seq_len
property
Number of valid positions currently stored in the cache.
append(new_keys, new_values)
Insert new KV projections at the current pointer position.
Uses slice assignment (`cache[:, :, :, ptr:ptr+n, :] = data`),
which is an in-place write — no `torch.cat`, no
`.clone()`, no reallocation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `new_keys` | `Tensor` | New key projections. Shape: `[num_layers, batch, num_heads, new_len, head_dim]`. | required |
| `new_values` | `Tensor` | New value projections. Same shape as `new_keys`. | required |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If the new tokens would exceed `max_seq_len`. |
| `ValueError` | If tensor shapes don't match the cache configuration. |
batch_select_indices(indices)
Keep only specified positions (Cache interface, maps to compact).
compact(keep_indices)
Re-order the cache to keep only specified positions.
Used for branching tree KV cache compaction where the accepted
path contains non-contiguous BFS positions. Unlike
`rollback()`, this performs actual tensor operations
(`torch.index_select`), but it is still far cheaper than
re-computing the full KV cache from scratch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `keep_indices` | `list[int]` | Ordered list of cache positions to retain. Must be valid indices in `[0, seq_len)`. | required |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If any index is out of range. |
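The compaction step can be illustrated directly on a raw buffer. A sketch under the module's stated `[num_layers, batch, num_heads, max_seq_len, head_dim]` layout (the variable names here are illustrative, not the class's internals):

```python
import torch

# Toy cache: 8 positions, each storing its own index as the value.
key_cache = torch.arange(8.0).reshape(1, 1, 1, 8, 1)

# Prefix + accepted path occupies non-contiguous positions after a branching round.
keep = torch.tensor([0, 1, 2, 4, 6])

# Gather the kept positions (index_select copies, so the write-back is alias-safe),
# then move them to the front of the sequence axis and shrink the pointer.
compacted = torch.index_select(key_cache, dim=3, index=keep)
key_cache[:, :, :, :len(keep)] = compacted
seq_len = len(keep)

print(key_cache[0, 0, 0, :seq_len, 0])  # -> tensor([0., 1., 2., 4., 6.])
```

Positions beyond `seq_len` still hold stale data, but as with `rollback()`, they are never read and will be overwritten by the next append.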
crop(max_length)
Crop cache to max_length (Cache interface, maps to rollback).
get_all_kv()
Get valid (key, value) pairs for ALL layers as a HF-compatible tuple.
This format is compatible with HuggingFace's past_key_values
argument: a tuple of (key, value) pairs per layer.
Returns:
| Type | Description |
|---|---|
| `tuple[tuple[Tensor, Tensor], ...]` | A tuple of length `num_layers`, containing one `(key, value)` pair per layer. |
Note
The returned tensors are views into the pre-allocated cache. Do not modify them in-place unless you intend to update the cache.
get_kv_for_layer(layer_idx)
Get the valid (key, value) tensors for a specific layer.
Returns views (not copies) sliced to the current seq_len.
These are suitable for passing directly into a transformer layer's
attention computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `layer_idx` | `int` | Index of the transformer layer (0-based). | required |

Returns:
| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | A `(key, value)` pair of views, each of shape `[batch, num_heads, seq_len, head_dim]`. |

Raises:
| Type | Description |
|---|---|
| `IndexError` | If `layer_idx` is out of range. |
get_seq_length(layer_idx=0)
Return the number of cached positions (Cache interface).
reset()
Reset the cache to empty (seq_len = 0).
Like `rollback()`, this is O(1) — it only resets the pointer.
The pre-allocated buffers remain in GPU memory for reuse.
rollback(accepted_length)
Roll back the cache to a given accepted prefix length.
This is an O(1) operation. It simply moves the `seq_len`
pointer backwards. No tensor data is copied or zeroed. Stale
data beyond the new pointer is harmless — it will be overwritten
by the next `append()` call, and it is never read because
all consumers use `[:, :, :, :seq_len, :]` slices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `accepted_length` | `int` | The number of positions to keep (measured from the start of the sequence). Must satisfy `0 <= accepted_length <= seq_len`. | required |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If `accepted_length` is negative or exceeds the current `seq_len`. |
stack_hf_cache(past_key_values)
staticmethod
Convert HuggingFace's cache format to 5D tensors for append().
HuggingFace models output past_key_values as::
tuple[ # num_layers
tuple[
Tensor[batch, heads, seq, head_dim], # keys
Tensor[batch, heads, seq, head_dim], # values
],
...
]
This method stacks them into unified 5D tensors::
keys: [num_layers, batch, heads, seq, head_dim]
values: [num_layers, batch, heads, seq, head_dim]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `past_key_values` | `tuple[tuple[Tensor, Tensor], ...]` | HuggingFace `past_key_values` output. | required |

Returns:
| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | A `(keys, values)` pair of stacked 5D tensors, suitable for passing to `append()`. |
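The stacking described above reduces to a pair of `torch.stack` calls over the per-layer tuples. A sketch with assumed toy dimensions:

```python
import torch

# Fake HF-style past_key_values: tuple of (keys, values) per layer,
# each tensor shaped [batch, heads, seq, head_dim].
num_layers, batch, heads, seq, head_dim = 3, 1, 2, 4, 8
past_key_values = tuple(
    (torch.randn(batch, heads, seq, head_dim),   # keys for layer i
     torch.randn(batch, heads, seq, head_dim))   # values for layer i
    for _ in range(num_layers)
)

# Stack per-layer keys/values along a new leading layer axis.
keys = torch.stack([kv[0] for kv in past_key_values], dim=0)
values = torch.stack([kv[1] for kv in past_key_values], dim=0)
print(keys.shape)  # -> torch.Size([3, 1, 2, 4, 8])
```

The resulting `[num_layers, batch, heads, seq, head_dim]` tensors match the layout the cache's `append()` expects.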
update(key_states, value_states, layer_idx, cache_kwargs=None)
Update the cache in-place. Called by HuggingFace attention layers.
Writes key/value states directly into the pre-allocated buffers via slice assignment or index_copy_. No torch.cat, no reallocation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key_states` | `Tensor` | Shape `[batch, num_heads, new_len, head_dim]`. | required |
| `value_states` | `Tensor` | Same shape as `key_states`. | required |
| `layer_idx` | `int` | Layer index. | required |
| `cache_kwargs` | `dict[str, Any] \| None` | May contain a `"cache_position"` tensor of write indices. | `None` |

Returns:
| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | `(keys, values)` for this layer, shape `[batch, num_heads, seq_len, head_dim]`. |
VirtualKVCache
Bases: StaticKVCache
Wraps StaticKVCache with a virtual index layer.
Instead of physically compacting the buffer on branching rollback, it
maintains an `_l2p` (logical-to-physical) mapping. Reads use gathered
indices; writes use scatter. `compact()` becomes an O(length) index remap
with no tensor copies.
append(new_keys, new_values)
Append new KV states at the end of the current logical sequence.
Updates the underlying physical buffers using the _l2p mapping and
advances _seq_len.
compact(keep_indices)
Compact the logical sequence by remapping _l2p to keep_indices.
keep_indices are logical indices that remain active after compaction.
get_kv_for_layer(layer_idx)
Return (key, value) slices for one layer at the current logical length.
get_physical_indices()
Return physical buffer indices backing the current logical prefix.
update(key_states, value_states, layer_idx, cache_kwargs=None)
Update KV states for a layer and return active slices.
When cache_kwargs["cache_position"] is provided, writes at those
logical indices; otherwise it appends at the end of the current logical
sequence. The returned tensors correspond to the active prefix.
specsplit.workers.target.tree_attn
Custom Tree Attention Masking for Disaggregated Speculative Decoding.
When the Target Worker receives a flat list of draft token IDs and a
"topology map" (a list of parent indices), it must verify the entire tree
in a single forward pass. This module builds the custom 2D boolean
attention mask and 1D position_ids tensor required to achieve that.
Terminology
- Topology map: a list of length `num_tree_nodes` where `topology_map[i]` is the local index (0-based into the tree nodes array) of node `i`'s parent, or `-1` if node `i` is a root of the tree. This is the standard flat-tree representation used in SpecInfer / Medusa / Eagle.
- Prefix: the already-processed prompt tokens whose KV projections live in the KV cache. Every tree node attends to the full prefix.
- Total sequence: `[prefix tokens | tree tokens]` concatenated.
Attention Rules
~~~~~~~~~~~~~~~
A tree node at position j (0-indexed within the tree) is allowed to
attend to:
1. All prefix tokens (positions 0 .. prefix_length-1).
2. Itself.
3. All of its ancestors in the tree (following parent pointers up).
It must NOT attend to siblings, cousins, or any other branch.
Position ID Rules
~~~~~~~~~~~~~~~~~
Siblings at the same depth in the tree represent alternative
continuations at the same logical decoding step, so they share the same
position_id. Specifically:
- Prefix positions: 0, 1, ..., prefix_length - 1
- Tree node positions: prefix_length + depth_of_node
Example
Consider a tree with topology_map = [-1, 0, 0, 1, 2] and prefix_length = 3::
Prefix: [p0, p1, p2]
Tree: t0
/ \
t1 t2
| |
t3 t4
Depths: t0=0, t1=1, t2=1, t3=2, t4=2
Position IDs: [0, 1, 2, 3, 4, 4, 5, 5] (prefix, then tree)
Attention mask (tree portion only — prefix columns are all True):
t0 attends to: prefix + t0
t1 attends to: prefix + t0, t1
t2 attends to: prefix + t0, t2
t3 attends to: prefix + t0, t1, t3
t4 attends to: prefix + t0, t2, t4
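The worked example above can be reproduced with a few lines of torch. This is a sketch of the masking and position-ID rules, not the module's `build_tree_attention()` itself:

```python
import torch

topology_map = [-1, 0, 0, 1, 2]  # the example tree
prefix_length = 3
n = len(topology_map)

# Depth of each node (parents appear before children in BFS order).
depth = [0] * n
for i, p in enumerate(topology_map):
    depth[i] = 0 if p == -1 else depth[p] + 1

# Position IDs: prefix positions 0..prefix_length-1, then prefix_length + depth.
position_ids = list(range(prefix_length)) + [prefix_length + d for d in depth]

# Tree-portion mask: each node attends to all prefix tokens, itself, and ancestors.
mask = torch.zeros(n, prefix_length + n, dtype=torch.bool)
mask[:, :prefix_length] = True
for i in range(n):
    j = i
    while j != -1:  # walk parent pointers up to the root
        mask[i, prefix_length + j] = True
        j = topology_map[j]

print(position_ids)            # -> [0, 1, 2, 3, 4, 4, 5, 5]
print(mask[4].int().tolist())  # t4: prefix + t0, t2, t4 -> [1, 1, 1, 1, 0, 1, 0, 1]
```

Note how siblings t1/t2 (and t3/t4) share a position ID, and how t4's row excludes the t1/t3 branch entirely.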
build_tree_attention(topology_map, prefix_length, device='cpu', tree_rows_only=False)
Build a causal tree-attention mask and position IDs tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `topology_map` | `list[int]` | List of parent indices for each tree node. | required |
| `prefix_length` | `int` | Number of prefix tokens already in the KV cache. These positions are always attended to by every tree node. | required |
| `device` | `device \| str` | Torch device for the output tensors. | `'cpu'` |
| `tree_rows_only` | `bool` | If True, allocate only a `[num_tree_nodes, total_len]` mask (for cache hits, when only the tree rows are needed). Avoids an O(total_len²) allocation when 99%+ of it would be discarded by slicing. | `False` |

Returns:
| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | A tuple `(attention_mask, position_ids)`. |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If `topology_map` contains an invalid parent index. |
bool_mask_to_float(mask, dtype=torch.float16)
Convert a boolean attention mask to a float mask with -inf masking.
Some model backends (e.g., Flash Attention v2) expect the attention
mask as a float tensor where masked positions are `-inf` and
attended positions are `0.0`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mask` | `Tensor` | Boolean attention mask to convert. | required |
| `dtype` | `dtype` | Output dtype (should match model precision). | `float16` |

Returns:
| Type | Description |
|---|---|
| `Tensor` | Float mask of the same shape, with `-inf` at masked positions and `0.0` elsewhere. |
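The conversion amounts to a single masked fill. A minimal sketch of the described behavior (not the function's exact implementation):

```python
import torch

bool_mask = torch.tensor([[True, True, False],
                          [True, False, True]])

# Additive float mask: 0.0 where attention is allowed, -inf where it is masked.
float_mask = torch.zeros_like(bool_mask, dtype=torch.float16)
float_mask.masked_fill_(~bool_mask, float("-inf"))

print(float_mask)
```

Adding this mask to attention scores before the softmax zeroes out the masked positions' probabilities, which is the convention backends like Flash Attention v2 expect.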
specsplit.workers.target.service
Target Worker gRPC service bindings.
Exposes the TargetService gRPC server that wraps the TargetEngine
for network-accessible tree-attention verification with session-based
KV caching.
TargetServiceServicer
Bases: TargetServiceServicer
gRPC servicer implementing the TargetService RPC interface.
Handles VerifyDrafts (with session-based KV caching) and
EndSession (for explicit cache cleanup).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `engine` | `TargetEngine` | The target verification engine. | required |
| `config` | `TargetWorkerConfig \| None` | Target worker config (for request size limits). | `None` |
| `telemetry` | `TelemetryLogger \| None` | Optional telemetry logger for span collection. | `None` |
EndSession(request, context)
Handle an EndSession RPC — release a session's KV cache.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `request` | `EndSessionRequest` | An `EndSessionRequest` identifying the session to release. | required |
| `context` | `ServicerContext` | gRPC server context. | required |

Returns:
| Type | Description |
|---|---|
| `EndSessionResponse` | An `EndSessionResponse` acknowledging the cleanup. |
Ping(request, context)
Health check endpoint.
VerifyDrafts(request, context)
Handle a VerifyDrafts RPC call with optional session KV caching.
If request.session_id is non-empty, the engine reuses (or creates)
a KV cache for that session, and automatically rolls it back to the
accepted prefix after verification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `request` | `VerifyRequest` | A `VerifyRequest` carrying the draft tree and session info. | required |
| `context` | `ServicerContext` | gRPC server context. | required |

Returns:
| Type | Description |
|---|---|
| `VerifyResponse` | A `VerifyResponse` with the verification results. |
serve(config=None)
Start the Target Worker gRPC server.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `TargetWorkerConfig \| None` | Optional configuration override. | `None` |