Source code for lydata.utils

"""Utility functions and classes."""

import os
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cmp_to_key
from typing import Any, Literal

import pandas as pd
from github import Auth
from loguru import logger
from pydantic import BaseModel, Field
from roman import InvalidRomanNumeralError
from roman import fromRoman as roman_to_int  # noqa: N813


def get_github_auth(
    token: str | None = None,
    user: str | None = None,
    password: str | None = None,
) -> Auth.Auth | None:
    """Build a GitHub auth object from explicit arguments or the environment.

    Each argument falls back to its environment variable (``GITHUB_TOKEN``,
    ``GITHUB_USER``, ``GITHUB_PASSWORD``) when it is falsy. A token takes
    precedence over a user/password pair. If neither is available, ``None`` is
    returned, meaning unauthenticated access.
    """
    resolved_token = token or os.getenv("GITHUB_TOKEN")
    resolved_user = user or os.getenv("GITHUB_USER")
    resolved_password = password or os.getenv("GITHUB_PASSWORD")

    if resolved_token:
        logger.debug("Using GITHUB_TOKEN for authentication.")
        return Auth.Token(resolved_token)

    if not (resolved_user and resolved_password):
        logger.info("No authentication provided. Using unauthenticated access.")
        return None

    logger.debug("Using GITHUB_USER and GITHUB_PASSWORD for authentication.")
    return Auth.Login(resolved_user, resolved_password)
def update_and_expand(
    left: pd.DataFrame,
    right: pd.DataFrame,
    **update_kwargs: Any,
) -> pd.DataFrame:
    """Update ``left`` with values from ``right``, also adding columns from ``right``.

    Unlike plain :py:meth:`~pandas.DataFrame.update`, columns that exist only in
    ``right`` are appended to the result as well. All keyword arguments are
    forwarded unchanged to :py:meth:`~pandas.DataFrame.update`. ``left`` itself
    is not modified; a copy is returned.

    >>> left = pd.DataFrame({"a": [1, 2, None], "b": [3, 4, 5]})
    >>> right = pd.DataFrame({"a": [None, 3, 4], "c": [6, 7, 8]})
    >>> update_and_expand(left, right)
         a  b  c
    0  1.0  3  6
    1  3.0  4  7
    2  4.0  5  8
    """
    expanded = left.copy()
    expanded.update(right, **update_kwargs)

    for col in right.columns:
        # Columns already present were handled by ``update`` above.
        if col in expanded.columns:
            continue
        expanded[col] = right[col]

    return expanded
def replace(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``left`` whose columns are overwritten by those of ``right``.

    Every column of ``right`` is assigned into the copy. Existing columns are
    replaced and new ones are appended. ``left`` itself is left untouched.
    """
    replaced = left.copy()
    for name in list(right.columns):
        replaced[name] = right[name]
    return replaced
@dataclass class _ColumnSpec: """Class for specifying column names and aggfuncs. This serves a dual purpose: 1. It is a simple container that ties together a short name and a long name. For this we could have used a `namedtuple` as well. 2. Every `_ColumnSpec` is also an aggregation function in itself. This is used in the :py:meth:`~lydata.accessor.LyDataAccessor.stats` method. """ short: str long: tuple[str, str, str] agg_func: str | Callable[[pd.Series], pd.Series] = "value_counts" agg_kwargs: dict[str, Any] = field(default_factory=lambda: {"dropna": False}) def __call__(self, series: pd.Series) -> pd.Series: """Call the aggregation function on the series.""" return series.agg(self.agg_func, **self.agg_kwargs) @dataclass class _ColumnMap: """Class for mapping short and long column names.""" from_short: dict[str, _ColumnSpec] from_long: dict[tuple[str, str, str], _ColumnSpec] def __post_init__(self) -> None: """Check ``from_short`` and ``from_long`` contain same ``_ColumnSpec``.""" for left, right in zip( self.from_short.values(), self.from_long.values(), strict=True ): if left != right: raise ValueError( "`from_short` and `from_long` contain different " "`_ColumnSpec` instances" ) @classmethod def from_list(cls, columns: list[_ColumnSpec]) -> "_ColumnMap": """Create a ColumnMap from a list of ColumnSpecs.""" short = {col.short: col for col in columns} long = {col.long: col for col in columns} return cls(short, long) def __iter__(self): """Iterate over the short names.""" return iter(self.from_short.values())
def get_default_column_map_old() -> _ColumnMap:
    """Return the column map for datasets that use the old header scheme.

    Each entry links a short accessor name to a three-level column key. In the
    old scheme, patient columns live under ``("patient", "#", ...)`` and tumor
    columns under ``("tumor", "1", ...)``.

    >>> from lydata import accessor, loader
    >>> df = next(loader.load_datasets(
    ...     institution="usz",
    ...     repo_name="lycosystem/lydata.private",
    ...     ref="ab04379a36b6946306041d1d38ad7e97df8ee7ba",
    ... ))
    >>> df.ly.surgery  # doctest: +ELLIPSIS
    0      False
    ...
    286    False
    Name: (patient, #, neck_dissection), Length: 287, dtype: bool
    >>> df.ly.smoke  # doctest: +ELLIPSIS
    0       True
    ...
    286     True
    Name: (patient, #, nicotine_abuse), Length: 287, dtype: bool
    """
    patient = ("patient", "#")
    tumor = ("tumor", "1")

    # (short name, first two header levels, last header level) -- order matters.
    entries = [
        ("id", patient, "id"),
        ("institution", patient, "institution"),
        ("sex", patient, "sex"),
        ("age", patient, "age"),
        ("weight", patient, "weight"),
        ("date", patient, "diagnose_date"),
        ("surgery", patient, "neck_dissection"),
        ("hpv", patient, "hpv_status"),
        ("smoke", patient, "nicotine_abuse"),
        ("alcohol", patient, "alcohol_abuse"),
        ("t_stage", tumor, "t_stage"),
        ("n_stage", patient, "n_stage"),
        ("m_stage", patient, "m_stage"),
        ("midext", tumor, "extension"),
        ("subsite", tumor, "subsite"),
        ("location", tumor, "location"),
        ("volume", tumor, "volume"),
        ("central", tumor, "central"),
        ("side", tumor, "side"),
    ]
    specs = [_ColumnSpec(short, (*prefix, name)) for short, prefix, name in entries]
    return _ColumnMap.from_list(specs)
def _new_from_old(long_name: tuple[str, str, str]) -> tuple[str, str, str]: """Convert an old long key name to a new long key name. >>> _new_from_old(("patient", "#", "neck_dissection")) ('patient', 'core', 'neck_dissection') >>> _new_from_old(("tumor", "1", "t_stage")) ('tumor', 'core', 't_stage') >>> _new_from_old(("a", "b", "c")) ('a', 'b', 'c') """ start, middle, end = long_name if (start == "patient" and middle == "#") or (start == "tumor" and middle == "1"): middle = "core" return (start, middle, end)
def is_old(dataset: pd.DataFrame) -> bool:
    """Tell whether ``dataset`` uses the old column header scheme.

    The old scheme is recognized by a ``"#"`` or ``"1"`` label anywhere in the
    second level of the column headers.
    """
    level_one_labels = dataset.columns.get_level_values(1)
    return any(marker in level_one_labels for marker in ("#", "1"))
def get_default_column_map_new() -> _ColumnMap:
    """Get the new default column map.

    This map defines which short column names can be used to access columns in
    DataFrames that use the new header scheme. In that scheme, the second header
    level of ``patient`` and ``tumor`` columns is ``"core"`` instead of ``"#"``
    or ``"1"``. The map is derived from :py:func:`get_default_column_map_old`
    by converting every long name with :py:func:`_new_from_old`. Each spec's
    aggregation function and keyword arguments are carried over.

    >>> from lydata import accessor, loader
    >>> df = next(loader.load_datasets(
    ...     institution="usz",
    ...     repo_name="lycosystem/lydata.private",
    ...     ref="fb55afa26ff78afa78274a86b131fb3014d0ceea",
    ... ))
    >>> df.ly.surgery  # doctest: +ELLIPSIS
    0      False
    ...
    286    False
    Name: (patient, core, neck_dissection), Length: 287, dtype: bool
    >>> df.ly.smoke  # doctest: +ELLIPSIS
    0       True
    ...
    286     True
    Name: (patient, core, nicotine_abuse), Length: 287, dtype: bool
    """
    return _ColumnMap.from_list(
        [
            _ColumnSpec(
                short=cs.short,
                long=_new_from_old(cs.long),
                agg_func=cs.agg_func,
                # Copy, so old and new specs never share one mutable dict.
                agg_kwargs=dict(cs.agg_kwargs),
            )
            for cs in get_default_column_map_old()
        ]
    )
class ModalityConfig(BaseModel):
    """Define a diagnostic or pathological modality."""

    # Both rates are validated to lie in the closed interval [0.5, 1.0].
    spec: float = Field(
        description="Specificity of the modality.",
        ge=0.5,
        le=1.0,
    )
    sens: float = Field(
        description="Sensitivity of the modality.",
        ge=0.5,
        le=1.0,
    )
    # Defaults to "clinical" when not given explicitly.
    kind: Literal["clinical", "pathological"] = Field(
        description="Clinical modalities cannot detect microscopic disease.",
        default="clinical",
    )
def get_default_modalities() -> dict[str, ModalityConfig]:
    """Return default sensitivity/specificity settings for known modalities.

    Values are taken from `de Bondt et al. (2007)
    <https://doi.org/10.1016/j.ejrad.2007.02.037>`_ and `Kyzas et al. (2008)
    <https://doi.org/10.1093/jnci/djn125>`_.

    A fresh dict of fresh :py:class:`ModalityConfig` objects is built on every
    call. Modalities without an explicit ``kind`` use the model's default.
    """
    settings: dict[str, dict[str, Any]] = {
        "CT": {"spec": 0.76, "sens": 0.81},
        "MRI": {"spec": 0.63, "sens": 0.81},
        "PET": {"spec": 0.86, "sens": 0.79},
        "FNA": {"spec": 0.98, "sens": 0.80, "kind": "pathological"},
        "diagnostic_consensus": {"spec": 0.86, "sens": 0.81},
        "pathology": {"spec": 1.0, "sens": 1.0, "kind": "pathological"},
        "pCT": {"spec": 0.86, "sens": 0.81},
    }
    return {name: ModalityConfig(**kwargs) for name, kwargs in settings.items()}
def _get_all_true(df: pd.DataFrame) -> pd.Series: """Return a mask with all entries set to ``True``.""" return pd.Series([True] * len(df)) def _get_numeral_with_sub_value(key: str) -> float: """Get the value of a Roman numeral with an optional sublevel. >>> _get_numeral_with_sub_value("I") 1.0 >>> _get_numeral_with_sub_value("IIa") 2.01 >>> _get_numeral_with_sub_value("IXb") 9.02 """ match = re.match(r"([IVXLCDM]+)([a-z]?)", key) if match is None: raise ValueError(f"Invalid Roman numeral with sublevel: {key}") numeral, sublvl = match.groups() base = roman_to_int(numeral) addition = 0.0 if len(sublvl) == 1: addition = "abcdefghijklmnopqrstuvwxyz".index(sublvl) / 100.0 + 0.01 return base + addition def _top_lvl_cmp(left: str, right: str) -> int: """Compare two top-level column names.""" if left == right: return 0 if left == "patient": return -1 if right == "patient": return 1 if left == "tumor": return -1 if right == "tumor": return 1 if left == "max_llh": return -1 if right == "max_llh": return 1 return (left > right) - (left < right) def _mid_lvl_cmp(left: str, right: str) -> int: """Compare two mid-level column names.""" if left == right: return 0 if left == "core": return -1 if right == "core": return 1 return (left > right) - (left < right) def _lnl_cmp(left: str, right: str) -> int: """Compare two roman numeral LNLs.""" try: left_value = _get_numeral_with_sub_value(left) right_value = _get_numeral_with_sub_value(right) return (left_value > right_value) - (left_value < right_value) except ValueError: if "id" in left: return -1 if "id" in right: return 1 return (left > right) - (left < right) def _sort_by( dataset: pd.DataFrame, which: Literal["top", "mid", "lnl"], level: int | None = None, ) -> pd.DataFrame: """Sort the DataFrame columns by the specified level.""" if level is None: level = ["top", "mid", "lnl"].index(which) cmps = { "top": _top_lvl_cmp, "mid": _mid_lvl_cmp, "lnl": _lnl_cmp, } if which not in cmps: raise ValueError(f"Invalid sorting level: 
{which} ('top', 'mid', or 'lnl').") if level < 0 or level > 2: raise ValueError(f"Invalid level: {level} (must be 0, 1, or 2).") columns = dataset.columns.get_level_values(level).unique() sorted_columns = sorted(columns, key=cmp_to_key(cmps[which])) return dataset.reindex(columns=sorted_columns, level=level) def _sort_all(dataset: pd.DataFrame) -> pd.DataFrame: """Use the custom sorting to sort the DataFrame columns by all levels.""" dataset = _sort_by(dataset, "lnl", level=2) dataset = _sort_by(dataset, "mid", level=1) return _sort_by(dataset, "top", level=0)