# Source code for rework_pysatl_mpest.estimators.iterative.pipeline

"""Provides a configurable, iterative estimator for mixture models.

This module defines the `Pipeline` class, which orchestrates an iterative
estimation process by executing a sequence of processing steps. It allows for
the flexible construction of algorithms like Expectation-Maximization (EM) by
combining different steps (:class:`PipelineStep`), stopping criteria (:class:`Breakpointer`),
and component pruning strategies (:class:`Pruner`).
"""

__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

import warnings
from collections.abc import Sequence
from copy import copy

import numpy as np
from numpy.typing import ArrayLike

from ...core import MixtureModel
from ..base_estimator import BaseEstimator
from ._logger import IterationRecord, IterationsHistory
from .breakpointer import Breakpointer
from .pipeline_state import PipelineState
from .pipeline_step import PipelineStep
from .pruner import Pruner


class Pipeline(BaseEstimator):
    """An estimator that fits a mixture model via a configurable iterative process.

    The pipeline executes a sequence of defined steps in a loop. After each full
    sequence of steps, pruning strategies are applied, and stopping conditions
    are checked. The loop continues until a breakpointer signals to stop.

    This allows for building complex, multi-stage estimation algorithms, such as
    variants of the EM algorithm.

    The core components used for configuration are:

    - :class:`.PipelineStep`
    - :class:`Breakpointer`
    - :class:`Pruner`

    Parameters
    ----------
    steps : Sequence[PipelineStep]
        An ordered sequence of steps to be executed in each iteration of the
        pipeline.
    breakpointers : Sequence[Breakpointer]
        A sequence of strategies that define the stopping conditions for the
        iterative process. This list cannot be empty.
    pruners : Sequence[Pruner] | None, optional
        A sequence of strategies for removing components from the mixture model
        during fitting. Defaults to None, meaning no pruning is performed.
    once_in_iterations : int, optional
        The logging frequency. A value of `n` means logging occurs every `n`
        iterations. Defaults to 1 (log every iteration).

    Attributes
    ----------
    steps : list[PipelineStep]
        The ordered list of operations to be performed in each iteration.
    breakpointers : list[Breakpointer]
        The list of objects that determine when the fitting process should
        terminate.
    pruners : list[Pruner]
        The list of objects that may remove components from the mixture during
        the fitting process.
    logger : IterationsHistory
        Object that collects information about each iteration of a
        :class:`Pipeline` estimator. It is created once in ``__init__`` and is
        not reset by :meth:`fit`.

    Raises
    ------
    ValueError
        If the sequence of :attr:`steps` is empty or invalid (i.e., a step is
        followed by a step not listed in its :attr:`available_next_steps`), or
        if the sequence of :attr:`breakpointers` is empty.

    Methods
    -------
    .. autosummary::
        :toctree: generated/

        fit
    """

    logger: IterationsHistory

    def __init__(
        self,
        steps: Sequence[PipelineStep],
        breakpointers: Sequence[Breakpointer],
        pruners: Sequence[Pruner] | None = None,
        once_in_iterations: int = 1,
    ):
        # Materialize each argument exactly once. Validating one ``list(steps)``
        # and storing another would leave ``self.steps`` empty for one-shot
        # iterables, and ``not breakpointers`` is never True for a generator.
        steps_list = list(steps)
        self._validate_steps(steps_list)

        breakpointers_list = list(breakpointers) if breakpointers else []
        if not breakpointers_list:
            raise ValueError(
                "The 'breakpointers' list cannot be empty. "
                "At least one stopping criterion must be provided to prevent an infinite loop."
            )

        self.breakpointers = breakpointers_list
        # Always a list, so ``fit`` can iterate over it unconditionally.
        self.pruners = list(pruners) if pruners else []
        self.steps = steps_list
        self.logger = IterationsHistory(once_in_iterations)

    def _validate_steps(self, steps: list[PipelineStep]) -> None:
        """Validates the sequence of pipeline steps.

        Checks if each step in the pipeline can legally be followed by the next
        one, based on the :attr:`available_next_steps` property of each step. It
        also checks that the pipeline can be run in a loop (the last step must
        be compatible with the first step).

        Parameters
        ----------
        steps : list[PipelineStep]
            The sequence of steps to validate.

        Raises
        ------
        ValueError
            If the :attr:`steps` list is empty or if the pipeline configuration
            is invalid, meaning a step is followed by an incompatible one.
        """
        if not steps:
            raise ValueError("The 'steps' list cannot be empty for a Pipeline.")

        # Pair every step with its successor, including the wrap-around pair
        # (last -> first), because ``fit`` restarts from the first step after
        # the last one. The wrap-around pair is checked first.
        predecessors = steps[-1:] + steps[:-1]
        for curr_step, next_step in zip(predecessors, steps):
            available_steps = tuple(curr_step.available_next_steps)
            if not isinstance(next_step, available_steps):
                raise ValueError(
                    f"Wrong pipeline configuration. Step '{curr_step}' has "
                    f"available next steps: '{curr_step.available_next_steps}', but got '{next_step}'"
                )

    def _record_step_error(self, failed_state: PipelineState) -> None:
        """Store the error carried by ``failed_state`` in :attr:`logger`.

        If the history already holds a record, the error is attached to the
        most recent one. Otherwise a new record is logged for the failed state,
        with no pruners.

        Parameters
        ----------
        failed_state : PipelineState
            The state returned by the step that reported an error.
        """
        if len(self.logger) > 0:
            # NOTE(review): the logger is never reset by ``fit``, so if a step
            # fails on the first iteration of a repeated ``fit`` call, this
            # record may belong to the previous run. Confirm this is intended.
            self.logger[-1].error = failed_state.error
        else:
            self.logger.log(
                IterationRecord(
                    self.logger._counter,
                    failed_state.curr_mixture,
                    failed_state.X,
                    failed_state.H,
                    None,
                    failed_state.error,
                )
            )

    def fit(self, X: ArrayLike, mixture: MixtureModel) -> MixtureModel:
        """Fits the mixture model to the data using the configured pipeline.

        This method initializes the pipeline's state and runs the main loop.
        The loop consists of executing all :attr:`steps` in order, followed by
        all :attr:`pruners`, after which one record is logged. This cycle
        repeats until any of the :attr:`breakpointers` indicates that the
        process should stop.

        Parameters
        ----------
        X : ArrayLike
            The input data sample. It is converted to a ``float64`` NumPy array.
        mixture : MixtureModel
            The initial mixture model to be fitted. An internal copy of this
            model is modified throughout the process.

        Returns
        -------
        MixtureModel
            The fitted mixture model after the pipeline has converged or been
            stopped. If a step reports an error, the mixture from that step's
            returned state is returned immediately.

        Warns
        -----
        RuntimeWarning
            If a step's returned state carries an error. The error is also
            recorded in :attr:`logger`.
        """
        X = np.asarray(X, dtype=np.float64)
        # Copy to avoid modifying the caller's object.
        # NOTE(review): ``copy`` is shallow unless MixtureModel defines
        # ``__copy__``. Confirm the caller's components are not shared.
        copied_mixture = copy(mixture)
        state = PipelineState(X, None, None, copied_mixture, None)

        while True:
            # Keep the mixture as it was at the start of this iteration.
            state.prev_mixture = copy(state.curr_mixture)

            for step in self.steps:
                result_state = step.run(state)
                if result_state.error:
                    self._record_step_error(result_state)
                    # Report the error from the state the failing step returned,
                    # not from the stale input state.
                    warnings.warn(
                        f"Pipeline fitting stopped prematurely due to an error in step "
                        f"'{step.__class__.__name__}': {result_state.error}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    return result_state.curr_mixture
                state = result_state

            for pruner in self.pruners:
                state = pruner.prune(state)

            self.logger.log(
                IterationRecord(
                    self.logger._counter,
                    state.curr_mixture,
                    state.X,
                    state.H,
                    self.pruners,
                    state.error,
                )
            )

            if any(breakpointer.check(state) for breakpointer in self.breakpointers):
                break

        return state.curr_mixture