"""
Characteristic Graph (global) + View (per distribution profile)
This module defines the global characteristic registry and per-distribution views.
The CharacteristicRegistry maintains a directed graph over characteristic names
(PDF, CDF, PPF, PMF, etc.) with nodes and edges guarded by constraints. Each
distribution profile sees a filtered view of this graph based on its specific
features (kind, dimension, etc.).
Core concepts:
- **Nodes**: Characteristics (PDF, CDF, etc.) with presence and definitiveness rules
- **Edges**: Unary computation methods between characteristics
- **Constraints**: Rules that determine when nodes/edges are applicable
- **View**: A filtered subgraph for a specific distribution
- **Definitive characteristics**: Starting points for computations
"""
from __future__ import annotations
__author__ = "Leonid Elkin, Mikhail Mikhailov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar, Self
from pysatl_core.distributions.registry.constraint import GraphPrimitiveConstraint
from pysatl_core.distributions.registry.graph_primitives import (
DEFAULT_COMPUTATION_KEY,
EdgeMeta,
GraphInvariantError,
)
if TYPE_CHECKING:
from pysatl_core.distributions.computation import ComputationMethod
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.types import GenericCharacteristicName
# --------------------------------------------------------------------------- #
# Registry (singleton)
# --------------------------------------------------------------------------- #
[docs]
class CharacteristicRegistry:
"""
Global characteristic graph with constraint-guarded nodes and edges.
This registry maintains the complete graph of characteristics and computation
methods. It serves as a singleton that can be configured once and then used
to create filtered views for specific distributions.
Invariants (enforced per view):
1. The subgraph induced by definitive characteristics is strongly connected
2. Every non-definitive characteristic is reachable from at least one definitive
3. No path exists from any non-definitive characteristic to any definitive
Methods
-------
add_characteristic(name, is_definitive, presence_constraint=None, definitive_constraint=None)
Declare a characteristic with presence and optional definitiveness rules.
add_computation(method, label=DEFAULT_COMPUTATION_KEY, constraint=None)
Add a unary computation edge between declared nodes.
view(distr)
Create a filtered view for the given distribution.
Notes
-----
- Nodes must be declared before adding computations
- Only unary computations (1 source → 1 target) are supported
- No invariant validation happens during mutation; validation occurs when
creating a view with view()
"""
_instance: ClassVar[Self | None] = None
def __new__(cls) -> Self:
if cls._instance is None:
inst = super().__new__(cls)
cls._instance = inst
return cls._instance
[docs]
def __init__(self) -> None:
if getattr(self, "_initialized", False):
return
# Adjacency: src → dst → label → [EdgeMeta]
self._adj: dict[
GenericCharacteristicName,
dict[GenericCharacteristicName, dict[str, list[EdgeMeta]]],
] = {}
self._all_nodes: set[GenericCharacteristicName] = set()
# Node constraints
self._presence_rules: dict[GenericCharacteristicName, GraphPrimitiveConstraint] = {}
self._def_rules: dict[GenericCharacteristicName, GraphPrimitiveConstraint] = {}
# Label preference for path finding
self.label_preference: tuple[str, ...] = (DEFAULT_COMPUTATION_KEY,)
self._initialized = True
[docs]
def __copy__(self) -> Self:
"""Singleton copy returns the same instance."""
return self
[docs]
def __deepcopy__(self, memo: dict[Any, Any]) -> Self:
"""Singleton deepcopy returns the same instance."""
return self
[docs]
def __reduce__(self) -> tuple[type[Self], tuple[()]]:
"""Ensure pickling preserves singleton semantics."""
return self.__class__, ()
@classmethod
def _reset(cls) -> None:
"""Reset the singleton"""
cls._instance = None
def _add_node(self, node: GenericCharacteristicName) -> None:
"""Insert node into the registry (idempotent)."""
self._adj.setdefault(node, {})
self._all_nodes.add(node)
def _ensure_node(self, node: GenericCharacteristicName) -> bool:
"""Check if a node has been declared via add_characteristic()."""
return node in self._all_nodes
def _add_presence_rule(
self, name: GenericCharacteristicName, constraint: GraphPrimitiveConstraint | None
) -> None:
"""
Register a presence rule for a node.
Warns if the node already has a presence rule.
"""
if name in self._presence_rules:
warnings.warn(
f"Node {name} already has a presence rule. New constraint will be ignored.",
UserWarning,
stacklevel=3,
)
return
self._presence_rules[name] = (
constraint if constraint is not None else GraphPrimitiveConstraint()
)
def _add_definitive_rule(
self, name: GenericCharacteristicName, constraint: GraphPrimitiveConstraint | None
) -> None:
"""
Register a definitiveness rule for a node.
Warns if the node already has a definitiveness rule.
"""
if name in self._def_rules and constraint is not None:
warnings.warn(
f"Node {name} already has a definitiveness rule. New constraint will be ignored.",
UserWarning,
stacklevel=3,
)
return
self._def_rules[name] = constraint if constraint is not None else GraphPrimitiveConstraint()
[docs]
def add_computation(
self,
method: ComputationMethod[Any, Any],
*,
label: str = DEFAULT_COMPUTATION_KEY,
constraint: GraphPrimitiveConstraint | None = None,
) -> None:
"""
Add a labeled unary computation edge.
Parameters
----------
method : ComputationMethod
Computation object with exactly one source and one target.
label : str, default=DEFAULT_COMPUTATION_KEY
Variant label for the edge.
constraint : GraphPrimitiveConstraint, optional
Edge applicability constraint. If None, a pass-through constraint is used.
Raises
------
ValueError
If method is not unary, or source/target nodes are not declared.
Notes
-----
- Multiple edges with different labels can exist between the same nodes
- The first matching edge for each label is kept when creating views
"""
if len(method.sources) != 1:
raise ValueError("Only unary computations are supported (1 source → 1 target).")
src = method.sources[0]
dst = method.target
if not self._ensure_node(src) or not self._ensure_node(dst):
raise ValueError("Source characteristic or destination characteristic is invalid.")
self._adj[src].setdefault(dst, {})
# TODO: We need to be careful here if some constraint more general and with the same label
# than other it can consume it. Actually, the same label methods should not intersect their
# constraints
self._adj[src][dst].setdefault(label, [])
self._adj[src][dst][label].append(
EdgeMeta(
method=method,
constraint=constraint or GraphPrimitiveConstraint(),
)
)
[docs]
def add_characteristic(
self,
name: GenericCharacteristicName,
is_definitive: bool,
*,
presence_constraint: GraphPrimitiveConstraint | None = None,
definitive_constraint: GraphPrimitiveConstraint | None = None,
) -> None:
"""
Declare a characteristic with presence and optional definitiveness rules.
Parameters
----------
name : str
Characteristic name (e.g., "pdf", "cdf").
is_definitive : bool
Whether this characteristic can serve as a starting point for computations.
presence_constraint : GraphPrimitiveConstraint, optional
Constraint determining when this characteristic exists for a distribution.
definitive_constraint : GraphPrimitiveConstraint, optional
Constraint determining when this characteristic is definitive.
Ignored if is_definitive is False.
Notes
-----
- If is_definitive is False but definitive_constraint is provided,
a warning is issued and the constraint is ignored
- Presence constraints are required; without one, the characteristic
will never appear in any view
"""
self._add_node(name)
self._add_presence_rule(name, presence_constraint)
if not is_definitive and definitive_constraint is not None:
warnings.warn(
f"Node {name} is non-definitive but has a definitive constraint. "
"Constraint will be ignored.",
UserWarning,
stacklevel=2,
)
if is_definitive:
self._add_definitive_rule(name, definitive_constraint)
# --------------------------------------------------------------------- #
# Views
# --------------------------------------------------------------------- #
def _compute_present_nodes(self, distr: Distribution) -> set[GenericCharacteristicName]:
"""
Compute characteristics present for the given distribution.
Returns
-------
set of str
Characteristics whose presence constraints allow this distribution.
"""
present: set[GenericCharacteristicName] = set()
for name, constraint in self._presence_rules.items():
if constraint.allows(distr):
present.add(name)
return present
def _compute_definitive_nodes(self, distr: Distribution) -> set[GenericCharacteristicName]:
"""
Compute definitive characteristics for the given distribution.
Returns
-------
set of str
Characteristics whose definitiveness constraints allow this distribution.
"""
definitive: set[GenericCharacteristicName] = set()
for name, constraint in self._def_rules.items():
if constraint.allows(distr):
definitive.add(name)
return definitive
[docs]
def view(self, distr: Distribution) -> RegistryView:
"""
Create a filtered view of the graph for the given distribution.
Parameters
----------
distr : Distribution
Distribution profile to filter for.
Returns
-------
RegistryView
Filtered view containing only applicable nodes and edges.
Notes
-----
1. Filters edges by their constraints
2. Removes edges touching absent nodes
3. Computes definitive nodes from the remaining present nodes
4. Validates graph invariants
"""
# 1) Filter edges by applicability
adj: dict[
GenericCharacteristicName, dict[GenericCharacteristicName, dict[str, EdgeMeta]]
] = {}
for src, d in self._adj.items():
for dst, variants in d.items():
kept: dict[str, EdgeMeta] = {}
for label, metas in variants.items():
for edge in metas:
if edge.constraint.allows(distr):
kept[label] = edge
# TODO: It is possible that there are two edges under the same label
# that fit the same distribution, this should not be the case.
# Taking the first one for now
break
if kept:
adj.setdefault(src, {}).setdefault(dst, {}).update(kept)
# 2) Filter by node presence
present_nodes = self._compute_present_nodes(distr)
if present_nodes:
adj = {
src: {dst: dict(variants) for dst, variants in d.items() if dst in present_nodes}
for src, d in adj.items()
if src in present_nodes
}
# Ensure isolated present nodes are preserved
for node in present_nodes:
adj.setdefault(node, {})
# 3) Compute definitive nodes (must be present)
definitive_nodes = self._compute_definitive_nodes(distr) & present_nodes
return RegistryView(adj, definitive_nodes, present_nodes)
# --------------------------------------------------------------------------- #
# Registry view
# --------------------------------------------------------------------------- #
[docs]
class RegistryView:
"""
Filtered view of the characteristic graph for a specific distribution.
This view contains only the nodes and edges applicable to a particular
distribution profile, with all graph invariants validated.
Parameters
----------
adj : Mapping[src, Mapping[dst, Mapping[label, EdgeMeta]]]
Filtered adjacency preserving label variants.
definitive_nodes : set of str
Definitive characteristics in this view.
present_nodes : set of str
All present characteristics in this view.
Raises
------
GraphInvariantError
If any graph invariant is violated.
Attributes
----------
definitive_characteristics : set of str
Definitive characteristics for this distribution.
all_characteristics : set of str
All present characteristics for this distribution.
"""
[docs]
def __init__(
self,
adj: Mapping[
GenericCharacteristicName,
Mapping[GenericCharacteristicName, Mapping[str, EdgeMeta]],
],
definitive_nodes: set[GenericCharacteristicName],
present_nodes: set[GenericCharacteristicName],
) -> None:
# Deep copy adjacency to ensure immutability
self._adj: dict[
GenericCharacteristicName, dict[GenericCharacteristicName, dict[str, EdgeMeta]]
] = {}
for src, d in adj.items():
self._adj[src] = {dst: dict(variants) for dst, variants in d.items()}
self.definitive_characteristics: set[GenericCharacteristicName] = set(definitive_nodes)
self.all_characteristics: set[GenericCharacteristicName] = set(present_nodes)
# Validate invariants immediately
self._validate_invariants()
@property
def indefinitive_characteristics(self) -> set[GenericCharacteristicName]:
"""
Present but non-definitive characteristics.
Returns
-------
set of str
Characteristics that exist but are not definitive.
"""
return self.all_characteristics - self.definitive_characteristics
[docs]
def successors(
self, v: GenericCharacteristicName
) -> Mapping[GenericCharacteristicName, Mapping[str, EdgeMeta]]:
"""
Get outgoing edges from a characteristic.
Parameters
----------
v : str
Source characteristic.
Returns
-------
Mapping[str, Mapping[str, EdgeMeta]]
Destination → label → edge metadata.
"""
return self._adj.get(v, {})
[docs]
def successors_nodes(self, v: GenericCharacteristicName) -> set[GenericCharacteristicName]:
"""
Get directly reachable characteristics from v.
Parameters
----------
v : str
Source characteristic.
Returns
-------
set of str
Characteristics directly reachable from v.
"""
return set(self._adj.get(v, {}).keys())
[docs]
def predecessors(self, v: GenericCharacteristicName) -> set[GenericCharacteristicName]:
"""
Get characteristics with edges to v.
Parameters
----------
v : str
Destination characteristic.
Returns
-------
set of str
Characteristics that can reach v directly.
"""
res: set[GenericCharacteristicName] = set()
for src, d in self._adj.items():
if v in d and d[v]:
res.add(src)
return res
[docs]
def variants(
self, src: GenericCharacteristicName, dst: GenericCharacteristicName
) -> Mapping[str, EdgeMeta]:
"""
Get all labeled edges between two characteristics.
Parameters
----------
src, dst : str
Edge endpoints.
Returns
-------
Mapping[str, EdgeMeta]
Label → edge metadata mapping.
"""
return self._adj.get(src, {}).get(dst, {})
[docs]
def find_path(
self,
src: GenericCharacteristicName,
dst: GenericCharacteristicName,
*,
prefer_label: str | None = None,
) -> list[Any] | None:
"""
Find a computation path from src to dst using BFS.
Parameters
----------
src, dst : str
Source and destination characteristics.
prefer_label : str, optional
Preferred edge label to use when multiple options exist.
Returns
-------
list of ComputationMethod or None
List of computation methods forming the path, or None if no path exists.
Notes
-----
Label selection priority:
1. prefer_label if present
2. DEFAULT_COMPUTATION_KEY if present
3. Lexicographically smallest label
"""
if src == dst:
return []
visited: set[GenericCharacteristicName] = {src}
parent: dict[GenericCharacteristicName, tuple[GenericCharacteristicName, Any]] = {}
queue: list[GenericCharacteristicName] = [src]
qi = 0
while qi < len(queue):
v = queue[qi]
qi += 1
for w, by_label in self._adj.get(v, {}).items():
if not by_label or w in visited:
continue
method = self._pick_method(by_label, prefer_label)
visited.add(w)
parent[w] = (v, method)
if w == dst:
# Reconstruct path
path: list[Any] = []
cur = dst
while cur != src:
pv, m = parent[cur]
path.append(m)
cur = pv
path.reverse()
return path
queue.append(w)
return None
def _validate_invariants(self) -> None:
"""
Validate all graph invariants.
Raises
------
GraphInvariantError
If any invariant is violated.
"""
if not self._definitive_strongly_connected():
raise GraphInvariantError("Definitive subgraph must be strongly connected.")
if not self._all_indefinitives_reachable_from_definitives():
raise GraphInvariantError(
"Every indefinitive characteristic must be reachable from some definitive."
)
if self._exists_path_from_indefinitive_to_definitive():
raise GraphInvariantError(
"No path from any indefinitive characteristic back to a definitive is allowed."
)
def _definitive_strongly_connected(self) -> bool:
"""
Check if definitive characteristics form a strongly connected subgraph.
Returns
-------
bool
True if strongly connected.
"""
defs = self.definitive_characteristics
if len(defs) <= 1:
return True
start = next(iter(defs))
fwd = self._reachable_from(start, allowed=defs)
if fwd != (defs - {start}):
return False
# Check reverse reachability
seen: set[GenericCharacteristicName] = {start}
stack = [start]
while stack:
v = stack.pop()
for w in self.predecessors(v):
if w in defs and w not in seen:
seen.add(w)
stack.append(w)
return seen == defs
def _all_indefinitives_reachable_from_definitives(self) -> bool:
"""
Check that all non-definitive nodes are reachable from definitive nodes.
Returns
-------
bool
True if all indefinitives are reachable.
"""
indefs = self.indefinitive_characteristics
if not indefs:
return True
total: set[GenericCharacteristicName] = set()
for d in self.definitive_characteristics:
total |= self._reachable_from(d)
return indefs.issubset(total)
def _exists_path_from_indefinitive_to_definitive(self) -> bool:
"""
Check if any non-definitive node can reach a definitive node.
Returns
-------
bool
True if such a path exists (which would violate invariants).
"""
defs = self.definitive_characteristics
return any(self._reachable_from(i) & defs for i in self.indefinitive_characteristics)
def _reachable_from(
self,
src: GenericCharacteristicName,
*,
allowed: set[GenericCharacteristicName] | None = None,
) -> set[GenericCharacteristicName]:
"""
Compute forward reachable nodes from src.
Parameters
----------
src : str
Starting node.
allowed : set of str, optional
Restrict to this set of nodes.
Returns
-------
set of str
Nodes reachable from src (excluding src itself).
"""
if allowed is not None and src not in allowed:
return set()
visited: set[GenericCharacteristicName] = set()
stack = [src]
while stack:
v = stack.pop()
if v in visited:
continue
visited.add(v)
for w in self.successors_nodes(v):
if allowed is not None and w not in allowed:
continue
if w not in visited:
stack.append(w)
visited.discard(src)
return visited
@staticmethod
def _pick_method(
variants: Mapping[str, EdgeMeta],
prefer_label: str | None,
) -> Any:
"""
Select a method from label variants.
Parameters
----------
variants : Mapping[str, EdgeMeta]
Available edge variants.
prefer_label : str, optional
Preferred label.
Returns
-------
Any
Selected computation method.
"""
if prefer_label and prefer_label in variants:
return variants[prefer_label].method
if DEFAULT_COMPUTATION_KEY in variants:
return variants[DEFAULT_COMPUTATION_KEY].method
label = sorted(variants.keys())[0]
return variants[label].method