Source code for rework_pysatl_mpest.estimators.ecm

"""Implements the Expectation-Conditional Maximization (ECM) algorithm.

This module provides the ``ECM`` class, which is a concrete
implementation of the :class:`~rework_pysatl_mpest.estimators.base_estimator.BaseEstimator`.
It uses a pipeline architecture (:class:`~rework_pysatl_mpest.estimators.iterative.pipeline.Pipeline`)
to fit the parameters of a mixture model to data.
"""

__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


from collections.abc import Sequence

from numpy.typing import ArrayLike

from ..core import MixtureModel
from ..optimizers import Optimizer
from .base_estimator import BaseEstimator
from .iterative import (
    Breakpointer,
    ExpectationStep,
    MaximizationStep,
    MaximizationStrategy,
    OptimizationBlock,
    Pipeline,
    Pruner,
)
from .iterative._logger import IterationsHistory


[docs] class ECM(BaseEstimator): """An estimator that implements the Expectation-Conditional Maximization (ECM) algorithm. This class encapsulates the logic for the ECM algorithm, a variant of the classic Expectation-Maximization (EM) algorithm. It constructs and executes a :class:`~.Pipeline` consisting of an expectation step (:class:`~.ExpectationStep`) and a conditional maximization step (:class:`~.MaximizationStep`). The key feature of this ECM implementation is that the maximization (M-step) is partitioned into separate, smaller optimization problems—one for each component in the mixture. For each component, all of its optimizable (non-fixed) parameters are estimated simultaneously by maximizing the Q-function. This component-wise update simplifies the optimization process. The overall fitting process can be customized with stopping criteria, component pruning strategies, and a numerical optimizer. Parameters ---------- breakpointers : Sequence[Breakpointer] A sequence of strategies that define the stopping conditions for the iterative process. pruners : Sequence[Pruner] A sequence of strategies for removing (pruning) components from the mixture model during fitting. optimizer : Optimizer A numerical optimizer instance used in the maximization step to find the optimal parameters. Attributes ---------- breakpointers : list[Breakpointer] The list of objects that determine when the fitting process should terminate. pruners : list[Pruner] The list of objects that may remove components from the mixture during the fitting process. optimizer : Optimizer The numerical optimizer used for parameter estimation. logger : IterationsHistory | None An object that collects information about each iteration. This attribute is only available after the :meth:`fit` method has been called. Accessing it beforehand will raise an :class:`AttributeError`. Methods ------- .. autosummary:: :toctree: generated/ fit """ def __init__(self, breakpointers: Sequence[Breakpointer], pruners: Sequence[Pruner], optimizer: Optimizer) -> None: self.breakpointers = list(breakpointers) self.pruners = list(pruners) self.optimizer = optimizer self._logger: IterationsHistory | None = None @property def logger(self) -> IterationsHistory: """An object that collects information about each iteration. Raises ------ AttributeError If accessed before the `fit` method has been called at least once. """ if self._logger is None: raise AttributeError("Logger is not available. Call the 'fit' method first.") return self._logger
[docs] def fit(self, X: ArrayLike, mixture: MixtureModel, once_in_iterations: int = 1) -> MixtureModel: """Fits the mixture model to the data using the ECM algorithm. This method sets up and runs an iterative pipeline to estimate the parameters of a given mixture model based on the input data. At each iteration, it performs an E-step and an M-step. The process repeats until one of the stopping criteria is met. Parameters ---------- X : ArrayLike The input dataset for fitting the model. mixture : MixtureModel The initial mixture model to be fitted. once_in_iterations : int, optional The logging frequency. A value of `n` means logging occurs every `n` iterations. Defaults to 1. Returns ------- MixtureModel The mixture model with the estimated parameters. """ blocks = [] for i, comp in enumerate(mixture): block = OptimizationBlock(i, comp.params_to_optimize, MaximizationStrategy.QFUNCTION) blocks.append(block) pipeline = Pipeline( [ExpectationStep(), MaximizationStep(blocks, self.optimizer)], self.breakpointers, self.pruners, once_in_iterations, ) result = pipeline.fit(X, mixture) self._logger = pipeline.logger return result