HAZARD GRADIENT PENALTY FOR SURVIVAL ANALYSIS

Abstract

Survival analysis appears in various fields such as medicine, economics, engineering, and business. Recent studies have shown that the Ordinary Differential Equation (ODE) modeling framework integrates many existing survival models and is flexible and widely applicable. However, naively applying the ODE framework to survival analysis problems may yield a density function that changes sharply with respect to the covariates, which may worsen the model's performance. Although L1 or L2 regularizers can be applied to the ODE model, their effect on the ODE modeling framework is barely known. In this paper, we propose the hazard gradient penalty (HGP) to enhance the performance of survival analysis models. Our method imposes constraints on local data points by regularizing the gradient of the hazard function with respect to the data point. It applies to any survival analysis model, including the ODE modeling framework, and is easy to implement. We show theoretically that our method is related to minimizing the KL divergence between the density function at a data point and that of its neighboring points. Experimental results on three public benchmarks show that our approach outperforms other regularization methods.

1. INTRODUCTION

Survival analysis (a.k.a. time-to-event modeling) is a branch of statistics that predicts the duration of time until an event occurs (Kleinbaum & Klein, 2012). Survival analysis appears in various fields such as medicine (Schwab et al., 2021), economics (Meyer, 1988), engineering (O'Connor & Kleyner, 2011), and business (Jing & Smola, 2017; Li et al., 2021). Due to the presence of right-censored data, i.e., data whose event has not occurred yet, survival analysis models require special considerations. The Cox proportional hazards model (CoxPH) (Cox, 1972; Katzman et al., 2018) and the accelerated failure time model (AFT) (Wei, 1992) are widely used to handle right-censored data. Yet the assumptions made by these models are frequently violated in the real world (Lee et al., 2018; Tang et al., 2022a). Recent studies showed that the Ordinary Differential Equation (ODE) modeling framework integrates many existing survival analysis models, including CoxPH and AFT (Groha et al., 2020; Tang et al., 2022a;b). They also showed that the ODE modeling framework is flexible and widely applicable. However, naively applying the ODE framework to survival analysis problems may result in a wildly oscillating density function that may worsen the model's performance. Regularization techniques that can tame this undesirable behavior are understudied. Although applying L1 or L2 regularizers to the ODE model is one option, their effects on the ODE modeling framework are barely known. The cluster assumption from semi-supervised learning states that decision boundaries should not cross high-density regions (Chapelle et al., 2006). Likewise, survival analysis models need hazard functions that change slowly in high-density regions. Suppose we attempt to predict the time to death of three individuals A, B, and C, where the traits of A and B are similar and the traits of B and C are dissimilar.
It is natural to expect that the probability distribution of the time to death of A should be close to that of B while far from that of C. This expectation aligns with the cluster assumption, and explicitly modeling the assumption enhances performance as long as it holds. In this paper, we propose the hazard gradient penalty to obtain a survival analysis model that changes slowly (with respect to covariates) in high-density regions. In a nutshell, the hazard gradient penalty regularizes the gradient of the hazard function with respect to data points drawn from the real data distribution. Our method has several advantages: 1) it is computationally efficient; 2) it is theoretically sound; 3) it is applicable to any survival analysis model, including the ODE modeling framework, as long as the model parameterizes a hazard function; and 4) it is easy to implement. We show theoretically that our method is related to minimizing the KL divergence between the density function at a data point and that of its neighboring points. Experimental results on three public benchmarks show that our approach outperforms other regularization methods.

2. PRELIMINARIES

Survival analysis data comprise an observed covariate x, a failure event time t, and an event indicator e. If an event is observed, t is the duration from the beginning of an individual's follow-up until the event occurs; in this case, the event indicator e = 1. If an event is unobserved, t is the duration from the beginning of follow-up until the last follow-up; in this case, we cannot know the exact time of the event, and e = 0. An individual is said to be right-censored if e = 0. The presence of right-censored data differentiates survival analysis from regression problems. In this paper, we focus on the single-risk problem, where the event e is a binary-valued variable. Given a set of triplets D = {(x_i, t_i, e_i)}_{i=1}^N, the goal of survival analysis is to predict the likelihood of an event occurring, p(t | x), or the survival probability S(t | x). The likelihood and the survival probability have the following relationship:

    S(t | x) = 1 - \int_0^t p(\tau | x) d\tau    (1)

A model of p(t | x) or S(t | x) must satisfy the following constraints:

    p(t | x) > 0,  \int_0^\infty p(\tau | x) d\tau = 1
    S(0 | x) = 1,  \lim_{t \to \infty} S(t | x) = 0,  S(t_1 | x) \ge S(t_2 | x) if t_1 \le t_2

Previous works instead modeled the hazard function (a.k.a. conditional failure rate) h(t | x) (Cox, 1972; Katzman et al., 2018; Wei, 1992; Zhong et al., 2021):

    h(t | x) := \lim_{\Delta t \to 0} \frac{P(t \le T < t + \Delta t \mid T \ge t, x)}{\Delta t} = \frac{p(t | x)}{S(t | x)}    (2)

As the hazard function is a probability per unit time, it is unbounded above. Hence, the only constraint on the hazard function is that it is non-negative: h(t | x) \ge 0.
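The relationships and constraints above can be sanity-checked numerically for a simple parametric family. Below is a minimal sketch (not from the paper) using the exponential distribution, whose hazard is constant:

```python
import numpy as np

# Exponential lifetime with rate lam: p(t) = lam * exp(-lam * t),
# S(t) = exp(-lam * t), so the hazard h(t) = p(t) / S(t) = lam is constant.
lam = 0.5
t = np.linspace(0.0, 10.0, 101)

p = lam * np.exp(-lam * t)      # density p(t | x)
S = np.exp(-lam * t)            # survival function S(t | x)
h = p / S                       # hazard (conditional failure rate)

assert np.allclose(h, lam)      # hazard is constant for this family
assert np.isclose(S[0], 1.0)    # S(0 | x) = 1
assert np.all(np.diff(S) <= 0)  # S is non-increasing
```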

2.1. THE ODE MODELING FRAMEWORK

We can obtain an ODE relating the hazard function and the survival function by substituting the derivative of equation 1 into equation 2 (Kleinbaum & Klein, 2012):

    h(t | x) = \frac{p(t | x)}{S(t | x)} = \frac{1}{S(t | x)} \left( -\frac{dS(t | x)}{dt} \right) = -\frac{d \log S(t | x)}{dt}    (3)

Starting from the initial value \log S(0 | x) = 0, we can define \log S(t | x) as the solution of the ODE initial value problem whose dynamics are given by equation 3:

    \log S(t | x) = \log S(0 | x) + \int_0^t -h(\tau | x) d\tau = -\int_0^t h(\tau | x) d\tau

We can train the ODE model by minimizing the negative log-likelihood:

    L_x = -e \log p_\theta(t | x) - (1 - e) \log S_\theta(t | x)    (4)
        = -e (\log h_\theta(t | x) + \log S_\theta(t | x)) - (1 - e) \log S_\theta(t | x)

Following Groha et al. (2020), we update the model parameters using Neural ODEs (Chen et al., 2018). The hazard function h_\theta(t | x) is modeled using a neural network followed by the softplus activation function to ensure that the output is always non-negative.
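The quantities above can be sketched numerically: given any non-negative hazard function, log S(t | x) = -∫_0^t h(τ | x) dτ can be approximated with the trapezoid rule, and the negative log-likelihood of equation 4 follows. The helper names below are illustrative, not the paper's implementation (which solves the integral with a Neural ODE solver):

```python
import numpy as np

def log_survival(hazard_fn, t_grid):
    """log S(t) = -integral of h from 0 to t, via the trapezoid rule."""
    h = hazard_fn(t_grid)
    dt = np.diff(t_grid)
    cum = np.concatenate([[0.0], np.cumsum(0.5 * (h[1:] + h[:-1]) * dt)])
    return -cum

def nll(hazard_fn, t_grid, t, e):
    """Negative log-likelihood of equation 4 for a single (t, e) pair."""
    logS = np.interp(t, t_grid, log_survival(hazard_fn, t_grid))
    if e == 1:    # event observed: -log p = -(log h + log S)
        return -(np.log(hazard_fn(np.array([t]))[0]) + logS)
    return -logS  # right-censored: -log S

t_grid = np.linspace(0.0, 10.0, 1001)
lam = 0.3
const_hazard = lambda tt: np.full_like(tt, lam)

# For a constant hazard, log S(t) = -lam * t exactly.
assert np.isclose(log_survival(const_hazard, t_grid)[-1], -lam * 10.0, atol=1e-6)
```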

2.2. NEURAL ODES

Neural ODEs model the continuous dynamics of variables (Chen et al., 2018). Starting from z(0), we define the output z(T) to be the solution of the following ordinary differential equation (ODE) initial value problem:

    \frac{dz(t)}{dt} = f(z(t), t, \theta),  z(T) = z(0) + \int_0^T f(z(t), t, \theta) dt

Naively backpropagating through an ODE solver leads to practical difficulties: the solver builds a large computation graph, which incurs high memory cost, and additional numerical errors may occur in the backpropagation steps. Chen et al. (2018) showed that we can obtain the gradients of a scalar-valued loss with respect to all inputs of any ODE solver with constant memory cost. The gradients are calculated not by backpropagating through the operations of the solver but with another call to an ODE solver.
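As an illustration of the initial value problem above, the sketch below solves dz/dt = f(z, t) with a fixed-step Euler method. This is only a toy stand-in: Neural ODEs use adaptive solvers and obtain gradients via an adjoint ODE solve rather than by backpropagating through solver steps.

```python
import numpy as np

def odeint_euler(f, z0, t0, t1, n_steps=1000):
    """Minimal fixed-step Euler solver for dz/dt = f(z, t).
    Illustration only; not the adaptive solver used by Neural ODEs."""
    z, t = np.asarray(z0, dtype=float), t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        z = z + dt * f(z, t)
        t += dt
    return z

# dz/dt = -z has the closed-form solution z(T) = z(0) * exp(-T).
z_T = odeint_euler(lambda z, t: -z, [1.0], 0.0, 2.0)
assert np.isclose(z_T[0], np.exp(-2.0), atol=1e-3)
```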

3. METHODS

Figure 1: Graphical overview of our proposed method. Our method minimizes the hazard gradient penalty E_{S_\theta(t|x)} \|\nabla_x h_\theta(t | x)\|_2 and the negative log-likelihood in equation 4 at the same time. Intuitively speaking, we regularize the model so that the hazard function does not vary much when a small noise \nu is added to or subtracted from the data point x. In Section 3.2, we show that minimizing the hazard gradient penalty is connected to minimizing the KL divergence between the density at x and the density at x' \in B(x, \epsilon).

In this section, we introduce the hazard gradient penalty and show that it is related to minimizing the KL divergence between the density function at a data point and that of its neighbors. See Figure 1 for a graphical overview of our method. The cluster assumption from semi-supervised learning states that decision boundaries should not cross high-density regions (Chapelle et al., 2006). In a similar vein, the hazard functions of survival analysis models should change slowly in high-density regions. Consider a case where two data points x_1, x_2 \in R^d fail at t_1 and t_2 (t_1 < t_2), respectively. Under the cluster assumption, a point x_2' \in B(x_2, \epsilon), i.e., in the \epsilon-ball centered at x_2, should fail at t_2' \approx t_2. If p(t | x_2') is skewed for some reason and puts high density at t_1' < t_1, it worsens the model's performance. To avoid such a situation, p(t | x_2') should not deviate too much from p(t | x_2). To achieve this, we propose the following regularizer:

    R_x = E_{t \sim S_\theta(t|x)} [\|\nabla_x h_\theta(t | x)\|_2]    (5)
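A sketch of how equation 5 can be estimated: sample time points, compute ∇_x h_θ(t | x), and average the gradient norms. The paper's implementation uses automatic differentiation in JAX; the version below uses central finite differences and a hypothetical toy hazard so that it is self-contained:

```python
import numpy as np

def hazard_gradient_penalty(hazard_fn, x, t_samples, eps=1e-5):
    """Monte Carlo estimate of R_x = E_{t ~ S(t|x)} ||grad_x h(t | x)||_2.
    The gradient is taken by central finite differences here; the paper's
    implementation would use automatic differentiation (e.g. jax.grad)."""
    penalties = []
    for t in t_samples:
        grad = np.zeros_like(x)
        for j in range(x.size):
            dx = np.zeros_like(x)
            dx[j] = eps
            grad[j] = (hazard_fn(t, x + dx) - hazard_fn(t, x - dx)) / (2 * eps)
        penalties.append(np.linalg.norm(grad))
    return float(np.mean(penalties))

# Hypothetical toy hazard: h(t, x) = softplus(w . x), constant in t.
w = np.array([0.3, -0.2])
hazard = lambda t, x: np.log1p(np.exp(w @ x))

x = np.array([1.0, 2.0])
r = hazard_gradient_penalty(hazard, x, t_samples=[0.5, 1.0, 2.0])

# Analytically, grad_x h = sigmoid(w . x) * w, so R_x = sigmoid(w . x) * ||w||_2.
sig = 1.0 / (1.0 + np.exp(-(w @ x)))
assert np.isclose(r, sig * np.linalg.norm(w), atol=1e-6)
```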

3.1. EFFICIENT SAMPLING FROM THE SURVIVAL DENSITY

The sampling operation t ~ S(t | x) in equation 5 may induce computational overhead. To speed up the sampling operation, we reuse \log S_\theta(t | x), which was already computed during the negative log-likelihood calculation in equation 4. Let [t_1, ..., t_K] be the union of the time points in a minibatch, sorted in increasing order. The adaptive time stepping in ODE solvers is sensitive to the time interval t_K - t_1 rather than to the number of time points (Rubanova et al., 2019). Hence, we can access \log S_\theta(t_k | x) with negligible overhead as long as t_1 < t_k < t_K. We sample t_k from a categorical distribution whose k-th weight is defined as S_\theta(t_k | x) = \exp(-\int_0^{t_k} h(\tau | x) d\tau). We finalize the sampling process by sampling t from the uniform distribution U([t_k, t_{k+1}]). In this way, we do not have to compute S(t | x) again for sampling t. See Algorithm 1 in the Appendix for the pseudocode of ODE-based survival analysis with the hazard gradient penalty.
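The two-stage sampler described above can be sketched as follows; `sample_from_survival` is an illustrative helper name, and the survival values S_θ(t_k | x) are assumed to be the ones already computed for the likelihood:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_from_survival(S_vals, t_points, n_samples):
    """Draw t ~ S(t | x) (up to normalization): first pick an index k from a
    categorical distribution over the precomputed S(t_k | x) values, then draw
    t uniformly from the interval ending at t_k, with t_0 = 0."""
    probs = np.asarray(S_vals) / np.sum(S_vals)
    ks = rng.choice(len(t_points), size=n_samples, p=probs)
    lo = np.where(ks == 0, 0.0, np.asarray(t_points)[np.maximum(ks - 1, 0)])
    hi = np.asarray(t_points)[ks]
    return rng.uniform(lo, hi)

# Survival values S(t_k | x) = exp(-cumulative hazard); a constant hazard
# of 0.5 is assumed here purely for illustration.
t_points = np.array([1.0, 2.0, 4.0, 8.0])
S_vals = np.exp(-0.5 * t_points)

ts = sample_from_survival(S_vals, t_points, n_samples=1000)
assert ts.shape == (1000,)
assert np.all(ts >= 0.0) and np.all(ts <= t_points[-1])
```

Because the categorical weights decay with the cumulative hazard, early intervals are sampled more often, mimicking draws from the (unnormalized) survival density.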

3.2. CONNECTION TO KL DIVERGENCE

We now show that the hazard gradient penalty in equation 5 amounts to minimizing an approximation of an upper bound of the KL divergence between the density function at a data point and that of its neighboring points. Henceforth, we denote by X a subset of the d-dimensional real space R^d.

Theorem 1. Suppose the hazard function is strictly positive for all data points x, x' \in X. The KL divergence E_{p(t|x)}[\log p(t | x) - \log p(t | x')] is upper bounded by

    E_{p(t|x)} \|\log h(t | x) - \log h(t | x')\|_2 + E_{S(t|x)} \|h(t | x) - h(t | x')\|_2    (6)

To prove Theorem 1, we need the following lemma.

Lemma 1. The expectation of the difference of log survival functions under the density is the negative of the expectation of the difference of hazard functions under the survival density. In other words, for all x, x' \in X,

    E_{p(t|x)}[\log S(t | x) - \log S(t | x')] = -E_{S(t|x)}[h(t | x) - h(t | x')]

Proof. We integrate the total derivative of S(t | x)(\log S(t | x) - \log S(t | x')). The boundary terms vanish, since \log S(0 | x) = \log S(0 | x') = 0 and S(t | x) \to 0 as t \to \infty. Hence

    0 = \int_0^\infty \frac{d}{dt}\Big[ S(t | x)\big(\log S(t | x) - \log S(t | x')\big) \Big] dt
      = \int_0^\infty -p(t | x)\big(\log S(t | x) - \log S(t | x')\big) dt + \int_0^\infty S(t | x)\big(-h(t | x) + h(t | x')\big) dt,

where we used dS(t | x)/dt = -p(t | x) and d\log S(t | x)/dt = -h(t | x). Rearranging gives E_{p(t|x)}[\log S(t | x) - \log S(t | x')] = -E_{S(t|x)}[h(t | x) - h(t | x')]. ∎

We now return to Theorem 1 and prove it.
Proof of Theorem 1.

    E_{p(t|x)}[\log p(t | x) - \log p(t | x')]
      = \|E_{p(t|x)}[\log p(t | x) - \log p(t | x')]\|_2    (∵ D_KL ≥ 0)
      = \|E_{p(t|x)}[\log h(t | x) - \log h(t | x')] + E_{p(t|x)}[\log S(t | x) - \log S(t | x')]\|_2    (∵ equation 2)
      = \|E_{p(t|x)}[\log h(t | x) - \log h(t | x')] - E_{S(t|x)}[h(t | x) - h(t | x')]\|_2    (∵ Lemma 1)
      ≤ \|E_{p(t|x)}[\log h(t | x) - \log h(t | x')]\|_2 + \|E_{S(t|x)}[h(t | x) - h(t | x')]\|_2    (∵ triangle inequality)
      ≤ E_{p(t|x)}\|\log h(t | x) - \log h(t | x')\|_2 + E_{S(t|x)}\|h(t | x) - h(t | x')\|_2    (∵ Jensen's inequality) ∎

Theorem 2. The approximation of the upper bound of the KL divergence given in equation 6 is upper bounded by 2\epsilon E_{S(t|x)}\|\nabla_x h(t | x)\|_2 if x' lies in the epsilon ball centered at x, i.e., x' \in B(x, \epsilon).

To prove the theorem, we first derive the approximation.

Lemma 2. 2 E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2 is an approximation of the upper bound of the KL divergence given in equation 6.

Proof. Using the first-order Taylor expansions \log h(t | x') \approx \log h(t | x) + \nabla_x \log h(t | x)^\top (x' - x) and h(t | x') \approx h(t | x) + \nabla_x h(t | x)^\top (x' - x),

    E_{p(t|x)}\|\log h(t | x) - \log h(t | x')\|_2 + E_{S(t|x)}\|h(t | x) - h(t | x')\|_2
      ≈ E_{p(t|x)}\|\nabla_x \log h(t | x)^\top (x' - x)\|_2 + E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2
      = E_{p(t|x)}\left\|\frac{\nabla_x h(t | x)^\top}{h(t | x)} (x' - x)\right\|_2 + E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2
      = \int \frac{p(t | x)}{h(t | x)} \|\nabla_x h(t | x)^\top (x' - x)\|_2 dt + E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2    (∵ h(t | x) > 0)
      = \int S(t | x) \|\nabla_x h(t | x)^\top (x' - x)\|_2 dt + E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2    (∵ p(t | x)/h(t | x) = S(t | x))
      = 2 E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2 ∎

Obviously, 2 E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2 ≤ 2 E_{S(t|x)} \max_{x' \in B(x, \epsilon)} \|\nabla_x h(t | x)^\top (x' - x)\|_2. The maximum is achieved at x' - x = \epsilon \nabla_x h(t | x) / \|\nabla_x h(t | x)\|_2. Hence 2 E_{S(t|x)}\|\nabla_x h(t | x)^\top (x' - x)\|_2 ≤ 2\epsilon E_{S(t|x)}\|\nabla_x h(t | x)\|_2, which concludes the proof. ∎

Theorem 2 shows that regularizing the hazard gradient penalty in equation 5 amounts to minimizing an approximation of the upper bound of the KL divergence E_{p(t|x)}[\log p(t | x) - \log p(t | x')].
To incorporate the regularizer into the negative log-likelihood loss, we minimize the Lagrangian defined as the sum of the negative log-likelihood and the hazard gradient penalty regularizer:

    L = E_{(x,t,e) \sim D}[L_x + \lambda R_x]    (7)

Here, \lambda is a coefficient that balances the negative log-likelihood and the regularizer. See Appendix B for a code snippet of our JAX implementation (Bradbury et al., 2018). Minimizing the hazard gradient penalty in equation 5 has two advantages over minimizing the KL divergence directly: a) computational efficiency and b) a reduced burden of hyperparameter tuning. To compute the KL divergence, we would first sample x' \in B(x, \epsilon) and then compute four quantities: h(t | x), S(t | x), h(t | x'), and S(t | x'). In this case, we have to compute hazard values for every t ~ S(t | x), and we need one more hazard function integration, S(t | x') = \exp(-\int_0^t h(\tau | x') d\tau). In contrast, regularizing the hazard gradient penalty only requires the gradient of the hazard function. Moreover, regularizing the KL divergence requires setting appropriate values for both the regularization coefficient \lambda' and the ball size \epsilon, whereas regularizing the hazard gradient penalty does not require tuning \epsilon, since \lambda in equation 7 absorbs \epsilon.
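A minimal sketch of the combined objective in equation 7, assuming the per-example negative log-likelihoods L_x and penalties R_x have already been computed; λ = 10 is the value the paper reports working well across datasets:

```python
import numpy as np

def total_loss(nll_values, hgp_values, lam=10.0):
    """L = E[L_x + lambda * R_x] over a minibatch (equation 7).
    nll_values and hgp_values are per-example L_x and R_x terms."""
    nll = np.asarray(nll_values, dtype=float)
    hgp = np.asarray(hgp_values, dtype=float)
    return float(np.mean(nll + lam * hgp))

# Two examples with L_x = [1.0, 2.0] and R_x = [0.1, 0.2]:
# mean([1 + 1, 2 + 2]) = 3.0
assert np.isclose(total_loss([1.0, 2.0], [0.1, 0.2], lam=10.0), 3.0)
```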

4. EXPERIMENTS

In this section, we experimentally show that the hazard gradient penalty outperforms other regularizers. Further, we examine the hyperparameter sensitivity of the hazard gradient penalty. Throughout the experiments, we use three public datasets: the Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT), the Molecular Taxonomy of Breast Cancer International Consortium (METABRIC), and the Rotterdam tumor bank and German Breast Cancer Study Group (RotGBSG). Table 4 summarizes the statistics of the datasets. See Appendix A for evaluation metrics and experimental details.

4.1. METHODS COMPARED

We compare ODE + HGP with four methods: vanilla ODE, ODE + L1, ODE + L2, and ODE + LCI. Vanilla ODE minimizes the expectation of the negative log-likelihood in equation 4. ODE + L1 minimizes the Lagrangian defined as the sum of the expectation of the negative log-likelihood and the L1 penalty term:

    E_{(x,t,e) \sim D}[L_x] + \alpha \sum_{p=1}^{P} |w_p|

Here, the w_p are model parameters and \alpha is a coefficient that balances the negative log-likelihood and the penalty term. This is an extension of Lasso-Cox (Tibshirani, 1997) to the ODE modeling framework. ODE + L2 minimizes the Lagrangian defined as the sum of the expectation of the negative log-likelihood and the L2 penalty term:

    E_{(x,t,e) \sim D}[L_x] + \alpha \sum_{p=1}^{P} w_p^2

This is an extension of Ridge-Cox (Verweij & Van Houwelingen, 1994) to the ODE modeling framework. ODE + LCI minimizes the Lagrangian defined as the sum of the expectation of the negative log-likelihood and the negative of a lower bound of a simplified version of the time-dependent C-index. The regularizer is defined as

    -\sum_t \frac{\sum_{i=1}^{N} \sum_{j=1}^{N} e_i I(T_i < T_j, T_i < t) \big(1 + \log \sigma(S_\theta(t | x_j) - S_\theta(t | x_i)) / \log 2\big)}{\sum_{i=1}^{N} \sum_{j=1}^{N} e_i I(T_i < T_j, T_i < t)}

This is equivalent to the time-dependent concordance index in Section A.1.1 if we do not take the Kaplan-Meier estimator into account, and it is reminiscent of the lower bound of the C-index (Steck et al., 2007). Although the lower bound of the C-index was originally proposed as a substitute for the negative log-likelihood, Chapfuwa et al. (2018) used the lower bound (Steck et al., 2007) as a regularizer of the AFT model (Wei, 1992).
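The L1 and L2 penalty terms above can be sketched as follows (illustrative helper names; the sums run over all parameter arrays of the model):

```python
import numpy as np

def l1_penalty(params):
    """Sum of |w_p| over all model parameters (Lasso-Cox-style penalty)."""
    return float(sum(np.abs(w).sum() for w in params))

def l2_penalty(params):
    """Sum of w_p^2 over all model parameters (Ridge-Cox-style penalty)."""
    return float(sum((w ** 2).sum() for w in params))

# Toy parameter list: one vector and one matrix.
params = [np.array([1.0, -2.0]), np.array([[0.5, -0.5]])]
assert np.isclose(l1_penalty(params), 4.0)   # 1 + 2 + 0.5 + 0.5
assert np.isclose(l2_penalty(params), 5.5)   # 1 + 4 + 0.25 + 0.25
```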

4.2. RESULTS

Table 1 shows the mC_td, mAUC, and iNBLL scores. The hazard gradient penalty outperforms the other methods across almost all metrics and datasets. Interestingly, both the L1 and L2 penalties do not affect the ODE model's performance in most cases. We speculate that regularizing the weight norm is effective in CoxPH because the model is simple and makes the strong assumption that the hazard ratio is constant. On the contrary, regularizing the weight norm may fail to affect the ODE model's performance, as ODE models are much more complex than CoxPH. The experimental results also highlight the possibility that the performance of survival analysis models is more related to local information, such as the gradient at each data point, than to global information, such as the weight norm of the model. Figure 2 shows that ODE + HGP effectively regularizes the variation of the density with respect to the input while the other methods do not. Table 1 also shows that regularizing the lower bound of the C-index is not effective in many cases. We conjecture that the method is ineffective because the ODE modeling framework is flexible and optimizing the negative log-likelihood can already discriminate each data point's rank. Furthermore, regularizing the lower bound of the C-index does not harness the information of the neighbors of data points. We also compare HGP against a Neural ODEs-specific regularizer; see Appendix D for details.

Table 2 (SUPPORT; columns: mC_td (↑), mAUC (↑), iNBLL (↓)):

    No reg.   0.771 ± 0.003   0.810 ± 0.002   0.516 ± 0.015
    M = 1     0.775 ± 0.004   0.814 ± 0.002   0.505 ± 0.010
    M = 5     0.775 ± 0.004   0.814 ± 0.002   0.506 ± 0.011
    M = 10    0.775 ± 0.004   0.815 ± 0.002   0.505 ± 0.009

Table 2 shows the results obtained by varying the number of samples M in the sampling process t ~ S(t | x) in equation 5. As long as the regularizer is applied, the number of samples M does not affect the performance; even with M = 1, the regularizer works well.
Figure 3 shows the results on the SUPPORT and RotGBSG datasets obtained by varying the coefficient \lambda in equation 7. Since the performance varies smoothly with \lambda, the hyperparameter \lambda can be tuned without much difficulty in practical setups.

5. RELATED WORKS

A line of research integrated deep neural networks into CoxPH (Faraggi & Simon, 1995; Katzman et al., 2018) and Extended Hazards (Zhong et al., 2021) for more model flexibility. Another line of research proposed distribution-free survival analysis models via time-domain discretization (Lee et al., 2018), an adversarial learning approach (Chapfuwa et al., 2018), or derivative-based models (Danks & Yau, 2022). Other works (Goldstein et al., 2020; Han et al., 2021) proposed new objectives to directly optimize the Brier score (Graf et al., 1999), the binomial log-likelihood, or distributional calibration. Yet, to the best of our knowledge, none of the previous works focused on the effect of a gradient penalty on survival analysis models. Previous works proposed L1 and L2 regularization in the survival analysis literature (Tibshirani, 1997; Verweij & Van Houwelingen, 1994). Those methods regularize survival analysis models so that the L1 or L2 norm of the model parameters does not grow too large. Our method differs in that we penalize the norm of the gradient at each local data point. Our method is closely related to semi-supervised learning (Chapelle et al., 2006). Among the many semi-supervised learning methods, ours is germane to virtual adversarial training (Miyato et al., 2018) in that it regularizes the function's variation between a local data point and its neighbors. However, virtual adversarial training differs from ours in that it was demonstrated in the classification setting, where the output is a discrete distribution. In the Generative Adversarial Nets (GANs) literature (Goodfellow et al., 2014), gradient penalties have been studied actively. Gulrajani et al. (2017) proposed a gradient penalty to satisfy the 1-Lipschitz constraint in the Kantorovich-Rubinstein duality. Mescheder et al. (2018) proposed a gradient penalty that penalizes the discriminator for deviating from the Nash equilibrium. Ours differs from these works in that we propose a gradient penalty so that the density at x does not deviate much from that at x's neighborhood points.

6. CONCLUSION

In this paper, we introduced a novel regularizer for survival analysis. Unlike previous methods, we focus on individual local data points rather than global information. We showed theoretically that regularizing the norm of the gradient of the hazard function with respect to the data point is related to minimizing the KL divergence between the density function at a data point and that of its neighbors. Empirically, we showed that the proposed regularizer outperforms other regularizers and is not sensitive to its hyperparameters. Furthermore, the proposed regularizer is computationally efficient and incurs negligible overhead. Nonetheless, as minimizing the proposed regularizer may conflict with optimizing the negative log-likelihood, practitioners should tune the balancing coefficient \lambda for each dataset. The paper highlights the new possibility that recent advancements in semi-supervised learning could enhance the performance of survival analysis models.

Algorithm 1 (excerpt): the per-minibatch update of ODE-based survival analysis with the hazard gradient penalty.

    repeat
        log p_θ(t | x) ← log h_θ(t | x) + log S_θ(t | x)
        L_x ← -e log p_θ(t | x) - (1 - e) log S_θ(t | x)                       ▷ Negative log-likelihood
        sample i_1, ..., i_M ~ Categorical(S_θ(t_1 | x), ..., S_θ(t_K | x))    ▷ t ~ S_θ(t | x)
        t'_m ~ Uniform(t_{i_m - 1}, t_{i_m})                                   ▷ t_0 = 0
        R_x ← (1/M) Σ_{m=1}^{M} ∥∇_x h_θ(t'_m | x)∥_2                          ▷ Hazard gradient penalty
        θ ← θ - γ ∇_θ (L_x + λ R_x)
    until convergence

A EVALUATION METRICS AND EXPERIMENTAL DETAILS

A.1 EVALUATION METRICS

Throughout this subsection, we denote by Ŝ(t | x) the estimate of S(t | x), by I(·) the indicator function, by (x_i, T_i, e_i) the i-th covariate, time, and event indicator of the dataset, by Ĝ(t) the Kaplan-Meier estimator for the censoring distribution (Kaplan & Meier, 1958), and by ω_i the weight 1/Ĝ(T_i).

A.1.1 TIME DEPENDENT CONCORDANCE INDEX (C td )

The concordance index, or C-index, is defined as the proportion of correctly ordered pairs among all comparable pairs. We use the time-dependent variant of the C-index, which truncates pairs at a prespecified time point (Uno et al., 2011). The time-dependent concordance index at t, C_td(t), is defined as

    C_{td}(t) = \frac{\sum_{i=1}^N \sum_{j=1}^N e_i \{\hat{G}(T_i)\}^{-2} I(T_i < T_j, T_i < t) I(\hat{S}(t | x_i) < \hat{S}(t | x_j))}{\sum_{i=1}^N \sum_{j=1}^N e_i \{\hat{G}(T_i)\}^{-2} I(T_i < T_j, T_i < t)}

To evaluate C_td at [t_1, ..., t_L] at the same time, we take its mean mC_td = (1/L) \sum_{l=1}^L C_{td}(t_l).

A.1.2 TIME DEPENDENT AREA UNDER CURVE (AUC)

The time-dependent AUC is an extension of the ROC-AUC to survival data (Hung & Chiang, 2010). It measures how well a model can distinguish individuals who fail before a given time (T_i ≤ t) from those who fail after it (T_j > t). The AUC at time t, AUC(t), is defined as

    AUC(t) = \frac{\sum_{i=1}^N \sum_{j=1}^N I(T_j > t) I(T_i \le t) \omega_i I(\hat{S}(t | x_i) \le \hat{S}(t | x_j))}{\left(\sum_{i=1}^N I(T_i > t)\right)\left(\sum_{i=1}^N I(T_i \le t) \omega_i\right)}

To evaluate AUC at [t_1, ..., t_L] at the same time, we take its mean mAUC = (1/L) \sum_{l=1}^L AUC(t_l).
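A direct transcription of the C_td(t) formula on toy data may help; `c_td` is an illustrative name, and the example sets Ĝ ≡ 1 to sidestep the Kaplan-Meier censoring weights:

```python
import numpy as np

def c_td(t, T, e, S_t, G_T):
    """Time-dependent concordance at time t (Uno et al., 2011 style).
    T: observed times; e: event indicators; S_t[i] = estimated S(t | x_i);
    G_T[i] = Kaplan-Meier censoring-survival estimate at T_i."""
    num = den = 0.0
    n = len(T)
    for i in range(n):
        if e[i] != 1 or T[i] >= t:   # only uncensored i with T_i < t contribute
            continue
        w = G_T[i] ** -2
        for j in range(n):
            if T[i] < T[j]:          # comparable pair
                den += w
                if S_t[i] < S_t[j]:  # concordant: earlier failure, lower survival
                    num += w
    return num / den

T = np.array([1.0, 2.0, 3.0])
e = np.array([1, 1, 0])
G = np.ones(3)                           # no censoring weighting
S_perfect = np.array([0.2, 0.5, 0.9])    # survival estimates match failure order
assert c_td(t=2.5, T=T, e=e, S_t=S_perfect, G_T=G) == 1.0
```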

A.1.3 NEGATIVE BINOMIAL LOG-LIKELIHOOD

We can evaluate the negative binomial log-likelihood (NBLL) to measure both discrimination and calibration performance (Kvamme et al., 2019). The negative binomial log-likelihood at t measures how close the survival probability is to 1 if the given data point survived beyond t and how close it is to 0 if the given data point failed before t. The NBLL at t, NBLL(t), is defined as

    NBLL(t) = -\frac{1}{N} \sum_{i=1}^N \left[ \frac{\log(1 - \hat{S}(t | x_i)) I(T_i \le t, e_i = 1)}{\hat{G}(T_i)} + \frac{\log \hat{S}(t | x_i) I(T_i > t)}{\hat{G}(t)} \right]

For the convenience of evaluation, we integrate the NBLL over the evaluation time span to obtain iNBLL.

A.2 EXPERIMENTAL DETAILS

To find the best \lambda in equation 7, we run experiments with \lambda = 1, 5, 10, 50 and report the results at \lambda = 10, as it shows decent performance across all metrics and datasets. We also have to set the number of samples M for the time sampling process t ~ S_\theta(t | x) in equation 5. We set M = 5 across all the hazard gradient penalty experiments.
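The NBLL(t) formula can be transcribed directly; `nbll` is an illustrative helper name, and Ĝ ≡ 1 is assumed for simplicity:

```python
import numpy as np

def nbll(t, T, e, S_t, G_T, G_t):
    """Negative binomial log-likelihood at time t (Kvamme et al., 2019 style).
    S_t[i] = estimated S(t | x_i); G_T[i] = G-hat(T_i); G_t = G-hat(t)."""
    failed = (T <= t) & (e == 1)   # failed before t, uncensored
    alive = T > t                  # still surviving at t
    terms = (np.log(1.0 - S_t) * failed / G_T
             + np.log(S_t) * alive / G_t)
    return -terms.mean()

T = np.array([1.0, 5.0])
e = np.array([1, 0])
S_good = np.array([0.1, 0.9])   # confident, correct predictions at t = 2
G = np.ones(2)

val = nbll(t=2.0, T=T, e=e, S_t=S_good, G_T=G, G_t=1.0)
assert val > 0.0
```

Predictions that contradict the observed outcomes (e.g. swapping the two survival estimates above) yield a larger NBLL, as expected for a proper scoring rule.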

D COMPARISON AGAINST NEURAL ODES SPECIFIC REGULARIZER

In this section, we compare HGP against a Neural ODEs-specific regularizer: STEER (Ghosh et al., 2020). STEER regularizes Neural ODEs by perturbing the final time of the integration. Table 5 compares the performance of HGP against STEER. Overall, HGP outperforms STEER across various setups.

E THE TAKEN TIME AND THE NUMBER OF FUNCTION EVALUATIONS

In this section, we compare HGP against the competing regularizers in terms of the elapsed time and the number of function evaluations on the SUPPORT dataset. See Table 6 for the details. Overall, the results align with those of Table 3.



Footnotes:

1. Tang et al. (2022b)'s formulation is slightly different in that their hazard function also depends on the cumulative hazard. To our understanding, depending on the cumulative hazard is redundant, so we conduct experiments without it.
2. B(x, ε) is the ε-ball centered at x.
3. In practice, we implement h_θ(t | x) using a neural network whose input is a combination (concatenation, addition, or both) of t and x. Hence we can write h_θ(t | x) and h_θ(t, x) interchangeably. The gradient ∇_x h_θ(t, x) is naturally defined, and so is ∇_x h_θ(t | x).
4. S(t | x) is not a valid probability distribution, as we cannot guarantee ∫ S(t | x) dt = 1. Rigorously, we sample t ~ s(t | x), where s(t | x) = S(t | x) / ∫ S(t | x) dt. We use t ~ S(t | x) for notational simplicity. See Appendix C for the existence of s_θ(t | x).
5. https://github.com/autonlab/auton-survival/blob/master/dsm/datasets/support2.csv
6. https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data/metabric
7. https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data/gbsg
8. See Appendix A.1 for the details of mC_td, mAUC, and iNBLL.
9. The code will be made publicly available in the near future.



Figure 2: A boxplot of the log-likelihood variation with respect to the input perturbation, E_{(x,t) \sim D_{e=1}} \|\log p_\theta(t | x) - \log p_\theta(t | x')\|_2, on three datasets. We denote by D_{e=1} the set of uncensored data. We choose x' = x + \epsilon g / \|g\|_2, where g = \nabla_x \log p_\theta(t | x), so that \|\log p_\theta(t | x) - \log p_\theta(t | x')\|_2 is maximized under the constraint x' \in B(x, \epsilon). We set \epsilon = 10^{-2} across all experiments. The hazard gradient penalty effectively regularizes the variation. As the density of the probability distribution p(\cdot | x) should be concentrated at t for uncensored data (x, t), the figures show that the hazard gradient penalty effectively regularizes the KL divergence between the density function at a point x and that of neighboring points x' \in B(x, \epsilon).

Figure 3: Violin plots of experimental results on SUPPORT and RotGBSG obtained by varying \lambda. Red, green, and blue denote mC_td, mAUC, and iNBLL, respectively. The thickness of a plot denotes the probability density of the results. The hazard gradient penalty may conflict with the negative log-likelihood if we set \lambda too high. The \lambda that achieves the best scores across all metrics on the RotGBSG dataset lies between \lambda = 5 and \lambda = 10. However, we report the result at \lambda = 5 on the RotGBSG dataset in Table 1c for consistency.

Excerpts from the Appendix B code listing (JAX; line numbers stripped):

    ...[unique_idx, jnp.arange(X.shape[0])]
    hazard_t = hazard_func(t, states_t0, args=params)[:, -1]
    log_prob_t = jnp.log(hazard_t + 1e-6) + log_surv_t
    assert log_surv_t.shape == log_prob_t.shape

Table 1: Experimental results on three datasets. Averages and standard deviations over 7 different random seeds are shown for each setting. See Appendix A.2 for the details of the experimental setups. The dagger mark indicates that the result is statistically significant (p < 0.05) compared to the result of vanilla ODE.

Table 2: Experimental results on SUPPORT, ablated in terms of the sample size M. We set \lambda = 10.



The number of function evaluations (NFE) during training of ODE + L1 and ODE + L2 is lower than that of vanilla ODE. The decrease in NFE compensates for the overhead of calculating the L1 and L2 penalties, which makes the training of ODE + L1 and ODE + L2 faster than that of vanilla ODE. The same applies to ODE + HGP. Despite the additional sampling process and gradient calculation, the

Table 3: The elapsed time (in seconds) and the number of function evaluations (NFE) for each step at training/evaluation time on RotGBSG. The numbers in parentheses indicate relative performance against vanilla ODE. HGP incurs negligible overhead (1% slowdown) at training time while yielding a 10% speedup at evaluation time. See Table 6 for the elapsed time and the NFE for each step of the regularizers on SUPPORT.

Table 4: Summary statistics of the datasets used in our experiments. N denotes the number of data points and d denotes the dimension of each data point.

Algorithm 1 (excerpt): per-minibatch setup.

    Retrieve unique times t_1, ..., t_K from the minibatch.
    Integrate -h_θ(t | x) from 0 to t_K and store log S_θ(t_1 | x), ..., log S_θ(t_K | x).
    log S_θ(t | x) ← the value among log S_θ(t_1 | x), ..., log S_θ(t_K | x) that corresponds to (x, t, e)

Across all experiments, we use an MLP with two hidden layers, each with 64 hidden units. Across all layers, we apply Layer Normalization (Ba et al., 2016). Instead of naively feeding the time t into the neural network, we feed the scaled time t' = (t - t_2) / (t_3 - t_1), where t_1, t_2, and t_3 are the first, second, and third quartiles of the failure event times. We found that this strategy enhances the ODE model's performance and speeds up training. To incorporate the time t into the survival analysis model, we project the time into an eight-dimensional vector using a single-layer MLP and then concatenate it to the input data. The time t is also injected by adding the projected output to each layer's output. We use the AdamW optimizer (Loshchilov & Hutter, 2019) and clip the gradient norm so that it does not exceed 1. We set the learning rate to 0.001. We implemented the code using JAX (Bradbury et al., 2018) and Diffrax (Kidger, 2021).
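The quartile-based time scaling described above can be sketched as follows (illustrative helper name):

```python
import numpy as np

def scale_time(t, event_times):
    """Scaled time t' = (t - t2) / (t3 - t1), where t1, t2, t3 are the first,
    second, and third quartiles of the observed failure event times."""
    q1, q2, q3 = np.quantile(event_times, [0.25, 0.5, 0.75])
    return (t - q2) / (q3 - q1)

times = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Quartiles: q1 = 2, q2 = 3, q3 = 4, so t' = (t - 3) / 2.
assert np.isclose(scale_time(3.0, times), 0.0)   # median maps to 0
assert np.isclose(scale_time(5.0, times), 1.0)
```

Centering at the median and dividing by the interquartile range keeps typical integration times near the origin, which is friendly to adaptive ODE solvers.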

To report mC_td and mAUC, we calculate C_td and AUC at the 10%, 20%, ..., 90% event quantiles and average them. To report iNBLL, we integrate from the minimum time of the test set to the maximum time of the test set. We use scikit-survival (Pölsterl, 2020) to report mC_td and mAUC and pycox (Kvamme et al., 2019) to report iNBLL. Across all experiments, we run 7 trials with different seeds and report their mean and standard deviation.

Excerpt from the Appendix B code listing (a function signature fragment, line numbers stripped):

    params: optax.Params, X: jnp.ndarray, event: jnp.ndarray,
    t: jnp.ndarray, timestamp: jnp.ndarray, unique_idx: jnp.ndarray,

Table 5: Comparison against STEER.

(a) mC_td (↑)

    No reg.    ± 0.002          0.729 ± 0.005    0.746 ± 0.006
    STEER      0.810 ± 0.001    0.730 ± 0.006    0.745 ± 0.006
    HGP        0.814 ± 0.002    0.732 ± 0.005    0.753 ± 0.005

(b) iNBLL (↓)

    No reg.    ± 0.015          0.472 ± 0.005    0.530 ± 0.012
    STEER      0.521 ± 0.012    0.475 ± 0.011    0.523 ± 0.006
    HGP        0.506 ± 0.011    0.479 ± 0.003    0.530 ± 0.003

Table 6: The elapsed time (in seconds) and the number of function evaluations (NFE) for each step at training/evaluation time on SUPPORT. The numbers in parentheses indicate relative performance against vanilla ODE. HGP speeds up training (8% faster) and yields a 9% speedup at evaluation time.


Excerpt from the Appendix B code listing:

    assert event.shape == log_prob_t.shape

C THE EXISTENCE OF s_θ(t | x)

In general, we cannot guarantee the existence of the probability distribution s_θ(t | x), as the integral ∫ S_θ(t | x) dt may not exist. To ensure that ∫ S_θ(t | x) dt exists, we simply add a constraint: h_θ(t | x) ≥ ε, where ε is a very small constant (e.g., ε = 1e-8); then S_θ(t | x) ≤ exp(-εt), which is integrable. As ε is a very small constant, it has a negligible impact on the algorithm. The constraint can be enforced easily by adding ε to the softplus output of the hazard function.

