# Source code for pysatl_core.transformations.operators_mixin
"""
Operator mixin for transformation-enabled distributions.
This mixin provides arithmetic operators implemented through
transformation primitives.
"""
from __future__ import annotations
__author__ = "Leonid Elkin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from numbers import Real
from types import NotImplementedType
from typing import TYPE_CHECKING, cast
from pysatl_core.types import BinaryOperationName
if TYPE_CHECKING:
from pysatl_core.distributions.distribution import Distribution
[docs]
class TransformationOperatorsMixin:
"""
Mixin adding affine and binary arithmetic operators to distributions.
"""
def _affine_transform(self, *, scale: float, shift: float) -> Distribution:
"""
Apply affine transformation ``Y = scale * X + shift``.
"""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.affine import affine
from pysatl_core.transformations.operations.methods.affine import (
default_affine_transformation_methods,
)
distribution = cast(Distribution, self)
return affine(
distribution,
scale=scale,
shift=shift,
methods=default_affine_transformation_methods(
kind=getattr(distribution.distribution_type, "kind", None),
scale=scale,
),
)
def _binary_transform(
self,
other: Distribution,
*,
operation: BinaryOperationName,
) -> Distribution:
"""
Apply binary transformation between two distributions.
"""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.binary.base import binary
from pysatl_core.transformations.operations.methods.binary.division import (
default_division_binary_transformation_methods,
)
from pysatl_core.transformations.operations.methods.binary.linear import (
default_linear_binary_transformation_methods,
)
from pysatl_core.transformations.operations.methods.binary.multiplication import (
default_multiplication_binary_transformation_methods,
)
distribution = cast(Distribution, self)
kind = getattr(distribution.distribution_type, "kind", None)
if operation in {BinaryOperationName.ADD, BinaryOperationName.SUB}:
methods = default_linear_binary_transformation_methods(kind=kind)
elif operation == BinaryOperationName.MUL:
methods = default_multiplication_binary_transformation_methods(kind=kind)
else:
methods = default_division_binary_transformation_methods(kind=kind)
return binary(
distribution,
other,
operation=operation,
methods=methods,
)
[docs]
def __add__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``self + other`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
if isinstance(other, Real):
return self._affine_transform(scale=1.0, shift=float(other))
if isinstance(other, Distribution):
return self._binary_transform(other, operation=BinaryOperationName.ADD)
return NotImplemented
[docs]
def __radd__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``other + self`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.binary.base import binary
from pysatl_core.transformations.operations.methods.binary.linear import (
default_linear_binary_transformation_methods,
)
if isinstance(other, Real):
return self._affine_transform(scale=1.0, shift=float(other))
if isinstance(other, Distribution):
kind = getattr(other.distribution_type, "kind", None)
return binary(
other,
cast(Distribution, self),
operation=BinaryOperationName.ADD,
methods=default_linear_binary_transformation_methods(kind=kind),
)
return NotImplemented
[docs]
def __sub__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``self - other`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
if isinstance(other, Real):
return self._affine_transform(scale=1.0, shift=-float(other))
if isinstance(other, Distribution):
return self._binary_transform(other, operation=BinaryOperationName.SUB)
return NotImplemented
[docs]
def __rsub__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``other - self`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.binary.base import binary
from pysatl_core.transformations.operations.methods.binary.linear import (
default_linear_binary_transformation_methods,
)
if isinstance(other, Real):
return self._affine_transform(scale=-1.0, shift=float(other))
if isinstance(other, Distribution):
kind = getattr(other.distribution_type, "kind", None)
return binary(
other,
cast(Distribution, self),
operation=BinaryOperationName.SUB,
methods=default_linear_binary_transformation_methods(kind=kind),
)
return NotImplemented
[docs]
def __mul__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``self * other`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
if isinstance(other, Real):
return self._affine_transform(scale=float(other), shift=0.0)
if isinstance(other, Distribution):
return self._binary_transform(other, operation=BinaryOperationName.MUL)
return NotImplemented
[docs]
def __rmul__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``other * self`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.binary.base import binary
from pysatl_core.transformations.operations.methods.binary.multiplication import (
default_multiplication_binary_transformation_methods,
)
if isinstance(other, Real):
return self._affine_transform(scale=float(other), shift=0.0)
if isinstance(other, Distribution):
kind = getattr(other.distribution_type, "kind", None)
return binary(
other,
cast(Distribution, self),
operation=BinaryOperationName.MUL,
methods=default_multiplication_binary_transformation_methods(kind=kind),
)
return NotImplemented
[docs]
def __truediv__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``self / other`` for scalar or distribution operands."""
from pysatl_core.distributions.distribution import Distribution
if isinstance(other, Real):
divisor = float(other)
if divisor == 0.0:
raise ZeroDivisionError("Cannot divide a distribution by zero.")
return self._affine_transform(scale=1.0 / divisor, shift=0.0)
if isinstance(other, Distribution):
return self._binary_transform(other, operation=BinaryOperationName.DIV)
return NotImplemented
[docs]
def __rtruediv__(self, other: object) -> Distribution | NotImplementedType:
"""Return ``other / self`` for distribution operands."""
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.transformations.operations.distributions.binary.base import binary
from pysatl_core.transformations.operations.methods.binary.division import (
default_division_binary_transformation_methods,
)
if isinstance(other, Distribution):
kind = getattr(other.distribution_type, "kind", None)
return binary(
other,
cast(Distribution, self),
operation=BinaryOperationName.DIV,
methods=default_division_binary_transformation_methods(kind=kind),
)
return NotImplemented
[docs]
def __neg__(self) -> Distribution:
"""Return ``-self`` as an affine transformation."""
return self._affine_transform(scale=-1.0, shift=0.0)
# Names exported by ``from ... import *``.
__all__ = ["TransformationOperatorsMixin"]