TeTrIS: Template Transformer Networks for Image Segmentation With Shape Priors

In this paper, we introduce and compare different approaches for incorporating shape prior information into neural network-based image segmentation. Specifically, we introduce the concept of template transformer networks, where a shape template is deformed to match the underlying structure of interest through an end-to-end trained spatial transformer network. This has the advantage of explicitly enforcing shape priors and, by providing a soft partial volume segmentation, is free of discretization artifacts. We also introduce a simple yet effective way of incorporating priors into state-of-the-art pixel-wise binary classification methods such as fully convolutional networks and U-net. Here, the template shape is given as an additional input channel; incorporating this information significantly reduces false positives. We report results on synthetic data and on sub-voxel segmentation of coronary lumen structures in cardiac computed tomography, showing the benefit of incorporating priors in neural network-based image segmentation.

was combined with anatomical atlases to perform segmentations of different organs [3]-[5]. However, these methods require either an image-to-image or image-to-segmentation likelihood function to drive atlas matching or alignment of the deformation model. Statistical methods such as active shape models [6] have been explored extensively, but shape models are difficult to construct in the first place and are then often limited in their expressiveness by the underlying manifold learning method (linear or non-linear principal component analysis).
State-of-the-art neural network based segmentation models [7]- [10] typically optimize pixel-wise loss functions such as mean squared error or cross entropy, and more recently differentiable Dice [11]. These objective functions do not take explicit priors into consideration during training. Nevertheless, smoothness priors can be enforced during test time by using conditional random fields or similar post-processing techniques. More recent work has shown improved results by directly incorporating shape constraints into their learning algorithm rather than applying them as post-processing such as in [12]- [14], where priors are learnt to regularize neural network embeddings during training. While this can lead to networks that favor plausible segmentations, there is no guarantee that the outputs adhere to desired shape constraints, such as a single connected component or a closed surface.

A. Contributions
In this paper we introduce a new neural network model based on template deformations which utilizes spatial transformer networks [15]. Our model leverages the representational power of neural networks while explicitly enforcing shape constraints in segmentations by restricting the model to perform segmentation through deformations of a given shape prior. We call this Template Transformer Networks for Image Segmentation (TETRIS). As with template deformations, our method produces anatomically plausible results by regularizing the deformation field. This also avoids discretization artifacts as we do not restrict the network to make pixel-wise classifications. Because the neural network is trained to align the shape prior to the structure of interest visible in the input image, there is no need for a hand-crafted (intensity-based) image-to-segmentation registration measure as with other template deformation models. To the best of our knowledge, this is the first full 3D neural network-based image segmentation through registration method combining deep convolutional neural nets with spatial transformers.
This work is licensed under a Creative Commons Attribution 4.0 License. For more information, see http://creativecommons.org/licenses/by/4.0/
Fig. 1. Schematic to illustrate the differences between traditional, pixel-wise segmentation models, the naive way of incorporating priors through additional input, and our TETRIS model which produces a set of parameters for a transformation. By restricting the output space of the network to only deformations of the prior, we obtain guarantees on topology.
Another contribution of this paper is the demonstration that state-of-the-art segmentation algorithms can be easily extended to incorporate implicit shape priors by providing a shape template as additional input during training. To the best of our knowledge, this simple yet effective enhancement has not been considered in the past. Our benchmarking shows that this can lead to a significant increase in segmentation accuracy. A high-level graphical overview of these methods is given in Fig. 1.
We present promising results on coronary artery segmentation from cardiac computed tomography, which further strengthen the case for the use of priors in medical image segmentation with deep neural nets. Experimentally we show that all methods which utilize prior information are able to consistently improve the cross entropy score of segmentation, and that our method is able to retain a singly connected component segmentation. Our quantitative results show the varying strengths and weaknesses of the two introduced methods. We also present qualitative results on synthetic examples to demonstrate the effects of out-of-sample data and how this affects neural network segmentations.

B. Related Work
Our proposed model lies in the intersection of machine learning based image registration and segmentation, and the incorporation of shape priors into neural networks. The following discusses the most related work but due to space limitations and the large amount of work in these fields, this cannot provide a comprehensive overview.
1) Atlas and Registration Based Segmentation: Atlas based segmentation algorithms [16] are among the most popular methods and rely on two key components: an intensity-based image-to-image matching term $L_\eta$ and a set of training examples, i.e. the atlases, with corresponding labels. During testing, images can be compared to examples in the atlas dataset using $L_\eta$, and the label masks of the most similar atlases are selected as candidate segmentations. This concept can be extended to employ patch based techniques or advanced label fusion procedures and is robust when label boundaries occur in homogeneous regions. However, this method often offers only a coarse segmentation, which may lack precision and can be refined using linear and nonlinear registration. This refinement can be done in either image or segmentation space [17]. The authors of [18] proposed a combination of an image-to-image and a segmentation-to-segmentation likelihood function, using a Lagrange multiplier to weight the contribution of each term. If an image-to-segmentation likelihood function is used, then this approach is better referred to as template registration [19].
2) Statistical Shape Models: Active shape models as introduced in [20] explicitly model shape based on training examples. By discretizing the $k$-dimensional shapes using $n$ control points, they create point distribution models using an ellipsoid prior. Principal modes of variation can then be found using principal component analysis (PCA) [21]. New shapes can be represented as a linear weighting of these components. Additionally, by restricting the model to use $t$ principal components, where $t < kn$, and restricting the range of values each linear weighting can take, a valid shape space is produced. Active appearance models [22] build on this technique and jointly model appearance together with shape. These models have been used widely in the medical imaging community to perform segmentation [6]. However, such models are heavily biased by the distribution of the training set used to build them.
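To make the point distribution model concrete, the following is a minimal NumPy sketch, not the implementation of [20]. The function names, the toy data and the choice to clip each weight to three standard deviations are illustrative assumptions; shapes are assumed to be already aligned and flattened to vectors of length $kn$.

```python
import numpy as np

def fit_shape_model(shapes, t):
    """Fit a PCA shape model to (num_examples, k * n) flattened control points.

    Returns the mean shape, the first t principal modes and their variances.
    """
    mean = shapes.mean(axis=0)
    centered = shapes - mean
    # Rows of vt are the principal modes of variation of the centered data.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    variances = (s ** 2) / (shapes.shape[0] - 1)
    return mean, vt[:t], variances[:t]

def project_to_shape_space(shape, mean, modes, variances, limit=3.0):
    """Project a shape onto the t modes and clip each weight.

    The +/- `limit` standard deviation bound is an assumed convention.
    """
    weights = modes @ (shape - mean)
    bound = limit * np.sqrt(variances)
    weights = np.clip(weights, -bound, bound)
    return mean + modes.T @ weights

# Toy data: 20 random "shapes" with n = 4 control points in k = 2 dimensions.
rng = np.random.default_rng(0)
shapes = rng.normal(size=(20, 8))
mean, modes, variances = fit_shape_model(shapes, t=3)
valid_shape = project_to_shape_space(shapes[0], mean, modes, variances)
```

Any shape passed through `project_to_shape_space` lies in the restricted subspace, which illustrates why such models are strongly tied to the training distribution.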
3) Network Based Image Registration: Traditional registration algorithms take two images, a moving image $M$ and a fixed image $F$, and perform registration by iteratively updating some parameterized transformation $T_\theta$, which maps image grid locations to each other, such that some loss function $L_\eta(M \circ T_\theta, F)$ is minimized; bespoke parameters $\theta$ are found for a given pair of images at test time. Optimization of the algorithm itself can be considered as optimizing $\eta$, some parameterization of the loss function which results in the 'best' registrations (such as the values of Lagrange multipliers), or as optimizing the choice of $T$ (the family of expressible transformations).
The key difference between neural network based image registration and traditional, iterative registration algorithms is that the loss function is only computed during training for neural networks. The parameters of the neural network implicitly encode what transformation, conditioned on the input, is needed to register the image with minimal cost, instead of repeatedly calculating a loss to iteratively update the parameters $\theta$.
Recent works on neural network based image registration fall into two major categories. The first treats network based registration as a regression problem on a given ground truth deformation field, such as in [23]-[27]. These methods, unlike ours, can be used as fast approximations to other registration models. The second group of methods learns the deformation field implicitly while optimizing a downstream task. For example, [28] combine momentum-parametrization for LDDMM shooting [29] and neural networks to learn an end-to-end model for registration. Reinforcement learning approaches have also been used to perform image registration [30]. They treat the registration problem as an iterative update of three translation and three rotation parameters and so do not handle free form deformations.
Related network-based registration methods [31] and [32] use 2D spatial transformer networks to embed the deformation model into a neural network pipeline in order to learn the registration model end-to-end; the latter uses a FlowNet architecture [24]. This work was built on in [33], which performs full unsupervised 3D registration. However, unlike our model, all of these methods perform image-to-image registration rather than template deformation for a downstream segmentation task. The authors of [34], [35] begin to investigate template deformations but do not investigate this to its full 3D potential.
4) Shape Priors in Neural Networks: Finally we discuss methods which incorporate shape priors into neural networks. Though conditional random fields (CRFs) are considered smoothness priors, they do assist in providing shape consistency in segmentations. CRFs are incorporated into the training process in [12] by casting CRFs as recurrent neural networks. This allows the segmentation and refinement model to be trained end-to-end. Adversarial training was used in [36] as a means of learning such regularization without the need for an explicit model whilst still being able to train end-to-end. A discriminator network was used to distinguish segmentations produced by a network from ground truth segmentations; this training process encourages the network segmentations to look more plausible. An interleaving process was proposed in [37], where iterative training of a neural network and CRF refinement was performed, inspired by the grab cut [38] method, though this model was not trained end-to-end. More recent work has shown improved results by directly incorporating principal component based shape constraints into the learning algorithm. Building on the work of active shape models [20], the authors of [13] use a PCA layer embedded in the neural network to restrict its output space to be weightings of the principal components; this was extended to a probabilistic model in [39]. Another approach proposed in [14] exploits the fact that autoencoders are able to capture a low dimensional representation of the shapes of segmentation maps. This encoding is then used at training time to constrain the outputs of a segmentation network to be close to this low dimensional manifold via adversarial training. The latter two methods utilize anatomical consistency across subjects.
5) Spatial Transformer Networks: Our work builds heavily on spatial transformer networks (STNs) [15], which we describe below. STNs are a neural network model that, conditioned on some input $I$, returns $\theta$ for some parameterized transformation model $T_\theta(G)$. That is, $\theta = f_\psi(I)$, where $f_\psi$ is a neural network, itself parameterized by $\psi$. Once we have $\theta$, we are able to differentiably re-sample our image $I$ to $V$ using $T_\theta$; as with image registration, here the image itself is re-sampled. The STN model then passes the re-sampled image $V$ to another neural network, $g_\xi(V)$, which performs some downstream task. In [15], the authors utilize this powerful model to train $g_\xi$, which performs a classification task, and simultaneously train $f_\psi$, a deformation model that makes the downstream task easier via a combination of rescaling, region of interest extraction and rotation of the input images.
During training, the loss is calculated on the downstream task only; for classification this could be the cross entropy loss between the predicted class produced by $g_\xi(V)$ and the true class. Since the neural network $g_\xi$ is a differentiable function and sampling from $I$ to $V$ is also differentiable, we are able to train both tasks end-to-end. Inherently, a spatial transformer performs deformations that assist the downstream task, as opposed to having a loss calculated directly on the task of deforming. This can be considered an implicit registration step, where the registration is autonomously discovered by the network for optimal downstream performance. We do not have to decide what kind of deformation will be good for the task, though we do need to specify the family of deformations $T$. During test time, unlike iterative registration models, no loss value needs to be calculated. We simply need to perform a forward pass through the network to get both the deformation and the class prediction.

II. TEMPLATE TRANSFORMER NETWORKS
Traditional template deformation models require the definition of an image-to-segmentation matching function as an approximation or surrogate to the actual segmentation objective. Iterative optimization is then used to incrementally update the transformation parameters in order to maximize agreement between a template and the image to be segmented. In contrast, our method makes use of network based registration, which only requires the computation of a corresponding loss function (equivalent to the matching function) during training time. This important difference means we no longer need to approximate our actual segmentation function via an intensity-based surrogate and can directly optimize for the task at hand.
We introduce a novel template deformation model that exploits the power of neural network-based registration. Our end-to-end model takes a shape prior in the form of a partial volume image (PVI) and an image as input to a neural network which learns to deform the input prior so as to produce an accurate segmentation of the input image. This is done by implicitly estimating a deformation field so as to maximize template alignment corresponding to optimal segmentation accuracy. We provide a detailed description of the main steps below. An overview of our method is shown in Fig. 2. In the following subsections (II-B, II-C and II-D) we describe in detail how we perform deformations, how we regularize our deformation field and how we handle large volume sizes.

A. Obtaining Shape Templates
Shape priors can be utilized in neural networks in various forms, such as level sets, PVIs, binary masks or shape parameters (e.g., mesh control points). In this work we focus on the use of a deformation model, conditioned on a shape prior, to deform a PVI into another PVI. Our shape prior in this particular case is itself a PVI, but we emphasize that this is not a necessity and that richer priors such as statistical appearance models can also be used. As template transformer networks predict a transformation instead of a pointwise segmentation map, they lend themselves naturally to other geometric representations for the priors, such as mesh-based models. Shape priors can generally be obtained via manual, semi-automatic and automatic methods, and the exact mechanism is application specific. We will later discuss one particular approach for obtaining shape priors for the application of coronary artery segmentation.

B. Deformation Model
To deform a template, consider some input image $I$, shape prior $U$ and ground truth segmentation $T$, all of size $H \times W \times D$. We have a sampling scheme (or deformation model) $T_\theta(G)$, where $G$ is considered a standard coordinate grid, and a loss function $L_\eta$. $T_\theta(G)$ is a function on $(x_i^t, y_i^t, z_i^t)$, grid coordinates in our target space, that maps to $(x_i^s, y_i^s, z_i^s)$, coordinates in our original source space, where we index voxel locations by $i \in [1, \ldots, HWD]$ for notational simplicity. Given this we can define $V$, a re-sampling of a prior $U$, based on $T_\theta$ as
$$V_i = \sum_{n=1}^{H} \sum_{m=1}^{W} \sum_{l=1}^{D} U_{nml}\, k(x_i^s - n; \Phi_x)\, k(y_i^s - m; \Phi_y)\, k(z_i^s - l; \Phi_z),$$
where $k$ is any sampling kernel with parameters $\Phi$. For image interpolation we use a trilinear kernel to prevent re-sampled pixel values from being extrapolated to outside of the original intensity domain, that is
$$V_i = \sum_{n=1}^{H} \sum_{m=1}^{W} \sum_{l=1}^{D} U_{nml}\, \max(0, 1 - |x_i^s - n|)\, \max(0, 1 - |y_i^s - m|)\, \max(0, 1 - |z_i^s - l|).$$
We choose $T_\theta$ to be a free form deformation, i.e. $\theta$ is a three dimensional vector field. For notational simplicity we define the sampling grid function as
$$\begin{bmatrix} x_i^s \\ y_i^s \\ z_i^s \end{bmatrix} = T_\theta(G_i) = \begin{bmatrix} x_i^t \\ y_i^t \\ z_i^t \end{bmatrix} + \theta_i.$$
If $T_\theta$ is a free form deformation which is not at the same resolution as the target image, we are required to re-sample the deformation field, potentially using a different set of sampling kernels $k$ with its own parameters $\Phi$. We choose to use B-Spline interpolation to ensure smooth fields [40], utilizing the Catmull-Rom solution to the interpolation problem [41].
Our method takes inspiration from STNs by using a neural network $f_\psi(I, U)$, which is conditioned on both the input image $I$ and the shape prior $U$, to produce parameters $\theta$ of the B-Spline deformation model $T_\theta$. We can then perform a deformation of the prior $U$, calculate a segmentation loss and update the parameters $\psi$ of our network.
By combining template deformation with neural networks, we mitigate the key problem with traditional template deformation models, namely the need to hand-craft a good image-to-segmentation alignment function. The source of this problem, as with any registration technique, lies in the fact that a loss calculation must be made during test time to update the deformation field parameters $\theta$. By utilizing STNs to produce $\theta$ during test time and instead updating a neural network $f_\psi$ during training, we can train a registration model with the true segmentation loss function (based on alignment between prior and reference segmentation), avoiding the need for surrogate functions at test time.
The template deformation model is network agnostic, so any neural network can be used. We choose a simple feed forward network architecture with convolutions and max pooling to produce a deformation field, which we use in the STN to deform the prior. Full details are provided in Figs. 3 and 12.
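The re-sampling step can be illustrated with a small NumPy sketch. It is not the authors' implementation: it assumes the displacement field is already dense (one vector per voxel, so no B-Spline up-sampling), uses zero-based voxel indices, and clamps source coordinates to the volume border. The function name `trilinear_warp` is hypothetical.

```python
import numpy as np

def trilinear_warp(prior, displacement):
    """Re-sample `prior` (H, W, D) at target grid + `displacement` (H, W, D, 3).

    Each output voxel is a weighted sum of the eight source voxels around its
    source coordinate, with weights max(0, 1 - |distance|) per axis.
    """
    shape = np.array(prior.shape)
    axes = [np.arange(s) for s in prior.shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
    # Source coordinates; clamping to the border is an assumption of this sketch.
    src = np.clip(grid + displacement, 0, shape - 1)
    lo = np.floor(src).astype(int)
    frac = src - lo
    out = np.zeros(prior.shape, dtype=float)
    for corner in np.ndindex(2, 2, 2):
        corner = np.array(corner)
        idx = np.minimum(lo + corner, shape - 1)
        # Per-axis weight: frac for the upper neighbor, 1 - frac for the lower.
        weight = np.prod(np.where(corner == 1, frac, 1.0 - frac), axis=-1)
        out += weight * prior[idx[..., 0], idx[..., 1], idx[..., 2]]
    return out

# Usage: shift a random partial volume by half a voxel along the last axis.
rng = np.random.default_rng(0)
prior = rng.random((4, 5, 6))
field = np.zeros((4, 5, 6, 3))
field[..., 2] = 0.5
warped = trilinear_warp(prior, field)
```

Because the weights are non-negative and sum to one, every output value is a convex combination of input values, which is why the re-sampled volume stays within the original intensity range.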
Our method is able to take any shape prior and deform it with sub-pixel accuracy, unlike other neural network based segmentation algorithms which typically treat segmentation as pixel-wise classification. Since our model smoothly deforms a prior, we are able to produce partial volume segmentations, reducing discretization artifacts in final segmentation maps. We provide experiments on both partial volume data as well as voxel-wise classification results.

C. Field Regularization
Due to the ill-posed nature of registration problems, it is common to constrain deformation fields by adding a regularization term to the optimization problem that favors some desired property, such as locally smooth deformations, or an l2 penalty on the vector field itself to favor minimum displacement solutions. We investigate two regularization terms: $L_{l2}$, which penalizes the l2-norm of the field, and $L_{\text{smooth}}$, which penalizes the sum of squared second order derivatives of the field.
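As a rough illustration of the two penalties, the sketch below computes them with finite differences on a dense displacement field. The exact forms used in the paper are not recoverable from the text, so both functions are assumptions: the l2 term is taken as the sum of squared displacements, and the smoothness term includes only unmixed second differences along each axis.

```python
import numpy as np

def l2_penalty(field):
    """Sum of squared displacement components; field has shape (H, W, D, 3)."""
    return float(np.sum(field ** 2))

def smoothness_penalty(field):
    """Sum of squared second-order finite differences along each spatial axis.

    Mixed partial derivatives are omitted in this simplified sketch.
    """
    total = 0.0
    for axis in range(3):
        second = np.diff(field, n=2, axis=axis)
        total += float(np.sum(second ** 2))
    return total

# A constant (pure translation) field is perfectly smooth but has non-zero l2 cost.
translation = np.ones((5, 5, 5, 3))
print(l2_penalty(translation), smoothness_penalty(translation))
```

The example highlights the practical difference between the two terms: the l2 penalty discourages any displacement, whereas the smoothness penalty leaves translations and linear ramps unpenalized.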

D. Field Aggregation
To deal with the size of the data and the memory restrictions of modern graphics processing units, we perform inference on a patch basis: we collect control points across patches and aggregate them before re-sampling using B-Spline interpolation. This, combined with only using valid padding, prevents ill-posed boundary conditions across the image. It also allows us to perform inference on variable size volumes with consistent control point spacing without modification to the neural network.

III. ILLUSTRATIVE EXAMPLE
As a proof of concept, we present qualitative results on the effects of corruption in the data, as these are not easily quantifiable. To investigate how incorporating a prior into a neural network can help when corruption is present, we create a toy dataset of 1500 randomly deformed P's, B's and R's for training and two hand-crafted test images, for which we provide qualitative results. We then train a deformation model to deform the prior (the letter that was originally deformed) to match the deformed letter. Additionally, we train a normal convolutional neural network (CNN) to predict the deformed letter on a pixel-wise level. Both the TETRIS model and the CNN are conditioned on the prior and the target. As expected, when the image signal is strong, the network learns to rely heavily on the image signal and ignores the prior; this can be seen in Fig. 4. We trained both models on only uncorrupted deformations to see how each model handles an out-of-sample test case. We see that the vanilla CNN learns to completely ignore prior information, so when inferring on corrupt data it is not able to extrapolate, unlike the TETRIS model. By restricting our model's output space to be within the range of deformations of the prior, we are able, even in the presence of corruption, to produce plausible results consistent with our prior.

IV. SYNTHETIC EXPERIMENTS
We argue that for a CNN to handle such corruption, it would need to be present in the training set. Fig. 5 shows the effects of having an increased amount of corrupted data in the training set. We construct a secondary dataset where corruption is more easily generated, consisting of 1000 randomly deformed discs. Corruption takes the form of smaller discs being cut from the main central disc and smaller peripheral discs being placed around it. The set is split in half for training and testing. We trained the models with 0%, 5%, 10% and 15% of the training set consisting of corrupted examples. As expected, the more corruption is present in the training set, the better the standard CNN model is able to handle it during test time. However, the artifacts that occur are not as topologically plausible as those produced by our TETRIS model, which is reflected in the high Dice scores but also high Hausdorff distances.

V. CORONARY ARTERY SEGMENTATION EXPERIMENTS
In this paper we focus on the application of vessel segmentation, where ambiguities arise from the functional distinction between veins and arteries, which may have similar image features. This has led some methods to approach the problem as a multi-stage process: first centerlines are extracted [42], then the vessels are segmented [43]. Shape priors can be enforced once good candidate centerlines have been extracted by treating the segmentation task as a wall distance estimation task. By utilizing curved planar reformation [44], the segmentation problem can be cast as a wall distance regression from the centerline, and topology can be guaranteed.
We train our network on a set of 274 annotated cardiac CT volumes with 0.5 millimeter isotropic spacing, reserving 138 volumes for validation and an additional 136 for testing. The ground truth labels, obtained through manual expert segmentation, are in the form of partial volumes.

A. Generating Priors for Coronary Arteries
To generate a shape prior for coronary artery segmentation, we first extract a centerline using a semi-automatic method which consists of a Random Forest voxel-wise classification, a tree extraction based on Dijkstra's shortest path algorithm and, finally, a human review step to correct outliers. The centerline is converted to a 3D volumetric representation by creating a tube around it with a fixed radius of 1 mm in a partial volume image, since the centerline exists in arbitrary space rather than voxel space. More examples of coronary centerline extraction methods in cardiac CTA can be found in [42]. Fig. 6 is a volume rendering illustrating the difference between the prior and the ground truth segmentation in an example vessel. More generally, priors can be extracted from sources such as automated algorithms, weak labels, human expert knowledge or population based statistics and will inherently be application specific.
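The conversion from centerline to tube prior can be sketched as follows. The paper does not state how partial volume values are computed at the tube boundary, so the one-voxel-wide linear ramp used here is an assumption, as are the function name `tube_prior` and the brute-force distance computation (suitable only for small toy volumes).

```python
import numpy as np

def tube_prior(centerline_mm, shape, spacing_mm=0.5, radius_mm=1.0):
    """Rasterize a soft tube of fixed radius around centerline points.

    centerline_mm: (P, 3) array of densely sampled points in millimeters.
    Returns a partial volume image of the given shape with values in [0, 1].
    """
    axes = [np.arange(s) * spacing_mm for s in shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
    # Distance from every voxel center to its nearest centerline sample.
    diff = grid[..., None, :] - centerline_mm
    dist = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=-1)
    # Assumed boundary model: linear ramp, one voxel wide, centered on the radius.
    return np.clip((radius_mm - dist) / spacing_mm + 0.5, 0.0, 1.0)

# Toy example: a straight centerline along the third axis of a 16^3 volume.
z = np.arange(0.0, 8.0, 0.25)
centerline = np.stack([np.full_like(z, 4.0), np.full_like(z, 4.0), z], axis=1)
prior = tube_prior(centerline, shape=(16, 16, 16))
```

With 0.5 mm spacing, the 1 mm radius corresponds to two voxels, so voxels on the centerline are fully inside the tube and voxels exactly at the radius receive a value of one half under this boundary model.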

B. Model Details and Baselines
We use two baseline models against which to compare the three models we present: i) a residual fully convolutional network (FCN) and ii) a residual U-net architecture, utilizing the implementations from [45] with residual blocks from [46]. Details of the architectures can be found in Fig. 3, where the building blocks are described in Fig. 12.
We also present results on naively incorporating shape priors into these state-of-the-art models. We do this by feeding the networks two channels of input: the image to be segmented and the prior that we have of the image at that location. This alternative method is a very simple extension of existing state-of-the-art pixel-wise approaches, computationally cheap and easy to implement. The shape prior, in this case, acts as a kind of initialization for the network's output.

C. Training Details
For all models we use the same patch extraction parameters. During training we dynamically extract 32 patches from each volume and randomly shuffle them into a buffer of 512 patches. Patches are extracted if they are near the centerline, biasing the sampling around the vessel. We use a batch size of 8 for all models and train them using the Adam optimizer [47] while exponentially decaying the learning rate. The learning rate at step $i$ is defined as
$$l_i = l_0 \cdot r^{i/s},$$
where our initial learning rate is $l_0 = 1 \cdot 10^{-5}$, the decay rate is $r = 0.99$ and the decay step is $s = 500$. Where regularization is used, we weight it by $5 \cdot 10^{-6}$.
We pretrain our baseline models using a weighted cross entropy function as defined in Equation 7,
$$H_w(p, q) = -\sum_i \left[ w\, p_i \log q_i + (1 - p_i) \log(1 - q_i) \right], \tag{7}$$
where $p$ is our target distribution, $q$ is our candidate distribution and $w$ is a weighting factor. By setting $w > 1$, we bias the loss term to penalize false negatives. This is beneficial as voxels containing the vessel interior are sparse in any given patch. Penalizing false negatives more heavily prevents the network from predicting all voxels as background during the initial stages of optimization, a trivial local optimum. Note that this does not need to be done with our TETRIS model, as the network is already biased towards the identity transform thanks to our regularization term, which favors a smooth deformation field. For our experiments we set $w = 2$ and pretrain our non-TETRIS models for 1000 iterations.
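For reference, the decay schedule can be written as a one-line function. This sketch assumes a continuous (non-staircase) exponential decay, which the text does not confirm; the function name is illustrative and the default values are those quoted above.

```python
def learning_rate(step, l0=1e-5, r=0.99, s=500):
    """Exponentially decayed learning rate at a given training step.

    Assumes continuous decay: the rate is multiplied by r every s steps,
    with fractional powers in between.
    """
    return l0 * r ** (step / s)

# The rate after each block of 500 steps shrinks by a factor of 0.99.
schedule = [learning_rate(step) for step in (0, 500, 1000)]
```

With these settings the decay is very gentle, so the learning rate changes only slightly over a few thousand iterations.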
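A minimal NumPy sketch of a weighted cross entropy of this kind is shown below. The form is an assumption based on the description (a weight $w$ applied to the foreground term only); averaging over voxels and clipping predictions for numerical stability are choices of this sketch, not details from the paper.

```python
import numpy as np

def weighted_cross_entropy(p, q, w=2.0, eps=1e-7):
    """Cross entropy between target p and prediction q, both in [0, 1].

    The foreground term is scaled by w, so w > 1 makes missed foreground
    (false negatives) more costly than spurious foreground.
    """
    q = np.clip(q, eps, 1.0 - eps)
    per_voxel = -(w * p * np.log(q) + (1.0 - p) * np.log(1.0 - q))
    return float(np.mean(per_voxel))

# A missed vessel voxel costs twice as much with w = 2 as with w = 1,
# while a false positive on background is unaffected by w.
missed = weighted_cross_entropy(np.array([1.0]), np.array([0.1]), w=2.0)
spurious = weighted_cross_entropy(np.array([0.0]), np.array([0.9]), w=2.0)
```

Setting `w=1.0` recovers the ordinary un-weighted cross entropy used for fine-tuning.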
We fine tune the models using normal, un-weighted, cross entropy for a further 5000 iterations.

D. Results
Results are presented on a test set of 136 cases. For the task of partial volume estimation we use the cross entropy as a measure of accuracy. We can see from Table II that incorporating shape priors into state-of-the-art neural network segmentation models significantly improves results. For comparison we include results on using the identity function on the prior, i.e. naively taking the shape prior as the segmentation.
We provide box plots of the results in Fig. 9 for a more fine-grained breakdown, where we have plotted the cross entropy on a log scale. Though our model with l2 field regularization performs the best, there is no significant difference between the methods, exhibiting the expressiveness of a deformation model despite constraining the output space to be a deformation of the prior. Both our U-Net (with prior) model and our proposed TETRIS model are able to consistently produce singly connected components without post-processing by incorporating prior information, further demonstrating the potential of our approaches. However, the FCN (with prior) model often does not capture these higher order requirements, even with prior information; we believe this is due to the inherent multi-scale nature of the U-Net architecture. TETRIS shines when using metrics that take into account partial volumes, whereas our CNNs with priors added as input channels work consistently well when the goal is pixel-wise binary segmentation. We also note a drastic reduction of trainable parameters by a factor of ten for TETRIS compared to U-Net and FCN, indicating a better balance between performance and model complexity.
To obtain the number of connected components, we threshold the partial volume segmentations at 0.5 and perform a 26-connected component analysis. Ideally, all segmentations should have only one connected component. We note that TETRIS without regularization may result in discontinuous segmentations, but we did not find this to be the case in practice. We notice no major difference between penalizing the field with an l2 penalty or by the sum of second order derivatives. Fig. 10 shows an example case where neural networks are not able to recover the vessel in the image without prior information, whereas all three of our models are able to fall back on the prior when the image signal is weak.
We further investigated the use of more complex models for TETRIS but found that convergence became slow and often resulted in similar validation scores, hence our choice of a simple TETRIS model. We notice that the models also have different strengths and weaknesses: as mentioned previously, the U-Net (with prior) model is better than the FCN (with prior) model at capturing global consistency of shape, but in regions where contrast is low, our TETRIS model produces smoother, more accurate segmentations, as seen in Fig. 7.
Fig. 9. Cross entropy for partial volume estimation of test cases for the different methods investigated, clearly demonstrating the benefits of incorporating prior information and the ability of a deformation model to perform just as well as, if not better than, a standard neural network which naively incorporates prior information.
Using a deformation model does have caveats. In Fig. 8 we see a trifurcation region where TETRIS over-segments and both the U-Net and FCN with prior under-segment. The resolution of the field and the penalty applied to large deformations prevent our model from doing well in such regions.
In summary, we should highlight the advantages of the template transformer based networks over point-wise segmentation models such as U-net and FCN, as they might not be apparent from the segmentation scores. Although U-Net performs best on Dice, TETRIS performs better on cross entropy, which assesses the agreement of the soft, partial volume predictions. Additionally, the template transformer networks can provide guarantees on the resulting shape, while both U-Net and FCN do not. This can be important in applications where the segmentations are used for downstream tasks such as shape analysis or blood flow calculation. One other important benefit of the TETRIS model (although not explored in this work) is the ability to incorporate a variety of shape priors such as mesh-based representations or probabilistic shape and appearance models (e.g. a mean and variance image).
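The component count can be reproduced with a short breadth-first search over the 26-neighborhood. This is an illustrative sketch rather than the evaluation code used for the experiments; the function name is hypothetical, and treating values equal to the threshold as foreground is an assumption.

```python
from collections import deque
from itertools import product

import numpy as np

def count_components_26(partial_volume, threshold=0.5):
    """Count 26-connected foreground components after thresholding."""
    mask = partial_volume >= threshold
    visited = np.zeros(mask.shape, dtype=bool)
    # All 26 neighbor offsets: every combination of -1, 0, 1 except the center.
    offsets = [o for o in product((-1, 0, 1), repeat=3) if o != (0, 0, 0)]
    count = 0
    for seed in zip(*np.nonzero(mask)):
        if visited[seed]:
            continue
        count += 1
        visited[seed] = True
        queue = deque([seed])
        while queue:
            x, y, z = queue.popleft()
            for dx, dy, dz in offsets:
                nb = (x + dx, y + dy, z + dz)
                inside = all(0 <= c < s for c, s in zip(nb, mask.shape))
                if inside and mask[nb] and not visited[nb]:
                    visited[nb] = True
                    queue.append(nb)
    return count

# Two voxels touching only at a corner count as one component under
# 26-connectivity; an isolated third voxel adds a second component.
volume = np.zeros((5, 5, 5))
volume[0, 0, 0] = volume[1, 1, 1] = 1.0
volume[3, 3, 3] = 0.8
n_components = count_components_26(volume)
```

For full-size CT volumes a library routine would be preferable to this pure Python loop; the sketch is meant only to make the connectivity definition explicit.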

VI. DISCUSSION
We introduced Template Transformer Networks for image segmentation, which are able to deform shape priors into segmentations. This work builds on template deformations by no longer requiring hand-crafted image-to-segmentation cost functions and makes use of Spatial Transformer Networks for differentiable end-to-end learning. Our method is competitive with state-of-the-art segmentation algorithms while being able to guarantee topological constraints.
Our work is a proof of concept which relied on a simple architecture that can be easily extended. Though our model is restricted in the sense that it can only perform deformations of a prior, we argue this can be an advantage where shape guarantees are important. Arguably, our model strikes a better balance between performance and model complexity due to having significantly fewer trainable parameters.
We consider the prior extraction beyond the scope of this work, but we believe that it is a critical part of not only our method, but all methods which require a prior. In problems where no sensible priors can be used, template based methods are likely not suitable. Though not explored in this work, our approach lends itself to the incorporation of much richer priors, such as probabilistic shape priors and other geometric representations such as meshes or point distribution models.
Our method replaces an iterative method with a one-shot method. We believe a natural extension to investigate would be to combine template deformations with recurrent or autoregressive neural networks for more flexible and potentially larger deformations, mitigating the effects of the chosen resolution for the control point grid. Though in this work we chose to use B-Splines, our method is agnostic to the choice of parameterization of the deformation field. The exploration of other and potentially more flexible parameterizations is also of great interest. Additionally, we would like to explore the use of deformation fields which are not on fixed grids, so as to allow for finer deformation fields as and when needed without the burden of excess computation.

APPENDIX
We provide full examples of vessel segmentations in Fig. 11 to give the reader a larger context for the accuracy of the model, where we have trimmed the aorta for clearer visualization. Without prior information it is clear that the sensitivity of the networks drops substantially.