Source code for pysatl_core.distributions.strategies

"""
Computation and Sampling Strategies

This module defines strategies for computing distribution characteristics
and generating random samples.
"""

from __future__ import annotations

__author__ = "Leonid Elkin, Mikhail Mikhailov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from typing import TYPE_CHECKING, Any, Protocol, cast

from pysatl_core.distributions.registry import characteristic_registry
from pysatl_core.types import Method, NumericArray

if TYPE_CHECKING:
    from collections.abc import Mapping

    from pysatl_core.distributions.computation import (
        AnalyticalComputation,
        FittedComputationMethod,
    )
    from pysatl_core.distributions.distribution import Distribution
    from pysatl_core.distributions.registry.graph import RegistryView
    from pysatl_core.types import GenericCharacteristicName, LabelName


class ComputationStrategy(Protocol):
    """
    Protocol for strategies that resolve computation methods for characteristics.

    A computation strategy decides how a characteristic of a distribution
    (e.g. ``"pdf"``, ``"cdf"``) is obtained, returning a callable that
    computes it.

    Notes
    -----
    The protocol only requires :meth:`query_method`. Whether an implementation
    caches fitted computation methods is an implementation detail and is not
    part of this interface (``DefaultComputationStrategy``, for instance,
    takes an ``enable_caching`` constructor argument).
    """

    def query_method(
        self, state: GenericCharacteristicName, distr: Distribution, **options: Any
    ) -> Method[Any, Any]:
        """
        Return a callable computing characteristic ``state`` for ``distr``.

        Parameters
        ----------
        state : str
            Target characteristic name.
        distr : Distribution
            Distribution the characteristic is computed for.
        **options : Any
            Strategy-specific options.

        Returns
        -------
        Method
            Callable that computes the characteristic.
        """
        ...
class DefaultComputationStrategy:
    """
    Default strategy for resolving characteristic computation methods.

    This strategy first checks for analytical implementations provided by the
    distribution. If none exists, it walks the characteristic graph to find a
    conversion path from an analytical characteristic to the target
    characteristic.

    Parameters
    ----------
    enable_caching : bool, default=False
        If True, cache fitted conversions to avoid repeated fitting.

    Attributes
    ----------
    _enable_caching : bool
        Whether caching is enabled.
    _cache : dict[tuple[int, str], tuple[Distribution, FittedComputationMethod]]
        Cache of fitted computation methods keyed by ``(id(distr), name)``.
        The distribution is stored next to the fitted method so that its id
        cannot be reused by another object while the entry exists, and so a
        single strategy instance can safely serve several distributions.
    _resolving : dict[int, set[str]]
        Tracking of currently resolving characteristics to detect cycles.
    """

    def __init__(self, enable_caching: bool = False) -> None:
        self._enable_caching = enable_caching
        self._cache: dict[
            tuple[int, GenericCharacteristicName],
            tuple[Distribution, FittedComputationMethod[Any, Any]],
        ] = {}
        self._resolving: dict[int, set[GenericCharacteristicName]] = {}

    @property
    def is_caching_enabled(self) -> bool:
        """Whether fitted conversions are cached."""
        return self._enable_caching

    def _cache_get(
        self, distr: Distribution, state: GenericCharacteristicName
    ) -> FittedComputationMethod[Any, Any] | None:
        """Return the cached fitted method of ``state`` for ``distr``, if any."""
        entry = self._cache.get((id(distr), state))
        if entry is None:
            return None
        owner, fitted = entry
        # Identity check guards against a recycled id; the stored owner keeps
        # the original object alive, so this is belt-and-braces.
        return fitted if owner is distr else None

    def _cache_put(
        self,
        distr: Distribution,
        state: GenericCharacteristicName,
        fitted: FittedComputationMethod[Any, Any],
    ) -> None:
        """Store ``fitted`` as the method of ``state`` for ``distr``."""
        self._cache[(id(distr), state)] = (distr, fitted)

    def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
        """
        Push a characteristic onto the resolution stack to detect cycles.

        Raises
        ------
        RuntimeError
            If a cycle is detected during resolution.
        """
        key = id(distr)
        seen = self._resolving.setdefault(key, set())
        if state in seen:
            raise RuntimeError(
                f"Cycle detected while resolving '{state}'. "
                "Provide at least one analytical base characteristic in the distribution."
            )
        seen.add(state)

    def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
        """Pop a characteristic from the resolution stack."""
        key = id(distr)
        seen = self._resolving.get(key)
        if seen is not None:
            seen.discard(state)
            # Drop the per-distribution entry once nothing is being resolved.
            if not seen:
                self._resolving.pop(key, None)

    @staticmethod
    def _pick_analytical_method(
        state: GenericCharacteristicName,
        methods: Mapping[LabelName, AnalyticalComputation[Any, Any]],
    ) -> AnalyticalComputation[Any, Any]:
        """
        Pick the first available analytical method for a characteristic.

        Raises
        ------
        RuntimeError
            If no labeled analytical methods are available for the characteristic.
        """
        try:
            return next(iter(methods.values()))
        except StopIteration as exc:
            raise RuntimeError(
                f"Characteristic '{state}' provides no labeled analytical computations."
            ) from exc

    @staticmethod
    def _pick_analytical_loop_method(
        state: GenericCharacteristicName,
        view: RegistryView,
    ) -> Method[Any, Any] | None:
        """
        Pick the first analytical self-loop method for a characteristic in a view.
        """
        loops = view.analytical_variants(state)
        if not loops:
            return None
        return cast(Method[Any, Any], next(iter(loops.values())).method)

    def query_method(
        self, state: GenericCharacteristicName, distr: Distribution, **options: Any
    ) -> Method[Any, Any]:
        """
        Resolve a computation method for the target characteristic.

        Resolution order:
        1. Cached fitted method for this distribution (if caching enabled)
        2. Analytical implementation for non-registry characteristics
        3. Analytical self-loop from the registry view
        4. Conversion path from analytical-loop characteristics via the graph

        Fitted conversions are only cached for targets that have no analytical
        self-loop in the view, so caching never shadows an analytical method.

        Parameters
        ----------
        state : str
            Target characteristic name (e.g., "pdf", "cdf").
        distr : Distribution
            Distribution to compute the characteristic for.
        **options : Any
            Additional options passed to fitters.

        Returns
        -------
        Method
            Callable that computes the characteristic.

        Raises
        ------
        RuntimeError
            If no analytical base exists, no conversion path is found,
            or a cycle is detected.
        """
        # 1. Cached fitted method, scoped to this exact distribution.
        if self._enable_caching:
            cached = self._cache_get(distr, state)
            if cached is not None:
                return cached

        # Every conversion must be grounded in at least one analytical characteristic.
        if not distr.analytical_computations:
            raise RuntimeError(
                "Distribution provides no analytical computations to ground conversions."
            )

        # 2. Non-registry characteristics are resolved directly. This covers a
        # user-provided analytical computation that isn't in the graph.
        registry = characteristic_registry()
        if state not in registry.declared_characteristics:
            if state in distr.analytical_computations:
                return self._pick_analytical_method(state, distr.analytical_computations[state])
            raise RuntimeError(
                f"Characteristic '{state}' is not declared in the registry and has no "
                "analytical implementation in the distribution."
            )

        # Filtered graph view for this distribution.
        view = registry.view(distr)

        self._push_guard(distr, state)
        try:
            # 3. Analytical self-loop for the target itself.
            loop_method = self._pick_analytical_loop_method(state, view)
            if loop_method is not None:
                return loop_method

            # 4. Try each analytical-loop characteristic as a conversion source.
            for src in distr.analytical_computations:
                if not view.analytical_variants(src):
                    continue

                path = view.find_path(src, state)
                if not path:
                    continue

                # Fit each edge along the path; the last one yields `state`.
                last_fitted: FittedComputationMethod[Any, Any] | None = None
                for edge in path:
                    fitted = edge.prepare(distr, **options)
                    # Never cache a conversion for a target the view can serve
                    # analytically: step 1 runs before step 3, so such an
                    # entry would shadow the analytical method.
                    if (
                        self._enable_caching
                        and edge.cacheable
                        and not view.analytical_variants(edge.target)
                    ):
                        self._cache_put(distr, edge.target, fitted)
                    last_fitted = fitted

                # Defensive: unreachable while empty paths are skipped above.
                if last_fitted is None:
                    raise RuntimeError(f"Empty path when resolving '{state}' from '{src}'.")
                return last_fitted

            raise RuntimeError(
                f"No conversion path from any analytical characteristic to '{state}'."
            )
        finally:
            self._pop_guard(distr, state)
class SamplingStrategy(Protocol):
    """
    Structural interface for objects that draw random samples from a distribution.

    Any class providing a compatible :meth:`sample` method satisfies this
    protocol for static type checking; no explicit inheritance is needed.
    """

    def sample(self, n: int, distr: Distribution, **options: Any) -> NumericArray:
        """
        Draw ``n`` samples from ``distr``.

        Parameters
        ----------
        n : int
            Number of samples to generate.
        distr : Distribution
            Distribution to sample from.
        **options : Any
            Strategy-specific options.

        Returns
        -------
        NumericArray
            The generated samples.
        """
        ...