Source code for rework_pysatl_mpest.estimators.iterative.steps.expectation_step
"""Provides the Expectation-step for an iterative estimation pipeline."""
__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"
import numpy as np
from scipy.special import logsumexp
from ..pipeline_state import PipelineState
from ..pipeline_step import PipelineStep
[docs]
class ExpectationStep(PipelineStep):
    """A pipeline step that performs the Expectation (E-step).

    This step calculates the responsibility matrix H, where H[i, j] is
    the posterior probability that the i-th data point belongs to the j-th
    mixture component. It can perform either a soft (probabilistic) or hard
    (winner-takes-all) assignment.

    A data point whose weighted likelihood cannot be normalized (for example,
    it has zero density under every component) cannot be attributed to any
    component. Its row of H is left as all zeros in both soft and hard mode.

    Parameters
    ----------
    is_soft : bool, optional
        If True (default), performs a soft assignment where H contains probabilities.
        If False, performs a hard assignment where each data point
        is assigned to the single most likely component (i.e., H contains
        only 0s and 1s).

    Attributes
    ----------
    is_soft : bool
        Flag indicating whether to perform soft or hard assignment.

    Methods
    -------
    .. autosummary::
        :toctree: generated/

        run
    """

    def __init__(self, is_soft: bool = True):
        self.is_soft = is_soft

    @property
    def available_next_steps(self) -> list[type[PipelineStep]]:
        """list[type[PipelineStep]]: Defines the valid subsequent steps.

        Specifies that an :class:`ExpectationStep` must be followed by a
        :class:`MaximizationStep` to form a standard EM iteration.
        """
        # Imported inside the property rather than at module level,
        # presumably to avoid a circular import between step modules.
        from rework_pysatl_mpest.estimators.iterative.steps.maximization_step import (
            MaximizationStep,
        )

        return [MaximizationStep]

    def run(self, state: PipelineState) -> PipelineState:
        """Executes the E-step by calculating the responsibility matrix H.

        This method computes the log-likelihood of each data point under each
        component, incorporates the component weights, and normalizes to find
        the posterior probabilities (responsibilities). The resulting matrix H
        is then stored in the pipeline state.

        Parameters
        ----------
        state : PipelineState
            The current state of the pipeline, which must contain the input
            data X and the current mixture model curr_mixture.

        Returns
        -------
        PipelineState
            The same state object, with its H attribute computed and set.

        Notes
        -----
        Rows whose normalization produces NaN (e.g. every component assigns
        the point zero density) are set to all zeros. In hard mode such rows
        stay all zeros instead of being arbitrarily assigned to component 0.
        """
        X, mixture = state.X, state.curr_mixture

        # One row per data point, one column per component -- assuming each
        # comp.lpdf(X) returns a 1-D array of per-sample log densities.
        log_p_xij_matrix = np.array([comp.lpdf(X) for comp in mixture.components]).T
        log_weighted_likelihoods = log_p_xij_matrix + mixture.log_weights

        # A row with no support gives a -inf denominator, and (-inf) - (-inf)
        # is NaN. Those NaNs are intentionally zeroed below, so the floating
        # point warnings for this arithmetic are silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_denominator = logsumexp(log_weighted_likelihoods, axis=1, keepdims=True)
            log_H = log_weighted_likelihoods - log_denominator

        H_soft = np.exp(log_H)
        H_soft[np.isnan(H_soft)] = 0.0

        if self.is_soft:
            state.H = H_soft
            return state

        # Hard assignment: pick a winner only for rows that carry probability
        # mass. An all-zero row would otherwise make argmax return 0 and
        # wrongly assign the point to the first component.
        H_hard = np.zeros_like(H_soft)
        supported_rows = np.flatnonzero(H_soft.any(axis=1))
        winners = np.argmax(H_soft[supported_rows], axis=1)
        H_hard[supported_rows, winners] = 1.0
        state.H = H_hard
        return state