Source code for pysatl_core.transformations.distribution

"""
Base classes for transformed distributions.

This module introduces the first architectural layer for derived
probability distributions produced by transformations. The goal is to
keep them fully compatible with the existing :class:`Distribution`
protocol and computation graph while still preserving transformation
metadata.
"""

from __future__ import annotations

__author__ = "Leonid Elkin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from abc import abstractmethod
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any

from pysatl_core.distributions.computations.computation import AnalyticalComputation
from pysatl_core.distributions.distribution import _KEEP, Distribution
from pysatl_core.distributions.strategies import (
    ComputationStrategy,
    SamplingStrategy,
)
from pysatl_core.transformations.lightweight_distribution import LightweightDistribution
from pysatl_core.transformations.operators_mixin import TransformationOperatorsMixin
from pysatl_core.transformations.transformation_method import TransformationMethod
from pysatl_core.types import (
    DistributionType,
    GenericCharacteristicName,
    LabelName,
    ParentRole,
    TransformationMethodSpecsMap,
    TransformationName,
)

if TYPE_CHECKING:
    from pysatl_core.distributions.support import Support
    from pysatl_core.transformations.approximations.approximation import (
        CharacteristicApproximationMethod,
    )


[docs] class DerivedDistribution(TransformationOperatorsMixin, Distribution): """ Base class for distributions obtained from one or more parents. Parameters ---------- distribution_type : DistributionType Type descriptor of the derived distribution. bases : Mapping[ParentRole, Distribution] Parent distributions participating in the transformation. Internally, they are stored as lightweight snapshots to avoid retaining full parent distribution objects. analytical_computations : Mapping[ GenericCharacteristicName, ( AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]] ), ] Derived characteristic methods exposed by the transformation. Presence here means that at least one ancestor in the derivation chain is analytical. transformation_name : TransformationName Logical name of the transformation. support : Support | None, optional Support of the transformed distribution. sampling_strategy : SamplingStrategy | None, optional Strategy used to generate random samples. computation_strategy : ComputationStrategy | None, optional Strategy used to resolve characteristics. loop_analytical_flags : Mapping[ GenericCharacteristicName, Mapping[LabelName, bool], ] | None, optional Optional graph flags for loop analytical status. A loop has ``is_analytical=True`` only when all required ancestors of that characteristic are analytical. """
[docs] def __init__( self, *, distribution_type: DistributionType, bases: Mapping[ParentRole, Distribution], analytical_computations: Mapping[ GenericCharacteristicName, (AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]]), ], transformation_name: TransformationName, support: Support | None = None, sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy | None = None, loop_analytical_flags: ( Mapping[GenericCharacteristicName, Mapping[LabelName, bool]] | None ) = None, ) -> None: self._loop_analytical_flags = { characteristic_name: dict(flags_by_label) for characteristic_name, flags_by_label in (loop_analytical_flags or {}).items() } super().__init__( distribution_type=distribution_type, analytical_computations=analytical_computations, support=support, sampling_strategy=sampling_strategy, computation_strategy=computation_strategy, ) self._bases = { role: LightweightDistribution.from_distribution(base) for role, base in bases.items() } self._transformation_name = transformation_name
@property def bases(self) -> Mapping[ParentRole, Distribution]: """Get parent distributions grouped by their logical roles.""" return self._bases @property @abstractmethod def parent_roles(self) -> tuple[ParentRole, ...]: """Get deterministic parent roles used by this transformed distribution.""" @property def transformation_name(self) -> TransformationName: """Get the logical name of the transformation.""" return self._transformation_name
[docs] def loop_is_analytical( self, characteristic_name: GenericCharacteristicName, label_name: LabelName, ) -> bool: """ Return transformation-aware analytical flag for a loop method. The method returns ``True`` only when all required predecessors are analytical. A method can still be present in ``analytical_computations`` when this returns ``False``. """ return self._loop_analytical_flags.get(characteristic_name, {}).get(label_name, True)
@staticmethod def _resolve_transformation_methods( *, methods: TransformationMethodSpecsMap | None, default_methods: TransformationMethodSpecsMap, ) -> TransformationMethodSpecsMap: """Resolve transformation methods from user input or defaults.""" return default_methods if methods is None else methods def _build_transformation_analytical_computations( self, *, transformation_name: TransformationName, bases: Mapping[ParentRole, LightweightDistribution], methods: TransformationMethodSpecsMap, source_validator: Callable[[ParentRole, GenericCharacteristicName], bool] | None = None, ) -> tuple[ Mapping[GenericCharacteristicName, Mapping[LabelName, TransformationMethod[Any, Any]]], Mapping[GenericCharacteristicName, Mapping[LabelName, bool]], ]: """ Build analytical computations from transformation method specifications. """ computations: dict[ GenericCharacteristicName, dict[LabelName, TransformationMethod[Any, Any]] ] = {} loop_analytical_flags: dict[GenericCharacteristicName, dict[LabelName, bool]] = {} for target, labeled_specs in methods.items(): for label, (source_requirements_resolver, evaluator) in labeled_specs.items(): source_requirements = ( source_requirements_resolver(self) if callable(source_requirements_resolver) else source_requirements_resolver ) if source_validator is not None and any( not source_validator(role, characteristic) for role, characteristics in source_requirements.items() for characteristic in characteristics ): continue try: method = TransformationMethod( target=target, transformation=transformation_name, bases=bases, source_requirements=source_requirements, evaluator=evaluator, owner=self, ) except ValueError: continue computations.setdefault(target, {})[label] = method loop_analytical_flags.setdefault(target, {})[label] = method.is_analytical return computations, loop_analytical_flags
[docs] def approximate( self, methods: Mapping[GenericCharacteristicName, CharacteristicApproximationMethod], **options: Any, ) -> ApproximatedDistribution: """ Approximate selected characteristics of the current derivation. Parameters ---------- methods : Mapping[GenericCharacteristicName, CharacteristicApproximationMethod] Mapping from characteristic names to characteristic-level approximation methods. **options : Any Extra options forwarded to each approximation method. Returns ------- ApproximatedDistribution Distribution with materialized approximations for selected characteristics. """ if not methods: raise ValueError("At least one characteristic approximation method must be provided.") analytical_computations: dict[ GenericCharacteristicName, AnalyticalComputation[Any, Any] ] = {} for characteristic_name, method in methods.items(): computation = method.approximate( self, **options, ) if computation.target != characteristic_name: raise ValueError( "Approximation method returned computation for a mismatched " f"target: expected '{characteristic_name}', got '{computation.target}'." ) analytical_computations[characteristic_name] = computation return ApproximatedDistribution( distribution_type=self.distribution_type, analytical_computations=analytical_computations, support=self.support, sampling_strategy=self.sampling_strategy, computation_strategy=self.computation_strategy, )
@abstractmethod def _clone_with_strategies( self, *, sampling_strategy: SamplingStrategy | None | object = _KEEP, computation_strategy: ComputationStrategy | None | object = _KEEP, ) -> DerivedDistribution: """ Return a copy of the derived distribution with updated strategies. Concrete subclasses must preserve their own transformation parameters while applying strategy overrides. """ _ = sampling_strategy, computation_strategy raise NotImplementedError
[docs] class ApproximatedDistribution(DerivedDistribution): """ Derived distribution whose analytical computations were materialized by an external approximator. Parameters ---------- distribution_type : DistributionType Type descriptor of the approximated distribution. analytical_computations : Mapping[ GenericCharacteristicName, ( AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]] ), ] Materialized methods produced by the approximator. They are exposed in ``analytical_computations`` for strategy resolution, but are never treated as fully analytical. support : Support | None, optional Support of the approximated distribution. sampling_strategy : SamplingStrategy | None, optional Sampling strategy to expose. computation_strategy : ComputationStrategy | None, optional Characteristic resolution strategy. """
[docs] def __init__( self, *, distribution_type: DistributionType, analytical_computations: Mapping[ GenericCharacteristicName, (AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]]), ], support: Support | None = None, sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy | None = None, ) -> None: super().__init__( distribution_type=distribution_type, bases={}, analytical_computations=analytical_computations, transformation_name=TransformationName.APPROXIMATION, support=support, sampling_strategy=sampling_strategy, computation_strategy=computation_strategy, loop_analytical_flags={}, ) self._loop_analytical_flags = self._build_non_analytical_loop_flags( self.analytical_computations )
@staticmethod def _build_non_analytical_loop_flags( analytical_computations: Mapping[ GenericCharacteristicName, Mapping[LabelName, AnalyticalComputation[Any, Any]], ], ) -> dict[GenericCharacteristicName, dict[LabelName, bool]]: """Build loop flags where every approximation loop is non-analytical.""" return { characteristic_name: dict.fromkeys(labeled_methods, False) for characteristic_name, labeled_methods in analytical_computations.items() } @property def parent_roles(self) -> tuple[ParentRole, ...]: """Return empty parent roles for approximation-based derived distributions.""" return () def _clone_with_strategies( self, *, sampling_strategy: SamplingStrategy | None | object = _KEEP, computation_strategy: ComputationStrategy | None | object = _KEEP, ) -> ApproximatedDistribution: """Return a copy of the approximated distribution with updated strategies.""" return ApproximatedDistribution( distribution_type=self.distribution_type, analytical_computations=self.analytical_computations, support=self.support, sampling_strategy=self._new_sampling_strategy(sampling_strategy), computation_strategy=self._new_computation_strategy(computation_strategy), )
__all__ = [ "ApproximatedDistribution", "DerivedDistribution", "LightweightDistribution", ]