"""
Computation and Sampling Strategies
This module defines strategies for computing distribution characteristics
and generating random samples.
"""
from __future__ import annotations
__author__ = "Leonid Elkin, Mikhail Mikhailov, Irina Sergeeva"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, cast
from pysatl_core.distributions.computations.options import (
EdgeOptionsDescriptor,
ResolvedEdgeOptions,
_BaseOption,
)
from pysatl_core.distributions.registry import characteristic_registry
from pysatl_core.types import Method, NumericArray
if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from pysatl_core.distributions.computations.computation import (
AnalyticalComputation,
FittedComputationMethod,
)
from pysatl_core.distributions.computations.options import StepOptions
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.distributions.registry.graph import RegistryView
from pysatl_core.distributions.registry.graph_primitives import (
ComputationEdgeMeta,
EdgeMeta,
)
from pysatl_core.types import GenericCharacteristicName, LabelName
# --------------------------------------------------------------------------- #
# Execution plan (introspection of how a strategy will compute a state)
# --------------------------------------------------------------------------- #
[docs]
@dataclass(frozen=True, slots=True)
class ComputationStep:
    """
    A single hop in the execution plan a strategy builds for a characteristic.

    Attributes
    ----------
    target : GenericCharacteristicName
        Name of the characteristic this step yields.
    sources : tuple[GenericCharacteristicName, ...]
        Names of the characteristics this step reads from.
    edge_kind : str
        Kind of the graph edge behind the step: ``"analytical_loop"``,
        ``"transformation_loop"`` or ``"computation"``.
    method_name : str
        Readable label of the method behind the step (the descriptor
        ``name`` if present, else ``target``).
    options_descriptor : EdgeOptionsDescriptor
        Descriptor of the user-supplied options read at this step.
        Self-loop steps carry an empty descriptor.
    """

    target: GenericCharacteristicName
    sources: tuple[GenericCharacteristicName, ...]
    edge_kind: str
    method_name: str
    # A fresh, empty descriptor per instance unless the caller supplies one.
    options_descriptor: EdgeOptionsDescriptor = field(default_factory=EdgeOptionsDescriptor)
[docs]
@dataclass(frozen=True, slots=True)
class ComputationPlan:
"""
Plan describing how a strategy will compute ``target`` for a distribution.
Attributes
----------
target : GenericCharacteristicName
Characteristic the plan resolves.
source : GenericCharacteristicName
Starting characteristic of the plan (a self-loop characteristic
in :attr:`Distribution.analytical_computations`).
steps : tuple[ComputationStep, ...]
Ordered sequence of steps. For a single-loop plan there is
exactly one step ``source -> source``; for a conversion plan
the first step starts at ``source`` and the last step targets
``target``.
"""
target: GenericCharacteristicName
source: GenericCharacteristicName
steps: tuple[ComputationStep, ...]
[docs]
def required_options(self) -> tuple[str, ...]:
"""
Return the names of all options that may be consumed by the plan.
Useful for users that want to know which keyword arguments are
meaningful for a particular ``query_method`` / ``calculate_characteristic``
call.
"""
seen: dict[str, None] = {}
for step in self.steps:
for opt in step.options_descriptor.options:
seen.setdefault(opt.name, None)
return tuple(seen)
[docs]
def required_characteristic_options(self) -> tuple[str, ...]:
"""
Return the names of all *characteristic* options across all steps.
These are the options that are intrinsic to the characteristic and
should be broadcast to every step that declares them. They also
affect the cache key.
"""
seen: dict[str, None] = {}
for step in self.steps:
for opt in step.options_descriptor.characteristic_options:
seen.setdefault(opt.name, None)
return tuple(seen)
[docs]
def required_computation_options(self) -> tuple[str, ...]:
"""
Return the names of all *computation* options across all steps.
These are fitter-specific options that control numerical algorithms.
They do **not** affect the cache key.
"""
seen: dict[str, None] = {}
for step in self.steps:
for opt in step.options_descriptor.computation_options:
seen.setdefault(opt.name, None)
return tuple(seen)
[docs]
def with_options(self, step_index: int, **kwargs: Any) -> dict[int, ResolvedEdgeOptions]:
"""
Create a :data:`StepOptions` mapping with validated options for one step.
This is the recommended way to build the ``options`` parameter for
:meth:`ComputationStrategy.query_method`. Call it once per step
that needs non-default options and merge the results::
plan = distr.explain_computation_path("ppf")
opts = plan.with_options(0, tol=0.1) | plan.with_options(1, eps=1e-3)
ppf = distr.query_method("ppf", options=opts)
Option values are validated eagerly (type-cast + predicate check)
so errors surface here rather than deep inside the strategy.
Parameters
----------
step_index : int
0-based index into :attr:`steps`.
**kwargs : Any
Option values for that step.
Returns
-------
dict[int, ResolvedEdgeOptions]
A single-entry :data:`StepOptions` mapping that can be
merged with other such mappings via ``|``.
Raises
------
IndexError
If ``step_index`` is out of range.
TypeError
If a value cannot be cast to the declared type.
ValueError
If a value fails the option's validation predicate.
"""
if step_index < 0 or step_index >= len(self.steps):
raise IndexError(
f"step_index {step_index} out of range for plan with {len(self.steps)} steps."
)
resolved = self.steps[step_index].options_descriptor.with_values(**kwargs)
return {step_index: resolved}
# --------------------------------------------------------------------------- #
# Cached plan (internal — keeps actual edge / loop refs alongside ComputationPlan)
# --------------------------------------------------------------------------- #
@dataclass(frozen=True, slots=True)
class _CachedPlan:
"""
Internal companion to :class:`ComputationPlan` that retains references
to the actual graph primitives required for execution.
Attributes
----------
plan : ComputationPlan
Public representation of the plan.
loop_method : Method | None
Loop method when the plan resolves through a single self-loop;
``None`` for multi-edge conversion plans.
edges : tuple[ComputationEdgeMeta, ...]
Conversion edges along the plan. Empty when ``loop_method`` is set.
"""
plan: ComputationPlan
loop_method: Method[Any, Any] | None
edges: tuple[ComputationEdgeMeta, ...]
# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
def _make_hashable(value: Any) -> Hashable:
"""
Convert a possibly-unhashable option value into a stable hashable key.
Lists / tuples become tuples of recursively hashable items. Dicts
and sets become sorted ``frozenset``s of hashable pairs / items.
Everything that is already hashable is returned as-is. As a last
resort the value is keyed by its ``repr`` so two equal objects
produce the same key while still keeping the cache safe.
"""
if isinstance(value, str | bytes):
return value
if isinstance(value, dict):
return frozenset((k, _make_hashable(v)) for k, v in value.items())
if isinstance(value, list | tuple):
return tuple(_make_hashable(v) for v in value)
if isinstance(value, set | frozenset):
return frozenset(_make_hashable(v) for v in value)
try:
hash(value)
except TypeError:
return repr(value)
return cast("Hashable", value)
def _freeze_options(resolved: Mapping[str, Any]) -> frozenset[tuple[str, Hashable]]:
    """Turn resolved options into an order-independent, hashable set of ``(name, key)`` pairs."""
    frozen_pairs = {(name, _make_hashable(value)) for name, value in resolved.items()}
    return frozenset(frozen_pairs)
def _resolve_option_group(
options: tuple[_BaseOption, ...],
per_step_values: Mapping[str, Any],
fallback_values: Mapping[str, Any],
*,
option_kind: str,
) -> dict[str, Any]:
"""Resolve one option group with per-step values taking precedence."""
resolved: dict[str, Any] = {}
for option in options:
if option.name in per_step_values:
raw = per_step_values[option.name]
elif option.name in fallback_values:
raw = fallback_values[option.name]
else:
raw = option.default
try:
value = option.type(raw)
except (TypeError, ValueError) as exc:
raise TypeError(
f"{option_kind} option '{option.name}': cannot convert "
f"{raw!r} to {option.type.__name__}"
) from exc
if option.validate is not None and not option.validate(value):
raise ValueError(
f"{option_kind} option '{option.name}': value {value!r} failed validation."
)
resolved[option.name] = value
return resolved
def _resolve_step_options(
    edge: ComputationEdgeMeta,
    step_idx: int,
    step_options: StepOptions | None,
    characteristic_options: Mapping[str, Any],
    computation_defaults: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Work out the characteristic and computation options for one edge/step.

    Resolution order
    ----------------
    *Characteristic options* (part of the cache key, broadcast to all steps):

    1. ``step_options[step_idx].values`` entries naming characteristic options
    2. the shared ``characteristic_options`` mapping
    3. the declared ``CharacteristicOption.default``

    *Computation options* (NOT part of the cache key, fitter-specific):

    1. ``step_options[step_idx].values`` entries naming computation options
    2. the ``computation_defaults`` mapping
    3. the declared ``ComputationOption.default``

    Parameters
    ----------
    edge : ComputationEdgeMeta
        Edge whose option descriptor is being resolved.
    step_idx : int
        0-based position of the step, used to find per-step overrides.
    step_options : StepOptions | None
        Caller overrides keyed by step index.
    characteristic_options : Mapping[str, Any]
        Characteristic options shared by every step.
    computation_defaults : Mapping[str, Any]
        Strategy/call-level computation defaults (ranked between the
        hardcoded defaults and the per-step overrides).

    Returns
    -------
    tuple[dict[str, Any], dict[str, Any]]
        ``(char_resolved, comp_resolved)``: the characteristic and the
        computation option dicts, in that order.
    """
    # A missing or empty step_options mapping means "no overrides for any step".
    overrides = None
    if step_options:
        overrides = step_options.get(step_idx)
    per_step_values: dict[str, Any] = {} if overrides is None else overrides.values

    descriptor = edge.options_descriptor
    # Characteristic options are resolved first, so their errors surface first.
    char_resolved = _resolve_option_group(
        descriptor.characteristic_options,
        per_step_values,
        characteristic_options,
        option_kind="Characteristic",
    )
    comp_resolved = _resolve_option_group(
        descriptor.computation_options,
        per_step_values,
        computation_defaults,
        option_kind="Computation",
    )
    return char_resolved, comp_resolved
# --------------------------------------------------------------------------- #
# Strategy protocol & default implementation
# --------------------------------------------------------------------------- #
[docs]
class ComputationStrategy(Protocol):
    """
    Protocol for strategies that resolve computation methods for characteristics.

    Attributes
    ----------
    enable_caching : bool
        Whether to cache fitted computation methods.
    """

    # NOTE(review): the docstring advertises ``enable_caching`` but no such member
    # is declared in the visible source -- confirm whether a declaration was lost
    # before adding one, since it would change which classes satisfy the protocol.

    def query_method(
        self,
        state: GenericCharacteristicName,
        distr: Distribution,
        options: StepOptions | None = None,
        *,
        characteristic_options: Mapping[str, Any] | None = None,
        computation_defaults: Mapping[str, Any] | None = None,
    ) -> Method[Any, Any]:
        """
        Resolve a callable that computes ``state`` for ``distr``.

        Parameters
        ----------
        state : GenericCharacteristicName
            Target characteristic name.
        distr : Distribution
            Distribution the characteristic is computed for.
        options : StepOptions | None, default=None
            Per-step option overrides keyed by 0-based step index.
        characteristic_options : Mapping[str, Any] | None, default=None
            Characteristic options shared across steps.
        computation_defaults : Mapping[str, Any] | None, default=None
            Per-call defaults for computation options.

        Returns
        -------
        Method
            Callable that computes the characteristic.
        """
        ...

    def explain_computation_path(
        self, state: GenericCharacteristicName, distr: Distribution
    ) -> ComputationPlan:
        """
        Describe how this strategy will compute ``state`` for ``distr``.

        Returns a :class:`ComputationPlan` that lists every step the
        strategy will perform along with the option descriptors it will
        consult at each step. Implementations are expected to *fix* the
        returned plan internally so that a subsequent call to
        :meth:`query_method` for the same ``(distr, state)`` follows
        the same path even if the strategy makes non-deterministic
        choices.
        """
        ...
[docs]
class DefaultComputationStrategy:
    """
    Default strategy for resolving characteristic computation methods.

    This strategy first checks for analytical implementations provided by
    the distribution. If none exists, it walks the characteristic graph
    to find a conversion path from an analytical characteristic to the
    target characteristic.

    Parameters
    ----------
    enable_caching : bool, default=False
        If True, cache fitted conversions to avoid repeated fitting.
    computation_defaults : Mapping[str, Any] | None, default=None
        Strategy-level defaults for computation options. These sit between
        the hardcoded ``ComputationOption.default`` and any per-step caller
        override. Resolution order (highest priority first):

        1. Per-step caller override (``options`` argument to
           :meth:`query_method`).
        2. ``computation_defaults`` supplied here.
        3. Hardcoded ``ComputationOption.default`` on the descriptor.

        Example::

            strategy = DefaultComputationStrategy(
                enable_caching=True,
                computation_defaults={"max_iter": 100, "limit": 50},
            )

    Attributes
    ----------
    _enable_caching : bool
        Whether caching is enabled.
    _computation_defaults : dict[str, Any]
        Strategy-level computation option defaults.
    _cache : dict
        Cache of fitted computation methods keyed by
        ``(distr_id, edge_id, target, frozen_all_options)`` so that
        different option sets (both characteristic and computation) produce
        independent cache entries. Characteristic options affect the
        *meaning* of the result; computation options affect the *accuracy*
        of the fitted callable -- both must be part of the key.
    _path_cache : dict
        Cache of resolved execution plans keyed by ``(distr_id, target)``.
        Lets repeated ``query_method`` calls reuse the path produced by a
        previous ``explain_computation_path`` / ``query_method`` and keeps
        both methods in sync for non-deterministic strategies. While
        ``query_method`` runs, entries may be shadowed by temporary loop
        plans; the shadowed entries are restored when the call finishes.
    _resolving : dict[int, set[str]]
        Tracking of currently resolving characteristics to detect cycles.
    _char_options_stack : list[dict[str, Any]]
        Stack of characteristic-options dicts, one entry per active
        query_method call. When a fitter on step N calls
        distribution.query_method(intermediate) recursively, the strategy
        picks up the characteristic options from the top of this stack so
        they are propagated automatically without the fitter needing to
        forward them explicitly.
    """

    def __init__(
        self,
        enable_caching: bool = False,
        computation_defaults: Mapping[str, Any] | None = None,
    ) -> None:
        """Create a strategy with empty caches; ``computation_defaults`` is copied."""
        self._enable_caching = enable_caching
        self._computation_defaults: dict[str, Any] = dict(computation_defaults or {})
        self._cache: dict[
            tuple[int, int, GenericCharacteristicName, frozenset[tuple[str, Hashable]]],
            FittedComputationMethod[Any, Any],
        ] = {}
        self._path_cache: dict[tuple[int, GenericCharacteristicName], _CachedPlan] = {}
        self._resolving: dict[int, set[GenericCharacteristicName]] = {}
        self._char_options_stack: list[dict[str, Any]] = []

    @property
    def is_caching_enabled(self) -> bool:
        """Whether fitted computation methods are cached by this strategy."""
        return self._enable_caching

    # ------------------------------------------------------------------ #
    # Cycle detection helpers
    # ------------------------------------------------------------------ #
    def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
        """
        Push a characteristic onto the resolution stack to detect cycles.

        Raises
        ------
        RuntimeError
            If a cycle is detected during resolution.
        """
        key = id(distr)
        seen = self._resolving.setdefault(key, set())
        if state in seen:
            raise RuntimeError(
                f"Cycle detected while resolving '{state}'. "
                "Provide at least one analytical base characteristic in the distribution."
            )
        seen.add(state)

    def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
        """Pop a characteristic from the resolution stack."""
        key = id(distr)
        seen = self._resolving.get(key)
        if seen is not None:
            seen.discard(state)
            if not seen:
                # Drop the empty set so ``_resolving`` does not grow per distribution.
                self._resolving.pop(key, None)

    # ------------------------------------------------------------------ #
    # Method picking helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _pick_analytical_method(
        state: GenericCharacteristicName,
        methods: Mapping[LabelName, AnalyticalComputation[Any, Any]],
    ) -> AnalyticalComputation[Any, Any]:
        """
        Pick the first available analytical method for a characteristic.

        Raises
        ------
        RuntimeError
            If no labeled analytical methods are available for the characteristic.
        """
        try:
            return next(iter(methods.values()))
        except StopIteration as exc:
            raise RuntimeError(
                f"Characteristic '{state}' provides no labeled analytical computations."
            ) from exc

    @staticmethod
    def _pick_loop_edge(
        state: GenericCharacteristicName,
        view: RegistryView,
    ) -> EdgeMeta | None:
        """
        Pick the first available self-loop edge for a characteristic in a view.

        Returns ``None`` when the view offers no ``state -> state`` variants.
        """
        loops = view.variants(state, state)
        if not loops:
            return None
        return next(iter(loops.values()))

    @staticmethod
    def _loop_method(loop_edge: EdgeMeta) -> Method[Any, Any]:
        """Return the loop-edge underlying method as a :data:`Method`."""
        return cast(Method[Any, Any], loop_edge.method)

    # ------------------------------------------------------------------ #
    # Plan building (introspection + caching of paths)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _step_for_loop(state: GenericCharacteristicName, loop_edge: EdgeMeta) -> ComputationStep:
        """Describe a self-loop edge as a ``state -> state`` step with no options."""
        return ComputationStep(
            target=state,
            sources=(state,),
            edge_kind=loop_edge.edge_kind(),
            method_name=getattr(loop_edge.method, "target", state),
            options_descriptor=EdgeOptionsDescriptor(),
        )

    @staticmethod
    def _step_for_edge(edge: ComputationEdgeMeta) -> ComputationStep:
        """Describe a conversion edge as a step carrying the edge's option descriptor."""
        method = edge.method
        return ComputationStep(
            target=method.target,
            sources=tuple(method.sources),
            edge_kind=edge.edge_kind(),
            method_name=edge.options_descriptor.name or method.target,
            options_descriptor=edge.options_descriptor,
        )

    def _build_plan(self, distr: Distribution, state: GenericCharacteristicName) -> _CachedPlan:
        """
        Resolve ``state`` against the registry and build a :class:`_CachedPlan`.

        This method does **not** perform any fitting -- it only chooses the
        loop / conversion path the strategy will use. The plan is cached
        so subsequent ``query_method`` calls for the same ``(distr, state)``
        follow exactly the same edges.

        Raises
        ------
        RuntimeError
            If the distribution has no analytical computations, if a
            non-registry characteristic has no analytical implementation,
            or if no conversion path to ``state`` exists.
        """
        cache_key = (id(distr), state)
        cached = self._path_cache.get(cache_key)
        if cached is not None:
            return cached

        if not distr.analytical_computations:
            raise RuntimeError(
                "Distribution provides no analytical computations to ground conversions."
            )

        registry = characteristic_registry()

        # Non-registry characteristics: use the distribution-supplied analytical method.
        if state not in registry.declared_characteristics:
            if state in distr.analytical_computations:
                method = self._pick_analytical_method(state, distr.analytical_computations[state])
                step = ComputationStep(
                    target=state,
                    sources=(state,),
                    edge_kind="analytical_loop",
                    method_name=getattr(method, "target", state),
                    options_descriptor=EdgeOptionsDescriptor(),
                )
                plan = _CachedPlan(
                    plan=ComputationPlan(target=state, source=state, steps=(step,)),
                    loop_method=cast(Method[Any, Any], method),
                    edges=(),
                )
                self._path_cache[cache_key] = plan
                return plan
            raise RuntimeError(
                f"Characteristic '{state}' is not declared in the registry and has no "
                "analytical implementation in the distribution."
            )

        view = registry.view(distr)

        # Direct self-loop hit first.
        loop_edge = self._pick_loop_edge(state, view)
        if loop_edge is not None:
            step = self._step_for_loop(state, loop_edge)
            plan = _CachedPlan(
                plan=ComputationPlan(target=state, source=state, steps=(step,)),
                loop_method=self._loop_method(loop_edge),
                edges=(),
            )
            self._path_cache[cache_key] = plan
            return plan

        # Otherwise: try each loop characteristic as a source.
        for src in distr.analytical_computations:
            if not view.variants(src, src):
                continue
            path = view.find_path(src, state)
            if not path:
                continue
            steps = tuple(self._step_for_edge(edge) for edge in path)
            plan = _CachedPlan(
                plan=ComputationPlan(target=state, source=src, steps=steps),
                loop_method=None,
                edges=tuple(path),
            )
            self._path_cache[cache_key] = plan
            return plan

        raise RuntimeError(
            f"No conversion path from any characteristic in analytical_computations to '{state}'."
        )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def explain_computation_path(
        self, state: GenericCharacteristicName, distr: Distribution
    ) -> ComputationPlan:
        """
        Describe and pin the plan that :meth:`query_method` will follow.

        The returned :class:`ComputationPlan` lists every step (loop or
        conversion edge) and the :class:`EdgeOptionsDescriptor` consulted at
        that step. The plan is cached per ``(distr, state)`` so that a
        subsequent :meth:`query_method` call goes through the very same
        edges -- this matters for non-deterministic strategy variants
        and is also what allows the second call to skip the BFS pass.
        """
        return self._build_plan(distr, state).plan

    def query_method(
        self,
        state: GenericCharacteristicName,
        distr: Distribution,
        options: StepOptions | None = None,
        *,
        characteristic_options: Mapping[str, Any] | None = None,
        computation_defaults: Mapping[str, Any] | None = None,
    ) -> Method[Any, Any]:
        """
        Resolve a computation method for the target characteristic.

        Resolution order:

        1. Cached fitted method (if caching enabled)
        2. Analytical implementation for non-registry characteristics
        3. First self-loop from the registry view
        4. Conversion path from loop characteristics via the graph

        Parameters
        ----------
        state : str
            Target characteristic name (e.g., "pdf", "cdf").
        distr : Distribution
            Distribution to compute the characteristic for.
        options : StepOptions | None, default=None
            Per-step options built via
            :meth:`ComputationPlan.with_options`. Each key is a 0-based
            step index and each value is a :class:`ResolvedEdgeOptions`
            produced by :meth:`EdgeOptionsDescriptor.with_values`.
            When ``None``, every edge uses its declared defaults.
        characteristic_options : Mapping[str, Any] | None, default=None
            Shared characteristic options broadcast to **every step** that
            declares a matching :class:`CharacteristicOption`. These are
            intrinsic to the characteristic (e.g. ``eps``, ``x0`` for PPF)
            and affect the *meaning* of the result. Per-step overrides in
            ``options`` take precedence over this dict; the dict takes
            precedence over the hardcoded ``CharacteristicOption.default``.
        computation_defaults : Mapping[str, Any] | None, default=None
            Per-call computation option defaults. These override the
            strategy-level ``computation_defaults`` set at construction time
            and the hardcoded ``ComputationOption.default``, but are
            overridden by per-step values in ``options``.

        Returns
        -------
        Method
            Callable that computes the characteristic.

        Raises
        ------
        RuntimeError
            If no analytical base exists, no conversion path is found,
            or a cycle is detected.
        """
        cached_plan = self._build_plan(distr, state)
        if cached_plan.loop_method is not None:
            return cached_plan.loop_method

        # Merge computation defaults: call-level overrides strategy-level.
        effective_comp_defaults: dict[str, Any] = dict(self._computation_defaults)
        if computation_defaults:
            effective_comp_defaults.update(computation_defaults)

        # Characteristic options are inherited from the enclosing query_method
        # call (if any) and then overridden by the ones passed explicitly.
        inherited_char_options: dict[str, Any] = dict(
            self._char_options_stack[-1] if self._char_options_stack else {}
        )
        if characteristic_options:
            inherited_char_options.update(characteristic_options)
        effective_char_options: Mapping[str, Any] = inherited_char_options

        self._push_guard(distr, state)
        self._char_options_stack.append(dict(effective_char_options))
        # Path-cache entries shadowed by the temporary loop plans injected below.
        # ``None`` records that the key was absent before injection. Only the
        # first displacement per key is recorded so the original entry survives.
        displaced: dict[tuple[int, GenericCharacteristicName], _CachedPlan | None] = {}
        try:
            last_fitted: FittedComputationMethod[Any, Any] | None = None
            for step_idx, edge in enumerate(cached_plan.edges):
                method = edge.method
                char_resolved, comp_resolved = _resolve_step_options(
                    edge,
                    step_idx,
                    options,
                    effective_char_options,
                    effective_comp_defaults,
                )
                all_resolved = {**char_resolved, **comp_resolved}
                cache_key = (
                    id(distr),
                    id(edge),
                    method.target,
                    _freeze_options(all_resolved),
                )
                cached_fitted: FittedComputationMethod[Any, Any] | None = None
                if self._enable_caching:
                    cached_fitted = self._cache.get(cache_key)
                if cached_fitted is not None:
                    fitted = cached_fitted
                else:
                    fitted = method.prepare(distr, **all_resolved)
                    if self._enable_caching and method.cacheable:
                        self._cache[cache_key] = fitted
                last_fitted = fitted

                # Expose the fitted result as a loop plan for the intermediate
                # target so that fitters on subsequent edges can retrieve it
                # via distribution.query_method(method.target).
                intermediate_key = (id(distr), method.target)
                if intermediate_key not in displaced:
                    displaced[intermediate_key] = self._path_cache.get(intermediate_key)
                self._path_cache[intermediate_key] = _CachedPlan(
                    plan=ComputationPlan(
                        target=method.target,
                        source=method.target,
                        steps=(
                            ComputationStep(
                                target=method.target,
                                sources=(method.target,),
                                edge_kind="analytical_loop",
                                method_name=method.target,
                            ),
                        ),
                    ),
                    loop_method=fitted,
                    edges=(),
                )
            if last_fitted is None:
                raise RuntimeError(f"Empty path when resolving '{state}'.")
            return last_fitted
        finally:
            # Undo the temporary loop plans. Entries that existed beforehand are
            # put back rather than dropped: the last edge targets ``state``
            # itself, so a plain pop would evict the pinned plan for ``state``
            # (and any pinned intermediate plan) after every call.
            # Placed in finally to ensure cleanup even if fitting raises.
            for key, previous in displaced.items():
                if previous is None:
                    self._path_cache.pop(key, None)
                else:
                    self._path_cache[key] = previous
            self._char_options_stack.pop()
            self._pop_guard(distr, state)
[docs]
class SamplingStrategy(Protocol):
    """Structural interface for strategies that generate samples from distributions."""

    def sample(
        self,
        n: int,
        distr: Distribution,
        **options: Any,
    ) -> NumericArray:
        """Generate ``n`` samples from ``distr``; ``options`` are left to the implementation."""
        ...