Core Modules
specsplit.core.config
Pydantic configuration models for SpecSplit services.
All settings can be overridden via environment variables with the SPECSPLIT_
prefix. For example, SPECSPLIT_DRAFT_MODEL_NAME=gpt2 overrides the default
draft model name.
Settings can also be loaded from a YAML or JSON config file using
:func:`load_config_file`. Priority (highest → lowest):
1. Constructor kwargs / CLI arguments
2. Environment variables (SPECSPLIT_* prefix)
3. Config file values
4. Field defaults
Example YAML config file::
orchestrator:
  draft_address: "localhost:50051"
  target_address: "localhost:50052"
  max_rounds: 50
  max_output_tokens: 1024
  tokenizer_model: "Qwen/Qwen2.5-7B-Instruct"
draft:
  model_name: "Qwen/Qwen2.5-0.5B-Instruct"
  max_draft_tokens: 5
target:
  model_name: "Qwen/Qwen2.5-7B-Instruct"
Usage::
from specsplit.core.config import OrchestratorConfig, load_config_file
file_cfg = load_config_file("config.yaml")
cfg = OrchestratorConfig(**file_cfg.get("orchestrator", {}))
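The precedence rules above amount to a layered merge. A minimal sketch of that merge order with plain dicts (illustrative only; the real classes inherit this behavior from pydantic `BaseSettings`, and `resolve_settings` is a hypothetical helper, not part of the module):

```python
# Illustrative sketch of the settings precedence: kwargs > env > file > defaults.
import os

def resolve_settings(defaults, file_values, kwargs, prefix="SPECSPLIT_"):
    merged = dict(defaults)            # 4. field defaults
    merged.update(file_values)         # 3. config file values
    for key in defaults:               # 2. SPECSPLIT_* environment variables
        env_val = os.environ.get(prefix + key.upper())
        if env_val is not None:
            merged[key] = env_val
    merged.update(kwargs)              # 1. constructor kwargs / CLI arguments
    return merged
```

For example, with `SPECSPLIT_DRAFT_MODEL_NAME=gpt2` set, an env value overrides both the file value and the default, but an explicit constructor kwarg still wins.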
DraftWorkerConfig
Bases: BaseSettings
Configuration for the Draft Worker (small, fast LLM).
TargetWorkerConfig
Bases: BaseSettings
Configuration for the Target Worker (large, accurate LLM).
OrchestratorConfig
Bases: BaseSettings
Configuration for the Orchestrator (pipeline coordinator).
load_config_file(path)
Load configuration from a YAML or JSON file.
The file should contain top-level keys matching the service names:
orchestrator, draft, and/or target. Each key maps to a
dict of field names → values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | Path to the config file (YAML or JSON). | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dict with keys `orchestrator`, `draft`, and/or `target` (any or all may be absent if not specified in the file). |
Raises:
| Type | Description |
|---|---|
| `FileNotFoundError` | If the file does not exist. |
| `ValueError` | If the file format is unsupported. |
Example::
cfg = load_config_file("specsplit.yaml")
orch = OrchestratorConfig(**cfg.get("orchestrator", {}))
draft = DraftWorkerConfig(**cfg.get("draft", {}))
specsplit.core.telemetry
High-precision timing and structured telemetry logging for SpecSplit.
Provides a Stopwatch for nanosecond-precision wall-clock measurement and
a TelemetryLogger for emitting structured JSON spans that can be collected
for distributed tracing and benchmarking.
Stopwatch
dataclass
Nanosecond-precision wall-clock stopwatch with context-manager support.
Usage::
sw = Stopwatch()
sw.start()
# ... do work ...
sw.stop()
print(f"Elapsed: {sw.elapsed_ms:.3f} ms")
Or as a context manager::
with Stopwatch() as sw:
    # ... do work ...
print(f"Elapsed: {sw.elapsed_ms:.3f} ms")
elapsed_ms
property
Elapsed time in milliseconds.
elapsed_ns
property
Elapsed time in nanoseconds.
elapsed_s
property
Elapsed time in seconds.
start()
Begin timing.
stop()
Stop timing and record elapsed nanoseconds.
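A stopwatch with this interface is essentially a thin wrapper around `time.perf_counter_ns()`. A minimal sketch (illustrative; not the module's exact code):

```python
import time
from dataclasses import dataclass

@dataclass
class Stopwatch:
    """Nanosecond wall-clock stopwatch with context-manager support (sketch)."""
    _start_ns: int = 0
    _elapsed_ns: int = 0

    def start(self) -> None:
        self._start_ns = time.perf_counter_ns()

    def stop(self) -> None:
        self._elapsed_ns = time.perf_counter_ns() - self._start_ns

    @property
    def elapsed_ns(self) -> int:
        return self._elapsed_ns

    @property
    def elapsed_ms(self) -> float:
        return self._elapsed_ns / 1e6

    @property
    def elapsed_s(self) -> float:
        return self._elapsed_ns / 1e9

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()
        return False  # do not suppress exceptions
```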
TelemetrySpan
dataclass
A single telemetry span representing a timed operation.
to_dict()
Serialize the span to a dictionary.
TelemetryEvent
dataclass
A point-in-time event for reconstructing request/response timelines.
to_dict()
Serialize the event to a dictionary.
TelemetryLogger
Structured telemetry logger that collects spans and exports them as JSON.
Each span is tagged with a unique span_id for distributed tracing.
Collected spans can be exported to a JSON file for offline analysis.
Usage::
tlog = TelemetryLogger(service_name="draft-worker")
with tlog.span("generate_draft", tokens_processed=128) as span_id:
    # ... do work ...
    pass
tlog.export("telemetry_output.json")
events
property
Return a copy of all recorded events.
spans
property
Return a copy of all recorded spans.
export(path)
Export all recorded spans to a JSON file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | Output file path. Parent directories are created if needed. | required |
export_csv(path)
Export all recorded spans to a CSV file.
export_prometheus(path)
Export all recorded spans to a Prometheus text-based format.
record_event(event_type, **metadata)
Record a point-in-time event.
record_span(span)
Record a completed span.
reset()
Clear all recorded spans.
span(operation, **metadata)
Create a timed span context manager.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `operation` | `str` | Name of the operation being timed. | required |
| `**metadata` | `Any` | Arbitrary key-value pairs to attach to the span. | `{}` |
Returns:
| Type | Description |
|---|---|
| `_SpanContext` | A context manager that records timing on exit. |
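The `span()` helper can be pictured as a context manager that couples a stopwatch to the span list. A minimal sketch of the pattern (illustrative; `MiniTelemetryLogger` is a toy stand-in, and the real logger also supports events, export, and reset):

```python
import time
import uuid
from contextlib import contextmanager

class MiniTelemetryLogger:
    """Toy span-collecting logger illustrating the span() pattern."""

    def __init__(self, service_name: str):
        self.service_name = service_name
        self.spans: list[dict] = []

    @contextmanager
    def span(self, operation: str, **metadata):
        span_id = uuid.uuid4().hex       # unique id for distributed tracing
        start = time.perf_counter_ns()
        try:
            yield span_id
        finally:
            # Record the span even if the body raised.
            self.spans.append({
                "service": self.service_name,
                "operation": operation,
                "span_id": span_id,
                "duration_ms": (time.perf_counter_ns() - start) / 1e6,
                "metadata": metadata,
            })
```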
get_current_context()
Return the current telemetry span context, if any.
specsplit.core.verification
Verification Mathematics for Disaggregated Speculative Decoding.
This module implements the core acceptance/rejection logic used by the Target Worker to validate draft token trees. We start with strictly greedy decoding (temperature = 0.0) as a correctness baseline before moving to stochastic rejection sampling.
Greedy Verification Algorithm
Given:
- draft_tokens[i]: The token ID drafted at tree position i.
- target_logits[i]: The target model's logit vector at position i.
- topology_map[i]: The parent index of tree position i (-1 for roots).
For each position i, compute argmax(target_logits[i]). A
drafted token is accepted if it matches the target's greedy choice.
The algorithm then walks the topology map to find the longest continuous path from a root to a leaf where every node along the path is accepted. The "bonus token" is the target's greedy choice at the first divergence point (or at the accepted leaf, extending the sequence by one).
This entire comparison is done on-device via torch.argmax and
boolean indexing — no CPU↔GPU synchronization until the final small
result extraction.
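The walk described above can be sketched in pure Python, with lists in place of GPU tensors (an illustrative sketch, not the module's on-device implementation; function names here are hypothetical):

```python
def _argmax(row):
    """Index of the largest logit in a plain list."""
    return max(range(len(row)), key=row.__getitem__)

def greedy_verify(draft_tokens, target_logits, topology_map):
    """Return (accepted_tokens, bonus_token, diverged) -- pure-Python sketch."""
    n = len(draft_tokens)
    accepted = [draft_tokens[i] == _argmax(target_logits[i]) for i in range(n)]
    children = {i: [] for i in range(n)}
    roots = []
    for i, parent in enumerate(topology_map):
        (roots if parent == -1 else children[parent]).append(i)

    candidates = []  # (accepted path, divergence node or None)

    def dfs(node, prefix):
        if not accepted[node]:
            candidates.append((prefix, node))  # path stops; node is the divergence
            return
        path = prefix + [node]
        if not children[node]:
            candidates.append((path, None))    # fully accepted to a leaf
        for child in children[node]:
            dfs(child, path)

    for root in roots:
        dfs(root, [])
    path, div = max(candidates, key=lambda c: len(c[0]))  # first-longest wins ties
    # On divergence, the bonus is the target's greedy choice at the rejected
    # node; when fully accepted (div is None), the caller samples the bonus
    # from the logit position after the leaf, signalled by diverged=False.
    bonus = _argmax(target_logits[div]) if div is not None else None
    return [draft_tokens[i] for i in path], bonus, div is not None
```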
VerificationResult
dataclass
Unified result for both greedy and stochastic tree verification.
Attributes:
| Name | Type | Description |
|---|---|---|
| `accepted_leaf_index` | `int` | The local tree-node index of the last accepted node on the longest accepted path. If no tokens were accepted, this is a sentinel value. |
| `accepted_tokens` | `list[int]` | Ordered list of accepted token IDs along the longest accepted path (root → leaf). |
| `bonus_token` | `int` | The target model's greedy/sampled choice at the divergence point (i.e., the token that would follow the accepted prefix). This is always produced; it extends the output by one token for free. |
| `accepted_indices` | `list[int]` | The local tree-node indices of the accepted nodes, in path order (root → leaf). Useful for KV cache rollback and position tracking. |
| `num_draft_tokens` | `int` | Total number of draft tokens in the tree (for computing acceptance rate). |
| `diverged` | `bool` | Whether the path ended at a rejection point (`True`) or was fully accepted to a leaf node (`False`). Used by the TargetEngine to correctly sample the bonus token from the right logit position. |
acceptance_rate
property
Fraction of the tree that was accepted (0.0-1.0).
Computed as num_accepted / num_draft_tokens. Note that this
measures the path acceptance, not the full tree utilization.
num_accepted
property
Number of draft tokens accepted.
verify_greedy_tree(draft_tokens, target_logits, topology_map)
Verify a draft token tree against target logits using greedy decoding.
All comparisons are performed on-device. The only CPU↔GPU sync happens at the very end when extracting the small result lists.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `draft_tokens` | `Tensor` | Flat tensor of drafted token IDs. Shape: `[num_nodes]`. | required |
| `target_logits` | `Tensor` | Target model's logit vectors at each tree position. Shape: `[num_nodes, vocab_size]`. | required |
| `topology_map` | `list[int]` | List of parent indices for each tree node (-1 for roots). | required |
Returns:
| Type | Description |
|---|---|
| `VerificationResult` | A `VerificationResult` with the accepted path, the bonus token, and metadata. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If tensor shapes are inconsistent with the topology map. |
Example::
>>> draft_tokens = torch.tensor([42, 17, 99, 8, 55])
>>> # Suppose target argmax at each position is [42, 17, 98, 7, 55]
>>> target_logits = torch.zeros(5, 100)
>>> target_logits[0, 42] = 10.0  # matches draft
>>> target_logits[1, 17] = 10.0  # matches draft
>>> target_logits[2, 98] = 10.0  # MISMATCH (draft=99, target=98)
>>> target_logits[3, 7] = 10.0   # MISMATCH (draft=8, target=7)
>>> target_logits[4, 55] = 10.0  # matches, but parent (node 2) rejected
>>> topology_map = [-1, 0, 0, 1, 2]  # binary tree
>>> result = verify_greedy_tree(draft_tokens, target_logits, topology_map)
>>> result.accepted_tokens  # Longest path: root(42) → left(17), stops at 8 ≠ 7
[42, 17]
>>> result.bonus_token  # Target says 7 where draft said 8
7
verify_stochastic_tree(draft_tokens, draft_probs, target_probs, topology_map, draft_probs_full=None)
Perform stochastic verification with rejection sampling over the full tree.
Explores all root-to-leaf paths via DFS. At each node, the draft token is accepted if p_target >= p_draft, otherwise accepted with probability p_target / p_draft. The longest accepted path across all branches is chosen; the bonus (correction) token is sampled from the target distribution at the divergence node (rejection point) or at the accepted leaf.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `draft_tokens` | `Tensor` | `[num_nodes]` tensor of drafted token IDs. | required |
| `draft_probs` | `Tensor` | `[num_nodes]` tensor of probabilities the draft assigned to the drafted token at each node (legacy; used when `draft_probs_full=None`). | required |
| `target_probs` | `Tensor` | `[num_nodes, vocab_size]` tensor of target model probabilities. | required |
| `topology_map` | `list[int]` | List mapping node index to parent index (-1 for root). | required |
| `draft_probs_full` | `Tensor \| None` | Optional `[num_nodes, vocab_size]` full draft distribution. When present, residual = max(0, P_target - P_draft) for correct sampling. | `None` |
Returns:
| Type | Description |
|---|---|
| `VerificationResult` | `VerificationResult` with the longest accepted path and sampled bonus token. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If shapes are inconsistent or the tree has no root nodes. |
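The per-node accept/reject rule and the residual correction can be sketched in isolation (pure Python; the function names are illustrative, and the zero-mass fallback to the target distribution is an assumption of this sketch, not documented module behavior):

```python
import random

def accept_draft_token(p_draft: float, p_target: float, rng: random.Random) -> bool:
    """Speculative-decoding acceptance test for one drafted token."""
    if p_target >= p_draft:
        return True                               # target at least as confident
    return rng.random() < p_target / p_draft      # else accept with prob ratio

def residual_distribution(target_row, draft_row):
    """Normalized max(0, P_target - P_draft): the correction-token distribution."""
    residual = [max(0.0, t - d) for t, d in zip(target_row, draft_row)]
    mass = sum(residual)
    if mass == 0.0:
        return list(target_row)                   # degenerate case (assumption)
    return [r / mass for r in residual]
```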
specsplit.core.model_loading
Helpers for loading local HuggingFace causal-LM checkpoints used by workers.
get_checkpoint_dtype(model_name, *, default=None)
Return the preferred compute dtype recorded in a local checkpoint.
get_model_config(model_or_config)
Return a model config object from either a model or config.
get_model_vocab_size(model)
Best-effort vocabulary size for a loaded causal language model.
specsplit.core.cache_utils
Cache format conversion utilities for HuggingFace transformers compatibility.
Provides version-agnostic helpers to convert between legacy tuple format and DynamicCache for causal language models used by SpecSplit.
cache_to_legacy(past_kv)
Convert a cache object to a legacy ((k, v), ...) tuple when possible.
This supports legacy tuples and the newer Cache API via either
to_legacy_cache() or .layers.
cache_supports_crop(past_kv)
Return True when the cache can be safely rolled back in-place.
crop_cache(past_kv, max_length)
Crop a cache to max_length tokens.
batch_model_caches(caches)
Combine single-item caches into one batched cache for a forward pass.
slice_batch_item_from_cache(cache, item_idx)
Extract one batch item from a batched cache object.
legacy_to_dynamic_cache(past_kv, config)
Convert legacy (key, value) tuple cache to DynamicCache for model forward.
Version-agnostic: tries from_legacy_cache first, then DynamicCache constructor, finally falls back to returning the legacy tuple (many models accept it).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `past_kv` | `tuple[tuple[Tensor, Tensor], ...]` | Legacy cache as a tuple of (key, value) pairs per layer. | required |
| `config` | `Any` | Model config (`PretrainedConfig`) for `DynamicCache`. | required |
Returns:
| Type | Description |
|---|---|
| `Any` | `DynamicCache` instance or the original tuple if conversion fails. |
specsplit.core.serialization
Serialization utilities for converting between PyTorch tensors and Python lists.
These helpers are used at the gRPC boundary to convert tensors into protobuf-compatible formats (lists of ints/floats) and back. All conversions are device-aware and preserve dtype where applicable.
tensor_to_token_ids(tensor)
Convert a 1-D tensor of token IDs to a plain Python list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | A 1-D tensor of token IDs. | required |
Returns:
| Type | Description |
|---|---|
| `list[int]` | A Python list of ints suitable for protobuf serialization. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the tensor has more than one dimension. |
Example::
>>> ids = tensor_to_token_ids(torch.tensor([101, 2003, 1037]))
>>> ids
[101, 2003, 1037]
token_ids_to_tensor(ids, device='cpu', dtype=torch.long)
Convert a list of token IDs to a 1-D PyTorch tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ids` | `list[int]` | A list of integer vocabulary indices. | required |
| `device` | `str \| device` | Target device (e.g. `"cpu"`, `"cuda"`). | `'cpu'` |
| `dtype` | `dtype` | Desired tensor dtype. Defaults to `torch.long`. | `torch.long` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | A 1-D tensor on the specified device. |
Example::
>>> t = token_ids_to_tensor([101, 2003, 1037], device="cpu")
>>> t.shape
torch.Size([3])
logits_to_probs(logits, temperature=1.0, dim=-1)
Convert raw logits to a probability distribution via softmax.
Applies temperature scaling before softmax. A temperature of 0 is treated as greedy (argmax), returning a one-hot distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Raw logits from a language model, shape `[..., vocab_size]`. | required |
| `temperature` | `float` | Sampling temperature. Values < 1.0 sharpen the distribution; values > 1.0 flatten it. | `1.0` |
| `dim` | `int` | Dimension along which to apply softmax. | `-1` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Probability tensor of the same shape as `logits`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `temperature` is negative. |
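The temperature handling described above can be sketched without torch, as a softmax over a plain list (illustrative only; the real function operates on tensors along an arbitrary `dim`):

```python
import math

def logits_to_probs(logits: list[float], temperature: float = 1.0) -> list[float]:
    """Temperature-scaled softmax; temperature 0 means greedy one-hot (sketch)."""
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Greedy: one-hot distribution at the argmax.
        best = max(range(len(logits)), key=logits.__getitem__)
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                 # subtract max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```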