# Source code for vivainsights.create_survival

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
create_survival: Parameterized Kaplan-Meier survival workflow (calc + viz + wrapper).

Design goals
------------
- General-purpose: works with any HR attribute column (segments, org, region, etc.).
- Uses lifelines.KaplanMeierFitter if available; falls back to a NumPy implementation.
- Reuses the figure header styling used in other vivainsights visuals.
- Returns either a plot or a table.

The typical workflow starts with `create_survival_prep()` to convert panel data
into the person-level format expected here.

Example
-------
Single overall curve (no grouping):

>>> import vivainsights as vi
>>> from vivainsights.create_survival import create_survival
>>> from vivainsights.create_survival_prep import create_survival_prep
>>>
>>> pq_data = vi.load_pq_data()
>>> surv_data = create_survival_prep(
...     data=pq_data,
...     metric="Copilot_actions_taken_in_Teams",
... )
>>> fig = create_survival(
...     data=surv_data,
...     time_col="time",
...     event_col="event",
... )

Grouped by HR attribute:

>>> fig = create_survival(
...     data=surv_data,
...     time_col="time",
...     event_col="event",
...     hrvar="Organization",
... )

Table output:

>>> tbl = create_survival(
...     data=surv_data,
...     time_col="time",
...     event_col="event",
...     hrvar="Organization",
...     return_type="table",
... )
"""

from typing import List, Optional, Tuple, Literal, Sequence, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ---------- lifelines (optional) ----------
try:
    from lifelines import KaplanMeierFitter
    _HAS_LIFELINES = True
except ImportError:  # pragma: no cover
    _HAS_LIFELINES = False

# ---------- Style / header helpers (aligned with create_radar / Lorenz) ----------
try:
    from vivainsights.color_codes import Colors
    _HIGHLIGHT = Colors.HIGHLIGHT_NEGATIVE.value
except Exception:  # pragma: no cover - fallback
    _HIGHLIGHT = "#fe7f4f"

_TITLE_Y = 0.955
_SUB_Y = 0.915
_RULE_Y = 0.900
_TOP_LIMIT = 0.80

__all__ = ["create_survival_calc", "create_survival_viz", "create_survival"]


def _axes_left(fig) -> float:
    """
    Locate the left edge (figure coords) of the leftmost axes in ``fig``.

    Returns the smallest ``get_position().x0`` across ``fig.get_axes()`` as a
    float. Falls back to 0.01 when the figure has no axes or when reading any
    axes position raises.
    """
    fallback = 0.01
    all_axes = fig.get_axes()
    if not all_axes:
        return fallback
    try:
        left_edges = [axis.get_position().x0 for axis in all_axes]
        return float(min(left_edges))
    except Exception:
        # Best-effort lookup: a failed position query must not break the caller.
        return fallback


def _retitle_left(fig, title_text: Optional[str], subtitle_text: Optional[str] = None) -> None:
    """
    Replace per-axes titles with a left-aligned, figure-level title and subtitle.

    Every axes title is blanked, an existing suptitle (if any) is hidden, and the
    new text is drawn with ``fig.text`` at the x position returned by
    ``_axes_left``. Empty or ``None`` strings are skipped.
    """
    for axis in fig.get_axes():
        try:
            axis.set_title("")
        except Exception:
            pass

    suptitle = getattr(fig, "_suptitle", None)
    if suptitle is not None:
        suptitle.set_visible(False)

    left = _axes_left(fig)
    header_lines = (
        (title_text, _TITLE_Y, {"fontsize": 13, "weight": "bold"}),
        (subtitle_text, _SUB_Y, {"fontsize": 11}),
    )
    for text, y_pos, font_kwargs in header_lines:
        if text:
            fig.text(left, y_pos, text, ha="left", alpha=.8, **font_kwargs)


def _add_header_decoration(fig, color: str = _HIGHLIGHT, y: float = _RULE_Y) -> None:
    """
    Draw the header rule and its small accent box on a full-figure overlay axes.

    A frameless, axis-off axes spanning the whole figure is added at zorder 10.
    On it, a horizontal line runs from the ``_axes_left`` anchor to x=1.0 at
    height ``y``, and a 0.03-wide rectangle hangs 0.015 below the line's start.
    Both use ``color``.
    """
    left = _axes_left(fig)

    canvas = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=10)
    canvas.set_axis_off()

    rule = Line2D(
        [left, 1.0],
        [y, y],
        transform=canvas.transAxes,
        color=color,
        linewidth=1.2,
    )
    accent_box = plt.Rectangle(
        (left, y),
        0.03,
        -0.015,  # negative height: the box extends downward from the rule
        transform=canvas.transAxes,
        facecolor=color,
        linewidth=0,
    )
    canvas.add_line(rule)
    canvas.add_patch(accent_box)


def _reserve_header_space(fig, top: float = _TOP_LIMIT) -> None:
    """
    Lower the top of the subplot area to ``top`` so the header has room.

    If the figure reports constrained layout as active it is switched off first;
    any error while checking or switching is ignored. ``subplots_adjust`` is
    always called afterwards.
    """
    try:
        uses_constrained = (
            hasattr(fig, "get_constrained_layout") and fig.get_constrained_layout()
        )
        if uses_constrained:
            fig.set_constrained_layout(False)
    except Exception:
        # Best-effort: fall through to the manual adjustment below.
        pass
    fig.subplots_adjust(top=top)


# ---------- Event coercion helper ----------
def _coerce_event(x: pd.Series) -> pd.Series:
    """
    Coerce an event indicator column to nullable integer 0/1 (``Int64`` dtype).

    Accepts:
    - Numeric: any value > 0 is treated as 1 (event occurred).
    - Boolean: True -> 1, False -> 0.
    - String: "true"/"yes"/"1" -> 1; "false"/"no"/"0" -> 0 (case-insensitive,
      surrounding whitespace ignored).
    - Object columns mixing the above; numeric-looking entries are handled
      numerically, the rest as tokens.

    Missing values are preserved as ``pd.NA``.

    Parameters
    ----------
    x : pd.Series
        Raw event indicator column.

    Returns
    -------
    pd.Series
        ``Int64`` series of 0/1/``pd.NA`` on the same index as ``x``. Every
        branch returns the same dtype so callers see a consistent result
        regardless of how the input was encoded.

    Raises
    ------
    ValueError
        If the column contains unrecognised string tokens.
    """
    if pd.api.types.is_bool_dtype(x):
        return x.astype("Int64")

    if pd.api.types.is_numeric_dtype(x):
        na_mask = x.isna()
        result = (x > 0).astype("Int64")
        if na_mask.any():
            # NaN > 0 evaluates to False; restore those positions to missing.
            result[na_mask] = pd.NA
        return result

    # Object dtype: may hold a mix of numeric values and strings.
    # Try numeric coercion first; parse remaining non-numeric values as tokens.
    numeric = pd.to_numeric(x, errors="coerce")
    if numeric.notna().all():
        # Int64 (not plain int) to match the dtype of every other branch.
        return (numeric > 0).astype("Int64")

    true_tokens = {"true", "yes", "1"}
    false_tokens = {"false", "no", "0"}

    result = pd.array([pd.NA] * len(x), dtype="Int64")
    for i, (raw, num) in enumerate(zip(x, numeric)):
        if pd.notna(num):
            result[i] = int(num > 0)
        elif pd.isna(raw):
            # Genuinely missing input stays missing (already pd.NA).
            continue
        else:
            token = str(raw).strip().lower()
            if token in true_tokens:
                result[i] = 1
            elif token in false_tokens:
                result[i] = 0
            else:
                raise ValueError(
                    f"Unrecognised event indicator value: {raw!r}. "
                    f"Expected numeric, bool, or one of: "
                    f"{sorted(true_tokens | false_tokens)}."
                )
    return pd.Series(result, index=x.index)


# ---------- Pure NumPy Kaplan-Meier (fallback if lifelines absent) ----------
def _km_numpy(
    durations: np.ndarray,
    events: np.ndarray,
    timeline: Optional[np.ndarray] = None,
) -> pd.DataFrame:
    """
    Compute an unweighted Kaplan-Meier curve with tied events handled at the same time.

    The product-limit estimate is always built on the sorted unique observed
    durations. When ``timeline`` is given, the resulting step function is then
    evaluated at those times, so events that fall between timeline points are
    still reflected in the reported survival, and the order of ``timeline``
    does not affect the values.

    Parameters
    ----------
    durations : array-like
        Event or censoring times (>=0). NaN entries are dropped.
    events : array-like
        1 = event observed, 0 = censored. NaN entries are dropped together with
        their duration.
    timeline : array-like, optional
        Times at which to evaluate the survival function. If None, use sorted unique
        durations.

    Returns
    -------
    pd.DataFrame with columns ["time", "at_risk", "events", "survival"].
        The first row is always an anchor at time 0.0 with survival 1.0 and
        ``at_risk`` equal to the number of valid subjects. Subsequent rows follow
        ``timeline`` (or the unique durations). ``at_risk`` counts subjects with
        duration >= time; ``events`` counts events occurring exactly at that time.
    """
    durations = np.asarray(durations, dtype=float)
    # Parse events as float first: casting NaN straight to int raises, which
    # would make the missing-value mask below unreachable.
    events = np.asarray(events, dtype=float)
    valid = ~np.isnan(durations) & ~np.isnan(events)
    durations = durations[valid]
    events = events[valid].astype(int)
    n_total = int(durations.size)

    sorted_durations = np.sort(durations)
    sorted_event_times = np.sort(durations[events == 1])

    def _at_risk(times: np.ndarray) -> np.ndarray:
        # Subjects whose duration is >= each requested time.
        return n_total - np.searchsorted(sorted_durations, times, side="left")

    def _event_counts(times: np.ndarray) -> np.ndarray:
        # Observed events falling exactly on each requested time.
        upper = np.searchsorted(sorted_event_times, times, side="right")
        lower = np.searchsorted(sorted_event_times, times, side="left")
        return upper - lower

    # Product-limit estimate on the observed times (at_risk >= 1 at each of them).
    observed_times = np.unique(sorted_durations)
    km_steps = np.cumprod(
        1.0 - _event_counts(observed_times) / _at_risk(observed_times)
    )

    if timeline is None:
        eval_times = observed_times
        survival = km_steps
    else:
        eval_times = np.asarray(timeline, dtype=float)
        # Number of observed times <= t selects the step in force at t;
        # index 0 (no observed time reached yet) maps to survival 1.0.
        step_index = np.searchsorted(observed_times, eval_times, side="right")
        survival = np.concatenate(([1.0], km_steps))[step_index]

    anchor = pd.DataFrame(
        {"time": [0.0], "at_risk": [n_total], "events": [0], "survival": [1.0]}
    )
    curve = pd.DataFrame(
        {
            "time": eval_times,
            "at_risk": _at_risk(eval_times).astype(int),
            "events": _event_counts(eval_times).astype(int),
            "survival": survival,
        }
    )
    return pd.concat([anchor, curve], ignore_index=True)


# =========================
# 1) CALC
# =========================
def create_survival_calc(
    data: pd.DataFrame,
    time_col: str,
    event_col: str,
    hrvar: Optional[str] = None,
    id_col: Optional[str] = "PersonId",
    mingroup: int = 5,
    timeline: Optional[Sequence[float]] = None,
    dropna: bool = True,
    use_lifelines: bool = True,
) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Name
    ----
    create_survival_calc

    Description
    -----------
    Compute Kaplan-Meier survival curves per group (segment, org, etc.).
    Uses lifelines.KaplanMeierFitter when available (and `use_lifelines=True`),
    otherwise falls back to a simple NumPy implementation.

    The `event_col` is coerced to integer 0/1 via `_coerce_event`, which accepts
    numeric (>0 = event), boolean, or string tokens ("true"/"yes"/"1").

    Parameters
    ----------
    data : pd.DataFrame
        Person-level data frame (one row per subject), as produced by
        `create_survival_prep()`, containing `time_col`, `event_col`, and
        optionally `hrvar`. The input is copied, not modified.
    time_col : str
        Column containing durations to event or censoring (numeric, e.g., weeks).
    event_col : str
        Event indicator column. Accepts numeric (>0 = event), boolean, or string
        tokens ("true"/"yes"/"1", "false"/"no"/"0").
    hrvar : str or None, default None
        HR attribute column for grouping. If None (or an empty string), a single
        overall curve labelled "Overall" is returned. Group values are converted
        to strings in the output.
    id_col : str or None, default "PersonId"
        Unique subject identifier used for `mingroup` counting. If None or not
        present, the row count per group is used instead.
    mingroup : int, default 5
        Minimum unique subjects required per group; groups with fewer are dropped.
        Only applied when `hrvar` is given.
    timeline : sequence of float, optional
        Common set of times at which to report survival. If None, per-group unique
        times are used.
    dropna : bool, default True
        Drop rows with NA in required columns before computing curves. Rows with
        NA in `time_col` or `event_col` are always dropped, because the curves
        cannot be computed from missing values.
    use_lifelines : bool, default True
        If True and lifelines is available, use KaplanMeierFitter; otherwise, use NumPy.

    Returns
    -------
    survival_long : pd.DataFrame
        Long-format table with columns [``hrvar`` (or ``"group"`` when ungrouped),
        ``"time"``, ``"survival"``, ``"at_risk"``, ``"events"``].
    counts : pd.Series
        Number of unique subjects per group (after filtering).

    Raises
    ------
    KeyError
        If any required column is missing from `data`.
    """
    # Treat an empty-string hrvar like None so every check below takes the same path.
    hrvar = hrvar or None
    grp_col = hrvar if hrvar is not None else "group"

    required = [time_col, event_col] + ([hrvar] if hrvar is not None else [])
    missing = [c for c in required if (c is not None and c not in data.columns)]
    if missing:
        raise KeyError(f"Missing required column(s): {missing}")

    df = data.copy()

    # Coerce event column to 0/1
    df[event_col] = _coerce_event(df[event_col])

    # NAs in time/event are always dropped: the float/int array conversion
    # below cannot represent pd.NA or NaN. hrvar NAs are dropped only on request.
    if dropna:
        na_subset = [c for c in required if c is not None]
    else:
        na_subset = [time_col, event_col]
    df = df.dropna(subset=na_subset)

    # Group sizes for mingroup
    has_ids = bool(id_col) and id_col in df.columns
    if hrvar is not None:
        grouped = df.groupby(hrvar)
        counts = grouped[id_col].nunique() if has_ids else grouped[time_col].size()
        keep = counts[counts >= mingroup].index
        df = df[df[hrvar].isin(keep)].copy()
        counts = counts[counts.index.isin(keep)]
    else:
        counts = pd.Series({"Overall": df[id_col].nunique() if has_ids else len(df)})

    # If nothing left, return empty
    if df.empty:
        empty = pd.DataFrame(columns=[grp_col, "time", "survival", "at_risk", "events"])
        return empty, counts

    # Loop-invariant setup: evaluation grid, backend choice and group keys.
    tl = None if timeline is None else np.asarray(timeline, dtype=float)
    fit_with_lifelines = use_lifelines and _HAS_LIFELINES
    if hrvar is None:
        group_keys = None
        labels = ["Overall"]
    else:
        group_keys = df[hrvar].astype(str)
        labels = sorted(group_keys.unique())

    # Build curves
    curves = []
    for label in labels:
        sdf = df if group_keys is None else df[group_keys == label]
        if sdf.empty:
            continue

        durs = sdf[time_col].to_numpy(dtype=float)
        evts = sdf[event_col].to_numpy(dtype=int)

        if fit_with_lifelines:
            kmf = KaplanMeierFitter()
            kmf.fit(durations=durs, event_observed=evts, timeline=tl)
            out = pd.DataFrame(
                {
                    "time": kmf.survival_function_.index.values.astype(float),
                    "survival": kmf.survival_function_["KM_estimate"].values.astype(float),
                }
            )
            et = (
                kmf.event_table.reset_index()
                .rename(columns={"observed": "events", "event_at": "time"})
            )
            # Left merge: times absent from the event table get NaN at_risk/events.
            out = out.merge(et[["time", "at_risk", "events"]], on="time", how="left")
        else:
            out = _km_numpy(durs, evts, tl)

        out[grp_col] = label
        curves.append(out)

    survival_long = pd.concat(curves, ignore_index=True)
    survival_long = survival_long[[grp_col, "time", "survival", "at_risk", "events"]]
    return survival_long, counts
# =========================
# 2) VIZ
# =========================
def create_survival_viz(
    data: pd.DataFrame,
    hrvar: str,
    figsize: Tuple[float, float] = (8, 6),
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    caption: Optional[str] = None,
    linewidth: float = 2.0,
) -> plt.Figure:
    """
    Name
    ----
    create_survival_viz

    Description
    -----------
    Draw one Kaplan-Meier step curve per group found in `data`. An empty
    `data` frame yields labelled, empty axes.

    Parameters
    ----------
    data : pd.DataFrame
        Output of `create_survival_calc`, with at least [hrvar, "time", "survival"].
    hrvar : str
        Column name identifying the groups to plot. Values are compared and
        labelled as strings.
    figsize : tuple of float, default (8, 6)
        Matplotlib figure size in inches (width, height).
    title : str, optional
        Figure-level title.
    subtitle : str, optional
        Smaller line beneath the title.
    caption : str, optional
        Small text near the bottom of the figure (e.g., date range).
    linewidth : float, default 2.0
        Line width for the step curves.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The constructed matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    has_curves = not data.empty
    if has_curves:
        group_labels = data[hrvar].astype(str)
        for name in group_labels.unique():
            curve = data[group_labels == name]
            if curve.empty:
                continue
            ax.step(
                curve["time"].to_numpy(dtype=float),
                curve["survival"].to_numpy(dtype=float),
                where="post",
                linewidth=linewidth,
                label=name,
            )

    ax.set_xlabel("Time")
    ax.set_ylabel("Survival probability")
    ax.set_ylim(0, 1)
    if has_curves:
        # The grid is only drawn when there is something to read off it.
        ax.grid(True, alpha=0.3)

    # Caption
    if caption:
        fig.text(0.01, 0.01, caption, ha="left", va="center", fontsize=9)

    # Legend (outside on the right) — only when labeled artists exist
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))

    # Layout + header styling
    plt.tight_layout()
    _retitle_left(fig, title, subtitle)
    _add_header_decoration(fig)
    _reserve_header_space(fig)

    return fig
# =========================
# 3) WRAPPER
# =========================
# Allowed values for the ``return_type`` argument of ``create_survival``.
ReturnType = Literal["plot", "table"]
def create_survival(
    data: pd.DataFrame,
    time_col: str,
    event_col: str,
    hrvar: Optional[str] = None,
    id_col: Optional[str] = "PersonId",
    mingroup: int = 5,
    timeline: Optional[Sequence[float]] = None,
    dropna: bool = True,
    use_lifelines: bool = True,
    return_type: ReturnType = "plot",
    figsize: Tuple[float, float] = (8, 6),
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    caption: Optional[str] = None,
) -> Union[plt.Figure, pd.DataFrame]:
    """
    Name
    ----
    create_survival

    Description
    -----------
    High-level convenience wrapper to compute Kaplan-Meier curves and either:
      (a) return the long survival table (return_type="table"), or
      (b) render the survival plot (return_type="plot").

    The input `data` should be a person-level data frame (one row per person)
    as produced by `create_survival_prep()`. The computation is delegated to
    `create_survival_calc` and the plotting to `create_survival_viz`.

    Parameters
    ----------
    data : pd.DataFrame
        Person-level data frame (one row per person), as produced by
        `create_survival_prep()`, containing at least `time_col` and `event_col`.
    time_col : str
        Duration-to-event column.
    event_col : str
        Event indicator column. Accepts numeric (>0 = event), boolean, or string
        tokens ("true"/"yes"/"1", "false"/"no"/"0").
    hrvar : str or None, default None
        HR attribute column for separate survival curves. If None, a single
        overall curve is computed.
    id_col : str or None, default "PersonId"
        Unique subject identifier for mingroup counting.
    mingroup : int, default 5
        Minimum number of unique subjects per group.
    timeline : sequence of float, optional
        Times at which to report survival.
    dropna : bool, default True
        Drop rows with NAs in required columns prior to calculation.
    use_lifelines : bool, default True
        Use lifelines.KaplanMeierFitter when available.
    return_type : {"plot","table"}, default "plot"
        - "plot": return a matplotlib Figure.
        - "table": return the survival-long DataFrame.
    figsize : tuple of float, default (8, 6)
        Figure size in inches (only used when return_type="plot").
    title : str, optional
        Plot title. If None, "Survival Curve" (ungrouped) or
        "Survival Curve by Group" (grouped) is used.
    subtitle : str, optional
        Optional subtitle beneath the title. If None, a Kaplan-Meier estimate
        label is used, mentioning `hrvar` when grouped.
    caption : str, optional
        Caption text shown at the bottom of the figure.
        Note: the typical input (output of `create_survival_prep`) contains no
        date column, so date ranges cannot be extracted automatically. Pass the
        date range string manually if needed, e.g. via
        `vi.extract_date_range(raw_data)`.

    Returns
    -------
    matplotlib.figure.Figure or pd.DataFrame
        - If return_type="plot": a Figure containing the survival curves.
        - If return_type="table": the long survival table.

    Raises
    ------
    ValueError
        If `return_type` is not "plot" or "table".
    KeyError
        If `hrvar` is given but is not a column of `data`.
    """
    if return_type not in ("plot", "table"):
        raise ValueError(f"return_type must be 'plot' or 'table', got {return_type!r}.")

    if hrvar is not None and hrvar not in data.columns:
        raise KeyError(f"hrvar '{hrvar}' not found in data.")

    # Compute survival curves. No defensive copy here: create_survival_calc
    # copies its input and this wrapper never mutates `data`.
    survival_long, _ = create_survival_calc(
        data=data,
        time_col=time_col,
        event_col=event_col,
        hrvar=hrvar,
        id_col=id_col,
        mingroup=mingroup,
        timeline=timeline,
        dropna=dropna,
        use_lifelines=use_lifelines,
    )

    if return_type == "table":
        return survival_long.reset_index(drop=True)

    # Titles
    grouped = hrvar is not None
    if title is not None:
        base_title = title
    elif grouped:
        base_title = "Survival Curve by Group"
    else:
        base_title = "Survival Curve"

    if subtitle is not None:
        subtitle_effective = subtitle
    elif grouped:
        subtitle_effective = f"Kaplan\u2013Meier estimate by {hrvar}"
    else:
        subtitle_effective = "Kaplan\u2013Meier estimate"

    # Ungrouped results carry their label in the "group" column.
    hrvar_for_viz = hrvar if grouped else "group"

    fig = create_survival_viz(
        data=survival_long,
        hrvar=hrvar_for_viz,
        figsize=figsize,
        title=base_title,
        subtitle=subtitle_effective,
        caption=caption,
    )
    return fig