Skip to content

Commit

Permalink
added legacy suport
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Nov 20, 2024
1 parent 95c2b74 commit 6085e96
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np # This package is implicitly used
Expand Down Expand Up @@ -127,6 +128,8 @@ def kmf(
weights: list[float] | None = None,
fit_options: dict | None = None,
censoring: Literal["right", "left"] = "right",
durations: Iterable | None = None,
event_observed: Iterable | None = None,
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.
Expand All @@ -151,6 +154,9 @@ def kmf(
fit_options: Additional keyword arguments to pass into the estimator.
censoring: 'right' for fitting the model to a right-censored dataset. (default, calls fit).
'left' for fitting the model to a left-censored dataset (calls fit_left_censoring).
durations: length n -- duration (relative to subject's birth) the subject was alive for. (legacy argument, use duration_col instead)
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed is equal to `None`.(this is a legacy argument, use event_col instead)
Returns:
Fitted KaplanMeierFitter.
Expand All @@ -162,22 +168,55 @@ def kmf(
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg", label="Mortality")
"""

return _univariate_model(
adata,
duration_col,
event_col,
KaplanMeierFitter,
True,
timeline,
entry,
label,
alpha,
ci_labels,
weights,
fit_options,
censoring,
)
# legacy support
if durations is not None:
# legacy warning
warnings.warn(
"The `durations` and `event_observed` arguments are deprecated, please use `duration_col` and `event_col` instead.",
DeprecationWarning,
stacklevel=2,
)
kmf = KaplanMeierFitter()
if censoring == "None" or "right":
kmf.fit(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)
elif censoring == "left":
kmf.fit_left_censoring(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)

return kmf
else:
return _univariate_model(
adata,
duration_col,
event_col,
KaplanMeierFitter,
True,
timeline,
entry,
label,
alpha,
ci_labels,
weights,
fit_options,
censoring,
)


def test_kmf_logrank(
Expand Down Expand Up @@ -461,6 +500,7 @@ def nelson_aalen(
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> naf = ep.tl.nelson_aalen(adata, "mort_day_censored", "censor_flg")
"""

return _univariate_model(
adata,
duration_col,
Expand Down

0 comments on commit 6085e96

Please sign in to comment.