"""
Support Structures for Probability Distributions
This module defines support structures for probability distributions,
including continuous intervals and discrete point sets.
Support defines the set of values where a probability distribution
is defined (non-zero probability for discrete, non-zero density for continuous).
"""
from __future__ import annotations
__author__ = "Leonid Elkin, Mikhail Mikhailov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from dataclasses import dataclass
from math import floor
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable
import numpy as np
from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
[docs]
@runtime_checkable
class Support(Protocol):
    """
    Structural interface every distribution support must satisfy.

    A support describes the set of values on which a distribution is
    defined. The only required operation is a membership test,
    :meth:`contains`, declared for two kinds of input:

    * a single ``Number``, answered with a plain ``bool``;
    * a ``NumericArray``, answered with a ``BoolArray``.

    The class is decorated with ``runtime_checkable``, so
    ``isinstance(obj, Support)`` checks that ``obj`` has a ``contains``
    attribute; it does not validate the signature.
    """

    # Only the overload signatures are declared here; implementers provide
    # the actual method body.
    @overload
    def contains(self, x: Number) -> bool: ...

    @overload
    def contains(self, x: NumericArray) -> BoolArray: ...
[docs]
class ContinuousSupport(Interval1D, Support):
    """
    Interval-shaped support for continuous distributions.

    The class adds no members of its own: it combines ``Interval1D`` with
    the ``Support`` protocol so that an interval [left, right] can be used
    wherever a support is expected. ``Interval1D`` comes first in the base
    list, so its attributes take precedence in attribute lookup.
    """
[docs]
@runtime_checkable
class DiscreteSupport(Support, Protocol):
    """
    Structural interface for supports of discrete distributions.

    A discrete support is a set of distinct points that carry non-zero
    probability mass. On top of the membership test required by
    ``Support``, implementers provide point enumeration and a
    "previous point" lookup.
    """

    def iter_points(self) -> Iterator[Number]:
        """Return an iterator over every point of the support."""
        ...

    def iter_leq(self, x: Number) -> Iterator[Number]:
        """
        Return an iterator over the points that do not exceed ``x``.

        Parameters
        ----------
        x : Number
            Inclusive upper bound for the yielded points.
        """
        ...

    def prev(self, x: Number) -> Number | None:
        """
        Return the largest support point that is strictly below ``x``.

        Parameters
        ----------
        x : Number
            Reference point.

        Returns
        -------
        Number or None
            The preceding point, or None when no such point exists.
        """
        ...
[docs]
class ExplicitTableDiscreteSupport(DiscreteSupport):
    """
    Discrete support defined by an explicit, finite table of points.

    The points are stored in a NumPy array in ascending order with
    duplicates removed, so membership tests and neighbour lookups are
    answered with binary search (``numpy.searchsorted``).

    Parameters
    ----------
    points : Iterable[Number]
        Points in the support. Any iterable is accepted, including
        generators and sets; it is consumed once.
    assume_sorted : bool, default=False
        If True, the sorting step is skipped and the caller guarantees
        ascending order. Equal neighbouring values are still collapsed.

    Raises
    ------
    ValueError
        If ``points`` contains no elements.

    Attributes
    ----------
    _points : numpy.ndarray
        Sorted unique points array.
    """

    __slots__ = ("_points",)

    def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> None:
        # np.array() does not read the items of a generator, set or other
        # non-sequence iterable: it wraps the object itself into a 0-d object
        # array. Materialize such inputs first so that every Iterable works.
        # np.array() copies, so the caller's data is never sorted in place.
        if isinstance(points, np.ndarray):
            arr = np.array(points)
        else:
            arr = np.array(list(points))

        if arr.size == 0:
            raise ValueError("Points must be non-empty")

        if not assume_sorted:
            arr.sort()

        # Keep the first element of every run of equal neighbours.
        keep = np.empty(arr.size, dtype=bool)
        keep[0] = True
        keep[1:] = arr[1:] != arr[:-1]
        self._points = arr[keep]

    @overload
    def contains(self, x: Number) -> bool: ...

    @overload
    def contains(self, x: NumericArray) -> BoolArray: ...

    def contains(self, x: Number | NumericArray) -> bool | BoolArray:
        """
        Check if point(s) are in the support.

        Parameters
        ----------
        x : Number or NumericArray
            Point(s) to check.

        Returns
        -------
        bool or BoolArray
            A plain ``bool`` for zero-dimensional input, otherwise an array
            with True at the positions whose value is a support point.
        """
        queries = np.asarray(x)
        idx = np.searchsorted(self._points, queries, side="left")
        last_index = self._points.size - 1
        # searchsorted returns ``size`` for queries above every point; clip
        # the index so the lookup stays valid and reject those positions.
        candidates = self._points[np.minimum(idx, last_index)]
        found = (idx <= last_index) & (candidates == queries)
        if queries.ndim == 0:
            return bool(found)
        return cast(BoolArray, found)

    def __contains__(self, x: object) -> bool:
        """Support the ``x in support`` syntax for a single point."""
        return bool(self.contains(cast(Number, x)))

    def iter_points(self) -> Iterator[Number]:
        """Iterate through all points in the support in ascending order."""
        return iter(self._points)

    def iter_leq(self, x: Number) -> Iterator[Number]:
        """
        Iterate through points less than or equal to x, in ascending order.

        Parameters
        ----------
        x : Number
            Inclusive upper bound.
        """
        # side="right" places the cut after any point equal to x.
        stop = np.searchsorted(self._points, x, side="right")
        return iter(self._points[:stop])

    def prev(self, x: Number) -> Number | None:
        """
        Find the largest point strictly less than x.

        Parameters
        ----------
        x : Number
            Reference point.

        Returns
        -------
        Number or None
            The preceding point, or None if every point is >= x.
        """
        # side="left" gives the number of points strictly below x.
        idx = np.searchsorted(self._points, x, side="left")
        if idx == 0:
            return None
        return cast(Number, self._points[idx - 1])

    def first(self) -> Number:
        """Get the smallest point in the support."""
        return cast(Number, self._points[0])

    def next(self, current: Number) -> Number | None:
        """
        Find the smallest point strictly greater than current.

        Parameters
        ----------
        current : Number
            Reference point.

        Returns
        -------
        Number or None
            The following point, or None if every point is <= current.
        """
        # side="right" gives the index of the first point strictly above.
        idx = np.searchsorted(self._points, current, side="right")
        if idx == self._points.size:
            return None
        return cast(Number, self._points[idx])

    @property
    def points(self) -> NumericArray:
        """Get a copy of the points array."""
        return cast(NumericArray, self._points.copy())

    # Iterating the support object itself walks its points.
    __iter__ = iter_points
[docs]
@dataclass(slots=True)
class IntegerLatticeDiscreteSupport(DiscreteSupport):
    """
    Discrete support defined by an integer lattice: {residue + k * modulus}.

    Parameters
    ----------
    residue : int
        Base value for the lattice.
    modulus : int
        Step size between lattice points (must be positive).
    min_k : int, optional
        Inclusive lower bound. Despite the name, every method compares it
        with the point *values*, not with the index ``k``; it does not have
        to be a lattice point itself. None means unbounded from below.
    max_k : int, optional
        Inclusive upper bound on the point values, with the same
        conventions as ``min_k``. None means unbounded from above.

    Raises
    ------
    ValueError
        If modulus is not positive.
    """

    residue: int
    modulus: int
    min_k: int | None = None
    max_k: int | None = None

    def __post_init__(self) -> None:
        if self.modulus <= 0:
            raise ValueError("modulus must be a positive integer.")

    @overload
    def contains(self, x: Number) -> bool: ...

    @overload
    def contains(self, x: NumericArray) -> BoolArray: ...

    def contains(self, x: Number | NumericArray) -> bool | BoolArray:
        """
        Check if point(s) are in the integer lattice support.

        A value belongs to the support when it is an integer, is congruent
        to ``residue`` modulo ``modulus`` and lies within ``min_k``/``max_k``
        where those are given. NaN and infinities are never members.

        Returns
        -------
        bool or BoolArray
            A plain ``bool`` for zero-dimensional input, otherwise an array.
        """
        xf = np.asarray(x, dtype=float)
        # Casting NaN/inf to an integer emits an invalid-cast RuntimeWarning
        # and yields an arbitrary value, so substitute 0 before the cast and
        # exclude those positions through the ``finite`` mask.
        finite = np.isfinite(xf)
        v = np.floor(np.where(finite, xf, 0.0)).astype(int)
        mask = finite & (xf == v)
        if self.min_k is not None:
            mask = mask & (v >= self.min_k)
        if self.max_k is not None:
            mask = mask & (v <= self.max_k)
        mask = mask & (((v - self.residue) % self.modulus) == 0)
        if xf.ndim == 0:
            return bool(mask)
        return cast(BoolArray, np.asarray(mask, dtype=bool))

    def __contains__(self, x: object) -> bool:
        """Support the ``x in support`` syntax for a single point."""
        return bool(self.contains(cast(Number, x)))

    def iter_points(self) -> Iterator[int]:
        """
        Iterate through all points in the integer lattice support.

        Note
        ----
        Points are produced in increasing order whenever a lower bound
        exists. If the support is bounded only from above, iteration starts
        at the largest point and proceeds in decreasing order. Iterators
        over half-bounded supports never terminate.

        Raises
        ------
        RuntimeError
            If both min_k and max_k are None (unbounded both ways).
        """
        if self.min_k is None and self.max_k is None:
            raise RuntimeError(
                "Cannot iterate points for an unbounded IntegerLatticeDiscreteSupport "
                "(both min_k and max_k are None). Provide at least one bound to enable enumeration."
            )

        def _walk(start: int, step: int) -> Iterator[int]:
            # Endless arithmetic progression used for half-bounded supports.
            current = start
            while True:
                yield current
                current += step

        if self.min_k is not None:
            first = self.first()
            if first is None:
                # Both bounds are set but no lattice point lies between them:
                # the support is empty, not unbounded.
                return iter(())
            if self.max_k is None:
                return _walk(first, self.modulus)
            return iter(range(first, self.max_k + 1, self.modulus))

        # Only the upper bound is set, so last() cannot be None here.
        return _walk(cast(int, self.last()), -self.modulus)

    def iter_leq(self, x: Number) -> Iterator[int]:
        """
        Iterate through points less than or equal to x, in increasing order.

        Raises
        ------
        RuntimeError
            If min_k is None (left-unbounded support).
        """
        if self.min_k is None:
            raise RuntimeError(
                "iter_leq is not supported for left-unbounded IntegerLatticeDiscreteSupport. "
                "Provide min_k to enable iter_leq."
            )
        first = self.first()
        if first is None:
            return iter(())
        # Largest integer not exceeding x, capped by the upper bound.
        upper = int(floor(float(x)))
        if self.max_k is not None and upper > self.max_k:
            upper = self.max_k
        # range() is empty when upper < first, which covers x below the
        # first point as well.
        return iter(range(first, upper + 1, self.modulus))

    def prev(self, x: Number) -> int | None:
        """
        Find the largest point strictly less than x.

        ``x`` may be any real number; it does not have to be a lattice
        point or even an integer.

        Returns
        -------
        int or None
            The preceding lattice point, or None if there is none within
            the bounds.
        """
        if self.min_k is not None and float(x) <= self.min_k:
            return None
        # The largest integer strictly below x is ceil(x) - 1, with
        # ceil(x) == -floor(-x). Using floor(x) - 1 would wrongly skip
        # floor(x) itself when x is not an integer.
        target = -int(floor(-float(x))) - 1
        if self.max_k is not None and target > self.max_k:
            target = self.max_k
        if self.min_k is not None and target < self.min_k:
            return None
        # Round down to the nearest lattice point.
        candidate = target - (target - self.residue) % self.modulus
        if self.min_k is not None and candidate < self.min_k:
            return None
        return candidate

    def first(self) -> int | None:
        """
        Get the smallest point in the support.

        Returns
        -------
        int or None
            None if the support is unbounded from below, or if no lattice
            point lies within the bounds.
        """
        if self.min_k is None:
            return None
        # Round the lower bound up to the nearest lattice point.
        first = self.min_k + (self.residue - self.min_k) % self.modulus
        if self.max_k is not None and first > self.max_k:
            return None
        return first

    def last(self) -> int | None:
        """
        Get the largest point in the support.

        Returns
        -------
        int or None
            None if the support is unbounded from above, or if no lattice
            point lies within the bounds.
        """
        if self.max_k is None:
            return None
        # Round the upper bound down to the nearest lattice point.
        last = self.max_k - (self.max_k - self.residue) % self.modulus
        if self.min_k is not None and last < self.min_k:
            return None
        return last

    def next(self, current: int) -> int | None:
        """
        Find the smallest point strictly greater than current.

        ``current`` does not have to be a lattice point: the result is
        always a lattice point within the bounds. For a lattice point
        inside the bounds this is ``current + modulus``.

        Returns
        -------
        int or None
            The following lattice point, or None if it would exceed max_k.
        """
        # Smallest integer strictly above ``current``, lifted to the lower
        # bound when ``current`` lies below the support.
        target = int(floor(current)) + 1
        if self.min_k is not None and target < self.min_k:
            target = self.min_k
        # Round up to the nearest lattice point.
        candidate = target + (self.residue - target) % self.modulus
        if self.max_k is not None and candidate > self.max_k:
            return None
        return candidate

    @property
    def is_left_bounded(self) -> bool:
        """Check if the support is bounded on the left."""
        return self.min_k is not None

    @property
    def is_right_bounded(self) -> bool:
        """Check if the support is bounded on the right."""
        return self.max_k is not None

    # Iterating the support object itself walks its points.
    __iter__ = iter_points
# Public API of this module: the names exported by a star-import.
__all__ = [
    # Base support protocol and the interval-based continuous support
    "Support",
    "ContinuousSupport",
    # Discrete support protocol and its two concrete implementations
    "DiscreteSupport",
    "ExplicitTableDiscreteSupport",
    "IntegerLatticeDiscreteSupport",
]