Source code for rework_pysatl_mpest.estimators.iterative.steps.expectation_step

"""Provides the Expectation-step for an iterative estimation pipeline."""

__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


import numpy as np
from scipy.special import logsumexp

from ..pipeline_state import PipelineState
from ..pipeline_step import PipelineStep


class ExpectationStep(PipelineStep):
    """A pipeline step that performs the Expectation (E-step).

    This step calculates the responsibility matrix H, where ``H[i, j]`` is
    the posterior probability that the i-th data point belongs to the j-th
    mixture component. It can perform either a soft (probabilistic) or a
    hard (winner-takes-all) assignment.

    Parameters
    ----------
    is_soft : bool, optional
        If True (default), performs a soft assignment where H contains
        probabilities. If False, performs a hard assignment where each data
        point is assigned to the single most likely component (i.e., H
        contains only 0s and 1s).

    Attributes
    ----------
    is_soft : bool
        Flag indicating whether to perform soft or hard assignment.

    Methods
    -------

    .. autosummary::
        :toctree: generated/

        run
    """

    def __init__(self, is_soft: bool = True):
        self.is_soft = is_soft

    @property
    def available_next_steps(self) -> list[type[PipelineStep]]:
        """list[type[PipelineStep]]: Defines the valid subsequent steps.

        Specifies that an :class:`ExpectationStep` must be followed by a
        :class:`MaximizationStep` to form a standard EM iteration.
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import between the step modules).
        from rework_pysatl_mpest.estimators.iterative.steps.maximization_step import MaximizationStep

        return [MaximizationStep]

    def run(self, state: PipelineState) -> PipelineState:
        """Executes the E-step by calculating the responsibility matrix H.

        This method computes the log-density of each data point under each
        component, adds the component log-weights, and normalizes every row
        in log-space (via ``logsumexp``) to obtain the posterior
        probabilities (responsibilities). The resulting matrix is stored in
        ``state.H``.

        Rows whose normalized value is NaN (for example, when every weighted
        log-likelihood in the row is ``-inf``) are set to 0.

        Parameters
        ----------
        state : PipelineState
            The current state of the pipeline, which must contain the input
            data ``X`` and the current mixture model ``curr_mixture``.

        Returns
        -------
        PipelineState
            The same state object, with its ``H`` attribute computed and set.
        """
        X, mixture = state.X, state.curr_mixture

        # One row per component after stacking; transpose so that rows index
        # data points and columns index components.
        log_p_xij_matrix = np.array([comp.lpdf(X) for comp in mixture.components]).T
        log_weighted_likelihoods = log_p_xij_matrix + mixture.log_weights
        log_denominator = logsumexp(log_weighted_likelihoods, axis=1, keepdims=True)

        # A row with a -inf denominator yields ``-inf - (-inf) = nan``. That
        # NaN is deliberately mapped to 0 below, so silence the "invalid
        # value" floating-point warning/error only for this subtraction.
        with np.errstate(invalid="ignore"):
            log_H = log_weighted_likelihoods - log_denominator

        H_soft = np.exp(log_H)
        H_soft[np.isnan(H_soft)] = 0.0

        if self.is_soft:
            state.H = H_soft
            return state

        # Hard assignment: put a single 1 in the arg-max column of each row.
        # The row index is taken from H itself so it always matches the
        # matrix being indexed.
        # NOTE(review): np.argmax returns the first index on ties, so a row
        # that is entirely 0 is assigned to component 0 -- confirm intended.
        H_hard = np.zeros_like(H_soft)
        row_indices = np.arange(H_soft.shape[0])
        H_hard[row_indices, np.argmax(H_soft, axis=1)] = 1.0
        state.H = H_hard
        return state