A Self-Attention-Guided 3D Deep Residual Network With Big Transfer to Predict Local Failure in Brain Metastasis After Radiotherapy Using Multi-Channel MRI

A noticeable proportion of larger brain metastases (BMs) are not locally controlled after stereotactic radiotherapy, and it may take months before local progression is apparent on standard follow-up imaging. This work proposes and investigates new explainable deep-learning models to predict the radiotherapy outcome for BM. A novel self-attention-guided 3D residual network is introduced for predicting local failure (LF) after radiotherapy using the baseline treatment-planning MRI. The 3D self-attention modules facilitate capturing long-range intra/inter-slice dependencies that are often overlooked by convolution layers. The proposed model was compared to a vanilla 3D residual network and a 3D residual network with CBAM attention in terms of outcome-prediction performance. A training recipe based on the recently proposed Big Transfer (BiT) principles was adapted for pretraining the outcome-prediction models and training them on the downstream task. A novel 3D visualization module was coupled with the model to demonstrate the impact of various intra/peri-lesional regions of the volumetric multi-channel MRI upon the network's prediction. The proposed self-attention-guided 3D residual network outperforms the vanilla residual network and the residual network with CBAM attention in accuracy, F1-score, and AUC. The visualization results show the importance of peri-lesional characteristics on treatment-planning MRI in predicting local outcome after radiotherapy. This study demonstrates the potential of self-attention-guided deep-learning features derived from volumetric MRI in radiotherapy outcome prediction for BM. The insights obtained via the developed visualization module for individual lesions can possibly be applied during radiotherapy planning to decrease the chance of LF.

H(x) = F(x) + x. Simply, by setting F(x) = 0, H(x) becomes the identity function. Figure 1.a shows the architecture of the 3D residual network applied in this study. The most important part of the architecture is the stack of residual blocks with skip connections to preserve the gradient.
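The identity shortcut described above can be sketched in a few lines; this is an illustrative NumPy toy, not the study's implementation:

```python
import numpy as np

def residual_block(x, f):
    """Residual mapping H(x) = F(x) + x, where f is the learned residual branch."""
    return f(x) + x

x = np.array([1.0, 2.0, 3.0])

# If the residual branch learns F(x) = 0, the block reduces to the identity:
identity_out = residual_block(x, lambda t: np.zeros_like(t))
print(np.allclose(identity_out, x))  # True
```

Because the skip connection carries x forward unchanged, the gradient of the loss with respect to x always contains a direct identity term, which is what "preserving the gradient" refers to.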
Beyond preserving the gradient, skip connections are useful because the learned features correlate with lower-level semantic information extracted from the input at earlier levels.
Without the skip connections, that information would become too abstract in this architecture.

CBAM: Convolutional Block Attention Module
CBAM is a simple yet effective attention module for feed-forward convolutional neural networks. Given an intermediate feature map, the module sequentially infers attention maps along two separate dimensions, i.e., channel and spatial, and then multiplies the attention maps by the input feature tensor to perform adaptive feature refinement. Because CBAM is a lightweight and universal module, it can be smoothly integrated into any CNN architecture with minimal overhead and trained end-to-end alongside the base CNN.
As mentioned above, CBAM consists of two sequential attention mechanisms: channel attention and spatial attention. Because each channel of a feature tensor may be considered a feature detector, channel attention focuses on 'what' is significant in the context of an input image. Channel attention begins by aggregating spatial information from the feature tensor using both average-pooling and max-pooling operations, resulting in two separate spatial context descriptors for each feature map, F_avg^c and F_max^c, which denote the average-pooled and max-pooled features, respectively.
Afterwards, both descriptors are forwarded to a shared network, which generates the channel attention map M_c ∈ R^(1×1×1×C), where C is the number of channels. The shared network is made up of a multi-layer perceptron (MLP) with one hidden layer. Following the application of the shared network to each descriptor, the resulting feature tensors are combined by element-wise summation to form a single feature tensor. To summarize, the channel attention module is computed as M_c(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F))), where σ denotes the sigmoid function.
The next step is to compute spatial attention. In order to build a spatial attention map, the spatial attention module uses the inter-spatial relationship between features. In contrast to channel attention, spatial attention focuses on 'where' informative features are located in the image, and it is considered a complement to channel attention. Using average-pooling and max-pooling operations along the channel axis, the spatial attention map M_s(F) ∈ R^(D×H×W) can be calculated. The two pooling operations aggregate the channel information in the feature tensor, generating two 3D maps, F_avg^s ∈ R^(D×H×W×1) and F_max^s ∈ R^(D×H×W×1), which denote the average-pooled and max-pooled features over the channel axis, respectively. A convolution layer is then applied to the concatenated feature descriptor to construct the spatial attention map. Formally, M_s(F) = σ(f^(7×7×7)([F_avg^s; F_max^s])), where σ is the sigmoid function and f^(7×7×7) is a convolution with a filter size of 7 × 7 × 7.
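A spatial-attention sketch in the same spirit is shown below; for brevity it replaces the 7 × 7 × 7 convolution with a hypothetical 1 × 1 × 1 convolution (a per-voxel weighted sum of the two pooled descriptors), so it only illustrates the pooling-and-gating structure, not the study's exact module:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def spatial_attention(F, w, b=0.0):
    """Spatial attention sketch for F of shape (D, H, W, C).
    The paper applies a 7x7x7 convolution to the concatenated descriptors;
    here a 1x1x1 convolution with weights w = (w_avg, w_max) is used instead."""
    f_avg = F.mean(axis=-1, keepdims=True)             # (D, H, W, 1), average over channels
    f_max = F.max(axis=-1, keepdims=True)              # (D, H, W, 1), max over channels
    desc = np.concatenate([f_avg, f_max], axis=-1)     # concatenated descriptor, (D, H, W, 2)
    return sigmoid(desc @ np.asarray(w) + b)           # spatial attention map, (D, H, W)

rng = np.random.default_rng(1)
F = rng.normal(size=(4, 8, 8, 16))
Ms = spatial_attention(F, w=[0.5, 0.5])
refined = F * Ms[..., None]                            # broadcast the gate over channels
print(Ms.shape)  # (4, 8, 8)
```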
The channel and spatial modules are applied to the intermediate feature maps sequentially and output the refined features. Figure 1.e shows the architecture of CBAM. In our proposed architecture, the CBAM block was added right before the 3D average pooling layer to filter out irrelevant information and focus on important details for classification. Figure 1.b shows the proposed architecture augmented with CBAM.

The Self-attention Module
A self-attention module is defined as a tensor mapping that transforms the input tensor into a query, a key, and a value tensor. The key and value are learned features extracted by convolution blocks, and the query determines which values to focus on during learning. The 3D convolution blocks (1 × 1 × 1 convolutions) before the key, query, and value perform linear transformations on the input feature tensors. The 3D self-attention module is added to the architecture after each residual block to capture long-range dependencies, complementing the convolution layers, which mostly capture local features and dependencies.
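The query/key/value mechanism can be sketched over flattened voxel positions; this is a minimal single-head NumPy illustration with hypothetical linear maps Wq, Wk, Wv standing in for the 1 × 1 × 1 convolutions, not the study's module:

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Self-attention over flattened voxel positions.
    X: (N, C) with N = D*H*W positions; the 1x1x1 convolutions of the text
    reduce to the linear maps Wq, Wk, Wv on a flattened tensor."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])             # pairwise similarity of positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)      # row-wise softmax
    return weights @ V                                  # every position attends to all others

rng = np.random.default_rng(2)
X = rng.normal(size=(64, 8))                            # 64 voxel positions, 8 channels
Wq, Wk, Wv = (rng.normal(size=(8, 8)) for _ in range(3))
out = self_attention(X, Wq, Wk, Wv)
print(out.shape)  # (64, 8)
```

Because the attention weights couple every position with every other position, the output at one voxel can depend on voxels in distant slices, which is the long-range intra/inter-slice dependency the text refers to.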

3D Visualization
In order to provide explainability to the model, we proposed a framework for creating 3D heatmaps:
i. A point cloud is generated from the coordinates of the voxel centers in the MRI volume.
ii. The generated point cloud is then normalized (each x, y, z coordinate is normalized between 0 and 1).
iii. A desired surface within the MRI volume is specified. Since the impact of all voxels within the MRI volume is estimated and color-coded, it is possible to explore and visualize any desired area throughout the MRI (or the volumetric ROI) for a comprehensive understanding of how different intra- and peri-lesional regions contribute to the network's decisions in therapy outcome prediction.
iv. Since the number of points in the point cloud is limited, interpolation is performed to generate new random points, constrained to lie on the specified surface, in order to improve the quality of the final 3D heatmap.
v. Using a k-nearest-neighbor algorithm, the 10 closest points to each newly generated point are identified and their average color code is assigned to it.
vi. The interpolated point cloud is used to calculate the surface normal orientation at each point, as required for surface reconstruction. The normal orientations were computed using a minimum spanning tree.
vii. Once the normal orientations are calculated, the Poisson surface reconstruction technique generates a smooth surface showing the important regions contributing to the network's decisions. Figure S1 shows the overall procedure for generating a desired 3D surface from the initial point cloud.
Figure S1 - The procedure for creating explorable 3D brain models and visualization heatmaps from a set of individual slices. (a) Initially, the coordinates (x, y, z) of each voxel center in the MRI volume are determined. Since the number of slices is often limited, the resulting point cloud consists of multiple clusters of points sharing the same z-coordinate but differing in x and y, which is visually undesirable. To mitigate this issue, points between slices are randomly interpolated. (b) The point cloud after the inter-slice interpolation. (c) The 3D brain model after assigning an intensity to each point in the point cloud and performing surface reconstruction. (d) Applying the same procedure to generate a color-coded 3D visualization heatmap of importance for a lesion within the brain.
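Step v of the pipeline above (k-nearest-neighbor color assignment for the interpolated points) can be sketched as follows; the arrays and brute-force search are illustrative, not the study's implementation:

```python
import numpy as np

def knn_color(points, colors, queries, k=10):
    """Assign each interpolated query point the mean color of its k nearest
    original points (step v of the visualization pipeline; brute-force kNN)."""
    d = np.linalg.norm(points[None, :, :] - queries[:, None, :], axis=-1)  # (Q, N) distances
    idx = np.argsort(d, axis=1)[:, :k]           # indices of the k closest original points
    return colors[idx].mean(axis=1)              # averaged color code per query point

rng = np.random.default_rng(3)
points = rng.uniform(size=(200, 3))              # normalized point cloud (step ii)
colors = rng.uniform(size=(200, 3))              # per-point importance color codes
queries = rng.uniform(size=(50, 3))              # new points on the chosen surface (step iv)
print(knn_color(points, colors, queries).shape)  # (50, 3)
```

For large point clouds, a spatial index (e.g., a k-d tree) would replace the brute-force distance matrix, but the averaging logic stays the same.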