"""
Computation and Sampling Strategies
This module defines strategies for computing distribution characteristics
and generating random samples.
"""
from __future__ import annotations
__author__ = "Leonid Elkin, Mikhail Mikhailov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from typing import TYPE_CHECKING, Protocol, cast
import numpy as np
from pysatl_core.distributions.registry import characteristic_registry
from pysatl_core.types import CharacteristicName, NumericArray
if TYPE_CHECKING:
from typing import Any
from pysatl_core.distributions.computation import AnalyticalComputation, FittedComputationMethod
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.types import GenericCharacteristicName
type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]
[docs]
class ComputationStrategy[In, Out](Protocol):
"""
Protocol for strategies that resolve computation methods for characteristics.
Attributes
----------
enable_caching : bool
Whether to cache fitted computation methods.
"""
enable_caching: bool
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
) -> Method[In, Out]: ...
[docs]
class DefaultComputationStrategy[In, Out]:
"""
Default strategy for resolving characteristic computation methods.
This strategy first checks for analytical implementations provided by
the distribution. If none exists, it walks the characteristic graph
to find a conversion path from an analytical characteristic to the
target characteristic.
Parameters
----------
enable_caching : bool, default=False
If True, cache fitted conversions to avoid repeated fitting.
Attributes
----------
enable_caching : bool
Whether caching is enabled.
_cache : dict[str, FittedComputationMethod]
Cache of fitted computation methods.
_resolving : dict[int, set[str]]
Tracking of currently resolving characteristics to detect cycles.
"""
[docs]
def __init__(self, enable_caching: bool = False) -> None:
self.enable_caching = enable_caching
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[In, Out]] = {}
self._resolving: dict[int, set[GenericCharacteristicName]] = {}
def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
"""
Push a characteristic onto the resolution stack to detect cycles.
Raises
------
RuntimeError
If a cycle is detected during resolution.
"""
key = id(distr)
seen = self._resolving.setdefault(key, set())
if state in seen:
raise RuntimeError(
f"Cycle detected while resolving '{state}'. "
"Provide at least one analytical base characteristic in the distribution."
)
seen.add(state)
def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
"""Pop a characteristic from the resolution stack."""
key = id(distr)
seen = self._resolving.get(key)
if seen is not None:
seen.discard(state)
if not seen:
self._resolving.pop(key, None)
[docs]
def query_method(
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
) -> Method[In, Out]:
"""
Resolve a computation method for the target characteristic.
Resolution order:
1. Analytical implementation from the distribution
2. Cached fitted method (if caching enabled)
3. Conversion path from an analytical characteristic via the graph
Parameters
----------
state : str
Target characteristic name (e.g., "pdf", "cdf").
distr : Distribution
Distribution to compute the characteristic for.
**options : Any
Additional options passed to fitters.
Returns
-------
Method
Callable that computes the characteristic.
Raises
------
RuntimeError
If no analytical base exists, no conversion path is found,
or a cycle is detected.
"""
# 1. Check for analytical implementation
if state in distr.analytical_computations:
return distr.analytical_computations[state]
# 2. Check cache if enabled
if self.enable_caching:
cached = self._cache.get(state)
if cached is not None:
return cached
# 3. Require at least one analytical characteristic
if not distr.analytical_computations:
raise RuntimeError(
"Distribution provides no analytical computations to ground conversions."
)
# 4. Get filtered graph view for this distribution
reg = characteristic_registry().view(distr)
self._push_guard(distr, state)
try:
# 5. Try each analytical characteristic as a source
for src in distr.analytical_computations:
if src == state:
return distr.analytical_computations[src]
# Find conversion path in the graph
path = reg.find_path(src, state)
if not path:
continue
# Fit each edge along the path
last_fitted: FittedComputationMethod[In, Out] | None = None
for edge in path:
fitted = edge.fit(distr, **options)
if self.enable_caching:
self._cache[edge.target] = fitted
last_fitted = fitted
if last_fitted is None:
raise RuntimeError(f"Empty path when resolving '{state}' from '{src}'.")
return last_fitted
raise RuntimeError(
f"No conversion path from any analytical characteristic to '{state}'."
)
finally:
self._pop_guard(distr, state)
class SamplingStrategy(Protocol):
    """Protocol for strategies that generate samples from distributions."""

    def sample(self, n: int, distr: Distribution, **options: Any) -> NumericArray:
        """
        Generate samples from a distribution.

        Parameters
        ----------
        n : int
            Number of samples to generate.
        distr : Distribution
            Distribution to sample from.
        **options : Any
            Additional implementation-specific options.

        Returns
        -------
        NumericArray
            The generated samples.
        """
        ...
class DefaultSamplingUnivariateStrategy(SamplingStrategy):
    """
    Default univariate sampler based on inverse transform sampling.

    This strategy generates samples by applying the PPF (inverse CDF)
    to uniformly distributed random variables.

    Notes
    -----
    - Requires the distribution to provide a PPF computation method.
    - Assumes that the PPF follows NumPy semantics (vectorized evaluation).
    - Graph-derived PPFs (scalar-only) are currently not supported.
    - Returns a NumPy array containing the generated samples.
    """

    def sample(self, n: int, distr: Distribution, **options: Any) -> NumericArray:
        """
        Generate samples from the distribution.

        Parameters
        ----------
        n : int
            Number of samples to generate.
        distr : Distribution
            Distribution to sample from.
        **options : Any
            Additional options forwarded to the PPF computation.

        Returns
        -------
        NumericArray
            NumPy array containing ``n`` generated samples.
            The exact array shape depends on the distribution and sampling strategy.
        """
        ppf = distr.query_method(CharacteristicName.PPF, **options)

        # A fresh, OS-seeded generator is created on every call.
        rng = np.random.default_rng()
        uniform_draws = rng.random(n)

        # TODO: This relies on the PPF having NumPy (vectorized) semantics,
        # which is much faster; as a consequence it does not currently work
        # with graph-computed characteristics.
        samples = ppf(uniform_draws)
        return cast(NumericArray, samples)