1 Introduction

Alzheimer’s disease (AD), the most common type of dementia, is a slowly progressive neurodegenerative disorder that leads to increasing memory loss and declining cognitive function. Accurate prediction of the associated cognitive decline is crucial and would facilitate optimal decision-making for clinicians and patients. Many researchers have studied cognitive progression using brain magnetic resonance imaging (MRI), and their work demonstrates great potential for using MRI biomarkers to predict cognitive decline presymptomatically in a sufficiently rapid and rigorous manner.

Traditional machine learning algorithms have been widely applied to AD progression modeling. Prior studies either predict the target clinical scores at an isolated single time point [8] or develop joint analysis schemes over multiple time points [13]. Recently, deep neural networks have brought breakthroughs in modeling AD progression. Suk et al. [9] proposed a deep multi-task neural network, while Wang et al. [11] modeled the problem with a recurrent neural network (RNN). Although the RNN learns the data in time order, it does not account for the fact that longitudinal image data are acquired as a continuous sequence rather than as a uniform batch (e.g., the MR images are taken at different time points, not at a single time point). This motivates us to develop a system that mimics how doctors monitor and prognosticate AD progression.

Fig. 1. Overview of the proposed continual longitudinal feature learning framework.

In real-world disease diagnosis applications, different batches of data arrive periodically (e.g., monthly, seasonally, or yearly), and the data distribution changes over time rather than all data arriving at once. This presents an opportunity for continual learning, whose primary goal is to learn consecutive tasks without forgetting the knowledge learned in the past (e.g., from less longitudinal data) and to leverage that previous knowledge, a step toward artificial general intelligence. A straightforward approach is to fine-tune the deep model on every new data set; however, this can cause “catastrophic forgetting”, a phenomenon in which training a model on new tasks interferes with previously learned knowledge, leading to degraded performance, the old knowledge being overwritten by the new, or the model failing to adapt to new tasks because it is biased toward the old knowledge. Many approaches [4, 5] have been proposed to overcome “catastrophic forgetting”; however, none of the existing methods consider a discriminative weight subset that incorporates the inherent correlations between old and new tasks. Moreover, it is important to respect the valuable temporal information in longitudinal data that arrive in time order (e.g., a patient’s 3-month MR image precedes the 12-month MR image). We therefore design a novel algorithm, termed Deep Multi-order Preserving Weight Consolidation (DMoPWC), to continually learn the time order of the longitudinal sequence without losing statistical power on less longitudinal data, while ensuring that the correlation between old and new tasks is respected. Figure 1 shows an overview of DMoPWC.

The key contributions of this work are threefold. First, we formulate AD progression modeling as a continual learning problem that respects longitudinal data sets arriving in sequence and yields equally accurate predictions for future visits. To the best of our knowledge, this is the first learning model that models disease progression in a continually sequential manner and accumulates knowledge for predicting future cognitive decline. Second, to overcome “catastrophic forgetting” of information learned at earlier time points, we propose the novel DMoPWC, which considers a discriminative weight subset that incorporates the inherent correlations between old and new time points’ information and learns the task-specific patient information from the new time point. Third, we incorporate time-order knowledge to guarantee that features at a given time point are temporally ahead of those at succeeding time points. Our extensive experimental results show the superiority of the proposed algorithm.

2 Method

2.1 Problem Definition and Preliminaries

We define the problem as follows: an unknown number of MR images belonging to different tasks (time points), with unknown distributions, arrive in sequence. A task can be a single task or multiple tasks (e.g., patients’ images from a single time point or from multiple time points). Our goal is to learn a deep model in such a continual learning scenario without “catastrophic forgetting”. At testing time, the task at time point t is given and we aim to predict the future clinical scores for time point \(t\,+\,1\). Given a sequence of T tasks, the task at time point \(t = 1, 2, \cdots , T\) with \(N_t\) images comes with dataset \(\mathbf {D}_t = \{\mathbf {x}_i^t, y_i^t\}_{i=1}^{N_t}\). Specifically, for task t, \(y_i^t\) is the ground truth of the clinical scores for the i-th subject \(\mathbf {x}_i^t \in \mathbb {R}^{p}\) at time point t. We denote the training data matrix for \(\mathbf {D}_t\) by \(\mathbf {X}^t\), i.e., \(\mathbf {X}^t = (\mathbf {x}^t_1, \cdots , \mathbf {x}^t_{N_t})\). When the dataset of time point t arrives, all previous training datasets \(\mathbf {D}_1, \cdots , \mathbf {D}_{t-1}\) are no longer available, but the parameters of the deep model with L layers, \(\theta ^{t-1} = \{\theta _l^{t-1}\}_{l=1}^L\), can still be accessed. The problem at time point t, given data \(\mathbf {D}_t\), can be defined as follows:

$$\begin{aligned} \min _{\theta ^t} \mathcal{L}_t (\theta ^t |\theta ^{t- 1}, \mathbf {D}_t) +\lambda \varOmega (\theta ^t),\ t = 1, \cdots , T \end{aligned}$$
(1)

where \(\mathcal {L}_t\) is the loss function for learning \(\theta ^t\), and \(\theta ^t\) denotes the model parameters for time point t. \(\varOmega (\cdot )\) includes one or more sparsity-inducing norms and \(\lambda \) is a non-negative parameter.
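To make the continual setting concrete, below is a minimal Python sketch of the data-availability constraint described above: at step t only \(\mathbf {D}_t\) and the previous parameters \(\theta ^{t-1}\) are visible. Names such as `TaskData` and `update_fn` are illustrative and not from the paper.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional
import numpy as np

@dataclass
class TaskData:
    """Dataset D_t for one time point t (illustrative container)."""
    X: np.ndarray  # feature matrix X^t, shape (N_t, p)
    y: np.ndarray  # clinical scores y^t, shape (N_t,)

def continual_training(tasks: List[TaskData], model, update_fn: Callable):
    """Observe tasks strictly in time order; earlier datasets are discarded
    and only the previous parameters theta^{t-1} are carried forward."""
    prev_params: Optional[dict] = None
    for t, data in enumerate(tasks, start=1):
        # D_1, ..., D_{t-1} are no longer accessible at this point.
        prev_params = update_fn(model, data, prev_params, t, len(tasks))
    return model  # encodes knowledge accumulated across all T time points
```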

Elastic Weight Consolidation (EWC) [4] addresses problem (1) by adding a quadratic penalty on the difference between the parameters \(\theta ^t\) and \(\theta ^{t-1}\) to slow down “catastrophic forgetting” of previously learned time point information. By Bayes’ rule, the posterior distribution \(p(\theta ^t|\mathbf {D}_t)\) satisfies \( \log p(\theta ^t |\mathbf {D}_t) = \log p(\mathbf {D}_t |\theta ^t) + \log p(\theta ^t | \mathbf {D}_{t-1}) - \log p(\mathbf {D}_t), \) where the posterior probability \(\log p(\theta ^t | \mathbf {D}_{t - 1})\) embeds all the information from task \(t-1\). EWC approximates this posterior as a Gaussian distribution with mean \(\bar{\theta }^{t-1}\) and precision given by the diagonal of the Fisher information matrix \(\mathbb {F}\). The Fisher information matrix \(\mathbb {F}\) is computed by

$$\begin{aligned} \mathbb {F}_i^t = I(\theta ^t)_{ii} = E_x\left[ \left( \frac{\partial }{\partial \theta ^{t}_i}\log p(\mathbf {D}_t|\theta ^t)\right) ^2|\theta ^t\right] . \end{aligned}$$
(2)

Therefore, the problem of EWC at time point t can be rewritten as follows:

$$\begin{aligned} \min _{\theta ^t} \quad \mathcal {L}_t(\theta ^t) + \frac{\lambda _1}{2}\sum _{i}\mathbb {F}^{t-1}_{i}(\theta ^t_{i} - \bar{\theta }^{t-1}_i)^2, \end{aligned}$$
(3)

where \(\lambda _1\) denotes how important the time point \(t-1\) data are compared to the time point t data, and i indexes each parameter (weight) of \(\theta \).
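As a concrete illustration, the following PyTorch-style sketch estimates the diagonal Fisher information of Eq. (2) and forms the EWC penalty of Eq. (3). The paper’s implementation uses TensorFlow, so the function names and the use of a regression loss as the negative log-likelihood are our assumptions.

```python
import torch

def diagonal_fisher(model, loss_fn, data_loader, max_batches=200):
    """Estimate the diagonal of the Fisher information matrix F (Eq. 2)
    by averaging squared gradients of the log-likelihood over data batches."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    count = 0
    for x, y in data_loader:
        model.zero_grad()
        loss = loss_fn(model(x), y)  # stands in for -log p(D_t | theta^t)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        count += 1
        if count >= max_batches:
            break
    return {n: f / max(count, 1) for n, f in fisher.items()}

def ewc_penalty(model, fisher, prev_params, lam1):
    """Quadratic EWC penalty (Eq. 3): (lam1/2) * sum_i F_i (theta_i - theta_bar_i)^2."""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher[n] * (p - prev_params[n]) ** 2).sum()
    return 0.5 * lam1 * penalty
```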

Fig. 2. Graphical illustration of the proposed Deep Multi-order Preserving Weight Consolidation (DMoPWC). DMoPWC first learns a model on the baseline data (blue), then updates it after observing the 6-month data (yellow), and finally updates the model again after learning the 12-month data (green). The thicker red arrow denotes a larger time-order penalty on the later time point. DMoPWC keeps most of the previously learned knowledge compared with EWC and fine-tuning. (Color figure online)

2.2 Multi-order Preserving Weight Consolidation

The main problem of Eq. (3) is that it only enforces the time point t parameters to stay close to those of time point \(t-1\). This ignores the patient’s inherent correlations within time point t, as well as the same patient’s information shared between previously learned time points and time point t; such relationships might help improve statistical power and overcome “catastrophic forgetting” of previously learned time points’ information. Learning multiple related time points’ data jointly can improve performance relative to learning each time point separately [1] when the two time points’ data are related. One appealing property of the \(l_{2, 1}\)-norm regularization is that it shares similar parameter sparsity patterns among multiple different tasks. Therefore, a new formulation, Eq. (4), may improve the ability to overcome “catastrophic forgetting” across multiple time points and enforce sparsity over features for multiple subjects simultaneously,

$$\begin{aligned} \min _{\theta ^t} \mathcal {L}_t(\theta ^t) + \frac{\lambda _1}{2}\sum _{i}\mathbb {F}^{t-1}_{i}(\theta ^{t}_i - \bar{\theta }^{t-1}_i)^2 + \lambda _2\sum _i||\theta ^t_i||_{2, 1}, \end{aligned}$$
(4)

where \(\lambda _2\) is a non-negative regularization parameter, \(||\theta ^t_i||_{2, 1}=\sum _{j}||\theta _{i,j}^t||_2\) learns the related representations, and j indexes the j-th subject (row).

Furthermore, we consider critical parameters that have better representation power for a subset of the data at a specific time point. It has been shown that the \(l_1\) sparse norm [6] can identify informative longitudinal phenotypic biomarkers related to pathological changes of AD in brain image analysis. To this end, we propose to learn the discriminative new task-specific parameters while learning task relatedness among multiple subjects across time points, and the objective function for time point t becomes:

$$\begin{aligned} \min _{\theta ^t} \quad \mathcal {L}_t(\theta ^t) + \frac{\lambda _1}{2}\sum _{i}\mathbb {F}^{t-1}_{i}(\theta ^{t}_i - \bar{\theta }^{t-1}_i)^2 + \lambda _2\sum _i||\theta ^{t}_i||_{2,1}+\lambda _3||\theta ^t||_1 , \end{aligned}$$
(5)

where \(\lambda _3\) is a non-negative regularization parameter. Equation (5) learns the discriminative task-specific weight subset with inherent correlations among multiple subjects across time points while keeping previously learned knowledge via weight consolidation.
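A possible implementation of the two sparsity terms in Eq. (5), continuing the PyTorch-style sketch above (helper names are ours, and rows are grouped following the paper’s subject/row notation):

```python
import torch

def l21_norm(weight: torch.Tensor) -> torch.Tensor:
    """l2,1 norm: sum of the l2 norms of the rows, so related rows share
    a common sparsity pattern across columns."""
    return weight.norm(p=2, dim=1).sum()

def sparsity_penalty(model: torch.nn.Module, lam2: float, lam3: float) -> torch.Tensor:
    """lam2 * sum_i ||theta^t_i||_{2,1} + lam3 * ||theta^t||_1 from Eq. (5)."""
    l21 = sum(l21_norm(p) for p in model.parameters() if p.dim() == 2)
    l1 = sum(p.abs().sum() for p in model.parameters())
    return lam2 * l21 + lam3 * l1
```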

How to utilize the time-ordering information of the imaging data remains an open problem. We thus introduce a novel time-order preserving criterion to enrich Eq. (5), which prevents the time point t information \(\theta ^{t}\) from being temporally in front of the time point \(t-1\) information \(\theta ^{t-1}\). For instance, for longitudinal data, we know that the 3-month visit comes after the baseline visit, and the 12-month visit comes after both the 3-month and baseline visits (see Fig. 1). In other words, the model observes the same temporal order as the input longitudinal time series. Thus, we introduce the term \(w^t ||\theta ^t\, - \,\theta ^{t-1}||_2^2\), where \(w^t\) is the temporal-order weight for time point t. Therefore, \(w^{t-1}\theta ^{t-1} < w^t\theta ^t\) represents the approximate temporal order of time point t. In this work, we choose a simple element-wise linear form of the weight function \(\mathbf {W}\) to reflect the longitudinal time-ordering information: \( \mathbf {W} = [\frac{1}{T}, \frac{2}{T}, \cdots , \frac{t}{T}, \cdots , \frac{T-1}{T}, 1]\).
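Under the linear weighting \(\mathbf {W}\) above, the time-order term can be sketched as follows (again PyTorch-style; the helper name is illustrative):

```python
import torch

def time_order_penalty(model: torch.nn.Module, prev_params: dict,
                       t: int, T: int, lam4: float) -> torch.Tensor:
    """lam4 * w^t * ||theta^t - theta^{t-1}||_2^2 with the linear weight w^t = t/T."""
    w_t = t / float(T)
    sq_diff = sum(((p - prev_params[n]) ** 2).sum()
                  for n, p in model.named_parameters())
    return lam4 * w_t * sq_diff
```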

Therefore, the final objective function of the proposed Deep Multi-order Preserving Weight Consolidation (DMoPWC) becomes

$$\begin{aligned} \min _{\theta ^t} \ \mathcal {L}_t(\theta ^t) + \frac{\lambda _1}{2}\sum _{i}\mathbb {F}^{t-1}_{i}(\theta ^{t}_i - \bar{\theta }^{t-1}_i)^2 + \lambda _2\sum _i||\theta ^{t}_i||_{2,1}+\lambda _3||\theta ^t||_1 + \lambda _4w^t ||\theta ^t - \theta ^{t-1}||_2^2, \end{aligned}$$
(6)

where \(\lambda _4\) is a non-negative parameter. Figure 2 shows a geometric illustration of DMoPWC: our model can learn the most common sub-area (the overlap of all three colors) among the three time points’ data and preserve the time order of the sequence, compared with EWC (the overlap of two colors) and fine-tuning. The left panel of Fig. 2 illustrates that our model keeps the same model size across multi-time-point learning.
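Putting the pieces together, a hedged sketch of the full objective in Eq. (6) could look like the following. It reuses the helper functions sketched above (`ewc_penalty`, `sparsity_penalty`, `time_order_penalty` are our illustrative names, not the paper’s code).

```python
def dmopwc_loss(model, batch, loss_fn, fisher, prev_params, t, T,
                lam1, lam2, lam3, lam4):
    """Eq. (6): data loss + EWC consolidation + l2,1/l1 sparsity + time-order term.
    Assumes ewc_penalty, sparsity_penalty, time_order_penalty defined as above."""
    x, y = batch
    total = loss_fn(model(x), y)            # L_t(theta^t)
    if prev_params is not None:             # no consolidation at the first time point
        total = total + ewc_penalty(model, fisher, prev_params, lam1)
        total = total + time_order_penalty(model, prev_params, t, T, lam4)
    total = total + sparsity_penalty(model, lam2, lam3)
    return total
```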


3 Experiments

Datasets. We evaluate our DMoPWC algorithm on the entire ADNI-1 cohort [3]. We study structural MR images from seven time points: baseline, M06, M12, M18, M24, M36 and M48; the responses are the MMSE and ADAS-Cog scores at these time points. The sample sizes corresponding to the seven time points are 837, 733, 728, 326, 641, 454 and 251. Specifically, we remove 25 subjects without MMSE and ADAS-Cog scores from the baseline data and use the remaining 812 subjects. The hippocampal surface multivariate morphometry statistics (MMS) [12] are used as learning features; they consist of surface multivariate tensor-based morphometry, which is computed from the conformal grid and describes surface deformation in a local surface region, and radial distance, which measures surface deformation along the surface normal direction. We use FIRST to segment hippocampi from the MR images, follow the same protocol as Shi et al. [7], and extract vertex-wise hippocampal morphometry features on every pair of hippocampal surfaces. As a result, each subject \(\mathbf {x}_i^t\) has \(p=120,000\) features in total. In prediction, we use the current time point’s data to predict a future clinical score, e.g., we study the baseline MR images to predict the 12-month MMSE/ADAS-Cog.

Network Settings. We use a two-layer fully-connected neural network with 100-100 units and ReLU activations as our initial network. All comparison algorithms are trained on a single Nvidia TITAN X GPU. All models and algorithms are implemented using the TensorFlow library.
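For reference, a PyTorch-style sketch of such an initial network is shown below; the single-output regression head and the use of the \(p=120,000\) MMS features as the input dimension are assumptions consistent with the setup described above (the original implementation is in TensorFlow).

```python
import torch.nn as nn

def build_initial_network(p: int = 120000) -> nn.Sequential:
    """Two fully-connected hidden layers of 100 units with ReLU activations,
    followed by a scalar regression output for the clinical score."""
    return nn.Sequential(
        nn.Linear(p, 100), nn.ReLU(),
        nn.Linear(100, 100), nn.ReLU(),
        nn.Linear(100, 1),
    )
```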

Hyperparameter Settings. All hyper-parameters in DMoPWC are optimized using grid search, and the best results for each model are reported. The SGD optimizer is used with a learning rate of 0.001 and a batch size of 256 for 1400 iterations; we set \(\lambda _1 = 15\), \(\lambda _2 = 0.0001\), \(\lambda _3 = 0.15\) and \(\lambda _4 = 0.5\) for MMSE and \(\lambda _1 = 13\), \(\lambda _2 = 0.015\), \(\lambda _3 = 0.00001\) and \(\lambda _4 = 0.1\) for ADAS-Cog. We use 200 subjects to compute the Fisher information \(\mathbb {F}^t\).

Evaluation Methods. To evaluate the proposed model, we randomly split the data into training and testing sets with a 9:1 ratio and repeat this procedure 20 times to avoid data bias. We report the mean and standard deviation over these 20 splits. We evaluate the regression performance using weighted correlation (wCC) and the Pearson correlation coefficient (PCC) as overall measures, and the root mean square error (rMSE) as a task-specific measure. The three measures are defined as \(wCC=\frac{\sum _{t=1}^TCorr(\mathbf {Y}_t, \hat{\mathbf {Y}_t})N_t}{\sum _{t=1}^T N_t}\), \(PCC = \frac{\sum _{i=1}^{N_{total}}(\mathbf {Y}_i-\mathbf {\bar{Y}})(\mathbf {\hat{Y}}_i-\mathbf {\bar{\hat{Y}}})}{\sqrt{\sum _i(\mathbf {Y}_i-\mathbf {\bar{Y}})^2}\sqrt{\sum _i(\mathbf {\hat{Y}}_i-\mathbf {\bar{\hat{Y}}})^2}}\) and \(rMSE=\sqrt{\frac{||\mathbf {Y}_t - \hat{\mathbf {Y}_t}||_2^2}{N_t}}\), where Corr is the correlation coefficient between two vectors and \(N_t\) is the number of subjects of task t. \(\mathbf {Y}_t\) and \(\hat{\mathbf {Y}_t}\) are the ground-truth targets and the corresponding predictions at time point t, while \(\mathbf {\bar{Y}}\) and \(\mathbf {\bar{\hat{Y}}}\) are the mean values of \(\mathbf {Y}\) and \(\mathbf {\hat{Y}}\), respectively.
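The three measures follow directly from their definitions; a small NumPy sketch (function names are ours) is given below.

```python
import numpy as np

def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root mean square error for a single time point."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def pcc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Pearson correlation coefficient over all subjects."""
    return float(np.corrcoef(y_true, y_pred)[0, 1])

def wcc(y_true_per_t: list, y_pred_per_t: list) -> float:
    """Weighted correlation: per-time-point correlations weighted by N_t."""
    num = sum(np.corrcoef(yt, yp)[0, 1] * len(yt)
              for yt, yp in zip(y_true_per_t, y_pred_per_t))
    den = sum(len(yt) for yt in y_true_per_t)
    return float(num / den)
```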

Comparison Methods. We compare our algorithm with three groups of methods. Single-task regression methods: (1) LASSO [10] and (2) Ridge regression [2]. Multi-task regression methods: (1) L21: multi-task \(\ell _{2, 1}\)-norm regularization with least squares loss [6]; (2) cFSGL: convex fused sparse group Lasso [13]; (3) MSMT: multi-source multi-target dictionary learning [12]. Deep learning methods: (1) SN: fine-tuning (no penalty); (2) EWC [4]: elastic weight consolidation; (3) DMoPWC: the proposed algorithm. For the linear regression methods, cross-validation is used to select the model parameters on the training data, and the same training datasets with MMS features are used to predict the clinical scores two time points later. We use the same patch set as [12] for the input. For the deep learning methods, we use the same settings for the initial model, and the sequential data are used to predict the clinical scores two time points later.

Performance Comparisons. We report the results of DMoPWC and the other methods for predicting MMSE on the ADNI-1 dataset in Table 1. The proposed DMoPWC outperforms the single-task and multi-task regression methods. Among the multi-task methods, we observe that the dictionary learning method obtains better results than the others. We also notice that the deep learning models substantially improve the prediction results over the linear regressions. The proposed DMoPWC performs better than SN because the retraining in SN does not consider the knowledge of previous time points. SN performs better than EWC and DMoPWC on M12 due to the random initialization of the deep neural network weights on the baseline data, but the M12 values of the three methods are very close compared with other time points. However, EWC performs worse than SN on most time points, while DMoPWC significantly improves on EWC because it exploits the time-order information along with the common weight subset and the discriminative new time point features while keeping the old time points’ knowledge. Moreover, owing to the online continual learning process, DMoPWC achieves equally accurate predictions regardless of how many visits a patient has made.

Table 1. The prediction results of MMSE on the ADNI-1 dataset.
Table 2. The prediction results of ADAS-Cog on the ADNI-1 dataset.

We follow the same experimental settings as in the MMSE study, explore the prediction of ADAS-Cog scores, and report the performance in Table 2. We observe that DMoPWC achieves the best ADAS-Cog prediction performance at three of the time points. MSMT has the smallest rMSE on M18 and M48 because the scores fluctuate when the amount of available data becomes smaller. However, once DMoPWC incorporates the temporal sequence information, the results are more linear, reasonable and accurate across all time points. We also find that the proposed DMoPWC shows a larger improvement on M24 and M36 than in the MMSE prediction. Since we keep the previous time points’ knowledge, the later time points do not exhibit bias, in contrast to the linear regression algorithms.

Fig. 3. Comparison of rMSE performance on the ADNI-1 dataset with and without the time-order preserving term.

Effect of the Time-Order Preserving Term. We compare DMoPWC against DMoPWC without the order-preserving (oP) term to assess the effectiveness of the time-order preserving term. Figure 3 shows the comparison results. DMoPWC achieves better rMSE performance than DMoPWC w/o oP, which demonstrates that DMoPWC further improves the results by considering the time-order smoothness in the longitudinal dataset; in particular, DMoPWC significantly improves the result at M48. This may be because the baseline data are less correlated with the later time points’ data, whereas DMoPWC w/o oP assumes each time point has the same correlation with the later time points; the results show that the oP term offers a unique perspective on prognosis with longitudinal data.

Fig. 4. Comparison of rMSE performance on MMSE and ADAS-Cog when learning data in batch and sequential modes.

Comparisons of Learning Data in Batch Mode and Sequential Mode. We study the difference between learning the longitudinal data in batch mode and in sequential mode in Fig. 4, in terms of rMSE on MMSE and ADAS-Cog. We observe that learning the longitudinal data sequentially performs better than learning all data in batch. This may be partly because, when we learn the longitudinal data via continual learning, the model keeps the previous time points’ knowledge and learns the new time point’s information to improve future results. In contrast, learning all images together ignores the relationships of the early time points and cannot leverage such knowledge to boost the prediction results for the later time points.

Future Works. In the future, we will investigate DMoPWC with few-shot training data to further improve the performance of the current continual learning framework.