diff --git a/lifelines/fitters/cox_time_varying_fitter.py b/lifelines/fitters/cox_time_varying_fitter.py
index a1089073f..a6fd01bb6 100644
--- a/lifelines/fitters/cox_time_varying_fitter.py
+++ b/lifelines/fitters/cox_time_varying_fitter.py
@@ -515,72 +515,71 @@ def _get_gradients(X, events, start, stop, weights, beta): # pylint: disable=to
     for t in unique_death_times:
-        # I feel like this can be made into some tree-like structure
         ix = (start < t) & (t <= stop)
 
+        # Extract all needed arrays at once to reduce indexing overhead
         X_at_t = X[ix]
         weights_at_t = weights[ix]
-        stops_events_at_t = stop[ix]
         events_at_t = events[ix]
+        stops_at_t = stop[ix]
 
-        phi_i = weights_at_t * np.exp(np.dot(X_at_t, beta))
-        phi_x_i = phi_i[:, None] * X_at_t
-        phi_x_x_i = np.dot(X_at_t.T, phi_x_i)
+        # Pre-compute exp(X @ beta) once for this risk set
+        exp_xb = np.exp(X_at_t @ beta)
+        phi_i = weights_at_t * exp_xb
+
+        # Scale each row of X_at_t by phi_i via broadcasting
+        phi_x_i = phi_i.reshape(-1, 1) * X_at_t
 
         # Calculate sums of Risk set
-        risk_phi = array_sum_to_scalar(phi_i)
-        risk_phi_x = matrix_axis_0_sum_to_1d_array(phi_x_i)
-        risk_phi_x_x = phi_x_x_i
+        risk_phi = np.sum(phi_i)
+        risk_phi_x = np.sum(phi_x_i, axis=0)
+        risk_phi_x_x = X_at_t.T @ phi_x_i
 
         # Calculate the sums of Tie set
-        deaths = events_at_t & (stops_events_at_t == t)
+        deaths = events_at_t & (stops_at_t == t)
 
-        tied_death_counts = array_sum_to_scalar(deaths.astype(int))  # should always at least 1. Why? TODO
+        tied_death_counts = np.sum(deaths)
 
-        xi_deaths = X_at_t[deaths]
+        # Skip if no deaths occur at this time. This should not happen, since t
+        # is drawn from unique_death_times, but guard against it defensively.
+        if tied_death_counts == 0:
+            continue
 
-        x_death_sum = matrix_axis_0_sum_to_1d_array(weights_at_t[deaths, None] * xi_deaths)
+        # Index the death set once and reuse the results below
+        deaths_weights = weights_at_t[deaths]
+        X_deaths = X_at_t[deaths]
 
-        weight_count = array_sum_to_scalar(weights_at_t[deaths])
+        x_death_sum = np.sum(deaths_weights.reshape(-1, 1) * X_deaths, axis=0)
+        weight_count = np.sum(deaths_weights)
         weighted_average = weight_count / tied_death_counts
 
-        #
-        # This code is near identical to the _batch algorithm in CoxPHFitter. In fact, see _batch for comments.
-        #
-
         if tied_death_counts > 1:
-            # A good explanation for how Efron handles ties. Consider three of five subjects who fail at the time.
-            # As it is not known a priori that who is the first to fail, so one-third of
-            # (φ1 + φ2 + φ3) is adjusted from sum_j^{5} φj after one fails. Similarly two-third
-            # of (φ1 + φ2 + φ3) is adjusted after first two individuals fail, etc.
+            # Pre-compute tie values to avoid repeated boolean indexing
+            phi_deaths = phi_i[deaths]
+            phi_x_deaths = phi_x_i[deaths]
 
-            # a lot of this is now in Einstein notation for performance, but see original "expanded" code here
-            # https://github.com/CamDavidsonPilon/lifelines/blob/e7056e7817272eb5dff5983556954f56c33301b1/lifelines/fitters/cox_time_varying_fitter.py#L458-L490
+            tie_phi = np.sum(phi_deaths)
+            tie_phi_x = np.sum(phi_x_deaths, axis=0)
+            tie_phi_x_x = X_deaths.T @ (phi_deaths.reshape(-1, 1) * X_deaths)
 
-            tie_phi = array_sum_to_scalar(phi_i[deaths])
-            tie_phi_x = matrix_axis_0_sum_to_1d_array(phi_x_i[deaths])
-            tie_phi_x_x = np.dot(xi_deaths.T, phi_i[deaths, None] * xi_deaths)
-
-            increasing_proportion = np.arange(tied_death_counts) / tied_death_counts
+            increasing_proportion = np.arange(tied_death_counts, dtype=np.float64) / tied_death_counts
             denom = 1.0 / (risk_phi - increasing_proportion * tie_phi)
             numer = risk_phi_x - np.outer(increasing_proportion, tie_phi_x)
 
-            a1 = np.einsum("ab, i->ab", risk_phi_x_x, denom) - np.einsum(
-                "ab, i->ab", tie_phi_x_x, increasing_proportion * denom
-            )
+            # np.einsum("ab, i->ab", A, v) is just A * v.sum(), so both
+            # contractions reduce to scalar sums
+            a1 = risk_phi_x_x * denom.sum() - tie_phi_x_x * (increasing_proportion * denom).sum()
         else:
-            # no tensors here, but do some casting to make it easier in the converging step next.
-            denom = 1.0 / np.array([risk_phi])
-            numer = risk_phi_x
-            a1 = risk_phi_x_x * denom
+            # Cast to 1d/2d arrays so the summand algebra below is shape-stable
+            denom = np.array([1.0 / risk_phi])
+            numer = risk_phi_x.reshape(1, -1)
+            a1 = risk_phi_x_x * denom[0]
 
-        summand = numer * denom[:, None]
-        a2 = summand.T.dot(summand)
+        summand = numer * denom.reshape(-1, 1)
+        a2 = summand.T @ summand
 
-        gradient = gradient + x_death_sum - weighted_average * summand.sum(0)
-        log_lik = log_lik + np.dot(x_death_sum, beta) + weighted_average * np.log(denom).sum()
-        hessian = hessian + weighted_average * (a2 - a1)
+        gradient += x_death_sum - weighted_average * np.sum(summand, axis=0)
+        log_lik += x_death_sum @ beta + weighted_average * np.sum(np.log(denom))
+        hessian += weighted_average * (a2 - a1)
 
     return hessian, gradient, log_lik
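Note on the a1 simplification: np.einsum("ab, i->ab", A, v) sums out the lone "i" index, so it equals A * v.sum(), and the two contractions in a1 collapse to scalar sums. Below is a minimal sanity check of that identity; A, B, v, and w are illustrative stand-ins for risk_phi_x_x, tie_phi_x_x, denom, and increasing_proportion * denom, not names from the patch.

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 4))  # stand-in for risk_phi_x_x
    B = rng.normal(size=(4, 4))  # stand-in for tie_phi_x_x
    v = rng.normal(size=7)       # stand-in for denom, one entry per tied death
    w = rng.normal(size=7)       # stand-in for increasing_proportion * denom

    # "ab, i->ab" contracts away the i index, leaving A scaled by v.sum()
    assert np.allclose(np.einsum("ab, i->ab", A, v), A * v.sum())

    # hence the einsum form of a1 and the scalar-sum form agree
    old_a1 = np.einsum("ab, i->ab", A, v) - np.einsum("ab, i->ab", B, w)
    new_a1 = A * v.sum() - B * w.sum()
    assert np.allclose(old_a1, new_a1)

The scalar-sum form avoids the einsum dispatch overhead entirely and reads as plain array algebra.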