Target Worker API

specsplit.workers.target.engine

Target Engine — session-based KV-cached tree-attention verification.

The TargetEngine wraps the large, accurate language model. Given a draft token tree from the Draft Worker, it performs a forward pass using tree attention to verify which draft tokens are accepted under the target distribution.

Architecture Notes
  • Session-based KV caching: Maintains a dict[session_id, KVCacheState] mapping that stores a pre-allocated :class:`StaticKVCache` per session. This eliminates prompt recomputation and torch.cat reallocation across verification rounds within the same generation session.
  • After verification, rollback_cache uses StaticKVCache.rollback() (an O(1) pointer update) for linear chains, or StaticKVCache.compact() for branching trees, to crop the cache to the accepted prefix.
  • Uses greedy verification: for each position in the tree, if argmax(target_logits) matches the drafted token, it is accepted.
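The greedy rule amounts to a walk over the flat tree: starting from the root, follow the child whose drafted token equals the target argmax, and stop at the first mismatch. A minimal sketch (a hypothetical helper, not the engine's actual verify_greedy_tree()):

```python
import torch

def greedy_accept_path(tree_logits, flat_token_ids, topology_map):
    """Hypothetical sketch of the greedy rule above, not the engine's
    verify_greedy_tree(). tree_logits[i] is the target distribution
    predicted *after* tree node i."""
    children: dict[int, list[int]] = {}
    root = None
    for i, parent in enumerate(topology_map):
        if parent == -1:
            root = i
        else:
            children.setdefault(parent, []).append(i)

    accepted = []
    # The root's own acceptance is decided by the logit at the last prefix
    # position, which is omitted in this sketch.
    node = root
    while node is not None:
        target_choice = int(torch.argmax(tree_logits[node]))
        nxt = None
        for child in children.get(node, []):
            if flat_token_ids[child] == target_choice:
                accepted.append(child)  # drafted child matches target argmax
                nxt = child
                break
        node = nxt
    return accepted
```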

CacheDesyncError

Bases: RuntimeError

Raised when target cache length disagrees with orchestrator's expected_prefix_length.

The orchestrator must retry with full prompt_token_ids to rebuild the cache.

VerificationResult dataclass

Result of verifying a draft tree against the target model.

acceptance_rate property

Fraction of draft tokens accepted (0.0-1.0).

Uses num_draft_tokens (the total draft candidates presented) as the denominator for a meaningful acceptance rate.

KVCacheState dataclass

Per-session KV cache state stored on the Target Worker.

Wraps a :class:`StaticKVCache`, providing O(1) rollback and zero-copy slice-assignment appends.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `cache` | `VirtualKVCache \| None` | The pre-allocated static KV cache for this session, or `None` if the session was created before the model was loaded. |
| `seq_len` | `int` | The current cached sequence length (number of tokens whose KV projections are stored). |
| `next_root_logit` | `Tensor \| None` | The logit predicting the token directly following the accepted sequence. Kept so the target model stays exactly aligned on a cache hit without a duplicate forward pass. |

TargetEngine

Session-aware tree-attention verification engine for the Target Worker.

Maintains a dictionary of session_id → KVCacheState to reuse KV projections across verification rounds within the same generation session. After each verification, the cache is rolled back to the longest accepted prefix via rollback_cache().

The KV cache uses :class:`StaticKVCache` — a pre-allocated, pointer-based cache that provides O(1) rollback and zero-copy appends, compared to HuggingFace's default DynamicCache, which requires torch.cat on every append.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `TargetWorkerConfig \| None` | Target worker configuration (model name, device, etc.). | `None` |

active_sessions property

Number of active sessions with cached KV state.

is_loaded property

Whether the model has been loaded.

model_vocab_size property

Vocabulary size exposed by the loaded model or tokenizer.

end_session(session_id)

Terminate a session and free its KV cache from memory.

Thread-safe: acquires the session lock before destroying the cache to ensure no active verify_draft_tree thread is using the tensors being deleted (Issue 10).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session_id` | `str` | The session to terminate. | required |

Returns:

| Type | Description |
| --- | --- |
| `bool` | `True` if the session existed and was removed, `False` otherwise. |

get_or_create_session(session_id)

Retrieve an existing session cache or create a new one.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session_id` | `str` | Unique identifier for the generation session. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[KVCacheState, bool]` | A `(cache_state, cache_hit)` tuple; `cache_hit` is `True` if an existing cache was found. |

has_session(session_id)

Return True if a session cache currently exists.

has_usable_session(session_id)

Return True if a session exists and currently stores a usable prefix.

load_model()

Load the target model and tokenizer via AutoModelForCausalLM.

purge_stale_sessions(ttl_seconds=None)

Remove sessions that have not been accessed within the TTL.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ttl_seconds` | `float \| None` | Override the configured TTL (for testing). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `int` | Number of purged sessions. |
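A minimal sketch of TTL-based purging under assumed bookkeeping (a per-session last-access timestamp dict; the names are illustrative, not the engine's internals):

```python
import time

def purge_stale(sessions, last_access, ttl_seconds, now=None):
    """Illustrative TTL sweep (names assumed, not the engine's internals):
    drop every session whose last access is older than the TTL."""
    now = time.monotonic() if now is None else now
    stale = [sid for sid, t in last_access.items() if now - t > ttl_seconds]
    for sid in stale:
        sessions.pop(sid, None)      # free the cached KV state
        last_access.pop(sid, None)   # and its bookkeeping entry
    return len(stale)                # number of purged sessions
```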

rollback_cache(session_id, accepted_depth, accepted_tree_indices=None, prefix_length=0)

Compact the KV cache to only the accepted prefix + accepted path.

Uses :class:`VirtualKVCache` for O(1) rollback (linear chains) or compact() (branching trees with non-contiguous accepted positions).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session_id` | `str` | The session whose cache to roll back. | required |
| `accepted_depth` | `int` | Total number of tokens in the accepted prefix, measured from the start of the full cached sequence. Used for simple O(1) rollback when `accepted_tree_indices` is `None`. | required |
| `accepted_tree_indices` | `list[int] \| None` | BFS indices (0-based into the tree portion of the input) of the accepted path nodes, in order from root to leaf. When provided, the cache is compacted to `prefix[:prefix_length] + tree[accepted_tree_indices]`. | `None` |
| `prefix_length` | `int` | Length of the prompt/prefix portion of the cached sequence. Only used when `accepted_tree_indices` is provided. | `0` |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If the session does not exist. |
| `ValueError` | If `accepted_depth` exceeds the current cache length. |

shutdown()

Stop the background TTL GC thread.

verify_draft_tree(prompt_ids, flat_token_ids, topology_map, flat_log_probs, flat_draft_probs_full, session_id=None, temperature=0.0, expected_prefix_length=0, new_token_ids=None)

Verify a draft token tree against the target model's distribution.

When a session_id is provided, the engine reuses (or creates) a KV cache for that session. After verification, the cache is automatically rolled back to the accepted prefix.

The verification pipeline
  1. Build a tree-attention mask via build_tree_attention().
  2. On cache hits, append any linear new_token_ids delta to the session cache so the target prefix matches the draft context.
  3. Forward pass the tree tokens through the target model with the current cached KV state.
  4. Run verify_greedy_tree() on the tree logits.
  5. Roll back the KV cache to the accepted prefix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompt_ids` | `list[int]` | Full prompt/prefix token IDs, used when rebuilding from scratch after a miss or cache desync. | required |
| `flat_token_ids` | `list[int]` | Flattened list of draft token IDs. | required |
| `topology_map` | `list[int]` | Parent indices for each node in `flat_token_ids`. | required |
| `flat_log_probs` | `list[float]` | Log probabilities for the draft tokens. | required |
| `flat_draft_probs_full` | `Tensor \| None` | Full draft probabilities if top-k is used. | required |
| `session_id` | `str \| None` | Optional session ID for KV cache reuse. If `None`, verification is fully stateless (no caching). | `None` |
| `new_token_ids` | `list[int] \| None` | Linear delta tokens to append to the cached prefix before tree verification. Used by the orchestrator to carry forward the previous round's bonus/correction token. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `VerificationResult` | A `VerificationResult` with accepted tokens, an optional correction token, and a `cache_hit` flag. |

specsplit.workers.target.kv_cache

Pre-allocated Static KV Cache for Disaggregated Speculative Decoding.

Standard HuggingFace past_key_values uses a tuple of per-layer (key, value) tensors that grow via torch.cat at every decoding step. This incurs reallocation + copy overhead that is unacceptable in a latency-sensitive speculative decoding pipeline.

This module provides a :class:`StaticKVCache` that pre-allocates the full maximum-length buffer once at session start, and then uses simple slice assignment (tensor[:, :, ptr:ptr+n, :] = new_data) for appends and a single integer update for rollbacks.

Performance Properties

  • Append: O(n) in the number of new tokens only (slice write, no reallocation, no torch.cat).
  • Rollback: O(1) — updates the seq_len pointer. Stale data beyond the pointer is ignored and overwritten on the next append.
  • Memory: Fixed at ``2 * num_layers * batch * num_heads * max_seq_len * head_dim * sizeof(dtype)`` bytes. No dynamic growth.

Tensor Layout

::

key_cache:   [num_layers, batch, num_heads, max_seq_len, head_dim]
value_cache: [num_layers, batch, num_heads, max_seq_len, head_dim]

The seq_len pointer tracks how many valid positions are stored. Only [:, :, :, :seq_len, :] contains meaningful data.
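The append/rollback mechanics can be illustrated with a toy version of this layout (a sketch, not the real StaticKVCache):

```python
import torch

class MiniStaticKV:
    """Toy illustration of the layout above (not the real StaticKVCache):
    pre-allocated buffers, slice-assignment appends, pointer-only rollback."""

    def __init__(self, num_layers, batch, num_heads, max_seq_len, head_dim):
        shape = (num_layers, batch, num_heads, max_seq_len, head_dim)
        self.key_cache = torch.zeros(shape)
        self.value_cache = torch.zeros(shape)
        self.max_seq_len = max_seq_len
        self.seq_len = 0  # pointer: number of valid positions

    def append(self, new_keys, new_values):
        n = new_keys.shape[3]
        if self.seq_len + n > self.max_seq_len:
            raise ValueError("append would exceed max_seq_len")
        # In-place slice write: no torch.cat, no reallocation.
        self.key_cache[:, :, :, self.seq_len:self.seq_len + n, :] = new_keys
        self.value_cache[:, :, :, self.seq_len:self.seq_len + n, :] = new_values
        self.seq_len += n

    def rollback(self, accepted_length):
        if not 0 <= accepted_length <= self.seq_len:
            raise ValueError("invalid accepted_length")
        self.seq_len = accepted_length  # O(1): stale data is never read
```

Rolled-back positions are simply overwritten by the next append, since every consumer slices to `seq_len`.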


StaticKVCache

Bases: `Cache` (falls back to `object` when the HuggingFace `Cache` class is unavailable)

Pre-allocated, pointer-based KV cache for a single session.

Allocates fixed-size key and value buffers at initialization. Appends use slice assignment, and rollbacks update a single integer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_layers` | `int` | Number of transformer layers (attention blocks). | required |
| `num_heads` | `int` | Number of attention heads per layer. | required |
| `max_seq_len` | `int` | Maximum sequence length to support. The buffers are allocated to this size and never grow. | required |
| `head_dim` | `int` | Dimension per attention head. | required |
| `batch_size` | `int` | Batch dimension (typically 1 for inference). | `1` |
| `dtype` | `dtype` | Tensor dtype (e.g., torch.float16). | `float16` |
| `device` | `device \| str` | Torch device (e.g., "cuda:0"). | `'cpu'` |

Example::

cache = StaticKVCache(
    num_layers=32, num_heads=32,
    max_seq_len=2048, head_dim=128,
    dtype=torch.float16, device="cuda:0",
)

# After a forward pass produces new KV projections:
# new_keys shape:   [num_layers, batch, num_heads, new_len, head_dim]
# new_values shape: [num_layers, batch, num_heads, new_len, head_dim]
cache.append(new_keys, new_values)

# After verification, roll back to accepted prefix:
cache.rollback(accepted_length=128)  # O(1)!

# Get KV for a specific layer to pass into model:
k, v = cache.get_kv_for_layer(layer_idx=0)
# k shape: [batch, num_heads, seq_len, head_dim]

is_full property

Whether the cache has reached maximum capacity.

remaining_capacity property

Number of additional positions that can be appended.

seq_len property

Number of valid positions currently stored in the cache.

append(new_keys, new_values)

Insert new KV projections at the current pointer position.

Uses slice assignment (cache[:, :, :, ptr:ptr+n, :] = data) which is an in-place write — no torch.cat, no .clone(), no reallocation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `new_keys` | `Tensor` | New key projections. Shape: `[num_layers, batch, num_heads, new_len, head_dim]`. | required |
| `new_values` | `Tensor` | New value projections. Shape: `[num_layers, batch, num_heads, new_len, head_dim]`. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the new tokens would exceed `max_seq_len`. |
| `ValueError` | If tensor shapes don't match the cache configuration. |

batch_select_indices(indices)

Keep only specified positions (Cache interface, maps to compact).

compact(keep_indices)

Re-order the cache to keep only specified positions.

Used for branching tree KV cache compaction where the accepted path contains non-contiguous BFS positions. Unlike :meth:`rollback`, this performs actual tensor operations (torch.index_select), but it is still far cheaper than re-computing the full KV cache from scratch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `keep_indices` | `list[int]` | Ordered list of cache positions to retain. Must be valid indices in `[0, seq_len)`. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If any index is out of range. |
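Per layer, the compaction described above likely reduces to an index_select along the sequence axis; a sketch under that assumption:

```python
import torch

def compact_layer(key, value, keep_indices):
    """Assumed per-layer compaction semantics: retain only keep_indices
    along the sequence axis (dim 2 of [batch, heads, seq_len, head_dim])."""
    idx = torch.tensor(keep_indices, dtype=torch.long)
    return key.index_select(2, idx), value.index_select(2, idx)
```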

crop(max_length)

Crop cache to max_length (Cache interface, maps to rollback).

get_all_kv()

Get valid (key, value) pairs for ALL layers as a HF-compatible tuple.

This format is compatible with HuggingFace's past_key_values argument: a tuple of (key, value) pairs per layer.

Returns:

| Type | Description |
| --- | --- |
| `tuple[tuple[Tensor, Tensor], ...]` | A tuple of length `num_layers`, where each element is `(key, value)` with shape `[batch, num_heads, seq_len, head_dim]`. |

Note

The returned tensors are views into the pre-allocated cache. Do not modify them in-place unless you intend to update the cache.

get_kv_for_layer(layer_idx)

Get the valid (key, value) tensors for a specific layer.

Returns views (not copies) sliced to the current seq_len. These are suitable for passing directly into a transformer layer's attention computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `layer_idx` | `int` | Index of the transformer layer (0-based). | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | A `(key, value)` tuple of tensors, each with shape `[batch, num_heads, seq_len, head_dim]`. |

Raises:

| Type | Description |
| --- | --- |
| `IndexError` | If `layer_idx` is out of range. |

get_seq_length(layer_idx=0)

Return the number of cached positions (Cache interface).

reset()

Reset the cache to empty (seq_len = 0).

Like :meth:`rollback`, this is O(1) — it only resets the pointer. The pre-allocated buffers remain in GPU memory for reuse.

rollback(accepted_length)

Roll back the cache to a given accepted prefix length.

This is an O(1) operation. It simply moves the seq_len pointer backwards. No tensor data is copied or zeroed. Stale data beyond the new pointer is harmless — it will be overwritten by the next :meth:append call, and it is never read because all consumers use [:, :, :, :seq_len, :] slices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `accepted_length` | `int` | The number of positions to keep, measured from the start of the sequence. Must satisfy `0 <= accepted_length <= seq_len`. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `accepted_length` is negative or exceeds the current `seq_len`. |

stack_hf_cache(past_key_values) staticmethod

Convert HuggingFace's cache format to 5D tensors for append().

HuggingFace models output past_key_values as::

tuple[  # num_layers
    tuple[
        Tensor[batch, heads, seq, head_dim],  # keys
        Tensor[batch, heads, seq, head_dim],  # values
    ],
    ...
]

This method stacks them into unified 5D tensors::

keys:   [num_layers, batch, heads, seq, head_dim]
values: [num_layers, batch, heads, seq, head_dim]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `past_key_values` | `tuple[tuple[Tensor, Tensor], ...]` | HuggingFace `past_key_values` tuple. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | A `(stacked_keys, stacked_values)` tuple of 5D tensors suitable for passing to `StaticKVCache.append()`. |
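The conversion amounts to two torch.stack calls over the per-layer pairs; a sketch assuming every layer holds tensors of identical shape:

```python
import torch

def stack_layers(past_key_values):
    """Sketch of the stacking described above, assuming every layer holds
    (key, value) tensors of identical shape [batch, heads, seq, head_dim]."""
    keys = torch.stack([kv[0] for kv in past_key_values], dim=0)
    values = torch.stack([kv[1] for kv in past_key_values], dim=0)
    return keys, values  # each: [num_layers, batch, heads, seq, head_dim]
```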

update(key_states, value_states, layer_idx, cache_kwargs=None)

Update the cache in-place. Called by HuggingFace attention layers.

Writes key/value states directly into the pre-allocated buffers via slice assignment or index_copy_. No torch.cat, no reallocation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key_states` | `Tensor` | New key states, shape `[batch, num_heads, new_len, head_dim]`. | required |
| `value_states` | `Tensor` | New value states, same shape as `key_states`. | required |
| `layer_idx` | `int` | Layer index. | required |
| `cache_kwargs` | `dict[str, Any] \| None` | May contain a `"cache_position"` tensor of write indices. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | `(keys, values)` for this layer, shape `[batch, num_heads, seq_len, head_dim]`. |

VirtualKVCache

Bases: StaticKVCache

Wraps StaticKVCache with a virtual index layer.

Instead of physically compacting the buffer on branching rollback, maintains a _l2p mapping (logical-to-physical). Reads use gathered indices; writes use scatter. Compact becomes an O(length) pointer swap.
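The indirection can be illustrated with a toy logical-to-physical table (assumed semantics; the real VirtualKVCache additionally routes tensor reads and writes through this mapping):

```python
import torch

class L2PMap:
    """Toy logical-to-physical table (assumed semantics, not the real
    VirtualKVCache internals)."""

    def __init__(self, max_len):
        self._l2p = torch.arange(max_len)  # identity permutation at start
        self.seq_len = 0

    def advance(self, n):
        self.seq_len += n  # logical append of n positions

    def compact(self, keep_indices):
        # Keep the physical slots backing keep_indices; move the freed
        # slots to the tail for reuse. No KV tensor data moves.
        keep = torch.tensor(keep_indices, dtype=torch.long)
        kept = self._l2p[keep]
        mask = torch.ones(len(self._l2p), dtype=torch.bool)
        mask[keep] = False
        self._l2p = torch.cat([kept, self._l2p[mask]])
        self.seq_len = len(keep_indices)

    def physical(self):
        return self._l2p[: self.seq_len].tolist()
```

Note that freed physical slots are appended to the tail of the permutation so later logical appends reuse them without aliasing the kept positions.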

append(new_keys, new_values)

Append new KV states at the end of the current logical sequence.

Updates the underlying physical buffers using the _l2p mapping and advances _seq_len.

compact(keep_indices)

Compact the logical sequence by remapping _l2p to keep_indices.

keep_indices are logical indices that remain active after compaction.

get_kv_for_layer(layer_idx)

Return (key, value) slices for one layer at the current logical length.

get_physical_indices()

Return physical buffer indices backing the current logical prefix.

update(key_states, value_states, layer_idx, cache_kwargs=None)

Update KV states for a layer and return active slices.

When cache_kwargs["cache_position"] is provided, writes at those logical indices; otherwise it appends at the end of the current logical sequence. The returned tensors correspond to the active prefix.

specsplit.workers.target.tree_attn

Custom Tree Attention Masking for Disaggregated Speculative Decoding.

When the Target Worker receives a flat list of draft token IDs and a "topology map" (a list of parent indices), it must verify the entire tree in a single forward pass. This module builds the custom 2D boolean attention mask and 1D position_ids tensor required to achieve that.

Terminology

  • Topology map: A list of length num_tree_nodes where topology_map[i] is the local index (0-based into the tree nodes array) of node i's parent, or -1 if node i is a root of the tree. This is the standard flat-tree representation used in SpecInfer / Medusa / Eagle.
  • Prefix: The already-processed prompt tokens whose KV projections live in the KV cache. Every tree node attends to the full prefix.
  • Total sequence: [prefix tokens | tree tokens] concatenated.

Attention Rules

A tree node at position j (0-indexed within the tree) is allowed to attend to:
  1. All prefix tokens (positions 0 .. prefix_length-1).
  2. Itself.
  3. All of its ancestors in the tree (following parent pointers up).

It must NOT attend to siblings, cousins, or any other branch.

Position ID Rules

Siblings at the same depth in the tree represent alternative continuations at the same logical decoding step, so they share the same position_id. Specifically:
  • Prefix positions: 0, 1, ..., prefix_length - 1
  • Tree node positions: prefix_length + depth_of_node

Example

Consider a tree with topology_map = [-1, 0, 0, 1, 2] and prefix_length = 3::

    Prefix: [p0, p1, p2]
    Tree:       t0
               /    \
             t1      t2
             |       |
             t3      t4

Depths:  t0=0, t1=1, t2=1, t3=2, t4=2
Position IDs: [0,1,2,  3,4,4,5,5]
               prefix   tree

Attention mask (tree portion only — prefix columns are all True):
    t0 attends to: prefix + t0
    t1 attends to: prefix + t0, t1
    t2 attends to: prefix + t0, t2
    t3 attends to: prefix + t0, t1, t3
    t4 attends to: prefix + t0, t2, t4
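The rules above can be sketched directly from the topology map, assuming parents precede children (BFS order); this is an illustration, not the library's build_tree_attention():

```python
import torch

def tree_attention_sketch(topology_map, prefix_length):
    """Illustration of the attention and position-ID rules above (not the
    library's build_tree_attention()); assumes parents precede children."""
    n = len(topology_map)
    total = prefix_length + n
    depth = [0] * n
    mask = torch.zeros(n, total, dtype=torch.bool)
    mask[:, :prefix_length] = True          # rule 1: full prefix
    for i, parent in enumerate(topology_map):
        mask[i, prefix_length + i] = True   # rule 2: itself
        if parent >= 0:
            # rule 3: inherit the parent's ancestor set
            mask[i, prefix_length:] |= mask[parent, prefix_length:]
            depth[i] = depth[parent] + 1
    # siblings share a position_id: prefix_length + depth
    position_ids = torch.tensor([prefix_length + d for d in depth])
    return mask, position_ids
```

Running this on the example tree reproduces the position IDs and mask rows listed above.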

build_tree_attention(topology_map, prefix_length, device='cpu', tree_rows_only=False)

Build a causal tree-attention mask and position IDs tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `topology_map` | `list[int]` | List of parent indices for each tree node. `topology_map[i] = j` means node i's parent is node j; `topology_map[i] = -1` means node i is a root. Length: `num_tree_nodes`. | required |
| `prefix_length` | `int` | Number of prefix tokens already in the KV cache. These positions are always attended to by every tree node. | required |
| `device` | `device \| str` | Torch device for the output tensors. | `'cpu'` |
| `tree_rows_only` | `bool` | If `True`, allocate only a `[num_tree_nodes, total_len]` mask (for cache hits, when only the tree rows are needed). Avoids the O(total_len²) allocation when 99%+ of the rows would be discarded by slicing. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | An `(attention_mask, position_ids)` tuple. `attention_mask` is a `torch.bool` tensor of shape `[1, 1, Q, total_len]`, with `Q = num_tree_nodes` if `tree_rows_only` else `total_len`; `True` = allowed to attend. `position_ids` is a `torch.long` tensor of shape `[1, Q]`. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `topology_map` contains an out-of-range parent index or forms a cycle. |

bool_mask_to_float(mask, dtype=torch.float16)

Convert a boolean attention mask to a float mask with -inf masking.

Some model backends (e.g., Flash Attention v2) expect the attention mask as a float tensor where masked positions are -inf and attended positions are 0.0.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mask` | `Tensor` | Boolean mask of shape `[1, 1, Q, K]`. `True` = attend, `False` = mask out. | required |
| `dtype` | `dtype` | Output dtype (should match model precision). | `float16` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Float mask of the same shape, with 0.0 for attended positions and -inf for masked positions. |
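The conversion likely reduces to a single masked fill; a sketch (not the library function):

```python
import torch

def to_float_mask(mask, dtype=torch.float32):
    """Sketch of the conversion above (not the library function):
    0.0 where attention is allowed, -inf where masked."""
    out = torch.zeros(mask.shape, dtype=dtype)
    return out.masked_fill_(~mask, float("-inf"))
```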

specsplit.workers.target.service

Target Worker gRPC service bindings.

Exposes the TargetService gRPC server that wraps the TargetEngine for network-accessible tree-attention verification with session-based KV caching.

TargetServiceServicer

Bases: TargetServiceServicer

gRPC servicer implementing the TargetService RPC interface.

Handles VerifyDrafts (with session-based KV caching) and EndSession (for explicit cache cleanup).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `engine` | `TargetEngine` | The target verification engine. | required |
| `config` | `TargetWorkerConfig \| None` | Target worker config (for request size limits). | `None` |
| `telemetry` | `TelemetryLogger \| None` | Optional telemetry logger for span collection. | `None` |

EndSession(request, context)

Handle an EndSession RPC — release a session's KV cache.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `EndSessionRequest` | An `EndSessionRequest` protobuf message. | required |
| `context` | `ServicerContext` | gRPC server context. | required |

Returns:

| Type | Description |
| --- | --- |
| `EndSessionResponse` | An `EndSessionResponse` protobuf message. |

Ping(request, context)

Health check endpoint.

VerifyDrafts(request, context)

Handle a VerifyDrafts RPC call with optional session KV caching.

If request.session_id is non-empty, the engine reuses (or creates) a KV cache for that session, and automatically rolls it back to the accepted prefix after verification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `VerifyRequest` | A `VerifyRequest` protobuf message. | required |
| `context` | `ServicerContext` | gRPC server context. | required |

Returns:

| Type | Description |
| --- | --- |
| `VerifyResponse` | A `VerifyResponse` protobuf message with verification results. |

serve(config=None)

Start the Target Worker gRPC server.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `TargetWorkerConfig \| None` | Optional configuration override. | `None` |