I'm confused about how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:
import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis
# load data
data_X, data_y = load_veterans_lung_cancer()
# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']
X = data_X.copy()
for c in categorical_cols:
    # replace each categorical column with its dummy (0/1) columns
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)
# display final X to fit Cox Elastic Net model on
del data_X
print(X.head(3))
So here's the X going into the model:
Age_in_years Celltype Karnofsky_score Months_from_Diagnosis \
0 69.0 squamous 60.0 7.0
1 64.0 squamous 70.0 5.0
2 38.0 squamous 60.0 3.0
Prior_therapy Treatment
0 no standard
1 yes standard
2 no standard
Moving on to fitting the model and generating predictions:
# Fit Model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)
# What are these predictions?
preds = coxnet.predict(X)
preds has the same number of records as X, but its values are wayyy different from the values in data_y, even when predicting on the same data the model was fit on.
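For context on why these numbers don't look like times: a Cox proportional hazards model (which CoxnetSurvivalAnalysis fits with an elastic-net penalty) models the hazard, not the survival time. As far as I understand, its "decision function" is essentially the linear predictor eta = x·beta, i.e. a log relative hazard or risk score (possibly shifted by a constant; I haven't checked scikit-survival's exact convention). A minimal stdlib-only sketch, where the coefficients in `coef` are made up for illustration and are NOT fitted on the veterans data:

```python
import math

# Hypothetical coefficients for three features (made up, NOT fitted values).
coef = {"Age_in_years": 0.01, "Karnofsky_score": -0.03, "Months_from_Diagnosis": 0.002}

# First two rows of the example data, numeric columns only.
patients = [
    {"Age_in_years": 69.0, "Karnofsky_score": 60.0, "Months_from_Diagnosis": 7.0},
    {"Age_in_years": 64.0, "Karnofsky_score": 70.0, "Months_from_Diagnosis": 5.0},
]


def linear_predictor(x, beta):
    """Risk score eta = sum_j x_j * beta_j (a log relative hazard, not a time)."""
    return sum(x[name] * b for name, b in beta.items())


risk_scores = [linear_predictor(p, coef) for p in patients]

# exp(eta_a - eta_b) is the hazard ratio of patient a relative to patient b.
hazard_ratio = math.exp(risk_scores[0] - risk_scores[1])
print(risk_scores, hazard_ratio)
```

With these made-up coefficients the first patient gets the higher risk score (mainly from the lower Karnofsky score), meaning a higher hazard and shorter expected survival; nothing in the score is measured in days.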
print(preds.mean())
print(data_y['Survival_in_days'].mean())
Output:
-0.044114643249153422
121.62773722627738
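If the predictions are risk scores, comparing their mean with the mean survival time isn't meaningful. Risk scores are usually checked against observed, possibly censored, times with Harrell's concordance index; scikit-survival provides `sksurv.metrics.concordance_index_censored` for this. The sketch below is a simplified stdlib-only re-implementation for illustration only: it assumes no tied event times (the library handles ties more carefully), and the toy data are made up.

```python
def harrell_c(event, time, risk):
    """Simplified Harrell's C: fraction of comparable pairs ordered correctly.

    A pair (i, j) is comparable if subject i had an event and time[i] < time[j].
    It is concordant if the earlier failure got the higher risk score.
    Ties in risk count as half; tied times are not handled (simplification).
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a censored subject cannot be the earlier failure in a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable


# Toy data (made up): event indicator, survival time in days, risk score.
event = [True, True, False, True]
time = [72, 411, 228, 126]
risk = [0.9, -0.5, 0.1, 0.05]
print(harrell_c(event, time, risk))
```

Here 4 of the 5 comparable pairs are ordered correctly, so C = 0.8; a value of 0.5 corresponds to random ordering.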
So what exactly are preds? Clearly .predict means something pretty different here than in scikit-learn, but I can't figure out what. The API reference says it returns "The predicted decision function," but what does that mean? And how do I get from these values to a predicted survival time yhat for a given X? I'm new to survival analysis, so I'm obviously missing something.
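One general way to turn a Cox risk score into something time-based: the model gives a whole survival curve, S(t | x) = S0(t) ** exp(eta), where S0 is a baseline survival function estimated from the training data and eta is the risk score. A single "predicted time" can then be summarized, for example, as the median survival time (the first time the curve drops to 0.5 or below). If I'm reading the API reference right, scikit-survival exposes this via CoxnetSurvivalAnalysis(fit_baseline_model=True) and predict_survival_function; the stdlib-only sketch below instead uses a hypothetical, made-up baseline curve to show the arithmetic.

```python
import math


def survival_curve(baseline_surv, risk_score):
    """Cox model: S(t | x) = S0(t) ** exp(eta), evaluated at the baseline step times."""
    factor = math.exp(risk_score)
    return [s0 ** factor for s0 in baseline_surv]


def median_survival_time(times, surv):
    """First time at which the step survival function is <= 0.5 (None if never reached)."""
    for t, s in zip(times, surv):
        if s <= 0.5:
            return t
    return None


# Hypothetical baseline survival S0(t) at a few event times in days (made-up numbers).
baseline_times = [30, 60, 90, 120, 180, 365]
baseline_surv = [0.9, 0.8, 0.7, 0.6, 0.45, 0.2]

average_risk = survival_curve(baseline_surv, 0.0)   # eta = 0: the baseline itself
high_risk = survival_curve(baseline_surv, 1.0)      # higher eta: curve drops faster
low_risk = survival_curve(baseline_surv, -1.0)      # lower eta: curve drops slower

for curve in (average_risk, high_risk, low_risk):
    print(median_survival_time(baseline_times, curve))
```

With these numbers the median is 180 days at eta = 0, 90 days at eta = 1, and undefined (None) at eta = -1 because that curve never falls to 0.5 within the follow-up, which is one reason a single yhat per patient is often not reported for Cox models.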