Source code for lydata.utils

"""Utility functions and classes."""

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Literal

import pandas as pd
from pydantic import BaseModel, Field


@dataclass
class _ColumnSpec:
    """Class for specifying column names and aggfuncs."""

    short: str
    long: tuple[str, str, str]
    agg_func: str | Callable[[pd.Series], pd.Series] = "value_counts"
    agg_kwargs: dict[str, Any] = field(default_factory=lambda: {"dropna": False})

    def __call__(self, series: pd.Series) -> pd.Series:
        """Call the aggregation function on the series."""
        return series.agg(self.agg_func, **self.agg_kwargs)


@dataclass
class _ColumnMap:
    """Look up ``_ColumnSpec`` instances by their short or long column name.

    ``from_short`` and ``from_long`` must contain equal specs in the same
    (insertion) order; this is validated on construction.
    """

    # Short alias -> spec.
    from_short: dict[str, _ColumnSpec]
    # Three-level column key -> spec.
    from_long: dict[tuple[str, str, str], _ColumnSpec]

    def __post_init__(self) -> None:
        """Check ``from_short`` and ``from_long`` contain same ``_ColumnSpec``.

        Raises:
            ValueError: If the two mappings hold a different number of specs,
                or if the specs at some position (in insertion order) differ.
        """
        # Report a size mismatch explicitly. Otherwise it would only surface as
        # zip()'s generic "argument 2 is longer than argument 1" error, e.g.
        # when ``from_list`` receives duplicated short or long names.
        if len(self.from_short) != len(self.from_long):
            raise ValueError(
                "`from_short` and `from_long` contain a different number of "
                f"`_ColumnSpec` instances ({len(self.from_short)} vs "
                f"{len(self.from_long)}); are some short or long names duplicated?"
            )
        for left, right in zip(
            self.from_short.values(), self.from_long.values(), strict=True
        ):
            if left != right:
                raise ValueError(
                    "`from_short` and `from_long` contain different "
                    "`_ColumnSpec` instances"
                )

    @classmethod
    def from_list(cls, columns: list[_ColumnSpec]) -> "_ColumnMap":
        """Create a ColumnMap from a list of ColumnSpecs.

        If a short or long name repeats, the later spec replaces the earlier one
        in that mapping. This usually makes construction raise ``ValueError``.
        """
        # Materialize first: the input is iterated twice below, so a one-shot
        # iterable (e.g. a generator) would otherwise leave ``long`` empty.
        columns = list(columns)
        short = {col.short: col for col in columns}
        long = {col.long: col for col in columns}
        return cls(short, long)

    def __iter__(self):
        """Iterate over the ``_ColumnSpec`` instances, in ``from_short`` order."""
        return iter(self.from_short.values())


def get_default_column_map() -> _ColumnMap:
    """Build the default column map.

    Each entry pairs a short alias with its three-level column key. Every spec
    uses the default aggregation settings of ``_ColumnSpec``.
    """
    aliases = [
        ("age", ("patient", "#", "age")),
        ("hpv", ("patient", "#", "hpv_status")),
        ("smoke", ("patient", "#", "nicotine_abuse")),
        ("alcohol", ("patient", "#", "alcohol_abuse")),
        ("t_stage", ("tumor", "1", "t_stage")),
        ("n_stage", ("patient", "#", "n_stage")),
        ("m_stage", ("patient", "#", "m_stage")),
        ("midext", ("tumor", "1", "extension")),
    ]
    specs = [_ColumnSpec(alias, key) for alias, key in aliases]
    return _ColumnMap.from_list(specs)
class ModalityConfig(BaseModel):
    """Configuration of a diagnostic or pathological modality.

    Stores specificity and sensitivity, each constrained to ``[0.5, 1.0]``,
    together with whether the modality is clinical or pathological.
    """

    # Specificity; values outside [0.5, 1.0] fail validation.
    spec: float = Field(description="Specificity of the modality.", ge=0.5, le=1.0)
    # Sensitivity; values outside [0.5, 1.0] fail validation.
    sens: float = Field(description="Sensitivity of the modality.", ge=0.5, le=1.0)
    # "clinical" (the default) or "pathological".
    kind: Literal["clinical", "pathological"] = Field(
        description="Clinical modalities cannot detect microscopic disease.",
        default="clinical",
    )
def get_default_modalities() -> dict[str, ModalityConfig]:
    """Get default values for sensitivities and specificities of modalities.

    Taken from `de Bondt et al. (2007) <https://doi.org/10.1016/j.ejrad.2007.02.037>`_
    and `Kyzas et al. (2008) <https://doi.org/10.1093/jnci/djn125>`_.
    """
    # (name, specificity, sensitivity, kind); ``None`` means "leave ``kind`` at
    # its default", so the argument is not passed at all in that case.
    table = [
        ("CT", 0.76, 0.81, None),
        ("MRI", 0.63, 0.81, None),
        ("PET", 0.86, 0.79, None),
        ("FNA", 0.98, 0.80, "pathological"),
        ("diagnostic_consensus", 0.86, 0.81, None),
        ("pathology", 1.0, 1.0, "pathological"),
        ("pCT", 0.86, 0.81, None),
    ]
    modalities: dict[str, ModalityConfig] = {}
    for name, spec, sens, kind in table:
        extra = {} if kind is None else {"kind": kind}
        modalities[name] = ModalityConfig(spec=spec, sens=sens, **extra)
    return modalities
def main() -> None:
    """Entry point used when the module runs as a script; currently does nothing."""
    return None
if __name__ == "__main__":
    main()