Seminar Title: Mini-batch Estimation for Deep Cox Models
Abstract:
Powered by modern deep learning techniques, deep Cox neural networks (Cox-NNs) have become widely used in survival analysis, particularly in applications with high-dimensional features, such as dynamically predicting time-to-disease onset from time-varying medical images. In practice, Cox-NNs are typically trained with the mini-batch stochastic gradient descent (SGD) algorithm, in which model parameters are updated iteratively using incomplete risk-set information from small batches of data. However, the Cox partial likelihood relies on the full risk set, which makes the theoretical justification for mini-batch SGD far from obvious and largely unresolved.
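To make the tension concrete, the following sketch (notation assumed for illustration: $f_\beta$ denotes the Cox-NN log-risk score, $R(t_i)$ the risk set at event time $t_i$, and $\delta_i$ the event indicator) contrasts the full-data log partial likelihood with the batch-restricted version that mini-batch SGD actually evaluates:

\[
\ell(\beta) \;=\; \sum_{i:\,\delta_i = 1} \left[ f_\beta(x_i) - \log \sum_{j \in R(t_i)} \exp\!\big(f_\beta(x_j)\big) \right],
\qquad
\ell_B(\beta) \;=\; \sum_{i \in B:\,\delta_i = 1} \left[ f_\beta(x_i) - \log \sum_{j \in R(t_i) \cap B} \exp\!\big(f_\beta(x_j)\big) \right],
\]

where $B$ is the current mini-batch. Because the inner sum in $\ell_B$ runs only over batch members, the per-batch gradient is not an unbiased gradient of $\ell$, which is the source of the theoretical gap the talk addresses.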
In this talk, I will show that mini-batch SGD training for Cox-NNs implicitly optimizes a new objective function that differs from the traditional partial likelihood. This leads to a new estimator for Cox models, namely the mini-batch maximum partial likelihood estimator (mb-MPLE). I will present the statistical properties of mb-MPLE, reveal their dependence on the batch size, and provide practical guidance for hyperparameter tuning in mini-batch training for Cox models. Finally, I will illustrate how insights from analyzing mb-MPLE naturally motivate a broader and more flexible rank-based statistical framework for prediction tasks.
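As a concrete illustration of the batch-level objective described above, here is a minimal numpy sketch of the negative log partial likelihood computed within a single mini-batch, with each event's risk set restricted to batch members. The function name, signature, and use of precomputed scores are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def minibatch_neg_log_partial_likelihood(scores, times, events):
    """Negative log partial likelihood restricted to one mini-batch.

    scores : 1-D array of model outputs f(x_i) for the batch (illustrative)
    times  : observed times for the batch
    events : 1 if the event was observed, 0 if censored

    For each event in the batch, the risk set contains only batch
    members still at risk (time >= event time) -- the 'incomplete
    risk set' that mini-batch SGD implicitly optimizes over.
    """
    loss = 0.0
    for i in range(len(scores)):
        if events[i] == 1:
            at_risk = times >= times[i]  # within-batch risk set
            log_denom = np.log(np.sum(np.exp(scores[at_risk])))
            loss -= scores[i] - log_denom
    return loss
```

Note that for a batch of size one with an observed event, the risk set is the subject itself and the loss term vanishes, a degenerate case that hints at why the batch size matters for the resulting estimator.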