Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 40 additions & 41 deletions lifelines/fitters/cox_time_varying_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,72 +515,71 @@ def _get_gradients(X, events, start, stop, weights, beta): # pylint: disable=to

for t in unique_death_times:

# I feel like this can be made into some tree-like structure
ix = (start < t) & (t <= stop)

# Extract all needed arrays at once to reduce indexing overhead
X_at_t = X[ix]
weights_at_t = weights[ix]
stops_events_at_t = stop[ix]
events_at_t = events[ix]
stops_at_t = stop[ix]

phi_i = weights_at_t * np.exp(np.dot(X_at_t, beta))
phi_x_i = phi_i[:, None] * X_at_t
phi_x_x_i = np.dot(X_at_t.T, phi_x_i)
# Pre-compute exp(X*beta) to avoid recalculation
exp_xb = np.exp(X_at_t @ beta)
phi_i = weights_at_t * exp_xb

# Vectorized computation of phi_x_i avoiding broadcasting overhead
phi_x_i = phi_i.reshape(-1, 1) * X_at_t

# Calculate sums of Risk set
risk_phi = array_sum_to_scalar(phi_i)
risk_phi_x = matrix_axis_0_sum_to_1d_array(phi_x_i)
risk_phi_x_x = phi_x_x_i
risk_phi = np.sum(phi_i)
risk_phi_x = np.sum(phi_x_i, axis=0)
risk_phi_x_x = X_at_t.T @ phi_x_i

# Calculate the sums of Tie set
deaths = events_at_t & (stops_events_at_t == t)
deaths = events_at_t & (stops_at_t == t)

tied_death_counts = array_sum_to_scalar(deaths.astype(int)) # should always at least 1. Why? TODO
tied_death_counts = np.sum(deaths)

xi_deaths = X_at_t[deaths]
# Early exit if no deaths at this time (shouldn't happen but safety check)
if tied_death_counts == 0:
continue

x_death_sum = matrix_axis_0_sum_to_1d_array(weights_at_t[deaths, None] * xi_deaths)
# Optimize death-related calculations using boolean indexing
deaths_weights = weights_at_t[deaths]
X_deaths = X_at_t[deaths]

weight_count = array_sum_to_scalar(weights_at_t[deaths])
x_death_sum = np.sum(deaths_weights.reshape(-1, 1) * X_deaths, axis=0)
weight_count = np.sum(deaths_weights)
weighted_average = weight_count / tied_death_counts

#
# This code is near identical to the _batch algorithm in CoxPHFitter. In fact, see _batch for comments.
#

if tied_death_counts > 1:

# A good explanation for how Efron handles ties. Consider three of five subjects who fail at the same time.
# As it is not known a priori who is the first to fail, one-third of
# (φ1 + φ2 + φ3) is adjusted from sum_j^{5} φj after one fails. Similarly, two-thirds
# of (φ1 + φ2 + φ3) is adjusted after the first two individuals fail, etc.
# Pre-compute tie values to avoid repeated indexing
phi_deaths = phi_i[deaths]
phi_x_deaths = phi_x_i[deaths]

# a lot of this is now in Einstein notation for performance, but see original "expanded" code here
# https://github.com/CamDavidsonPilon/lifelines/blob/e7056e7817272eb5dff5983556954f56c33301b1/lifelines/fitters/cox_time_varying_fitter.py#L458-L490
tie_phi = np.sum(phi_deaths)
tie_phi_x = np.sum(phi_x_deaths, axis=0)
tie_phi_x_x = X_deaths.T @ (phi_deaths.reshape(-1, 1) * X_deaths)

tie_phi = array_sum_to_scalar(phi_i[deaths])
tie_phi_x = matrix_axis_0_sum_to_1d_array(phi_x_i[deaths])
tie_phi_x_x = np.dot(xi_deaths.T, phi_i[deaths, None] * xi_deaths)

increasing_proportion = np.arange(tied_death_counts) / tied_death_counts
denom = 1.0 / (risk_phi - increasing_proportion * tie_phi)
increasing_proportion = np.arange(tied_death_counts, dtype=np.float64) / tied_death_counts
denom_inv = risk_phi - increasing_proportion * tie_phi
denom = 1.0 / denom_inv
numer = risk_phi_x - np.outer(increasing_proportion, tie_phi_x)

a1 = np.einsum("ab, i->ab", risk_phi_x_x, denom) - np.einsum(
"ab, i->ab", tie_phi_x_x, increasing_proportion * denom
)
# More efficient einsum operations
a1 = risk_phi_x_x * denom.reshape(1, 1, -1).sum(axis=2) - tie_phi_x_x * (increasing_proportion * denom).sum()
else:
# no tensors here, but do some casting to make it easier in the converging step next.
denom = 1.0 / np.array([risk_phi])
numer = risk_phi_x
a1 = risk_phi_x_x * denom
denom = np.array([1.0 / risk_phi])
numer = risk_phi_x.reshape(1, -1)
a1 = risk_phi_x_x * denom[0]

summand = numer * denom[:, None]
a2 = summand.T.dot(summand)
summand = numer * denom.reshape(-1, 1)
a2 = summand.T @ summand

gradient = gradient + x_death_sum - weighted_average * summand.sum(0)
log_lik = log_lik + np.dot(x_death_sum, beta) + weighted_average * np.log(denom).sum()
hessian = hessian + weighted_average * (a2 - a1)
gradient += x_death_sum - weighted_average * np.sum(summand, axis=0)
log_lik += x_death_sum @ beta + weighted_average * np.sum(np.log(denom))
hessian += weighted_average * (a2 - a1)

return hessian, gradient, log_lik

Expand Down
Loading