Source code for pysatl_core.families.parametric_family

"""
Parametric family definitions and management infrastructure.

This module contains the main class for defining parametric families of
distributions, including support for multiple parameterizations, distribution
characteristics, sampling strategies, and computation methods.
"""

from __future__ import annotations

__author__ = "Leonid Elkin, Mikhail Mikhailov, Fedor Myznikov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


from functools import partial
from typing import TYPE_CHECKING, dataclass_transform

from pysatl_core.distributions.computation import AnalyticalComputation
from pysatl_core.distributions.strategies import (
    DefaultComputationStrategy,
    DefaultSamplingUnivariateStrategy,
)
from pysatl_core.families.distribution import ParametricFamilyDistribution
from pysatl_core.types import DistributionType

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    from pysatl_core.distributions.strategies import ComputationStrategy, SamplingStrategy
    from pysatl_core.distributions.support import Support
    from pysatl_core.families.parametrizations import (
        Parametrization,
    )
    from pysatl_core.types import (
        GenericCharacteristicName,
        ParametrizationName,
    )

    type ParametrizedFunction = Callable[[Parametrization, Any], Any]
    type SupportArg = Callable[[Parametrization], Support | None] | None
    type SupportResolver = Callable[[Parametrization], Support | None]


[docs] class ParametricFamily: """ A family of distributions with multiple parametrizations. Represents a parametric family of distributions (e.g., normal, lognormal) that can be parameterized in different ways. Manages parametrizations, distribution characteristics, and provides factory methods for creating distribution instances. Parameters ---------- name : str Name of the distribution family. distr_type : DistributionType or Callable[[Parametrization], DistributionType] Distribution type or function that infers type from base parametrization. distr_parametrizations : list[ParametrizationName] List of parametrization names (first is base parametrization). distr_characteristics : dict[str, dict[str, Callable] or Callable] Mapping from characteristic names to computation functions. Single functions are treated as defined for the base parametrization. sampling_strategy : SamplingStrategy, optional Strategy for sampling from distributions. computation_strategy : ComputationStrategy, optional Strategy for computing distribution characteristics. support_by_parametrization : Callable or None, optional Function that returns support for given parameters. """
[docs] def __init__( self, name: str, distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], distr_characteristics: dict[ GenericCharacteristicName, dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, ], sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy[Any, Any] | None = None, support_by_parametrization: SupportArg = None, ): self._name = name self._distr_type: Callable[[Parametrization], DistributionType] = ( (lambda params: distr_type) if isinstance(distr_type, DistributionType) else distr_type ) self.computation_strategy = ( DefaultComputationStrategy() if computation_strategy is None else computation_strategy ) if support_by_parametrization is None: self._support_resolver: SupportResolver self._support_resolver = lambda _params: None else: self._support_resolver = support_by_parametrization # Ordered names; the first one is the base parametrization name self.parametrization_names: list[ParametrizationName] = distr_parametrizations self.base_parametrization_name: ParametrizationName = self.parametrization_names[0] # Runtime registry of parametrization classes self._parametrizations: dict[ParametrizationName, type[Parametrization]] = {} self.sampling_strategy = ( DefaultSamplingUnivariateStrategy() if sampling_strategy is None else sampling_strategy ) def _process_char_val( value: dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, ) -> dict[ParametrizationName, ParametrizedFunction]: return value if isinstance(value, dict) else {self.parametrization_names[0]: value} self.distr_characteristics: dict[ GenericCharacteristicName, dict[ParametrizationName, ParametrizedFunction] ] = {key: _process_char_val(val) for key, val in distr_characteristics.items()} # Precompute analytical plan self._analytical_plan: dict[ ParametrizationName, dict[GenericCharacteristicName, ParametrizationName] ] = {} base_name = self.base_parametrization_name for pname in self.parametrization_names: plan_for_p: dict[GenericCharacteristicName, ParametrizationName] = {} for characteristic, forms in self.distr_characteristics.items(): if pname in forms: plan_for_p[characteristic] = pname elif base_name in forms: plan_for_p[characteristic] = base_name self._analytical_plan[pname] = plan_for_p
@property def name(self) -> str: """Get the family name.""" return self._name @property def parametrizations(self) -> dict[ParametrizationName, type[Parametrization]]: """Get mapping from parametrization names to classes.""" return self._parametrizations @property def base(self) -> type[Parametrization]: """ Get the base parametrization class. Raises ------ ValueError If base parametrization is not registered. """ try: return self._parametrizations[self.base_parametrization_name] except KeyError as exc: raise ValueError( f"Base parametrization '{self.base_parametrization_name}' is not registered." ) from exc @property def support_resolver(self) -> SupportResolver: """Get the support resolver function.""" return self._support_resolver
[docs] def register_parametrization( self, name: ParametrizationName, parametrization_class: type[Parametrization], ) -> None: """ Register a parametrization class. Parameters ---------- name : ParametrizationName Unique parametrization name. parametrization_class : type[Parametrization] Parametrization class to register. Raises ------ ValueError If name is already registered. """ if name in self._parametrizations: raise ValueError(f"Parametrization '{name}' is already registered.") self._parametrizations[name] = parametrization_class
[docs] def get_parametrization(self, name: ParametrizationName) -> type[Parametrization]: """ Fetch a parametrization class by name. Raises ------ KeyError If name is not registered. """ return self._parametrizations[name]
[docs] def to_base(self, parameters: Parametrization) -> Parametrization: """ Convert parameters to the base parametrization. Parameters ---------- parameters : Parametrization Parameters in any parametrization. Returns ------- Parametrization Equivalent parameters in base parametrization. """ if parameters.name == self.base_parametrization_name: return parameters return parameters.transform_to_base_parametrization()
def _build_analytical_computations( self, parameters: Parametrization ) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: """ Build analytical computations for given parameters. Uses precomputed provider plan for efficient computation. """ plan = self._analytical_plan.get(parameters.name, {}) result: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] = {} base_params: Parametrization | None = None for characteristic, provider_name in plan.items(): if provider_name == parameters.name: params_obj = parameters else: if base_params is None: base_params = self.to_base(parameters) params_obj = base_params func_factory = self.distr_characteristics[characteristic][provider_name] result[characteristic] = AnalyticalComputation( target=characteristic, func=partial(func_factory, params_obj), ) return result
[docs] def distribution( self, parametrization_name: str | None = None, **parameters_values: Any, ) -> ParametricFamilyDistribution: """ Create a distribution instance with given parameters. Parameters ---------- parametrization_name : str, optional Name of parametrization to use (defaults to base). **parameters_values Parameter values for the distribution. Returns ------- ParametricFamilyDistribution Distribution instance with specified parameters. Raises ------ KeyError If parametrization name is not registered. ValueError If parameters don't satisfy constraints. """ if parametrization_name is None: parametrization_class = self.base else: parametrization_class = self._parametrizations[parametrization_name] parameters = parametrization_class(**parameters_values) parameters.validate() base_parameters = self.to_base(parameters) distribution_type = self._distr_type(base_parameters) return ParametricFamilyDistribution( self.name, distribution_type, parameters, self.support_resolver(parameters) )
[docs] @dataclass_transform() def parametrization( self, *, name: str ) -> Callable[[type[Parametrization]], type[Parametrization]]: """ Create a class decorator that registers a parametrization. If you want to use this syntax and so that Mypy doesn't swear, you should mark your class as a dataclass. At the moment, Mypy cannot identify dataclass_transform if the decorator is a class method. Parameters ---------- name : str Name of the parametrization. Returns ------- Callable[[type[Parametrization]], type[Parametrization]] Class decorator for registering parametrizations. """ from pysatl_core.families.parametrizations import parametrization as _param_deco return _param_deco(family=self, name=name)
__call__ = distribution