Source code for vivainsights.create_survival_prep

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
create_survival_prep: Convert Standard Person Query panel data into person-level
survival analysis format (time-to-event + event indicator).

This is typically used as the first step before calling `create_survival()`.

Example
-------
>>> import vivainsights as vi
>>> from vivainsights.create_survival_prep import create_survival_prep
>>>
>>> pq_data = vi.load_pq_data()
>>> surv_data = create_survival_prep(
...     data=pq_data,
...     metric="Copilot_actions_taken_in_Teams",
... )
>>> surv_data.head()

With a custom event condition and HR attribute:

>>> surv_data = create_survival_prep(
...     data=pq_data,
...     metric="Copilot_actions_taken_in_Teams",
...     event_condition=lambda x: x >= 10,
...     hrvar="LevelDesignation",
... )

Pass the output directly to create_survival:

>>> import vivainsights as vi
>>> pq_data = vi.load_pq_data()
>>> surv_data = vi.create_survival_prep(pq_data, metric="Copilot_actions_taken_in_Teams")
>>> fig = vi.create_survival(surv_data, time_col="time", event_col="event", hrvar="Organization")
"""

from typing import Callable, Optional

import pandas as pd

# Public API: restricts what `from <this module> import *` exports to the single entry point.
__all__ = ["create_survival_prep"]


def create_survival_prep(
    data: pd.DataFrame,
    metric: str,
    event_condition: Callable[[pd.Series], pd.Series] = lambda x: x > 0,
    hrvar: Optional[str] = "Organization",
    id_col: str = "PersonId",
    date_col: str = "MetricDate",
) -> pd.DataFrame:
    """
    Name
    ----
    create_survival_prep

    Description
    -----------
    Convert a Standard Person Query panel dataset (multiple rows per person,
    one per period/week) into a person-level survival analysis table suitable
    for use with `create_survival()` or `create_survival_calc()`.

    For each person the function determines:

    - **time**: the number of observed periods until the event first occurred,
      or the total number of periods observed if the event never occurred
      (censored).
    - **event**: 1 if ``event_condition`` was satisfied in at least one period,
      0 if censored (condition never met within the observation window).

    The HR attribute column (`hrvar`) is carried through using the most recent
    non-null value per person (last non-null row after sorting by `date_col`).

    Parameters
    ----------
    data : pd.DataFrame
        Standard Person Query panel data. One row per person per period.
        The input is not modified; a copy is used internally.
    metric : str
        Numeric metric column to evaluate against `event_condition`.
    event_condition : callable, default ``lambda x: x > 0``
        A function that accepts a pandas Series of metric values and returns a
        boolean Series. The event is considered to have occurred at the first
        period where this condition is True. Examples:

        - ``lambda x: x > 0`` — any non-zero activity (default)
        - ``lambda x: x >= 10`` — at least 10 actions in a period
    hrvar : str or None, default "Organization"
        HR attribute column to carry through into the output (most recent
        non-null value per person). Set to None to omit. Silently omitted if
        the column is not present in ``data``.
    id_col : str, default "PersonId"
        Column uniquely identifying each person.
    date_col : str, default "MetricDate"
        Date/period column used to sort rows chronologically before computing
        the time-to-event. If absent, the existing row order is preserved.

    Returns
    -------
    pd.DataFrame
        One row per person with columns:

        - ``id_col`` (e.g. "PersonId")
        - ``"time"`` — periods until event, or total observed periods if censored
        - ``"event"`` — 1 (event occurred) or 0 (censored)
        - ``hrvar`` column, if supplied and present in ``data``

        If ``data`` has no rows, an empty frame with these columns is returned.

    Raises
    ------
    KeyError
        If `metric` or `id_col` is not found in `data`.
    ValueError
        If `event_condition` does not return a boolean-compatible Series.

    Notes
    -----
    This function mirrors ``create_survival_prep()`` in the R vivainsights
    package. The typical workflow is::

        surv_data = create_survival_prep(pq_data, metric="Copilot_actions_taken_in_Teams")
        fig = create_survival(surv_data, time_col="time", event_col="event")

    Examples
    --------
    >>> import vivainsights as vi
    >>> pq_data = vi.load_pq_data()
    >>> surv_data = create_survival_prep(
    ...     data=pq_data,
    ...     metric="Copilot_actions_taken_in_Teams",
    ... )
    >>> surv_data.head()
    """
    # Validate required columns
    missing = [c for c in [id_col, metric] if c not in data.columns]
    if missing:
        raise KeyError(f"Missing required column(s): {missing}")

    has_rows = len(data) > 0

    # Validate event_condition on a one-row probe so that a bad callable fails
    # fast with a clear message instead of deep inside the groupby pipeline.
    try:
        test_result = event_condition(data[metric].iloc[:1])
        if not hasattr(test_result, "dtype"):
            raise TypeError("event_condition must return a pandas Series.")
        # With zero input rows there is no element to probe; indexing would
        # raise an IndexError that is unrelated to the callable's validity.
        if has_rows:
            bool(test_result.iloc[0])
    except (TypeError, AttributeError) as exc:
        raise ValueError(
            "`event_condition` must accept a pandas Series and return a boolean-compatible "
            f"Series. Got error: {exc}"
        ) from exc

    keep_hrvar = bool(hrvar) and hrvar in data.columns

    # Empty input: there is nobody to summarise, so return an empty table that
    # still has the documented column layout.
    if not has_rows:
        empty_cols = {
            id_col: data[id_col].iloc[:0].reset_index(drop=True),
            "time": pd.Series(dtype=int),
            "event": pd.Series(dtype=int),
        }
        if keep_hrvar:
            empty_cols[hrvar] = data[hrvar].iloc[:0].reset_index(drop=True)
        return pd.DataFrame(empty_cols)

    # Sort chronologically if date column exists (work on a copy so the
    # caller's frame is never mutated).
    df = data.copy()
    if date_col in df.columns:
        df = df.sort_values([id_col, date_col]).reset_index(drop=True)

    # Apply event_condition
    df["_event_flag"] = event_condition(df[metric]).astype(bool)

    # 1-based chronological period number within each person
    df["_period"] = df.groupby(id_col, sort=False).cumcount() + 1

    # Whether the event ever occurred per person
    event_series = df.groupby(id_col, sort=False)["_event_flag"].any().astype(int)

    # First period where event occurred (present only for people who had an event)
    first_event_period = (
        df.loc[df["_event_flag"]]
        .groupby(id_col, sort=False)["_period"]
        .min()
    )

    # Total observed periods per person (used for censored observations)
    total_periods = df.groupby(id_col, sort=False)["_period"].max()

    # time = first event period if event occurred, else total observed periods
    time_series = first_event_period.combine_first(total_periods).astype(int)

    result = pd.DataFrame({"time": time_series, "event": event_series})
    result.index.name = id_col
    result = result.reset_index()

    # Carry through hrvar: most recent non-null value per person
    if keep_hrvar:
        hrvar_series = (
            df.dropna(subset=[hrvar])
            .groupby(id_col, sort=False)[hrvar]
            .last()
            .rename(hrvar)
        )
        result = result.merge(hrvar_series, on=id_col, how="left")

    # Column ordering: id, time, event, [hrvar]
    cols = [id_col, "time", "event"]
    if hrvar and hrvar in result.columns:
        cols.append(hrvar)

    return result[cols].reset_index(drop=True)