Source code for pysatl_core.families.parametric_family

"""
Parametric family definitions and management infrastructure.

This module contains the main class for defining parametric families of
distributions, including support for multiple parameterizations, distribution
characteristics, sampling strategies, and computation methods.
"""

from __future__ import annotations

__author__ = "Leonid Elkin, Mikhail Mikhailov, Fedor Myznikov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

import inspect
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Any, cast, dataclass_transform

from pysatl_core.distributions.computation import AnalyticalComputation
from pysatl_core.families.distribution import ParametricFamilyDistribution
from pysatl_core.types import (
    DEFAULT_ANALYTICAL_COMPUTATION_LABEL,
    ComputationFunc,
    DistributionType,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    from pysatl_core.distributions.strategies import ComputationStrategy, SamplingStrategy
    from pysatl_core.distributions.support import Support
    from pysatl_core.families.parametrizations import (
        Parametrization,
    )
    from pysatl_core.types import (
        GenericCharacteristicName,
        LabelName,
        ParametrizationName,
    )

    type SupportArg = Callable[[Parametrization], Support | None] | None
    type SupportResolver = Callable[[Parametrization], Support | None]
    type LabeledCharacteristicProvider = (
        Mapping[LabelName, ParametricFamilyCharacteristic[Any, Any]]
        | ParametricFamilyCharacteristic[Any, Any]
    )
    type CharacteristicProvider = (
        Mapping[ParametrizationName, LabeledCharacteristicProvider]
        | ParametricFamilyCharacteristic[Any, Any]
    )
    type CharacteristicsMap = Mapping[GenericCharacteristicName, CharacteristicProvider]
    type NonParametrizedCharacteristic[In, Out] = Callable[[], Out]
    type ParametricFamilyCharacteristic[In, Out] = (
        NonParametrizedCharacteristic[In, Out] | ParametrizedCharacteristic[In, Out]
    )


type ParametrizedCharacteristic[In, Out] = (
    Callable[[Parametrization, In], Out] | Callable[[Parametrization], Out]
)


[docs] class ParametricFamily: """ A family of distributions with multiple parametrizations. Represents a parametric family of distributions (e.g., normal, lognormal) that can be parameterized in different ways. Manages parametrizations, distribution characteristics, and provides factory methods for creating distribution instances. Parameters ---------- name : str Name of the distribution family. distr_type : DistributionType or Callable[[Parametrization], DistributionType] Distribution type or function that infers type from base parametrization. distr_parametrizations : list[ParametrizationName] List of parametrization names (first is base parametrization). distr_characteristics : CharacteristicsMap Mapping from characteristic names to analytical provider callables. Each provider callable may accept a parametrization instance as the first argument. The remaining signature is characteristic-specific: - nullary characteristics (e.g., mean, var): provider(params, **kwargs) -> Any - pointwise characteristics (e.g., pdf, cdf, ppf): provider(params, x, **kwargs) -> Any Providers are grouped by parametrization and may define multiple labeled methods. If a single callable is provided, it is treated as the base-parametrization method under ``DEFAULT_ANALYTICAL_COMPUTATION_LABEL``. support_by_parametrization : Callable or None, optional Function that returns support for given parameters. """
[docs] def __init__( self, name: str, distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], distr_characteristics: CharacteristicsMap, support_by_parametrization: SupportArg = None, ): if not distr_parametrizations: raise ValueError( "distr_parametrizations must be non-empty (base parametrization is required)." ) self._name = name # Ordered names; the first one is the base parametrization name self.parametrization_names = distr_parametrizations self.base_parametrization_name = self.parametrization_names[0] self._distr_type: Callable[[Parametrization], DistributionType] = ( (lambda params: distr_type) if isinstance(distr_type, DistributionType) else distr_type ) self._support_resolver: SupportResolver = support_by_parametrization or (lambda _p: None) # Runtime registry of parametrization classes self._parametrizations: dict[ParametrizationName, type[Parametrization]] = {} def _normalize_labeled_provider( characteristic_name: GenericCharacteristicName, parametrization_name: ParametrizationName, provider: LabeledCharacteristicProvider, ) -> dict[LabelName, ParametricFamilyCharacteristic[Any, Any]]: normalized = ( dict(provider) if isinstance(provider, Mapping) else {DEFAULT_ANALYTICAL_COMPUTATION_LABEL: provider} ) if not normalized: raise ValueError( f"Characteristic '{characteristic_name}' has no labeled providers for " f"parametrization '{parametrization_name}'." ) return normalized def _normalize_characteristic( characteristic_name: GenericCharacteristicName, value: CharacteristicProvider, ) -> dict[ParametrizationName, dict[LabelName, ParametricFamilyCharacteristic[Any, Any]]]: if not isinstance(value, Mapping): base_name = self.base_parametrization_name return { base_name: _normalize_labeled_provider(characteristic_name, base_name, value) } normalized_by_parametrization: dict[ ParametrizationName, dict[LabelName, ParametricFamilyCharacteristic[Any, Any]] ] = {} for parametrization_name, provider in value.items(): normalized_by_parametrization[parametrization_name] = _normalize_labeled_provider( characteristic_name, parametrization_name, provider, ) return normalized_by_parametrization self.distr_characteristics: dict[ GenericCharacteristicName, dict[ParametrizationName, dict[LabelName, ParametricFamilyCharacteristic[Any, Any]]], ] = { characteristic_name: _normalize_characteristic(characteristic_name, provider) for characteristic_name, provider in distr_characteristics.items() } # Validate characteristic providers valid_names = set(self.parametrization_names) for char_name, forms in self.distr_characteristics.items(): if not forms: raise ValueError(f"Characteristic '{char_name}' has no providers.") unknown = set(forms) - valid_names if unknown: raise ValueError( f"Characteristic '{char_name}' has providers for unknown parametrizations: " f"{sorted(unknown)}." ) # Precompute analytical plan: for each parametrization pick provider (self or base) self._analytical_plan: dict[ ParametrizationName, dict[GenericCharacteristicName, ParametrizationName] ] = {} base = self.base_parametrization_name for pname in self.parametrization_names: plan: dict[GenericCharacteristicName, ParametrizationName] = {} for characteristic, forms in self.distr_characteristics.items(): if pname in forms: plan[characteristic] = pname elif base in forms: plan[characteristic] = base self._analytical_plan[pname] = plan
@property def name(self) -> str: """Get the family name.""" return self._name @property def parametrizations(self) -> dict[ParametrizationName, type[Parametrization]]: """Get mapping from parametrization names to classes.""" return self._parametrizations @property def base(self) -> type[Parametrization]: """ Get the base parametrization class. Raises ------ ValueError If base parametrization is not registered. """ try: return self._parametrizations[self.base_parametrization_name] except KeyError as exc: raise ValueError( f"Base parametrization '{self.base_parametrization_name}' is not registered." ) from exc @property def support_resolver(self) -> SupportResolver: """Support resolver callable.""" return self._support_resolver
[docs] def register_parametrization( self, name: ParametrizationName, parametrization_class: type[Parametrization], ) -> None: """ Register a parametrization class. Parameters ---------- name : ParametrizationName Unique parametrization name. parametrization_class : type[Parametrization] Parametrization class to register. Raises ------ ValueError If name is already registered. """ if name in self._parametrizations: raise ValueError(f"Parametrization '{name}' is already registered.") self._parametrizations[name] = parametrization_class
[docs] def get_parametrization(self, name: ParametrizationName) -> type[Parametrization]: """ Fetch a parametrization class by name. Raises ------ KeyError If name is not registered. """ return self._parametrizations[name]
[docs] def to_base(self, parameters: Parametrization) -> Parametrization: """ Convert parameters to the base parametrization. Parameters ---------- parameters : Parametrization Parameters in any parametrization. Returns ------- Parametrization Equivalent parameters in base parametrization. """ if parameters.name == self.base_parametrization_name: return parameters return parameters.transform_to_base_parametrization()
@staticmethod def _bind_parametrization[In, Out]( func: ParametricFamilyCharacteristic[In, Out], params_obj: Parametrization ) -> ComputationFunc[In, Out]: """Bind ``params_obj`` to ``func`` only when ``func`` can accept positional arguments. This allows parametrization-independent analytical providers to be written without a dummy first argument (e.g. ``def skew_func()`` or ``def kurt_func(*, excess=False)``), while still supporting the usual ``def f(parameters, ...)`` style. It means that we will always make any other analytical_computation params like ``excess`` as keyword-only """ sig = inspect.signature(func) params = list(sig.parameters.values()) accepts_first_positional = bool(params) and params[0].kind in ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ) return cast( ComputationFunc[In, Out], partial(cast(ParametrizedCharacteristic[In, Out], func), params_obj) if accepts_first_positional else func, ) def _build_analytical_computations( self, parameters: Parametrization ) -> dict[GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]]]: """ Build analytical computations for given parameters. Uses precomputed provider plan for efficient computation. """ plan = self._analytical_plan.get(parameters.name, {}) result: dict[ GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]] ] = {} base_params: Parametrization | None = None for characteristic, provider_name in plan.items(): if provider_name == parameters.name: params_obj = parameters else: base_params = base_params or self.to_base(parameters) params_obj = base_params labeled_providers = self.distr_characteristics[characteristic][provider_name] result[characteristic] = { label_name: AnalyticalComputation( target=characteristic, func=self._bind_parametrization(func_factory, params_obj), ) for label_name, func_factory in labeled_providers.items() } return result
[docs] def distribution( self, parametrization_name: ParametrizationName | None = None, sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy | None = None, **parameters_values: Any, ) -> ParametricFamilyDistribution: """ Create a distribution instance with given parameters. Parameters ---------- parametrization_name : ParametrizationName | None, optional Name of parametrization to use (defaults to base). sampling_strategy : SamplingStrategy Strategy for generating random samples. Such an object is unique for each distribution. computation_strategy : ComputationStrategy Strategy for computing characteristics and conversions. Such an object is unique for each distribution. **parameters_values Parameter values for the distribution. Returns ------- ParametricFamilyDistribution Distribution instance with specified parameters. Raises ------ KeyError If parametrization name is not registered. ValueError If parameters don't satisfy constraints. """ parametrization_class = ( self.base if parametrization_name is None else self._parametrizations[parametrization_name] ) parameters = parametrization_class(**parameters_values) parameters.validate() base_parameters = self.to_base(parameters) distribution_type = self._distr_type(base_parameters) analytical_computations = self._build_analytical_computations(parameters) return ParametricFamilyDistribution( family_name=self.name, distribution_type=distribution_type, analytical_computations=analytical_computations, parametrization=parameters, support=self.support_resolver(parameters), sampling_strategy=sampling_strategy, computation_strategy=computation_strategy, )
[docs] @dataclass_transform() def parametrization( self, *, name: ParametrizationName ) -> Callable[[type[Parametrization]], type[Parametrization]]: """ Create a class decorator that registers a parametrization. If you want to use this syntax and so that Mypy doesn't swear, you should mark your class as a dataclass. At the moment, Mypy cannot identify dataclass_transform if the decorator is a class method. Parameters ---------- name : str Name of the parametrization. Returns ------- Callable[[type[Parametrization]], type[Parametrization]] Class decorator for registering parametrizations. """ from pysatl_core.families.parametrizations import parametrization as _param_deco return _param_deco(family=self, name=name)
__call__ = distribution