"""
Fitter registry for descriptor lookup and matching.
Provides ``FitterRegistry`` for registering, querying, and prioritising
fitter descriptors by target, sources, and constraint tags.
The module-level ``fitter_registry()`` function returns a process-wide
singleton that is populated lazily on first access with all built-in
descriptors from :mod:`pysatl_core.distributions.computations.continuous`
and :mod:`pysatl_core.distributions.computations.discrete`.
"""
from __future__ import annotations
__author__ = "Irina Sergeeva"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, Self

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence
    from collections.abc import Set as AbstractSet

    from pysatl_core.distributions.computations.descriptors import FitterDescriptor
    from pysatl_core.types import GenericCharacteristicName
[docs]
class FitterRegistry:
"""
Registry that stores fitter descriptors and selects the best match.
This class is a singleton: every call to ``FitterRegistry()`` returns the
same instance. Use ``FitterRegistry._reset()`` in tests to clear state.
Examples
--------
>>> registry = FitterRegistry()
>>> registry.register(some_descriptor)
>>> desc = registry.find("cdf", ["pdf"], required_tags={"continuous", "univariate"})
"""
_instance: ClassVar[FitterRegistry | None] = None
def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance # type: ignore[return-value]
[docs]
def __init__(self) -> None:
if getattr(self, "_initialized", False):
return
self._by_key: dict[
tuple[GenericCharacteristicName, tuple[GenericCharacteristicName, ...]],
list[FitterDescriptor],
] = {}
self._all: list[FitterDescriptor] = []
self._initialized = True
[docs]
def __copy__(self) -> Self:
"""Singleton copy returns the same instance."""
return self
[docs]
def __deepcopy__(self, memo: dict[Any, Any]) -> Self:
"""Singleton deepcopy returns the same instance."""
return self
@classmethod
def _reset(cls) -> None:
"""
Clear the singleton instance (for testing purposes only).
Resets all registered descriptors and allows the next ``FitterRegistry()``
call to create a fresh instance.
"""
cls._instance = None
[docs]
def register(self, descriptor: FitterDescriptor) -> None:
"""
Register a fitter descriptor.
Parameters
----------
descriptor : FitterDescriptor
Fitter to register.
"""
key = (descriptor.target, tuple(descriptor.sources))
self._by_key.setdefault(key, []).append(descriptor)
self._all.append(descriptor)
[docs]
def register_many(self, descriptors: Sequence[FitterDescriptor]) -> None:
"""Register multiple descriptors at once."""
for d in descriptors:
self.register(d)
[docs]
def find(
self,
target: GenericCharacteristicName,
sources: Sequence[GenericCharacteristicName],
*,
required_tags: frozenset[str] | None = None,
) -> FitterDescriptor | None:
"""
Find the first fitter matching the given target and sources.
Parameters
----------
target : str
Target characteristic name.
sources : Sequence[str]
Source characteristic names.
required_tags : frozenset[str] | None
If provided, only fitters whose ``constraint_tags`` are a superset
of *required_tags* are considered.
Returns
-------
FitterDescriptor | None
First matching descriptor, or ``None`` if no match.
"""
key = (target, tuple(sources))
candidates = self._by_key.get(key)
if candidates is None:
return None
for desc in candidates:
if required_tags is not None and not desc.constraint_tags.issuperset(required_tags):
continue
return desc
return None
[docs]
def find_all(
self,
target: GenericCharacteristicName,
sources: Sequence[GenericCharacteristicName],
*,
required_tags: frozenset[str] | None = None,
) -> list[FitterDescriptor]:
"""
Return *all* matching fitters in registration order.
Parameters
----------
target : str
Target characteristic name.
sources : Sequence[str]
Source characteristic names.
required_tags : frozenset[str] | None
Tag filter (same semantics as :meth:`find`).
Returns
-------
list[FitterDescriptor]
Matching descriptors in registration order.
"""
key = (target, tuple(sources))
candidates = self._by_key.get(key, [])
return [
d
for d in candidates
if required_tags is None or d.constraint_tags.issuperset(required_tags)
]
[docs]
def all_descriptors(self) -> list[FitterDescriptor]:
"""Return all registered descriptors in insertion order."""
return list(self._all)
def __len__(self) -> int:
return len(self._all)
def __contains__(self, name: str) -> bool:
return any(d.name == name for d in self._all)
@lru_cache(maxsize=1)
def fitter_registry() -> FitterRegistry:
    """
    Return the process-wide singleton ``FitterRegistry``, populated lazily.

    All built-in continuous and discrete fitter descriptors are created and
    registered on the first call. Subsequent calls return the cached instance.

    Notes
    -----
    - Descriptors are **not** created at import time; they are built here on
      first access, keeping module-level side-effects to a minimum.
    - Both descriptor lists are fully built *before* anything is registered.
      If either builder raises, the registry is left unchanged, so a retry
      does not register duplicates. ``lru_cache`` does not cache a call that
      raised, so the next call retries.
    - ``FitterRegistry() is fitter_registry()`` holds after the first call as
      long as the singleton is only reset through ``reset_fitter_registry()``.
    - To reset the singleton (e.g. in tests), call ``reset_fitter_registry()``.
    """
    # Deferred imports keep descriptor construction out of module import time.
    from pysatl_core.distributions.computations.continuous import (
        _build_continuous_descriptors,
    )
    from pysatl_core.distributions.computations.discrete import (
        _build_discrete_descriptors,
    )

    # Materialise everything first so a failing builder cannot leave the
    # singleton half-populated.
    continuous = list(_build_continuous_descriptors())
    discrete = list(_build_discrete_descriptors())

    reg = FitterRegistry()
    reg.register_many(continuous)
    reg.register_many(discrete)
    return reg
def reset_fitter_registry() -> None:
    """
    Discard the populated registry so the next access starts from scratch.

    Drops both the ``FitterRegistry`` singleton instance and the memoised
    result of ``fitter_registry()``. The next ``fitter_registry()`` call
    therefore builds a new registry and re-registers the built-in
    descriptors. Intended for tests.
    """
    FitterRegistry._reset()
    fitter_registry.cache_clear()
# Public names exported by ``from ... import *``.
__all__ = ["FitterRegistry", "fitter_registry", "reset_fitter_registry"]