# Source code for lydata.augmentor

"""Provides functions for augmenting and enhancing the lyDATA tables.

This module does the heavy lifting of inferring the most likely true involvement based
on several - possibly conflicting - diagnoses and their sensitivities and
specificities. It also resolves the sub- and super-level involvement information,
e.g. if a sublevel is involved, the superlevel is also involved, and if the
superlevel is healthy, all of its sublevels are healthy, too.

All this is achieved in the :py:func:`combine_and_augment_levels` function, which is
also used by the :py:meth:`~lydata.accessor.LyDataAccessor.combine`,
:py:meth:`~lydata.accessor.LyDataAccessor.augment`, and
:py:meth:`~lydata.accessor.LyDataAccessor.enhance` methods of the
:py:class:`~lydata.accessor.LyDataAccessor` class.
"""

from collections.abc import Mapping, Sequence
from itertools import product
from typing import Literal

import numpy as np
import pandas as pd

from lydata.utils import _sort_by


def _keep_only_involvement(table: pd.DataFrame) -> pd.DataFrame:
    """Keep only the involvement information under ``"ipsi"`` and ``"contra"``.

    For tables with :py:class:`~pandas.MultiIndex` columns, a column is kept only if
    its *top* level is exactly ``"ipsi"`` or ``"contra"``. Columns that merely contain
    one of these words somewhere in their label, like ``("info", "contralateral")``,
    are dropped. For flat (single-level) columns, every label containing ``"ipsi"`` or
    ``"contra"`` is kept.

    >>> table = pd.DataFrame({
    ...     ("ipsi", "I"): [True, False, None],
    ...     ("contra", "II"): [False, True, None],
    ...     ("foo", "bar"): [1, 2, 3],
    ...     ("info", "contralateral"): [4, 5, 6],
    ... })
    >>> _keep_only_involvement(table)
        ipsi contra
           I     II
    0   True  False
    1  False   True
    2   None   None
    """
    if not isinstance(table.columns, pd.MultiIndex):
        # Flat labels have no separate side level to compare exactly, so fall back
        # to a substring search on the label itself.
        return table.filter(regex=r"(ipsi|contra)", axis="columns")

    # Compare the side level exactly instead of regex-searching the stringified
    # tuple, which would also match e.g. ``("info", "contralateral")``.
    top_level = table.columns.get_level_values(0).astype(str)
    return table.loc[:, top_level.isin(["ipsi", "contra"])]


def _align_tables(tables: Sequence[pd.DataFrame]) -> list[pd.DataFrame]:
    """Reindex every table in ``tables`` onto the combined set of all their columns.

    Columns a table does not have are inserted and filled with ``NaN``. With two or
    more tables, the shared column index is built by chaining
    :py:meth:`pandas.Index.union`, which sorts the labels where possible. A single
    table keeps its own column order, and an empty sequence yields an empty list.

    >>> one = pd.DataFrame({
    ...     ("x", "a"): [1, 2],
    ...     ("x", "b"): [3, 4],
    ...     ("y", "c"): [5, 6],
    ...     ("y", "b"): [19, 120],
    ... })
    >>> two = pd.DataFrame({
    ...     ("y", "c"): [91, 10],
    ...     ("y", "b"): [9, 10],
    ...     ("x", "a"): [7, 8],
    ... })
    >>> three = pd.DataFrame({
    ...     ("x", "c"): [71, 81],
    ...     ("y", "b"): [5, 6],
    ...     ("x", "a"): [5, 61],
    ... })
    >>> aligned = _align_tables([one, two, three])
    >>> aligned[0]  # doctest: +NORMALIZE_WHITESPACE
       x           y
       a  b   c    b  c
    0  1  3 NaN   19  5
    1  2  4 NaN  120  6
    >>> aligned[1]  # doctest: +NORMALIZE_WHITESPACE
       x           y
       a   b   c   b   c
    0  7 NaN NaN   9  91
    1  8 NaN NaN  10  10
    >>> aligned[2]  # doctest: +NORMALIZE_WHITESPACE
        x          y
        a   b   c  b   c
    0   5 NaN  71  5 NaN
    1  61 NaN  81  6 NaN
    """
    shared_columns = None
    for frame in tables:
        # The first table seeds the column set; every later one is merged into it.
        if shared_columns is None:
            shared_columns = frame.columns
        else:
            shared_columns = shared_columns.union(frame.columns)

    # No table was visited, so there is nothing to align.
    if shared_columns is None:
        return []

    return [frame.reindex(columns=shared_columns) for frame in tables]


def _convert_to_float_matrix(diagnoses: Sequence[pd.DataFrame]) -> np.ndarray:
    """Convert a sequence of ``diagnoses`` to a 3D float matrix.

    The tables are stacked along a new first axis. The result therefore has the shape
    ``(len(diagnoses), n_rows, n_columns)``, and all tables are expected to share
    the same shape, e.g. after :py:func:`_align_tables`. Every value that
    :py:func:`pandas.isna` flags as missing (``None``, ``NaN``, ``pd.NA``, ...)
    becomes ``np.nan``, and booleans become ``0.0``/``1.0``.

    >>> one = pd.DataFrame({"a": [1, None], "b": [3, 4]})
    >>> two = pd.DataFrame({"a": [5, 6], "b": [7, None]})
    >>> _convert_to_float_matrix([one, two])  # doctest: +NORMALIZE_WHITESPACE
    array([[[ 1.,  3.],
            [nan,  4.]],
           [[ 5.,  7.],
            [ 6., nan]]])
    >>> _convert_to_float_matrix([pd.DataFrame({"a": [1, 0]})])  # doctest: +NORMALIZE_WHITESPACE
    array([[[1.],
            [0.]]])
    """
    # Go through ``object`` dtype first: an integer (or boolean) stack cannot hold
    # ``NaN``, and NumPy rejects writing it there even through an all-False mask.
    matrix = np.array(diagnoses).astype(object)
    matrix[pd.isna(matrix)] = np.nan
    # ``ndarray.astype`` works on every NumPy version, ``np.astype`` only on >= 2.0.
    return matrix.astype(float)


def _compute_likelihoods(
    diagnosis_matrix: np.ndarray,
    sensitivities: np.ndarray,
    specificities: np.ndarray,
    method: Literal["max_llh", "rank"],
) -> tuple[np.ndarray, np.ndarray]:
    """Return the likelihoods of the observed diagnoses under both hypotheses.

    ``diagnosis_matrix`` has the shape ``(n_modalities, n_patients, n_levels)``. It
    holds ``1.0`` for positive and ``0.0`` for negative findings, and ``np.nan``
    where a modality reports nothing. ``sensitivities`` and ``specificities`` are 1D
    arrays with one entry per modality.

    For each modality, the likelihood of its finding is computed assuming the level
    is truly involved (driven by the sensitivity) and assuming it is truly healthy
    (driven by the specificity). The modalities are then reduced along axis 0 with a
    NaN-aware function, so missing findings are ignored:

    - ``"max_llh"`` multiplies the likelihoods of all modalities (``np.nanprod``).
    - ``"rank"`` keeps the largest likelihood across modalities (``np.nanmax``).
      This is done separately for the involved and the healthy hypothesis, so the
      two maxima may come from different modalities.

    Returns a tuple ``(involved_likelihoods, healthy_likelihoods)``, each of the
    shape ``(n_patients, n_levels)``. Raises a ``ValueError`` for an unknown
    ``method``.
    """
    # Lift the per-modality values so they broadcast over patients and levels.
    sens = sensitivities[:, None, None]
    spec = specificities[:, None, None]
    positive = diagnosis_matrix
    negative = 1 - diagnosis_matrix

    # P(finding | involved): sens for a positive finding, 1 - sens for a negative one.
    llh_if_involved = sens * positive + (1 - sens) * negative
    # P(finding | healthy): spec for a negative finding, 1 - spec for a positive one.
    llh_if_healthy = spec * negative + (1 - spec) * positive

    reducers = {"max_llh": np.nanprod, "rank": np.nanmax}
    if method not in reducers:
        raise ValueError(f"Unknown method {method}")

    aggregate = reducers[method]
    return aggregate(llh_if_involved, axis=0), aggregate(llh_if_healthy, axis=0)


def _compute_involved_probs(
    diagnosis_matrix: np.ndarray,
    sensitivities: np.ndarray,
    specificities: np.ndarray,
    method: Literal["max_llh", "rank"],
) -> np.ndarray:
    """Return the probability of involvement for each patient and level.

    The likelihoods from :py:func:`_compute_likelihoods` are normalized as
    ``L_involved / (L_involved + L_healthy)``. This is the posterior probability of
    involvement under equal prior probabilities for the two hypotheses.
    """
    involved_llh, healthy_llh = _compute_likelihoods(
        diagnosis_matrix=diagnosis_matrix,
        sensitivities=sensitivities,
        specificities=specificities,
        method=method,
    )
    total_llh = involved_llh + healthy_llh
    return involved_llh / total_llh


def combine_and_augment_levels(
    diagnoses: Sequence[pd.DataFrame],
    specificities: Sequence[float],
    sensitivities: Sequence[float],
    method: Literal["max_llh", "rank"] = "max_llh",
    sides: Sequence[Literal["ipsi", "contra"]] | None = None,
    subdivisions: Mapping[str, Sequence[str]] | None = None,
) -> pd.DataFrame:
    """Combine ``diagnoses`` and add sub-/superlevel involvement info.

    Different diagnostic modalities may conflict with each other, e.g. on MRI an LNL
    may look metastatic, while FNA finds no malignancy. This function combines the
    available diagnoses based on their ``sensitivities`` and ``specificities`` into a
    sort of consensus. With ``method="max_llh"``, the likelihoods of all modalities
    are multiplied and the more probable diagnosis is chosen. With ``method="rank"``,
    only the largest likelihood across modalities is kept instead (separately for the
    involved and the healthy hypothesis).

    Additionally, the function adds and resolves sub- and superlevel involvement
    information. For example, some datasets report the overall involvement in LNL II,
    while others differentiate between sublevels IIa and IIb. If IIa harbors
    disease, then the overall involvement in II is also true. By specifying
    ``subdivisions``, the function consistently updates these super- and sublevel
    involvement patterns.

    The returned :py:class:`~pandas.DataFrame` has a two-level column index. The
    first level holds the ``sides`` and the second level the involvement levels, so
    it is in the same format as each of the input ``diagnoses``. Its rows carry the
    index of the first table in ``diagnoses``. The result therefore aligns with the
    table the diagnoses were taken from.

    See the accessor methods :py:meth:`~lydata.accessor.LyDataAccessor.augment` and
    :py:meth:`~lydata.accessor.LyDataAccessor.combine` for some examples.

    :param diagnoses: One table per diagnostic modality. Only the involvement columns
        under ``"ipsi"`` and ``"contra"`` are used. Rows are matched by position, so
        all tables are expected to have the same number of rows.
    :param specificities: One specificity per modality, in the order of ``diagnoses``.
    :param sensitivities: One sensitivity per modality, in the order of ``diagnoses``.
    :param method: ``"max_llh"`` or ``"rank"``, see above.
    :param sides: Sides for which to resolve sub-/superlevels. Defaults to
        ``["ipsi", "contra"]``.
    :param subdivisions: Maps a superlevel to the suffixes of its sublevels, e.g.
        ``{"II": ["a", "b"]}`` for IIa and IIb. Defaults to subdividing I, II, and V
        into ``a`` and ``b``.
    :returns: Table of ``True``/``False``/``None`` (unknown) involvement.
    :raises ValueError: If ``diagnoses`` is empty or ``method`` is unknown.
    """
    if len(diagnoses) == 0:
        raise ValueError("At least one table of diagnoses is required.")

    diagnoses = _align_tables([_keep_only_involvement(table) for table in diagnoses])
    columns = diagnoses[0].columns
    # Keep the row labels of the input. Otherwise the result gets a default
    # ``RangeIndex`` and misaligns with filtered/re-indexed source tables.
    index = diagnoses[0].index

    matrix = _convert_to_float_matrix(diagnoses)
    # Entries that no modality reports on must stay unknown in the result.
    all_nan_mask = np.all(np.isnan(matrix), axis=0)

    involved_probs = _compute_involved_probs(
        diagnosis_matrix=matrix,
        sensitivities=np.array(sensitivities),
        specificities=np.array(specificities),
        method=method,
    )
    # ``ndarray.astype`` instead of ``np.astype``, which requires NumPy >= 2.0.
    # Object dtype allows ``None`` (unknown) next to ``True``/``False``.
    combined = (involved_probs >= 0.5).astype(object)
    combined[all_nan_mask] = None
    combined = pd.DataFrame(combined, index=index, columns=columns)

    healthy_probs = 1.0 - involved_probs
    involved_probs[all_nan_mask] = np.nan
    healthy_probs[all_nan_mask] = np.nan
    involved_probs = pd.DataFrame(involved_probs, index=index, columns=columns)
    healthy_probs = pd.DataFrame(healthy_probs, index=index, columns=columns)

    if sides is None:
        sides = ["ipsi", "contra"]

    if subdivisions is None:
        subdivisions = {
            "I": ["a", "b"],
            "II": ["a", "b"],
            "V": ["a", "b"],
        }

    for side, (superlvl, subids) in product(sides, subdivisions.items()):
        if side not in combined.columns:
            continue

        superlvl_col = (side, superlvl)
        sublvls = [superlvl + subid for subid in subids]
        sublvl_cols = [(side, sublvl) for sublvl in sublvls]

        # Nothing to resolve if neither the superlevel nor any sublevel is present.
        if set([superlvl] + sublvls).isdisjoint(set(combined[side].columns)):
            continue

        # Make sure all super- and sublevel columns exist; missing ones are unknown.
        nans = [np.nan] * len(combined)
        for lvl in [superlvl] + sublvls:
            combined[(side, lvl)] = combined.get((side, lvl), [None] * len(combined))
            involved_probs[(side, lvl)] = involved_probs.get((side, lvl), nans)
            healthy_probs[(side, lvl)] = healthy_probs.get((side, lvl), nans)

        # These masks are taken BEFORE any of the updates below and are deliberately
        # not refreshed, so each rule sees the state right after combining.
        # ``== False``/``== True`` (not ``~``) because the columns hold ``None``, too.
        is_super_unknown = combined[superlvl_col].isna()
        is_super_healthy = combined[superlvl_col] == False  # noqa: E712
        is_super_involved = combined[superlvl_col] == True  # noqa: E712
        is_any_sub_involved = combined[sublvl_cols].any(axis=1)
        is_one_sub_unknown = combined[sublvl_cols].isna().sum(axis=1) == 1
        are_all_subs_healthy = (combined[sublvl_cols] == False).all(axis=1)  # noqa: E712
        are_all_subs_unknown = combined[sublvl_cols].isna().all(axis=1)

        # Superlvl unknown => no conflict, use sublvl info
        combined.loc[is_super_unknown & is_any_sub_involved, superlvl_col] = True
        combined.loc[is_super_unknown & are_all_subs_healthy, superlvl_col] = False

        # No sublvl involved => no conflict, use superlvl info
        combined.loc[~is_any_sub_involved & is_super_healthy, sublvl_cols] = False

        # Conflicts: the superlevel is healthy only if all sublevels are, so its
        # healthy probability implied by the sublevels is their product (unknown
        # sublevels, i.e. NaN, are ignored by ``nanprod``).
        super_healthy_prob_from_subs = np.nanprod(healthy_probs[sublvl_cols], axis=1)
        super_involved_prob_from_subs = 1.0 - super_healthy_prob_from_subs

        # 1) Subs override superlvl
        do_subs_determine_super_healthy = (
            is_super_involved
            & ~are_all_subs_unknown
            & (super_healthy_prob_from_subs > involved_probs[superlvl_col])
        )
        combined.loc[do_subs_determine_super_healthy, superlvl_col] = False

        do_subs_determine_super_involved = (
            is_super_healthy
            & ~are_all_subs_unknown
            & (super_involved_prob_from_subs > healthy_probs[superlvl_col])
        )
        combined.loc[do_subs_determine_super_involved, superlvl_col] = True

        # 2) Superlvl overrides subs
        does_super_determine_all_subs_healthy = (
            is_any_sub_involved
            & is_super_healthy
            & (healthy_probs[superlvl_col] > super_involved_prob_from_subs)
        )
        combined.loc[does_super_determine_all_subs_healthy, sublvl_cols] = False

        # An involved superlevel does not tell which sublevel is involved, so
        # otherwise-healthy sublevels become unknown.
        does_super_determine_subs_unknown = (
            are_all_subs_healthy
            & is_super_involved
            & (involved_probs[superlvl_col] > super_healthy_prob_from_subs)
        )
        combined.loc[does_super_determine_subs_unknown, sublvl_cols] = None

        for sublvl in sublvls:
            sublvl_col = (side, sublvl)
            # Re-read on purpose: this reflects the updates made above.
            is_sub_unknown = combined[sublvl_col].isna()
            does_super_determine_unknown_sub_involved = (
                is_super_involved
                & is_sub_unknown
                & is_one_sub_unknown
                & ~is_any_sub_involved
                & (involved_probs[superlvl_col] > super_healthy_prob_from_subs)
            )
            # The above combination of conditions means that the current `sublvl` is
            # unknown, while all others are healthy, while the superlvl is involved.
            # Then below, we change the sublvl to involved.
            combined.loc[does_super_determine_unknown_sub_involved, sublvl_col] = True

    combined = _sort_by(combined, which="lnl", level=1)
    return _sort_by(combined, which="mid", level=0)