Clinically-Inspired Multi-Agent Transformers for Disease Trajectory Forecasting From Multimodal Data

Deep neural networks are often applied to medical images to automate medical diagnosis. However, a more clinically relevant question that practitioners usually face is how to predict the future trajectory of a disease. Current methods for prognosis or disease trajectory forecasting often require domain knowledge and are complicated to apply. In this paper, we formulate the prognosis prediction problem as a one-to-many prediction problem. Inspired by a clinical decision-making process with two agents (a radiologist and a general practitioner), we predict prognosis with two transformer-based components that share information with each other. The first transformer in this framework aims to analyze the imaging data, and the second one leverages its internal states as inputs, also fusing them with auxiliary clinical data. The temporal nature of the problem is modeled within the transformer states, allowing us to treat the forecasting problem as a multi-task classification, for which we propose a novel loss. We show the effectiveness of our approach in predicting the development of structural knee osteoarthritis changes and forecasting Alzheimer's disease clinical status directly from raw multi-modal data. The proposed method outperforms multiple state-of-the-art baselines with respect to performance and calibration, both of which are needed for real-world applications. An open-source implementation of our method is made publicly available at https://github.com/Oulu-IMEDS/CLIMATv2.


I. INTRODUCTION
Recent developments in Machine Learning (ML) suggest that it is soon to be tightly integrated into many fields, including healthcare [1], [2]. One particular subfield of ML, Deep Learning (DL), has advanced the most, as it opened the possibility to make predictions from high-dimensional data. In medicine, this impacted the field of radiology, in which highly trained human readers identify pathologies in medical images. The full clinical pipeline, however, aims to assess the [...] the slowing down of the disease, for example using behavioral interventions [13].
Individuals with AD have difficulties with reading, learning, and even performing daily activities. AD is fatally progressive and caused more than 120,000 deaths in the United States in 2019; however, no effective cure for it has been made available [8]. The benefits of early AD diagnosis are similar to those for OA: the progression of the disease can be delayed, and patients may be assigned relevant care in a timely manner [14].
In both of the aforementioned fields, OA and AD, there is a lack of studies on prognosis prediction. From an ML perspective, a more conventional setup is to predict whether the patient has the disease at present or at a specific point of time in the future [15], [16], [17], [18], [19], [20], [21]. However, prognosis prediction aims to answer whether and how the disease would evolve over time. Furthermore, in a real-life situation, the treating physician makes the prognosis while interacting with a radiologist or other stakeholders who can provide information (e.g., blood tests or radiology reports) about the patient's condition. This also largely differentiates the diagnostic task from predicting a prognosis.
In this paper, we present an extended version of our earlier work on automatic disease trajectory forecasting (DTF) [22], where we proposed the Clinically-Inspired Multi-Agent Transformers (CLIMAT) framework, aiming to mimic the interaction process between a general practitioner / treating physician and a radiologist. In our system, a radiologist module, consisting of a feature extractor (convolutional neural network; CNN) and a transformer, analyses the input imaging data and then provides an output state of the transformer, representing a radiology report, to the module corresponding to the general practitioner (purely transformer-based). The latter fuses this information with auxiliary patient data and makes the prognosis prediction. We graphically illustrate the described idea in Figure 1.
Compared to the conference version [22], we have enhanced our framework such that the module corresponding to the general practitioner not only performs prognosis, but is also encouraged to make diagnostic predictions consistent with the radiologist module. The earlier version of CLIMAT relies on a simplifying assumption of independence between the diagnostic label task and non-imaging data. The introduced update helps the framework to expand beyond the knee osteoarthritis domain and to be more realistic, thereby allowing our method to be applied in fields where diagnosis could rely on both imaging and non-imaging data. Moreover, we equip the framework with a new loss, Calibrated Loss based on Upper Bound (CLUB), that aims to maintain the performance while improving the calibration of the framework's predictions. Finally, we have also expanded the application of our framework to the case of AD.
To summarize, our contributions are the following:
1) We propose CLIMATv2, a clinically-inspired transformer-based framework that can learn to forecast disease severity from multi-modal data in an end-to-end manner. The main novelty of our approach is the incorporation of prior knowledge of the decision-making process into the model design.
2) We derive the CLUB loss, an upper bound on a temperature-scaled cross-entropy (TCE), and apply it to the DTF problem we have at hand. Experimentally, we show that CLUB provides better calibration and yields similar or better balanced accuracy than the competitive baselines.
3) From a clinical perspective, our results show the feasibility of performing fine-grained prognosis of knee OA and AD directly from raw multi-modal 2D and 3D data.

II. RELATED WORK
1) Knee osteoarthritis prognosis: The attention of the literature has gradually been shifting from diagnosing the current OA severity of a knee to predicting whether degenerative changes will happen within a specified time frame. While some studies [15], [16], [17] aimed to predict whether knee OA progresses within a specified duration, others [18], [19] tried to predict if a patient will undergo a total knee replacement (TKR) surgery at some point in the future. However, the common problem of the aforementioned studies is that the scope of knee OA progression is limited to a single period of time or outcome, which substantially differentiates our work from the prior art.
2) Alzheimer's disease prognosis: Compared to the field of OA, a variety of approaches have been proposed to process longitudinal data in the AD field. Lu et al. [21] utilized a fully-connected network (FCN) to predict AD progression within a time frame of 3 years from magnetic resonance imaging (MRI) and fluorodeoxyglucose positron emission tomography (FDG-PET) scans. Ghazi et al. [23] and Jung et al. [20] used different long short-term memory (LSTM)-based models to predict AD clinical statuses from scalar MRI biomarkers.
Albright et al. [24] took into account various combinations of scalar measures and clinical variables to predict changes in AD statuses using FCNs and recurrent neural networks (RNNs). In contrast to the prior art relying on either raw imaging data or scalar measures, our method enables learning from raw imaging scans, imaging-based measurements, and other scalar variables simultaneously. Additionally, whereas FCNs and sequential networks were widely used in the literature, we propose to use a transformer-based framework to perform the AD clinical status prognosis task. Furthermore, we use an FCN and two well-known sequential models, the gated recurrent unit (GRU) and LSTM, as our reference approaches.
3) Transformers for vision tasks: Although originally developed in the field of natural language processing [25], [26], transformer-based architectures have recently also been applied to vision tasks. Dosovitskiy et al. [27] pioneered the use of transformer-based architectures without a CNN for image classification problems. Girdhar et al. [28] and Arnab et al. [29] studied the same family of architectures to perform video recognition tasks. However, Hassani et al. [30] pointed out that such pure transformers require a significantly large amount of imaging data to perform well. The reason is that transformers do not have well-informed inductive biases, which are strengths of CNNs. Thus, our method relies on [30] due to medium dataset sizes.
4) Multimodal data processing with transformers: Transformers have been empirically robust in learning various categories of tasks from sequential data such as text or tabular data [25], [31]. However, in medical imaging, it is common to acquire multiple modalities comprising both raw images (e.g., plain radiographs, MRI, or PET scans) and tabular data, which are challenging for a single transformer. Recent work has shown that multiple transformers are needed to handle such multiple modalities [32]. Therefore, similar to our previous version [22], this study adopts the idea of using multiple transformers in our framework to perform DTF from multiple modalities.

A. The CLIMAT framework: a conceptual overview
As mentioned earlier, we base our framework on multi-agent decision-making processes in a clinical setting. In many applications, this can be considered information passing between two agents: a radiologist and a general practitioner [33]. While the radiologist specializing in imaging diagnosis is in charge of producing radiology reports, the general practitioner relies on various modalities including the radiologic findings to forecast the severity of a certain disease. We model such collaboration by the concept presented in Figure 1. Specifically, the radiologist analyzes a medical image $x_0$ (e.g., radiograph or PET image) of a patient to provide an interpretation with rich visual description and annotations, allowing the diagnosis of the current stage $y_0^R$ of the disease. Subsequently, the general practitioner relies on (i) the clinical data $m_0$ (e.g., questionnaires or symptomatic assessments) with a further interpretation if needed, (ii) the provided radiology report, and (iii) the referenced diagnosis of the radiologist $y_0^R$ to predict the course of the disease $y_{0:T}$.
We implement the concept proposed above in the CLIMATv2 framework (see Figure 4 and Section III-B). CLIMATv2 comprises three primary transformer-based blocks, namely Radiologist (R), Context (C), and General Practitioner (P). Firstly, assume that we obtain visual features learned from the imaging data $x_0$. Then, the block R acts as the radiologist to perform visual reasoning from the visual features and predict the current stage $\hat{y}_0^R$ of a disease. The other two blocks are responsible for context extraction and prognosis prediction. As such, the block C aims to extract a context embedding from clinical variables $m_0$. Subsequently, the block P utilizes the combination of the context embedding and the output states of the block R to forecast the disease trajectory $\hat{y}_{0:T}$.
In this work, we have two major upgrades to CLIMATv1 [22]. Firstly, we no longer assume that $y_0$ and $m_0$ are independent, as this does not hold in many medical imaging domains, e.g., for OA [34]. Namely, in the current version of CLIMAT, both the blocks R and P are allowed to make diagnosis predictions simultaneously, making sure that the learned embeddings contain information on $y_0$. Furthermore, we encourage their predictions to be consistent with the final module of our model. Secondly, besides performance, in this work, we take into account model calibration, which allows us to gain better insights into the reliability of models' predictions [35]. To facilitate better calibration within our proposed framework, we propose a novel loss, called CLUB, presented in Section III-C.

B. Technical realization
1) Transformer: A transformer encoder comprises a stack of $L$ multi-head self-attention layers, whose input is a sequence of vectors $\{s_i\}_{i=1}^{N}$, where $s_i \in \mathbb{R}^{1 \times C}$ and $C$ is the feature size. As such, a transformer is formulated as [25]

$$h_0 = [E_{[CLS]}; s_1; \dots; s_N] + E_{[POS]}, \quad (1)$$
$$h'_l = \mathrm{MSA}(\mathrm{LN}(h_{l-1})) + h_{l-1}, \quad l = 1, \dots, L,$$
$$h_l = \mathrm{MLP}(\mathrm{LN}(h'_l)) + h'_l, \quad l = 1, \dots, L,$$

where $E_{[CLS]}$ is a learnable token, $E_{[POS]}$ is a learnable positional embedding, and $h_L$ represents the features extracted from the last layer. MLP is a multi-layer perceptron (i.e., a fully-connected network), LN is the layer normalization [36], and $\mathrm{MSA}(\cdot)$ is a multi-head self-attention (MSA) layer [25]. The self-attention mechanism relies on the learning of query, key, and value parameter matrices, denoted by $W_l^Q$, $W_l^K$, and $W_l^V$ with $l = 1, \dots, L$, respectively. Initially, we simultaneously set $Q_0$, $K_0$, and $V_0$ to $h_0$ defined in Eq. (1). When iterating through layers $l = 1, \dots, L$, we update the states as follows

$$Q_l = Q_{l-1} W_l^Q, \quad K_l = K_{l-1} W_l^K, \quad V_l = V_{l-1} W_l^V.$$

Finally, the self-attention is established thanks to the scaled dot-product function applied to $Q_l$, $K_l$, and $V_l$, and defined as

$$\mathrm{Attention}(Q_l, K_l, V_l) = \mathrm{softmax}\left(\frac{Q_l K_l^\top}{\sqrt{d_k}}\right) V_l,$$

where $d_k$ is the feature dimension of $Q_l$. In essence, $Q_l K_l^\top$ represents the association between all pairs of queries and keys. The normalization based on $d_k$ is critical to address the case where the magnitude of entries in $Q_l K_l^\top$ is too large. The essential part that produces the attention is the softmax, which creates a normalized heatmap over the association of $Q_l$ and $K_l$. Subsequently, by adding more sets of learnable weights $W_l^Q$, $W_l^K$, and $W_l^V$, we can obtain MSA by concatenating different output heads of attention. Precisely, the MSA mechanism is formulated as follows

$$\mathrm{MSA}(\cdot) = [\mathrm{head}_1; \dots; \mathrm{head}_H] \, W_l^O,$$

where each $\mathrm{head}_i$ is the attention output computed with its own set of weights, $H$ is the number of heads, and $W_l^O$ represents learning parameters associated with the $H$ output heads.
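As a small illustration of the scaled dot-product self-attention described in this subsection, the sketch below computes single-head attention on a toy sequence in plain Python. The helper names (`softmax`, `matmul`, `attention`) and the toy values are illustrative assumptions; the learnable projections and the multi-head concatenation are omitted for brevity.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def attention(q, k, v):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(q[0])
    k_t = [list(col) for col in zip(*k)]            # transpose of K
    scores = matmul(q, k_t)                         # pairwise query-key association
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights

# Toy sequence of N = 3 tokens with d_k = 2 (values are illustrative only).
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, weights = attention(tokens, tokens, tokens)
```

Each row of `weights` is a normalized heatmap over the keys for one query, which is the role the softmax plays in the description above.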
The three main blocks in our framework are transformer-based networks (see Figure 4). While the blocks R and C have only 1 [CLS] token, the block P can include $K$ [CLS] tokens to allow for multi-target predictions. The hyperparameter $K$ is introduced in the block P to ensure that there are enough output heads for multi-task predictions. We typically set $K$ to 1 or $T + 1$. In the case $K = T + 1$, each output head has a corresponding [CLS] token.
2) Multimodal feature extraction: Our framework is able to handle multimodal imaging and non-imaging data. As input data can be clinical variables, raw images (i.e., 2D or 3D images), and biomarkers extracted by human experts or specialized software, we have distinct feature extraction modules for different input formats. Specifically, we use feed-forward network (FFN), 2D-CNN, and 3D-CNN-based architectures for scalar or 1D inputs, 2D images, and 3D images, respectively. As such, we pre-define common feature lengths $C_X$ and $C_M$ for all imaging and non-imaging embeddings, respectively. Each FFN-based feature extractor consists of a linear layer, a GELU activation [37], and a layer normalization [36], and has an output shape of $1 \times C_X$ or $1 \times C_M$ depending on the type of input data. In the CNN-based modules, we first unroll their output feature maps into sequences of feature vectors per image super-pixel or super-voxel, then linearly project them into a $C_X$-dimensional space.
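The unrolling of a CNN feature map into a sequence of per-location vectors, followed by a linear projection, can be sketched as below. This is a minimal plain-Python illustration under assumed toy shapes; the function names, the fixed weights, and the choice of $C_X = 3$ are hypothetical and not taken from the implementation.

```python
def unroll_feature_map(fmap):
    """Turn a (channels, height, width) map into height*width vectors of length channels."""
    channels = len(fmap)
    height, width = len(fmap[0]), len(fmap[0][0])
    return [[fmap[c][i][j] for c in range(channels)]
            for i in range(height) for j in range(width)]

def linear(vec, weight, bias):
    """Project one vector; weight has shape (out_features, in_features)."""
    return [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weight, bias)]

# Toy 2-channel 2x2 feature map standing in for a CNN output.
fmap = [[[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]]]
sequence = unroll_feature_map(fmap)        # 4 "super-pixels", 2 channels each

# Hypothetical projection to C_X = 3 with fixed (not learned) weights.
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0, 0.0]
projected = [linear(v, weight, bias) for v in sequence]
```

For a 3D feature map, the same idea applies with one more spatial loop, giving one vector per super-voxel.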
3) Radiologist module: The Radiologist block is a transformer network with $L_R$ layers and is responsible for processing all imaging features previously extracted in Section III-B.2. For the input data preparation, we concatenate all features of different imaging modalities to form a sequence of length $N$ that contains $C_X$-dimensional image representations. Subsequently, we propagate this sequence through the transformer R. To this end, the visual embedding $h_R \in \mathbb{R}^{(N+1) \times C_X}$ produced by its last layer serves two purposes: representing radiology reports and providing visual features for diagnosis predictions. For the former, we subsequently combine $h_R$ with non-imaging embeddings to constitute inputs for the General Practitioner block (see Section III-B.5). For the latter, following a common practice in [38], [39], [40], we perform an average pooling over $h_R$ to generate a $C_X$-dimensional vector. Afterward, we pass the resulting vector through an FFN comprised of a linear layer, a GELU activation [37], and a layer normalization [36] to predict the current stage $y_0^R$ of the disorder (see Figure 4).
4) Clinical context embedding module: Here, we aim to mimic the comprehension of a general practitioner over different clinical modalities (e.g., questionnaires, extra tests, and risk factors). As such, we take a single [CLS] embedding followed by $M$ clinical vector representations extracted in Section III-B.2 to form the input sequence for the Context block (see Figure 4). The underlying architecture of the block is a transformer-based network. After passing the input sequence through the transformer C with $L_C$ layers, we merely use the first feature vector $h_C^0$ of the last feature maps $h_{L_C}$ as a common contextual token representing all the non-imaging modalities.
5) General Practitioner module: As soon as the contextual token of length $C_M$ is acquired from the Context block, we concatenate $N + 1$ copies of the token $h_C^0$ into the last states $h_R$ of the transformer R to generate a sequence of $N + 1$ mixed feature vectors with a feature size of $C_X + C_M$. We then process the obtained sequence using Eq. (1) to have the sequence of $(K + N + 1)$ feature vectors. Here, we utilize the third transformer-based module to simulate the analysis of the general practitioner over all sources of data for prognosis predictions. Specifically, after passing the input sequence through the transformer P, we utilize the first $T + 1$ vector representations of the last layer to forecast the disease severity trajectory $(\hat{y}_0, \dots, \hat{y}_T)$. Predicting disease severity at each time point requires a common or distinct FFN, which comprises a layer normalization followed by two fully connected layers separated by a GELU activation [37].

C. Calibrated Loss based on Upper Bound for Multi-Task
Compared to CLIMATv1 [22], we aim to optimize not only the performance but also the calibration of our model's predictions. As CLIMATv2 simultaneously predicts a sequence of $T + 1$ targets with different difficulties, we treat it as a multi-task predictive model. The temporal information here is contained within the transformer states. Inspired by [41], to harmonize all the tasks, we propose the CLUB loss (abbreviated from Calibrated Loss Upper Bound). However, unlike [41], which relies on the 'not always true' assumption that

$$\frac{1}{\sigma^2} \sum_{c'} \exp\left(\frac{1}{\sigma^2} f_{c'}(x)\right) \approx \left(\sum_{c'} \exp\left(f_{c'}(x)\right)\right)^{1/\sigma^2},$$

where $\sigma$ is a noise factor and $f_{c'}(\cdot)$ is the $c'$-th element of the logits produced by a parametric function $f$, we theoretically derive CLUB as an upper bound of the temperature-scaled cross-entropy (CE) loss.
Consider the $t$-th task with $t \in \{0, \dots, T\}$, and let $f_t = (f_{t,1}, \dots, f_{t,N_c^t})^\top$ denote the predicted logits of CLIMATv2 (i.e., an output of the transformer P) on task $t$, where $N_c^t$ is the number of classes of the $t$-th target. Let $g_t = (g_{t,1}, \dots, g_{t,N_c^t})^\top = \exp(f_t)$. Similar to [41], we model the effect of noise $\sigma_t$ on the prediction of $y_t$ in the scaled form $\mathrm{softmax}\left(\frac{1}{\sigma_t^2 + \varepsilon} f_t\right)$, where $\varepsilon \in \mathbb{R}^+$ is needed to ensure that the scaled softmax is valid for all $\sigma_t \in \mathbb{R}$. For convenience, we temporarily drop the index $t$ from all notations. By denoting $\tau = \frac{1}{\sigma^2 + \varepsilon} \in \mathbb{R}^+$, we rewrite the scaled softmax as

$$\mathrm{softmax}(\tau f)_c = \frac{\exp(\tau f_c)}{\sum_{c'} \exp(\tau f_{c'})} = \frac{g_c^\tau}{\sum_{c'} g_{c'}^\tau},$$

where $c, c'$ are class indices, and $\tau \in \mathbb{R}^+$ is a noise factor.
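Without loss of generality, $c$ is assumed to be the ground truth class of a certain input $x$. $\tau$ is the inverse temperature that can smoothen ($\tau \leq 1$) or sharpen ($\tau > 1$) predicted probabilities. Here, one can observe that $\left(\sum_{c'} g_{c'}^\tau\right)^{1/\tau}$ can be seen as an absolutely homogeneous function or an $\ell_\tau$-norm $\|g\|_\tau$ in a Lebesgue space, when $\tau$ belongs to $(0, 1)$ or $[1, \infty)$, respectively. Therefore, a TCE loss can be formulated as

$$\mathcal{L}_{\mathrm{TCE}} = -\log \frac{g_c^\tau}{\sum_{c'} g_{c'}^\tau} = -\tau \log \frac{g_c}{\|g\|_\tau},$$

where $c$ is the true class. When $\tau = 1$, the TCE loss becomes the vanilla CE loss

$$\mathcal{L}_{\mathrm{CE}} = -\log \frac{g_c}{\|g\|_1}. \quad (9)$$

For the purpose of improving calibration, we are interested in the case of $\tau \in (0, 1]$ [35], allowing us to apply the reverse Hölder's inequality to have $\|g\|_\tau \leq N_c^{(1-\tau)/\tau} \|g\|_1$. Then, we can derive an upper bound of $\mathcal{L}_{\mathrm{TCE}}$, called the CLUB loss, as

$$\mathcal{L}_{\mathrm{TCE}} \leq \mathcal{L}_{\mathrm{CLUB}} = -\tau \log \frac{g_c}{N_c^{(1-\tau)/\tau} \|g\|_1} = \tau \mathcal{L}_{\mathrm{CE}} + (1 - \tau) \log N_c, \quad (10)$$

where the equality holds if and only if $\tau = 1$. Unlike $\mathcal{L}_{\mathrm{TCE}}$, our CLUB loss directly depends on $\|g\|_1$ rather than $\|g\|_\tau$. Eq. (10) indicates that $\mathcal{L}_{\mathrm{CLUB}}$ is a convex combination between the CE loss (9) and $\log N_c$, which takes into account the task complexity in terms of the number of classes.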
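The relation between the three losses can be checked numerically. The sketch below is a minimal plain-Python illustration, assuming that the TCE loss is the cross-entropy of the softmax applied to logits scaled by $\tau$, and that the CLUB loss is the convex combination of the CE loss and $\log N_c$ with weights $\tau$ and $1 - \tau$. The function names and the toy logits are hypothetical.

```python
import math

def ce_loss(logits, target):
    """Vanilla cross-entropy: -log(g_c / ||g||_1) with g = exp(logits)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(f - m) for f in logits))
    return log_sum - logits[target]

def tce_loss(logits, target, tau):
    """Temperature-scaled CE: cross-entropy of softmax(tau * logits)."""
    return ce_loss([tau * f for f in logits], target)

def club_loss(logits, target, tau):
    """Upper bound tau * CE + (1 - tau) * log(N_c), intended for tau in (0, 1]."""
    return tau * ce_loss(logits, target) + (1.0 - tau) * math.log(len(logits))

logits = [2.0, -1.0, 0.5, 0.0]     # illustrative logits for N_c = 4 classes
taus = (0.25, 0.5, 0.75, 1.0)
gaps = [club_loss(logits, 0, t) - tce_loss(logits, 0, t) for t in taus]
```

Every entry of `gaps` is non-negative, and the last one (for $\tau = 1$) vanishes up to floating-point error, in line with the bound and its equality condition.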
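2) Performance and calibration optimization: In our setting, we consider each $\tau_t$ associated with task $t$ as a learnable parameter. As the model's parameters $\theta$ and the $\tau_t$'s are independent, we can respectively derive the gradients of $\mathcal{L}_{\mathrm{CLUB}}(t)$ w.r.t. $\theta$ and the $\tau_t$'s as follows

$$\frac{\partial \mathcal{L}_{\mathrm{CLUB}}(t)}{\partial \theta} = \tau_t \frac{\partial \mathcal{L}_{\mathrm{CE}}(t)}{\partial \theta}, \qquad \frac{\partial \mathcal{L}_{\mathrm{CLUB}}(t)}{\partial \tau_t} = \mathcal{L}_{\mathrm{CE}}(t) - \log N_c^t, \quad (11)$$

where $\mathcal{L}_{\mathrm{CE}}(t)$ and $\mathcal{L}_{\mathrm{CLUB}}(t)$ are the CE and CLUB losses on the $t$-th task, respectively. Whereas the optimization w.r.t. $\theta$ essentially aims to improve the performance of our model, learning the $\tau_t$'s directly impacts its calibration quality. Eqs. (10) and (11) indicate that the $\tau_t$'s can be seen as learnable coefficients of different tasks.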
To effectively constrain $\tau_t \leq 1$ and avoid a trivial solution where $\tau_t = 1$ for all $t \in \{0, \dots, T\}$, we constrain the learnable parameters $\{\tau_t\}_{t=0}^{T}$ using Algorithm 1. Specifically, Line 1 guarantees that the $\rho_t$'s are valid for any $\sigma_t$'s. Lines 2 to 4 prevent all the $\tau_t$'s from converging to the obvious value 1. Lines 5 and 6 rescale the $\tau_t$'s such that only those with the maximum value become 1. This last step is necessary to prevent the values of the $\tau_t$'s from shrinking in inverse proportion to the number of tasks.
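One procedure that matches the three steps described above can be sketched as follows. This is only an assumed reading, not necessarily the steps of Algorithm 1: it takes $\rho_t = 1/(\sigma_t^2 + \varepsilon)$ for the first step, a softmax normalization across tasks for the middle step, and division by the maximum for the final rescaling. The function name `constrain_taus` and the value of `eps` are hypothetical.

```python
import math

def constrain_taus(sigmas, eps=1e-6):
    """Map unconstrained sigma_t's to tau_t's in (0, 1] whose maximum equals 1."""
    # Step 1 (assumed): rho_t is positive and finite for any real sigma_t.
    rhos = [1.0 / (s * s + eps) for s in sigmas]
    # Step 2 (assumed): softmax across tasks, so the values sum to 1 and
    # cannot all be equal to 1 when there is more than one task.
    m = max(rhos)
    exps = [math.exp(r - m) for r in rhos]
    total = sum(exps)
    taus = [e / total for e in exps]
    # Step 3: rescale so that only the largest value becomes exactly 1.
    top = max(taus)
    return [t / top for t in taus]

taus = constrain_taus([0.5, 1.0, 2.0])   # illustrative sigma values for 3 tasks
```

Without the last step, the normalized values would average $1/(T+1)$, which is the shrinking effect the text refers to.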

D. Multi-Task Learning for Disease Trajectory Forecasting
In practice, it is highly common to have data that are not fully annotated. Thus, our framework should allow for handling missing targets by design. As such, our multi-task loss can tackle such an impaired condition with ease by using an indicator function to mask out targets without annotation. Formally, we minimize the following prognosis forecasting loss

$$\mathcal{L}_{\mathrm{prog}} = \sum_{t=0}^{T} \mathbb{I}_t \, \mathcal{L}_{\mathrm{CLUB}}(t),$$

where $\mathbb{I}_t$ is an indicator function for task $t$.
While the radiologist has strong expertise in imaging diagnosis, in relation to prognosis, the general practitioner has more advantages due to the access to multimodal data, such as the patient's background. On the other hand, general practitioners are also able to assess images to some extent. We incorporate the corresponding prior into our learning framework by enforcing consistency in predictions between the two agents through a consistency loss $\mathcal{L}_{\mathrm{cons}}$ defined on $f_0^R$ and $f_0$, the logits of the blocks R and P for diagnosis predictions, respectively. It is worth noting that while $\mathcal{L}_{\mathrm{prog}}$ operates solely on annotated targets, $\mathcal{L}_{\mathrm{cons}}$ optimizes all targets.
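To optimize the whole framework, we minimize the final loss $\mathcal{L}$ as follows

$$\mathcal{L} = \mathcal{L}_{\mathrm{prog}} + \lambda \mathcal{L}_{\mathrm{cons}},$$

where $\lambda \in \mathbb{R}^+$ is a consistency regularization coefficient.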
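The masking of missing targets and the weighting of the consistency term can be sketched in a few lines. The example below is a simplified illustration, assuming that the prognosis loss sums the per-task losses of annotated targets only and that the final loss adds the consistency term scaled by $\lambda$. Missing targets are encoded as `None`, the per-task loss values are made up, and the consistency term is passed in as a precomputed number because its exact form is not assumed here.

```python
def prognosis_loss(task_losses, targets):
    """Sum per-task losses, skipping tasks whose target is missing (None)."""
    return sum(loss for loss, y in zip(task_losses, targets) if y is not None)

def total_loss(l_prog, l_cons, lam=0.5):
    """Final objective: prognosis loss plus lambda times the consistency loss."""
    return l_prog + lam * l_cons

# Three time points; the second target has no annotation (illustrative values).
task_losses = [0.7, 1.2, 0.4]
targets = [1, None, 0]
l_prog = prognosis_loss(task_losses, targets)
l_total = total_loss(l_prog, l_cons=0.2, lam=0.5)
```

Setting `lam=0` recovers training without the consistency term.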

A. Data
In this study, we conducted experiments on two public datasets for knee OA and AD. The overall description and subject selection of the two datasets and corresponding tasks can be seen in Figure 5 and Table I. The details of data pre-processing and prognosis prediction tasks are presented as follows.
1) Knee OA structural prognosis prediction: We conducted experiments on the Osteoarthritis Initiative (OAI) cohort, publicly available at https://nda.nih.gov/oai/. 4,796 participants from 45 to 79 years old participated in the OAI cohort, which consisted of a baseline, and follow-up visits after 12, 18, 24, 30, 36, 48, 60, 72, 84, 96, 108, 120, and 132 months. In the present study, we used all knee images that were acquired with large imaging cohorts: the baseline, and the 12, 24, 36, 48, 72, and 96-month follow-ups.
As the OAI dataset includes data from five acquisition centers, we used data from 4 centers for training and validation, and considered data from the left-out one as an independent test set. On the former set, we performed a 5-fold cross-validation strategy.
Following [15], [42], we utilized the BoneFinder tool [43] to extract a pair of knee regions from each bilateral radiograph, and pre-processed each of them. Subsequently, we resized each pre-processed image to 256 × 256 pixels (pixel spacing of 0.5 mm), and horizontally flipped it if that image corresponded to a right knee.
We utilized the Kellgren-Lawrence (KL) as well as the OARSI grading systems to assess knee OA severity. The KL system classifies knee OA into 5 levels from 0 to 4, proportional to the OA severity increase. The OARSI system consists of 6 sub-systems, namely lateral/medial joint space (JSL/JSM), osteophytes in the lateral/medial side of the femur (OSFL/OSFM), and osteophytes in the lateral/medial side of the tibia (OSTL/OSTM). Accordingly, the furthest targets in KL, JSL, and JSM were 8 years from the baseline, while they were 4 years for the other grading aspects.
Regarding the KL grading system, we grouped KL-0 and KL-1 into the same class as they are clinically similar, and added TKR knees as the fifth class. As a result, there were 5 classes in KL, and there were 4 severity levels in each of the OARSI sub-systems. Following [15], [22], we utilized age, sex, body mass index (BMI), history of injury, history of surgery, and total Western Ontario and McMaster Universities Arthritis Index (WOMAC) as clinical variables. We quantized the continuous variables, and represented each of them by a 4-element one-hot vector depending on the relative position of its value in the interval created by the minimum and the maximum.
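The quantization of a continuous variable into a 4-element one-hot vector can be sketched as below. This is a minimal illustration that assumes four equal-width bins over the interval between the minimum and the maximum; the exact binning rule is not specified in the text, and the function name `quantize_one_hot` is hypothetical.

```python
def quantize_one_hot(value, lo, hi, bins=4):
    """One-hot vector marking which of `bins` equal-width parts of [lo, hi] holds value."""
    if hi <= lo:
        raise ValueError("hi must be greater than lo")
    position = (value - lo) / (hi - lo)             # relative position in [0, 1]
    index = min(int(position * bins), bins - 1)     # value == hi falls in the last bin
    index = max(index, 0)                           # guard against values below lo
    return [1 if i == index else 0 for i in range(bins)]

# Age range of the OAI cohort as stated in the text: 45 to 79 years.
youngest = quantize_one_hot(45, 45, 79)
middle = quantize_one_hot(62, 45, 79)
oldest = quantize_one_hot(79, 45, 79)
```

Here `youngest`, `middle`, and `oldest` fall into the first, third, and fourth bins, respectively.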
For clinical relevance, we did not perform knee OA prognosis predictions on knees that underwent TKR or were diagnosed with the highest grade in any OARSI sub-system. In addition, we ignored one single entry whose pair of knees was improperly localized from its lateral radiograph by the BoneFinder tool. To have more training samples, we generated multiple entries from the longitudinal record of each participant by considering imaging and non-imaging data at different follow-up visits (except for the last one) as additional inputs.
2) AD clinical status prognosis prediction: We applied our framework to forecast the Alzheimer's disease (AD) clinical status from multi-modal data on the Alzheimer's Disease Neuroimaging Initiative (ADNI) cohort, which is available at https://ida.loni.usc.edu. The recruitment was done at 57 sites across the United States and Canada, and there were 2,577 male and female participants aged 55 to 90 enrolled in the cohort. The participants underwent a series of tests such as clinical evaluation, neuropsychological tests, genetic testing, lumbar puncture, MRI, and PET imaging at a baseline and follow-up visits at 1, 2, and 4-year periods.
In this study, we used raw FDG-PET scans, MRI measures, cognitive tests, clinical history, and risk factors as predictor variables. The raw FDG-PET scans were pre-processed by the dataset owner, and were then standardized to voxel dimensions of 160 × 160 × 160 (1.5 × 1.5 × 1.5 mm³ voxel spacing) using the NiBabel library [44]. To be in line with the OAI dataset, we applied the same technique to convert scalar inputs to one-hot encoding vectors with a length of 4. In querying subjects, while we only selected entries whose raw FDG-PET scans were available, the other input variables were allowed to be missing.
Our objective was to forecast the AD clinical statuses of participants' brains (cognitively normal (CN), mild cognitive impairment (MCI), or probable AD) in the next 4 years. Since the amount of the queried data was substantially limited (see Table I), we sampled entries from follow-up examinations to increase the amount of training data, and performed 10-fold cross-validation on this task.

B. Experimental Setup
1) Implementation details: We trained and evaluated our method and the reference approaches using Nvidia V100 GPUs. Each experimental setting was performed on a single GPU with 12 GB. We implemented all the methods using the PyTorch framework [45], and trained each of them with the same set of configurations and hyperparameters. For each problem, we used the Adam optimizer [46]. Learning rates of 1e−4 and 1e−5 were set for the OA and AD-related tasks, respectively.
To extract visual representations of 2D images, we utilized the ResNet18 architecture [47] whose weights were pretrained on the ImageNet dataset [48]. We used a batch size of 128 for the knee OA experiments. Regarding 3D images, we chose the 3D-ShuffleNet2 architecture because it was well-balanced between efficiency and performance as shown in [49], which allowed us to train each model with a batch size of 36 on a single consumer-level GPU. We utilized 3D-ShuffleNet2's weights previously pretrained on the Kinetics-600 dataset [50]. Moreover, we used a common feature extraction architecture with a linear layer, a ReLU activation, and a layer normalization [36] for all scalar numerical and categorical inputs. We provide the detailed description of the input variables in Tables II and III.
2) Baselines: For fair comparisons, our baselines were models that had the same feature extraction modules for multi-modal data, as described in Section IV-B.1, but utilized different architectures to perform discrete time series forecasting. As such, we compared our method to baselines with the forecasting module using a fully-connected network (FCN), GRU [51], LSTM [52], multi-modal transformer (MMTF) [31], Reformer [53], Informer [54], Autoformer [55], or CLIMAT [22]. While FCN, MMTF, Reformer, Informer, Autoformer, and CLIMAT are parallel models, GRU and LSTM are sequential approaches. Among the transformer-based methods, both versions of CLIMAT have a modular structure of transformers rather than a flat structure as in MMTF, Reformer, Informer, and Autoformer.
3) Metrics: As data from both OAI and ADNI were imbalanced, balanced accuracy (BA) [56] was an essential metric in our experiments. As there were only 3 classes in the AD clinical status prognosis prediction task, we also utilized the one-vs-one multi-class area under the ROC curve (mAUC-ROC) [57] as another metric. To quantitatively measure calibration, we used the expected calibration error (ECE) [58], [35]. We reported means and standard errors of each metric computed over 5 runs with different random seeds.
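For reference, the two main metrics can be sketched in plain Python as follows. This is a simplified illustration: balanced accuracy is computed as the mean of per-class recalls, and ECE uses equal-width confidence bins, with the number of bins (`n_bins=10`) being an assumption rather than a setting reported in the text. The toy labels and confidences are made up.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recalls over the classes present in y_true."""
    recalls = []
    for c in sorted(set(y_true)):
        idx = [i for i, y in enumerate(y_true) if y == c]
        recalls.append(sum(y_pred[i] == c for i in idx) / len(idx))
    return sum(recalls) / len(recalls)

def expected_calibration_error(confidences, correct, n_bins=10):
    """Weighted mean gap between accuracy and confidence over equal-width bins."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        idx = [i for i, p in enumerate(confidences) if lo < p <= hi]
        if not idx:
            continue                                  # empty bins contribute nothing
        acc = sum(correct[i] for i in idx) / len(idx)
        conf = sum(confidences[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(acc - conf)
    return ece

ba = balanced_accuracy([0, 0, 0, 1], [0, 0, 1, 1])
ece = expected_calibration_error([0.9, 0.9, 0.6, 0.6], [1, 1, 1, 0])
```

In the toy example, the majority class has a recall of 2/3 and the minority class a recall of 1, so `ba` is their mean rather than the plain accuracy of 3/4.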
To analyze the statistical significance of our results, we utilized the two-sided Wilcoxon signed-rank test to validate the advantage of our method compared to each baseline [59]. We equally split the test set into 20 subsets without overlapping patients. For each subset, we computed metrics averaged over 5 random seeds per method. The statistical testing was done patient-wise by comparing our method with every baseline individually. In the case of the OAI dataset, for all patients, we did two rounds of hypothesis testing: one for the left and one for the right knee, respectively. Subsequently, we applied the Bonferroni correction to adjust the significance thresholds for multiple comparisons (p = 0.025 due to two knees per patient) [60].

C. Ablation studies
1) Overview: We conducted a thorough ablation study to investigate the effects of different components in our CLIMATv2 architecture on the OAI dataset. The empirical results are presented in Table IV and summarized in the following subsections.
2) Effect of the transformer P's depth: Firstly, we searched for an optimal depth of the transformer P. The results show that the transformer P with a depth of 4 provides the best performance, yielding a 0.2% gain in averaged BA compared to depths of 2 and 4. The average BA (over 4 years) indicates a substantial boost in performance. We, therefore, use the depth of 4 for the transformer P in the sequel.
3) Effect of the number of [CLS] embeddings and FFNs in the transformer P: Then, we simultaneously validated two components: using single or multiple [CLS] embeddings, and using a common or separate FFNs in the transformer P. Of the 4 combinations of settings, the quantitative results suggest that the transformer should have 9 individual [CLS] embeddings, each of which corresponds to an output head, and merely use one common FFN to make predictions at different time points.
4) Effect of the consistency term: To validate the necessity of the $\mathcal{L}_{\mathrm{cons}}$ term, we conducted an experiment on a set of $\lambda$ values {0, 0.25, 0.5, 0.75, 1}. The empirical evidence in Table IV shows that a $\lambda$ of 0.5 resulted in the best performance, which was 0.7% higher than the setting without $\mathcal{L}_{\mathrm{cons}}$. We further validated the effects of the consistency term on other knee OA grading criteria as well as the AD status forecasting task. The empirical results in Table V consistently demonstrate that the term $\mathcal{L}_{\mathrm{cons}}$ has a positive impact on performance, albeit with a trade-off in calibration. A consistency coefficient $\lambda$ of 0.5 is the best setting in terms of performance across the tasks. Specifically, we observed BA gains of 1.7%, 0.6%, and 0.5% with trade-off ECEs of 0.4%, 0.8%, and 0.4% for JSL, JSM, and AD, respectively.
5) Average pooling for image representation: In contrast to the previous version, we adopted a conventional approach used in prior studies [38], [39], [40], which involves performing an average pooling over the output sequence of the Radiologist block to constitute an imaging feature vector for diagnosis prediction $\hat{y}_0^R$. According to Table IV, such an approach results in a gain of 1.1% BA compared to the baseline, which solely utilized the first vector of the sequence generated by the block R.

6) Multimodal channel-wise concatenation: We conducted an ablation study on the combination of multimodal embeddings. As such, we compared our channel-wise approach to a sequence-wise baseline that simply concatenates imaging embeddings and a projected version of non-imaging ones. For the baseline, we utilized a linear projection layer to ensure that imaging and non-imaging embeddings are in the same C X-dimensional space. We reported the K-fold cross-validation results in Table VI. On the knee OA-related tasks, our approach tends to benefit both performance and calibration. Specifically, the performance gains were 2.3%, 2.0%, and 0.5% for KL, JSL, and JSM, respectively. Except for JSL, with an increase of 0.1% ECE, the approach results in calibration improvements of 2.5% and 3.2% for KL and JSM, respectively. On the AD-related task, the channel-wise approach leads to improvements of 1.1% BA and 0.2% ECE.
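The two fusion strategies compared in this ablation differ mainly in which tensor dimension grows, which a small shape-level sketch can make concrete. Everything below is illustrative: the helper names, the toy dimensions, and the fixed projection weights are assumptions, and the sketch is not the authors' implementation.

```python
def channel_wise(imaging, context):
    """Append one context vector to every imaging vector.

    The sequence length stays N; the channel width grows."""
    return [vec + context for vec in imaging]


def linear_project(vec, weight):
    """Bias-free linear map; `weight` holds one row per output channel."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def sequence_wise(imaging, non_imaging, weight):
    """Project non-imaging vectors to the imaging width, then extend the sequence.

    The channel width stays the same; the sequence length grows to N + M."""
    return imaging + [linear_project(v, weight) for v in non_imaging]


# Hypothetical toy inputs: N = 2 imaging vectors of width 2,
# M = 3 non-imaging vectors of width 1, and a context vector of width 1.
imaging = [[1.0, 2.0], [3.0, 4.0]]
non_imaging = [[1.0], [2.0], [3.0]]
context = [0.5]
weight = [[1.0], [-1.0]]  # assumed projection from width 1 to width 2

cw = channel_wise(imaging, context)
sw = sequence_wise(imaging, non_imaging, weight)
print(len(cw), len(cw[0]))  # 2 3
print(len(sw), len(sw[0]))  # 5 2
```

In the channel-wise variant each imaging vector carries the non-imaging context directly, while the sequence-wise variant leaves it to the subsequent attention layers to relate the two groups of tokens.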
7) Effectiveness of CLUB loss: We compared the CLUB loss to CE itself, multi-task loss (MTL) [41], focal loss (FL) [61], and adaptive focal loss (AFL) [62]. Whereas the first two baselines and our loss are based on CE loss, the remaining ones are related to FL. In Figure 6, we graphically visualize the trade-off between performance and calibration, in which the best settings in both aspects are expected to be located close to the top-left corners. We observe that our model trained with FL-related losses was substantially worse calibrated compared to the settings with any CE-based loss. Among the losses based on CE, the proposed CLUB helped our model to achieve the best ECEs in all three OA grading systems with insubstantial drops in performance.
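For readers less familiar with the calibration metric used in this comparison, the following is a minimal sketch of the expected calibration error in its common equal-width binning form. The number of bins, the binning rule, and the toy predictions are assumptions made for illustration; the exact ECE configuration used in the experiments is not restated in this section.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Weighted average gap between accuracy and mean confidence per bin."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Equal-width bins over [0, 1]; a confidence of 1.0 goes to the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, hit))

    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


# Hypothetical predictions: top-class confidence and whether it was correct.
confidences = [0.9, 0.9, 0.6, 0.6]
correct = [1, 1, 1, 0]
ece = expected_calibration_error(confidences, correct)
print(round(ece, 3))  # approximately 0.1
```

A lower value means that the predicted confidences track the observed accuracy more closely, which is why the best settings in Figure 6 are expected on the low-ECE side.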
2) Alzheimer's disease status prognosis prediction: We reported the quantitative results in Table VII. Regarding performance, both CLIMAT methods achieved the best performances across the prediction targets, in which CLIMATv2 was top-1 at the first 2 years in both BA and mROCAUC. Compared to the transformer-based baseline MMTF, our method achieved BAs that were 2.2%, 1.8%, and 2.2% higher at years 1, 2, and 4, respectively. In calibration, CLIMATv2 yielded substantially lower ECEs than all the references at every prediction target. That observation was supported by the statistical test results in Table VII.

E. Attention maps over multiple modalities
The self-attention mechanism of the transformers in CLIMATv2 allowed us to visualize attention maps over imaging and non-imaging modalities when our model made a prediction at a specific target. Specifically, we used $\mathrm{softmax}\left(\frac{Q_L K_L^\top}{\sqrt{d_k}}\right)$, where $Q_L$ and $K_L$ are the query and key matrices of the last layer $L$, respectively, and $d_k$ is the feature dimension of the key matrix, as attention maps [25]. While we utilized the softmax output corresponding to $h_C^0$ in the transformer F for clinical variables, we took the softmax output in computing $h_P^t$ with $t = 0, \dots, T$ in the transformer P to visualize attention maps on imaging modalities. Here, we set $t = 1$, corresponding to the forecast of disease severity 1 year from the baseline.
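The attention map described here is the softmax-normalized scaled dot product between the queries and keys of the last layer [25]. A minimal sketch is given below; the toy matrices are hypothetical, a single attention head is assumed, and the code does not reproduce how the released implementation extracts or aggregates attention weights.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_map(Q, K):
    """Row-wise softmax of Q K^T / sqrt(d_k) for a single head."""
    d_k = len(K[0])
    scale = math.sqrt(d_k)
    scores = [
        [sum(q * k for q, k in zip(q_row, k_row)) / scale for k_row in K]
        for q_row in Q
    ]
    return [softmax(row) for row in scores]


# Hypothetical last-layer queries (2 tokens) and keys (3 tokens), d_k = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

A = attention_map(Q, K)
for row in A:
    print([round(a, 3) for a in row])
```

Each row of the result is a distribution over the attended tokens, so the row belonging to a given output embedding can be read directly as the importance assigned to each input position.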
1) Knee OA structural prognosis: In Figure 9, we visualized attention maps over different input modalities across 7 grading criteria. As such, in Figure 9a, we displayed a healthy knee at the baseline overlaid by 7 corresponding saliency maps. For differentiation, we also provided colored ellipses. Figure 9b shows the heatmap over the 6 clinical variables on each grading criterion. Values on each row sum up to 1. In this particular case, we observe that the model has paid the most attention to the intercondylar notch, together with BMI and WOMAC [63].
2) AD clinical status prognosis: As imaging data consisted of 3D FDG-PET scans as well as the other imaging measurements, we had to separate them into Figures 10a and 10b. We can observe that an attention sphere is located around the posterior cingulate cortex, the inferior frontal gyrus, and the middle gyrus [64]. Figure 10b shows accumulated attention weights corresponding to the FDG-PET feature vectors alongside those of the other imaging measurements. The imaging variables were assigned a substantially higher importance because the 3D visual embeddings far outnumbered the others (i.e., 125 versus 6). In Figure 10c, high attention can be observed on the percent forgetting score of the Rey Auditory Verbal Learning Test (RAVLT), RAVLT immediate, the AD assessment score 11-item (ADAS11), Clinical Dementia Rating Scale-Sum of Boxes (CDRSB), Mini-Mental State Exam (MMSE), and Functional Activities Questionnaire (FAQ).

V. CONCLUSIONS
In this paper, we proposed a novel general-purpose transformer-based method to forecast the trajectory of a disease's stage from multimodal data. We applied our method to two real-world applications related to OA and AD. Our framework provides tools to integrate multi-modal data and has interpretation capabilities through self-attention.
In comparison with the prior version, CLIMATv2 has two primary upgrades. First, we have eliminated the assumption of independence between non-imaging data $m_0$ and diagnostic predictions $y_0$ used in CLIMATv1 [22], since it holds neither in OA and AD nor in other diseases. Specifically, Liu et al. [34] provided empirical evidence of the benefit of including non-imaging data in the knee OA grading task. The study conducted by Bird et al. [65] indicated a link between human genes and AD, while Li et al. [66] showed that a blood test can detect the existence of amyloid-beta plaques in the human brain, which is strongly associated with AD status. Second, we have proposed the CLUB loss, which allowed us to optimize for both performance and calibration.
There are some limitations in this study that are worth mentioning. First, we used common DL architectures as imaging and non-imaging feature extractors. While such a standardized procedure resulted in fair comparisons, better results could have been obtained with, e.g., Neural Architecture Search methods [67]. Furthermore, a wider range of DL modules could have been considered, but this could substantially increase the use of computing resources. Specifically, obtaining the results in this work required, for every method, roughly 400 GPU hours for the experiments in Table VII and 525 GPU hours for those in Figure 7.
The second limitation of the present study is that the attention maps produced by transformers act as human-friendly signals of our model and should be used carefully in practice, together with expert knowledge in the domain. Transformers may highlight areas not associated with the body part, which can be seen in Figure 11 as well as in other studies [68], [69], [70].
To conclude, to our knowledge, this is not only the first study in the realm of OA, but also the first work on AD clinical status prognosis prediction from a multi-modal setup that includes raw 3D scans and scalar variables. The developed method may be of interest to other fields in which calibrated forecasting of disease trajectories is needed. An implementation of our method is made publicly available at https://github.com/Oulu-IMEDS/CLIMATv2.

VI. ACKNOWLEDGEMENT
The OAI is a public-private partnership comprised of five contracts (N01-AR-2-2258; N01-AR-2-2259; N01-AR-2-2260; N01-AR-2-2261; N01-AR-2-2262) funded by the National Institutes of Health, a branch of the Department of Health and Human Services, and conducted by the OAI Study Investigators. Private funding partners include Merck Research Laboratories; Novartis Pharmaceuticals Corporation, GlaxoSmithKline; and Pfizer, Inc. Private sector funding for the OAI is managed by the Foundation for the National Institutes of Health. Data collection and sharing for this project was funded by the Alzheimer's Disease Neuroimaging Initiative (ADNI) (National Institutes of Health Grant U01 AG024904) and DOD ADNI (Department of Defense award number W81XWH-12-2-0012). ADNI is funded by the National Institute on Aging, the National Institute of Biomedical Imaging and Bioengineering, and through generous contributions from the following: AbbVie, Alzheimer's Association; Alzheimer's Drug Discovery Foundation; Araclon Biotech; BioClinica, Inc.; Biogen; Bristol-Myers Squibb Company; CereSpir, Inc.; Cogstate; Eisai Inc.; Elan Pharmaceuticals, Inc.; Eli Lilly and Company; EuroImmun; F. Hoffmann-La Roche Ltd and its affiliated company Genentech, Inc.; Fujirebio; GE Healthcare; IXICO Ltd.; Janssen Alzheimer Immunotherapy Research & Development, LLC.; Johnson & Johnson Pharmaceutical Research & Development LLC.; Lumosity; Lundbeck; Merck & Co., Inc.; Meso Scale Diagnostics, LLC.; NeuroRx Research; Neurotrack Technologies; Novartis Pharmaceuticals Corporation; Pfizer Inc.; Piramal Imaging; Servier; Takeda Pharmaceutical Company; and Transition Therapeutics. The Canadian Institutes of Health Research is providing funds to support ADNI clinical sites in Canada. Private sector contributions are facilitated by the Foundation for the National Institutes of Health (www.fnih.org). The grantee organization is the Northern California Institute for Research and Education, and the study is coordinated by the Alzheimer's Therapeutic Research Institute at the University of Southern California. ADNI data are disseminated by the Laboratory for Neuro Imaging at the University of Southern California.
The authors wish to acknowledge CSC - IT Center for Science, Finland, for generous computational resources.

Fig. 2. Radiographs of a patient with knee OA progressed in 8 years. The orange arrow indicates joint space narrowing. The disease progressed from Kellgren-Lawrence (KL) grade 0 at the baseline (BL) to 3 in 6 years. At the 8th year, the patient underwent a total knee replacement (TKR) surgery.

Fig. 3. The three projections of a 3D FDG-PET scan, which is converted to the jet colormap for demonstration purposes. The red regions are associated with high concentrations of the FDG radioactive tracer in the brain.

Fig. 4. The CLIMAT framework (best viewed in color). There are N and M input imaging and non-imaging feature vectors, respectively. The first feature vector $h_C^0$ of the last layer of the transformer C is appended to every output vector of $h_R$ to form the input for the transformer P. All the blue blocks are transformer-based networks. [CLS] and [POS] embeddings are in white and orange, respectively.

Fig. 6. Performance and calibration comparisons between CLUB and other baselines. All the measures are on the medial side. The losses can be categorized into two groups: (1) FL and AFL, and (2) CE, MTL, and CLUB, which are based on focal loss and cross-entropy, respectively.

Fig. 7. Performance comparisons between our CLIMAT models and other baselines on the knee osteoarthritis prognosis task via different types of grading (means and standard errors over 5 random seeds). * and ** indicate the statistically significant differences between CLIMATv2 and each baseline via Wilcoxon signed-rank tests (p < 0.05 and p < 0.001, respectively). As the statistical tests were conducted on both knees, p-value thresholds were adjusted to 0.025 and 0.0005, respectively.

TABLE I
DATASET STATISTICS. SUBJECTS ARE PATIENT KNEE JOINTS AND PATIENT BRAINS FOR OAI AND ADNI, RESPECTIVELY.

TABLE II
INPUT VARIABLES FOR FORECASTING KNEE OA SEVERITY GRADES.

TABLE IV
HYPERPARAMETER AND MODEL SELECTION BASED ON CV PERFORMANCES ON THE KL-BASED KNEE OA PROGNOSIS PREDICTION TASK. BA* INDICATES THE AVERAGES OF BAS OF THE TARGETS AT THE BASELINE AND THE FIRST 4 YEARS.

TABLE V
EFFECT OF THE CONSISTENCY TERM ON PERFORMANCE AND CALIBRATION (K-FOLD CROSS-VALIDATION). REPORTED RESULTS ARE AVERAGES OF BAS AND ECES OVER THE FIRST 4 YEARS.

TABLE VI
ABLATION STUDY ON IMAGING AND NON-IMAGING COMBINATION WITH K-FOLD CROSS-VALIDATION (K = 5 AND K = 10 FOR OAI AND ADNI, RESPECTIVELY). CHANNEL-WISE APPROACH (OURS) IS COMPARED TO THE SEQUENCE-WISE APPROACH, CONCATENATING IMAGING EMBEDDINGS PRODUCED BY THE BLOCK R WITH PROJECTED NON-IMAGING EMBEDDINGS OUTPUTTED BY THE BLOCK C. REPORTED RESULTS ARE AVERAGED BAS AND ECES OVER THE FIRST 4 YEARS.

TABLE VII
CV PERFORMANCE AND CALIBRATION COMPARISONS ON THE ADNI DATA (MEAN AND STANDARD ERRORS OVER 5 RANDOM SEEDS). THE BEST PERFORMANCES WITH AND WITHOUT SUBSTANTIAL DIFFERENCES ARE INDICATED BY BOLD AND UNDERLINED VALUES, RESPECTIVELY. THE SUBSTANTIAL IMPROVEMENT IS DETERMINED BY WHETHER THE BEST PERFORMANCE OVERLAPS WITH ANY OTHER METHOD'S. * AND ** INDICATE THE STATISTICALLY SIGNIFICANT DIFFERENCES BETWEEN CLIMATV2 VS. EACH BASELINE VIA WILCOXON SIGNED-RANK TESTS (P < 0.05 AND P < 0.001, RESPECTIVELY).