Core Modules

specsplit.core.config

Pydantic configuration models for SpecSplit services.

All settings can be overridden via environment variables with the SPECSPLIT_ prefix. For example, SPECSPLIT_DRAFT_MODEL_NAME=gpt2 overrides the default draft model name.

Settings can also be loaded from a YAML or JSON config file using load_config_file. Priority (highest → lowest):

1. Constructor kwargs / CLI arguments
2. Environment variables (SPECSPLIT_* prefix)
3. Config file values
4. Field defaults

Example YAML config file::

orchestrator:
  draft_address: "localhost:50051"
  target_address: "localhost:50052"
  max_rounds: 50
  max_output_tokens: 1024
  tokenizer_model: "Qwen/Qwen2.5-7B-Instruct"
draft:
  model_name: "Qwen/Qwen2.5-0.5B-Instruct"
  max_draft_tokens: 5
target:
  model_name: "Qwen/Qwen2.5-7B-Instruct"

Usage::

from specsplit.core.config import OrchestratorConfig, load_config_file

file_cfg = load_config_file("config.yaml")
cfg = OrchestratorConfig(**file_cfg.get("orchestrator", {}))
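The priority order can be illustrated with a plain dict merge. This is an illustrative sketch only (the real resolution is performed by pydantic's BaseSettings), and the default values shown are placeholders, not the library's actual defaults:

```python
def resolve_settings(defaults, file_values, env_values, kwargs):
    """Merge config sources; later sources override earlier ones."""
    # Priority (lowest → highest): defaults < file < env < kwargs
    return {**defaults, **file_values, **env_values, **kwargs}

cfg = resolve_settings(
    defaults={"max_rounds": 100, "max_output_tokens": 512},    # hypothetical defaults
    file_values={"max_rounds": 50, "max_output_tokens": 1024}, # from config.yaml
    env_values={"max_output_tokens": 2048},  # e.g. SPECSPLIT_MAX_OUTPUT_TOKENS=2048
    kwargs={},
)
# max_rounds comes from the file; max_output_tokens from the environment
```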

DraftWorkerConfig

Bases: BaseSettings

Configuration for the Draft Worker (small, fast LLM).

TargetWorkerConfig

Bases: BaseSettings

Configuration for the Target Worker (large, accurate LLM).

OrchestratorConfig

Bases: BaseSettings

Configuration for the Orchestrator (pipeline coordinator).

load_config_file(path)

Load configuration from a YAML or JSON file.

The file should contain top-level keys matching the service names: orchestrator, draft, and/or target. Each key maps to a dict of field names → values.

Parameters:

Name Type Description Default
path str | Path

Path to the config file (.yaml, .yml, or .json).

required

Returns:

Type Description
dict[str, Any]

A dict with keys "orchestrator", "draft", "target" (any or all may be absent if not specified in the file).

Raises:

Type Description
FileNotFoundError

If the file does not exist.

ValueError

If the file format is unsupported.

Example::

cfg = load_config_file("specsplit.yaml")
orch = OrchestratorConfig(**cfg.get("orchestrator", {}))
draft = DraftWorkerConfig(**cfg.get("draft", {}))
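A minimal, JSON-only sketch of what such a loader does (the real load_config_file also parses .yaml/.yml files; the function name here is hypothetical):

```python
import json
from pathlib import Path

def load_config_file_sketch(path):
    """JSON-only sketch: load a config dict keyed by service name."""
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Config file not found: {p}")
    if p.suffix != ".json":
        raise ValueError(f"Unsupported config format: {p.suffix!r}")
    return json.loads(p.read_text())
```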

specsplit.core.telemetry

High-precision timing and structured telemetry logging for SpecSplit.

Provides a Stopwatch for nanosecond-precision wall-clock measurement and a TelemetryLogger for emitting structured JSON spans that can be collected for distributed tracing and benchmarking.

Stopwatch dataclass

Nanosecond-precision wall-clock stopwatch with context-manager support.

Usage::

sw = Stopwatch()
sw.start()
# ... do work ...
sw.stop()
print(f"Elapsed: {sw.elapsed_ms:.3f} ms")

Or as a context manager::

with Stopwatch() as sw:
    # ... do work ...
print(f"Elapsed: {sw.elapsed_ms:.3f} ms")

elapsed_ms property

Elapsed time in milliseconds.

elapsed_ns property

Elapsed time in nanoseconds.

elapsed_s property

Elapsed time in seconds.

start()

Begin timing.

stop()

Stop timing and record elapsed nanoseconds.
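A minimal sketch of such a stopwatch built on time.perf_counter_ns; this illustrates the documented interface, not the shipped implementation:

```python
import time

class StopwatchSketch:
    """Nanosecond-precision stopwatch sketch with context-manager support."""

    def __init__(self):
        self._start_ns = 0
        self._elapsed_ns = 0

    def start(self):
        """Begin timing."""
        self._start_ns = time.perf_counter_ns()

    def stop(self):
        """Stop timing and record elapsed nanoseconds."""
        self._elapsed_ns = time.perf_counter_ns() - self._start_ns

    @property
    def elapsed_ns(self):
        return self._elapsed_ns

    @property
    def elapsed_ms(self):
        return self._elapsed_ns / 1_000_000

    @property
    def elapsed_s(self):
        return self._elapsed_ns / 1_000_000_000

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()
        return False
```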

TelemetrySpan dataclass

A single telemetry span representing a timed operation.

to_dict()

Serialize the span to a dictionary.

TelemetryEvent dataclass

A point-in-time event for reconstructing request/response timelines.

to_dict()

Serialize the event to a dictionary.

TelemetryLogger

Structured telemetry logger that collects spans and exports them as JSON.

Each span is tagged with a unique span_id for distributed tracing. Collected spans can be exported to a JSON file for offline analysis.

Usage::

tlog = TelemetryLogger(service_name="draft-worker")

with tlog.span("generate_draft", tokens_processed=128) as span_id:
    # ... do work ...
    pass

tlog.export("telemetry_output.json")

events property

Return a copy of all recorded events.

spans property

Return a copy of all recorded spans.

export(path)

Export all recorded spans to a JSON file.

Parameters:

Name Type Description Default
path str | Path

Output file path. Parent directories are created if needed.

required

export_csv(path)

Export all recorded spans to a CSV file.

export_prometheus(path)

Export all recorded spans to a Prometheus text-based format.

record_event(event_type, **metadata)

Record a point-in-time event.

record_span(span)

Record a completed span.

reset()

Clear all recorded spans.

span(operation, **metadata)

Create a timed span context manager.

Parameters:

Name Type Description Default
operation str

Name of the operation being timed.

required
**metadata Any

Arbitrary key-value pairs to attach to the span.

{}

Returns:

Type Description
_SpanContext

A context manager that records timing on exit.

get_current_context()

Return the current telemetry span context, if any.
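The span() flow can be sketched with contextlib (illustrative only; field names in the recorded dict are assumptions, and the real logger also tracks nested span context):

```python
import time
import uuid
from contextlib import contextmanager

class TelemetryLoggerSketch:
    """Collects spans as dicts; a sketch of the span() lifecycle only."""

    def __init__(self, service_name):
        self.service_name = service_name
        self._spans = []

    @property
    def spans(self):
        return list(self._spans)  # return a copy, as documented

    @contextmanager
    def span(self, operation, **metadata):
        span_id = uuid.uuid4().hex  # unique id for distributed tracing
        start = time.perf_counter_ns()
        try:
            yield span_id
        finally:
            # Timing is recorded on exit, even if the body raised.
            self._spans.append({
                "service": self.service_name,
                "operation": operation,
                "span_id": span_id,
                "duration_ns": time.perf_counter_ns() - start,
                "metadata": metadata,
            })
```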

specsplit.core.verification

Verification Mathematics for Disaggregated Speculative Decoding.

This module implements the core acceptance/rejection logic used by the Target Worker to validate draft token trees. We start with strictly greedy decoding (temperature = 0.0) as a correctness baseline before moving to stochastic rejection sampling.

Greedy Verification Algorithm

Given:

- draft_tokens[i]: The token ID drafted at tree position i.
- target_logits[i]: The target model's logit vector at position i.
- topology_map[i]: The parent index of tree position i (-1 for roots).

For each position i, compute argmax(target_logits[i]). A drafted token is accepted if it matches the target's greedy choice.

The algorithm then walks the topology map to find the longest continuous path from a root to a leaf where every node along the path is accepted. The "bonus token" is the target's greedy choice at the first divergence point (or at the accepted leaf, extending the sequence by one).

This entire comparison is done on-device via torch.argmax and boolean indexing — no CPU↔GPU synchronization until the final small result extraction.
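The topology walk can be illustrated in plain Python over a boolean acceptance mask (the shipped code performs the argmax comparison and masking on-device; this sketch covers only the longest-path search):

```python
def longest_accepted_path(accepted, topology_map):
    """Return node indices (root → leaf) of the longest fully-accepted path.

    accepted[i]     -- True when draft token i matched the target argmax.
    topology_map[i] -- parent index of node i (-1 marks a root).
    """
    # Build children lists from the parent map.
    children = {i: [] for i in range(len(topology_map))}
    roots = []
    for i, parent in enumerate(topology_map):
        if parent == -1:
            roots.append(i)
        else:
            children[parent].append(i)

    best = []

    def dfs(node, path):
        nonlocal best
        if not accepted[node]:
            return  # a rejection prunes the entire subtree below it
        path = path + [node]
        if len(path) > len(best):
            best = path
        for child in children[node]:
            dfs(child, path)

    for root in roots:
        dfs(root, [])
    return best

# Tree 0 → {1, 2}, 1 → 3, 2 → 4; nodes 2 and 3 were rejected.
path = longest_accepted_path(
    accepted=[True, True, False, False, True],
    topology_map=[-1, 0, 0, 1, 2],
)
# path == [0, 1]: the rejection at node 2 cuts node 4 off even though it matched
```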

VerificationResult dataclass

Unified result for both greedy and stochastic tree verification.

Attributes:

Name Type Description
accepted_leaf_index int

The local tree-node index of the last accepted node on the longest accepted path. If no tokens were accepted, this is -1.

accepted_tokens list[int]

Ordered list of accepted token IDs along the longest accepted path (root → leaf).

bonus_token int

The target model's greedy/sampled choice at the divergence point (i.e., the token that would follow the accepted prefix). This is always produced — it extends the output by one token for free.

accepted_indices list[int]

The local tree-node indices of the accepted nodes, in path order (root → leaf). Useful for KV cache rollback and position tracking.

num_draft_tokens int

Total number of draft tokens in the tree (for computing acceptance rate).

diverged bool

Whether the path ended at a rejection point (True) or was fully accepted to a leaf node (False). Used by the TargetEngine to correctly sample the bonus token from the right logit position.

acceptance_rate property

Fraction of the tree that was accepted (0.0-1.0).

Computed as num_accepted / num_draft_tokens. Note that this measures the path acceptance, not the full tree utilization.

num_accepted property

Number of draft tokens accepted.

verify_greedy_tree(draft_tokens, target_logits, topology_map)

Verify a draft token tree against target logits using greedy decoding.

All comparisons are performed on-device. The only CPU↔GPU sync happens at the very end when extracting the small result lists.

Parameters:

Name Type Description Default
draft_tokens Tensor

Flat tensor of drafted token IDs. Shape: [num_tree_nodes], dtype: torch.long. draft_tokens[i] is the token drafted at tree position i.

required
target_logits Tensor

Target model's logit vectors at each tree position. Shape: [num_tree_nodes, vocab_size], dtype: torch.float*. target_logits[i] corresponds to tree position i.

required
topology_map list[int]

List of parent indices for each tree node. topology_map[i] = j means node i's parent is j. topology_map[i] = -1 means node i is a root. Length: num_tree_nodes.

required

Returns:

Type Description
VerificationResult

A VerificationResult with the longest accepted path, the bonus token, and metadata.

Raises:

Type Description
ValueError

If tensor shapes are inconsistent with the topology map.

Example::

>>> draft_tokens = torch.tensor([42, 17, 99, 8, 55])
>>> # Suppose target argmax at each position is [42, 17, 98, 7, 55]
>>> target_logits = torch.zeros(5, 100)
>>> target_logits[0, 42] = 10.0  # matches draft
>>> target_logits[1, 17] = 10.0  # matches draft
>>> target_logits[2, 98] = 10.0  # MISMATCH (draft=99, target=98)
>>> target_logits[3, 7]  = 10.0  # MISMATCH (draft=8, target=7)
>>> target_logits[4, 55] = 10.0  # matches but parent rejected
>>> topology_map = [-1, 0, 0, 1, 2]  # binary tree
>>> result = verify_greedy_tree(draft_tokens, target_logits, topology_map)
>>> result.accepted_tokens  # Longest: root(42) → left(17), stops at 8≠7
[42, 17]
>>> result.bonus_token  # Target says 7 where draft said 8
7

verify_stochastic_tree(draft_tokens, draft_probs, target_probs, topology_map, draft_probs_full=None)

Perform stochastic verification with rejection sampling over the full tree.

Explores all root-to-leaf paths via DFS. At each node, the draft token is accepted if p_target >= p_draft, otherwise accepted with probability p_target / p_draft. The longest accepted path across all branches is chosen; the bonus (correction) token is sampled from the target distribution at the divergence node (rejection point) or at the accepted leaf.

Parameters:

Name Type Description Default
draft_tokens Tensor

[num_nodes] tensor of drafted token IDs.

required
draft_probs Tensor

[num_nodes] tensor of probabilities the draft assigned to the drafted token at each node (legacy; used when draft_probs_full=None).

required
target_probs Tensor

[num_nodes, vocab_size] tensor of target model probabilities.

required
topology_map list[int]

List mapping node index to parent index (-1 for root).

required
draft_probs_full Tensor | None

Optional [num_nodes, vocab_size] full draft distribution. When present, residual = max(0, P_target - P_draft) for correct sampling.

None

Returns:

Type Description
VerificationResult

VerificationResult with the longest accepted path and sampled bonus token.

Raises:

Type Description
ValueError

If shapes are inconsistent or the tree has no root nodes.
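The per-node acceptance test and the residual distribution used at a rejection point can be sketched in plain Python (illustrative; the real code operates on torch tensors and full tree batches):

```python
import random

def accept_node(p_draft, p_target, rng):
    """Accept the drafted token if p_target >= p_draft,
    otherwise accept with probability p_target / p_draft."""
    if p_target >= p_draft:
        return True
    return rng.random() < p_target / p_draft

def residual_distribution(p_target, p_draft):
    """Residual max(0, P_target - P_draft), renormalized.

    Sampling the correction token from this residual (rather than from
    P_target directly) preserves the target model's output distribution.
    """
    res = [max(0.0, t - d) for t, d in zip(p_target, p_draft)]
    z = sum(res)
    return [r / z for r in res] if z > 0 else list(p_target)
```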

specsplit.core.model_loading

Helpers for loading local HuggingFace causal-LM checkpoints used by workers.

get_checkpoint_dtype(model_name, *, default=None)

Return the preferred compute dtype recorded in a local checkpoint.

get_model_config(model_or_config)

Return a model config object from either a model or config.

get_model_vocab_size(model)

Best-effort vocabulary size for a loaded causal language model.

specsplit.core.cache_utils

Cache format conversion utilities for HuggingFace transformers compatibility.

Provides version-agnostic helpers to convert between legacy tuple format and DynamicCache for causal language models used by SpecSplit.

cache_to_legacy(past_kv)

Convert a cache object to a legacy ((k, v), ...) tuple when possible.

This supports legacy tuples and the newer Cache API via either to_legacy_cache() or .layers.

cache_supports_crop(past_kv)

Return True when the cache can be safely rolled back in-place.

crop_cache(past_kv, max_length)

Crop a cache to max_length tokens.

batch_model_caches(caches)

Combine single-item caches into one batched cache for a forward pass.

slice_batch_item_from_cache(cache, item_idx)

Extract one batch item from a batched cache object.

legacy_to_dynamic_cache(past_kv, config)

Convert legacy (key, value) tuple cache to DynamicCache for model forward.

Version-agnostic: tries from_legacy_cache first, then the DynamicCache constructor, and finally falls back to returning the legacy tuple (many models accept it).

Parameters:

Name Type Description Default
past_kv tuple[tuple[Tensor, Tensor], ...]

Legacy cache as tuple of (key, value) per layer.

required
config Any

Model config (PretrainedConfig) for DynamicCache.

required

Returns:

Type Description
Any

DynamicCache instance or the original tuple if conversion fails.
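The fallback chain can be sketched as follows (a hypothetical helper; the real function targets transformers' DynamicCache, but the degradation pattern is the same):

```python
def convert_with_fallbacks(past_kv, cache_cls):
    """Try the newest conversion API first, then degrade gracefully."""
    # 1) Preferred: classmethod from_legacy_cache (newer transformers).
    from_legacy = getattr(cache_cls, "from_legacy_cache", None)
    if callable(from_legacy):
        try:
            return from_legacy(past_kv)
        except Exception:
            pass
    # 2) Next: the plain constructor.
    try:
        return cache_cls(past_kv)
    except Exception:
        pass
    # 3) Last resort: hand the legacy tuple back unchanged.
    return past_kv
```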

specsplit.core.serialization

Serialization utilities for converting between PyTorch tensors and Python lists.

These helpers are used at the gRPC boundary to convert tensors into protobuf-compatible formats (lists of ints/floats) and back. All conversions are device-aware and preserve dtype where applicable.

tensor_to_token_ids(tensor)

Convert a 1-D tensor of token IDs to a plain Python list.

Parameters:

Name Type Description Default
tensor Tensor

A 1-D torch.LongTensor (or compatible integer dtype) containing vocabulary indices.

required

Returns:

Type Description
list[int]

A Python list of ints suitable for protobuf serialization.

Raises:

Type Description
ValueError

If the tensor has more than one dimension.

Example::

>>> ids = tensor_to_token_ids(torch.tensor([101, 2003, 1037]))
>>> ids
[101, 2003, 1037]

token_ids_to_tensor(ids, device='cpu', dtype=torch.long)

Convert a list of token IDs to a 1-D PyTorch tensor.

Parameters:

Name Type Description Default
ids list[int]

A list of integer vocabulary indices.

required
device str | device

Target device ("cpu", "cuda:0", etc.).

'cpu'
dtype dtype

Desired tensor dtype. Defaults to torch.long.

long

Returns:

Type Description
Tensor

A 1-D tensor on the specified device.

Example::

>>> t = token_ids_to_tensor([101, 2003, 1037], device="cpu")
>>> t.shape
torch.Size([3])

logits_to_probs(logits, temperature=1.0, dim=-1)

Convert raw logits to a probability distribution via softmax.

Applies temperature scaling before softmax. A temperature of 0 is treated as greedy (argmax), returning a one-hot distribution.

Parameters:

Name Type Description Default
logits Tensor

Raw logits from a language model, shape (..., vocab_size).

required
temperature float

Sampling temperature. Values < 1.0 sharpen the distribution; values > 1.0 flatten it.

1.0
dim int

Dimension along which to apply softmax.

-1

Returns:

Type Description
Tensor

Probability tensor of the same shape as logits.

Raises:

Type Description
ValueError

If temperature is negative.
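The temperature-scaled softmax, with the temperature-0 greedy special case, can be sketched in plain Python over a list of floats (the real function operates on torch tensors along an arbitrary dim):

```python
import math

def logits_to_probs_sketch(logits, temperature=1.0):
    """Temperature-scaled softmax; temperature 0 yields a one-hot argmax."""
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0.0:
        # Greedy: all probability mass on the argmax.
        best = max(range(len(logits)), key=logits.__getitem__)
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]
```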