Source code for rework_pysatl_mpest.core.parameter
"""Provides a descriptor for parameters of custom distributions.

This module contains the `Parameter` descriptor class, which is used to define
and validate parameters in classes inheriting from `ContinuousDistribution`.
It allows you to set invariants for parameter values, raises an error when an
assignment violates them, and protects fixed parameters from being changed."""
__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
from typing import Callable, Union, overload
[docs]
class Parameter:
"""A descriptor for validating and managing distribution parameters.
This class implements the descriptor protocol for managing access to
attributes representing the parameters of a statistical distribution.
It allows you to set the conditions (invariants) that the
parameter value must satisfy.
Parameters
----------
invariant : Callable[[float], bool], optional
A predicate function for validating parameter values.
Defaults to `lambda x: True`.
error_message : str, optional
An error message in case of failed validation.
Defaults to "Parameter value is not valid.".
Attributes
----------
invariant : Callable[[float], bool]
A function that checks if the parameter value is valid.
Returns True if the value is correct, and False otherwise.
error_message : str
An error message that is raised if the invariant is not satisfied.
public_name : str
The name of the attribute as defined in the owner class.
private_name : str
The name of the attribute used to store the value within an instance
of the owner class.
Examples
--------
.. code-block:: python
class NormalDistribution(ContinuousDistribution):
# Location parameter can be any number
loc = Parameter()
# Scale parameter must be a positive number
scale = Parameter(invariant=lambda s: s > 0, error_message="Standard deviation (scale) must be positive.")
"""
def __init__(
self,
invariant: Callable[[float], bool] = lambda x: True,
error_message: str = "Parameter value is not valid.",
):
self.invariant = invariant
self.error_message = error_message
def __set_name__(self, owner: type[object], name: str):
"""Sets the name for the public and private attributes.
This method is automatically called when a descriptor instance is created
in the owner class. It uses the attribute name to create the
public and private names.
Parameters
----------
owner : type[object]
The class that uses the descriptor.
name : str
The attribute name assigned to the descriptor instance.
"""
self.public_name = name
self.private_name = "_" + name
@overload
def __get__(self, instance: None, owner: type[object]) -> "Parameter":
"""If access is via a class, return the descriptor object itself."""
@overload
def __get__(self, instance: object, owner: type[object]) -> float:
"""If access is via an object, return the value."""
def __get__(self, instance: object | None, owner: type[object]) -> Union[float, "Parameter"]:
"""Returns the parameter value or the descriptor itself.
If access is through an instance of the class, it returns the
parameter's value. If access is through the class itself, it returns
the descriptor object.
Parameters
----------
instance : object or None
An instance of the owner class, or `None` if access
is through the class.
owner : type[object]
The owner class.
Returns
-------
float or Parameter
The value of the parameter or the descriptor itself.
"""
if instance is None:
return self
return getattr(instance, self.private_name)
def __set__(self, instance: object, value: float):
"""Sets the parameter value after validation.
Before setting a new value, it checks whether the parameter is
"fixed." Then, it validates the value using the :attr:`invariant` function.
Parameters
----------
instance : object
An instance of the owner class.
value : float
The new value for the parameter.
Raises
------
AttributeError
If an attempt is made to change a "fixed" parameter.
ValueError
If the new value does not pass the :attr:`invariant` check.
"""
if self.public_name in getattr(instance, "_fixed_params", set()):
raise AttributeError(
f"Cannot set '{self.public_name}' for instance of '{type(instance).__name__}' class. "
"This parameter is fixed."
)
if not self.invariant(value):
raise ValueError(f"Invalid value for '{self.public_name}': {self.error_message}")
setattr(instance, self.private_name, value)