Balancing reconstruction error and Kullback-Leibler divergence in Variational Autoencoders

In the loss function of Variational Autoencoders there is a well-known tension between two components: the reconstruction loss, which improves the quality of the resulting images, and the Kullback-Leibler divergence, which acts as a regularizer of the latent space. Correctly balancing these two components is a delicate issue and can easily result in poor generative behaviour. In a recent work, Dai and Wipf obtained a noticeable improvement by allowing the network to learn the balancing factor during training, according to a suitable loss function. In this article, we show that learning can be replaced by a simple deterministic computation, which helps to understand the underlying mechanism and results in faster and more accurate behaviour. On typical datasets such as CIFAR-10 and CelebA, our technique noticeably outperforms all previous VAE architectures.

The main frameworks of generative models investigated so far are Generative Adversarial Networks (GAN) [11] and Variational Autoencoders (VAE) [17,19], both of which have generated an enormous body of work, addressing variants, theoretical investigations, and practical applications.
The main feature of Variational Autoencoders is that they offer a strongly principled probabilistic approach to generative modeling. The key insight is to address the problem of learning representations as a variational inference problem, coupling the generative model P(X|z) for X given the latent variable z with an inference model Q(z|X) that synthesizes the latent representation of the given data.
The loss function of VAEs is composed of two parts: one is the log-likelihood of the reconstruction, while the second is a term aimed at enforcing a known prior distribution P(z) on the latent space (typically a spherical normal distribution). Technically, this is achieved by minimizing the Kullback-Leibler divergence between Q(z|X) and the prior distribution P(z); as a side effect, this also improves the similarity of the aggregate inference distribution Q(z) = E_X Q(z|X) with the desired prior, which is our final objective. The resulting objective, to be maximized, is:
$$
\underbrace{\mathbb{E}_{z\sim Q(z|X)}\log P(X|z)}_{\text{log-likelihood}} \;-\; \lambda\cdot\underbrace{KL(Q(z|X)\,\|\,P(z))}_{\text{KL-divergence}}
$$
Log-likelihood and KL-divergence are typically balanced by a suitable λ-parameter (called β in the terminology of β-VAE [12,7]), since they have somewhat contrasting effects: the former tries to improve the quality of the reconstruction, neglecting the shape of the latent space; the latter normalizes and smooths the latent space, possibly at the cost of some additional "overlapping" between latent variables, eventually resulting in a noisier encoding [1].
If not properly tuned, KL-divergence can also easily induce a sub-optimal use of network capacity, where only a limited number of latent variables are exploited for generation: this is the so-called over-pruning/variable-collapse/sparsity phenomenon [6,23].
Tuning down λ typically reduces the number of collapsed variables and improves the quality of reconstructed images. However, this may not result in a better quality of generated samples, since we lose control over the shape of the latent space, which becomes harder to exploit by a random generator.
Several techniques have been considered for the correct calibration of λ, including an annealed optimization schedule [5] or a policy enforcing a minimum KL contribution from subsets of latent units [15]. Most of these schemes require hand-tuning and, quoting [23], they may easily "take away the principled regularization scheme that is built into VAE." An interesting alternative, recently introduced in [8], consists in learning the correct value of the balancing parameter during training, which also allows its automatic calibration along the training process. In this context the parameter is called γ, and it is considered as a normalizing factor for the reconstruction loss.
When monitoring the loss function and the learned parameter γ during training, it becomes evident that γ is proportional to the reconstruction error, with the result that the relevance of the KL-component inside the whole loss function becomes independent of the current error.
Considering the shape of the loss function, it is easy to give a theoretical justification for this behaviour. As a consequence, there is no need for learning, which can be replaced by a simple deterministic computation, eventually resulting in faster and more accurate behaviour.
The structure of the article is the following. In Section 2, we give a quick introduction to Variational Autoencoders, with particular emphasis on generative issues (Section 2.1). In Section 3, we discuss our approach to the problem of balancing reconstruction error and Kullback-Leibler divergence in the VAE loss function; it is obtained from a simple theoretical investigation of the loss function in [8], and essentially amounts to keeping a constant balance between the two components along training. Experimental results are provided in Section 4, relative to standard datasets such as CIFAR-10 (Section 4.1) and CelebA (Section 4.2): to the best of our knowledge, we obtain the best generative scores in terms of Fréchet Inception Distance (FID) ever achieved by Variational Autoencoders. In Section 5, we investigate the reasons why our technique seems to be more effective than previous approaches, by considering the evolution of latent variables along training. Concluding remarks and ideas for future investigations are offered in Section 5.1.

Variational Autoencoders
In a generative setting, we are interested in expressing the probability of a data point X through marginalization over a vector of latent variables:
$$
P(X) = \int P(X|z)\,P(z)\,dz = \mathbb{E}_{z\sim P(z)}P(X|z)
$$
For most values of z, P(X|z) is likely to be close to zero, contributing negligibly to the estimation of P(X), and hence making this kind of sampling in the latent space practically infeasible. The variational approach exploits sampling from an auxiliary "inference" distribution Q(z|X), hopefully producing values of z that are more likely to contribute effectively to the (re)generation of X. The relation between P(X) and E_{z∼Q(z|X)} P(X|z) is given by the following equation, where KL denotes the Kullback-Leibler divergence:
$$
\log P(X) - KL(Q(z|X)\,\|\,P(z|X)) = \mathbb{E}_{z\sim Q(z|X)}\log P(X|z) - KL(Q(z|X)\,\|\,P(z))
$$
KL-divergence is always non-negative, so the term on the right provides a lower bound to the log-likelihood log P(X), known as the Evidence Lower Bound (ELBO). If Q(z|X) is a reasonable approximation of P(z|X), the quantity KL(Q(z|X)||P(z|X)) is small; in this case the log-likelihood log P(X) is close to the Evidence Lower Bound. The learning objective of VAEs is the maximization of the ELBO.
In traditional implementations, we additionally assume that Q(z|X) is normally distributed around an encoding function µ_θ(X), with variance σ²_θ(X); similarly, P(X|z) is normally distributed around a decoder function d_θ(z). The functions µ_θ, σ²_θ and d_θ are approximated by deep neural networks. Knowing the variance of latent variables allows sampling during training.
Provided the model for the decoder function d_θ(z) is sufficiently expressive, the shape of the prior distribution P(z) for latent variables can be arbitrary, and for simplicity we may assume it is a normal distribution P(z) = G(0, 1). The term KL(Q(z|X)||P(z)) is hence the KL-divergence between two Gaussian distributions G(µ_θ(X), σ²_θ(X)) and G(0, 1), which can be computed in closed form (j ranges over the latent variables):
$$
KL\big(G(\mu_\theta(X),\sigma^2_\theta(X))\,\|\,G(0,1)\big) = \frac{1}{2}\sum_{j}\left(\mu_{\theta,j}(X)^2 + \sigma^2_{\theta,j}(X) - \log\sigma^2_{\theta,j}(X) - 1\right)
$$
As for the term E_{z∼Q(z|X)} log P(X|z), under the Gaussian assumption the logarithm of P(X|z) is, up to a scale factor and an additive constant, the negated quadratic distance between X and its reconstruction d_θ(z); the λ parameter balancing reconstruction error and KL-divergence can be understood in terms of the variance of this Gaussian [9]. The problem of integrating sampling with backpropagation is solved by the well-known reparametrization trick [16,19].
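As a minimal sketch (pure Python, one sample at a time), the closed-form KL term above and the reparametrization trick can be written as follows. Parametrizing the variance through its logarithm, `log_var`, is a common implementation choice assumed here rather than something stated in the text, and the function names are hypothetical.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent variables."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))


def reparametrize(mu, log_var, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1).

    The randomness is isolated in eps, so z is a deterministic, differentiable
    function of mu and log_var once eps is drawn (the reparametrization trick)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: posterior equals the prior
print(kl_to_standard_normal([1.0], [0.0]))            # 0.5
z = reparametrize([0.5, -1.0], [-2.0, 0.0], rng=random.Random(0))
```

In a deep learning framework the same expressions would be written with tensor operations, so that gradients flow through `mu` and `log_var` while `eps` is treated as an input.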

Generation of new samples
The whole point of VAEs is to force the generator to produce a marginal distribution Q(z) = E_X Q(z|X) close to the prior P(z). If we average the Kullback-Leibler regularizer KL(Q(z|X)||P(z)) over all input data, and expand the KL-divergence in terms of entropy, we get:
$$
\mathbb{E}_X\,KL(Q(z|X)\,\|\,P(z)) = -\,\mathbb{E}_X\,\mathcal{H}(Q(z|X)) + \mathcal{H}(Q(z),P(z))
$$
where H(Q(z|X)) is the entropy of Q(z|X) and H(Q(z), P(z)) is the cross-entropy between Q(z) and P(z). The cross-entropy between two distributions is minimal when they coincide, so we are pushing Q(z) towards P(z). At the same time, we try to increase the entropy of each Q(z|X); under the assumption that Q(z|X) is Gaussian, this amounts to enlarging its variance, further improving the coverage of the latent space, which is essential for generative sampling (at the cost of more overlapping, and hence more confusion between the encodings of different data points).

Since our prior distribution is a Gaussian, we expect Q(z) to be normally distributed too, so its mean µ should be 0 and its variance σ² should be 1. If Q(z|X) = N(µ(X), σ²(X)), we may look at Q(z) = E_X Q(z|X) as a Gaussian Mixture Model (GMM). Then, we expect
$$
\mu = \mathbb{E}_X\,\mu(X) = 0
$$
and especially, assuming the previous equation (see [2] for details),
$$
\sigma^2 = \mathbb{E}_X\,\mu(X)^2 + \mathbb{E}_X\,\sigma^2(X) = 1
$$
This rule, which we call the variance law, provides a simple sanity check to test whether the regularization effect of the KL-divergence is properly working.
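The variance law lends itself to a quick empirical check. The sketch below (the function name is hypothetical) estimates E_X µ_j(X)² + E_X σ²_j(X) for each latent variable j from a batch of encoder outputs, assuming the means and variances are available as plain lists, one list per sample.

```python
def variance_law(mus, variances):
    """Per latent variable j, estimate E_X[mu_j(X)^2] + E_X[sigma_j^2(X)].

    mus and variances are lists of per-sample lists (one entry per latent
    variable). Values close to 1 are what the variance law predicts when the
    KL regularizer works as expected (and E_X[mu_j(X)] is close to 0)."""
    n = len(mus)
    k = len(mus[0])
    return [
        sum(mu[j] ** 2 for mu in mus) / n + sum(var[j] for var in variances) / n
        for j in range(k)
    ]


# Variable 0 is active (large |mu|, small variance); variable 1 is collapsed
# (mu close to 0, variance close to 1). Both satisfy the law.
mus = [[0.9, 0.0], [-0.9, 0.0]]
variances = [[0.19, 1.0], [0.19, 1.0]]
print(variance_law(mus, variances))  # both entries approximately 1.0
```

A sum noticeably below 1 for many variables is the kind of symptom that, in the CelebA experiments, points to a KL-component that is too weak.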
The fact that the first two moments of the marginal inference distribution are 0 and 1 does not imply that it looks like a Normal distribution. The possible mismatch between Q(z) and the expected prior P(z) is indeed a problematic aspect of VAEs that, as observed in several works [13,20,2], could compromise the whole generative framework. To fix this, some works extend the VAE objective by encouraging the aggregated posterior to match P(z) [21], or exploit more complex priors [14,22,4].
In [8] (the current state of the art), a second VAE is trained to learn an accurate approximation of Q(z); samples from a Normal distribution are first used to generate samples of Q(z), which are then fed to the actual generator of data points. Similarly, in [10], the authors give an ex-post estimation of Q(z), for instance by imposing a distribution of sufficient complexity (they consider a combination of 10 Gaussians, reflecting the ten categories of MNIST and CIFAR-10).

The balancing problem
As we already observed, the problem of correctly balancing reconstruction error and KL-divergence in the loss function has been the object of several investigations. Most of the approaches were based on empirical evaluation, and often required manual hand-tuning of the relevant parameters. A more theoretical approach has been recently pursued in [8]. The generative loss (GL), to be summed with the KL-divergence, is defined by the following expression (directly borrowed from the public code):
$$
GL = \frac{\text{mse}}{2\gamma^2} + \log\gamma + \frac{\log 2\pi}{2} \tag{7}
$$
where mse is the mean square error on the minibatch under consideration and γ is a parameter of the model, learned during training. The previous loss is derived in [8] through a complex analysis of the behaviour of the VAE objective function, assuming the decoder has a Gaussian error with variance γ², and investigating the case of arbitrarily small but explicitly nonzero values of γ². Since γ has no additional constraints, we can explicitly minimize equation (7) with respect to it. The derivative
$$
GL'(\gamma) = -\frac{\text{mse}}{\gamma^3} + \frac{1}{\gamma}
$$
vanishes for γ² = mse, which corresponds to a minimum of equation (7).
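The minimization can be checked numerically. The sketch below evaluates the per-pixel form of GL and its derivative with respect to γ; it illustrates the algebra only and is not the code of [8]. The sample value of the mse is arbitrary, and the function names are hypothetical.

```python
import math


def generative_loss(mse, gamma):
    """Gaussian negative log-likelihood per pixel, with standard deviation gamma:
    mse / (2 gamma^2) + log(gamma) + log(2 pi) / 2."""
    return mse / (2.0 * gamma ** 2) + math.log(gamma) + 0.5 * math.log(2.0 * math.pi)


def d_generative_loss(mse, gamma):
    """Derivative of generative_loss with respect to gamma: -mse / gamma^3 + 1 / gamma."""
    return -mse / gamma ** 3 + 1.0 / gamma


mse = 0.01
best_gamma = math.sqrt(mse)  # stationary point: gamma^2 = mse
print(d_generative_loss(mse, best_gamma))  # zero up to floating-point rounding
# The loss is larger on both sides of the stationary point, so it is a minimum.
print(generative_loss(mse, 0.5 * best_gamma) > generative_loss(mse, best_gamma))  # True
print(generative_loss(mse, 2.0 * best_gamma) > generative_loss(mse, best_gamma))  # True
```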
This suggests a very simple deterministic policy for computing γ instead of learning it: just use the current estimate of the mean square error. In principle, this estimate can be computed as a discounted combination of the mse of the current minibatch with the previous approximation; in our implementation, we simply take the minimum between these two values, in order to have a monotonically decreasing value for γ (we work with minibatches of size 100, which is sufficiently large to provide a reasonable approximation of the real mse). The update is performed at every minibatch.
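A minimal sketch of the deterministic update described above. It assumes that γ² is tracked directly and initialized to 1 (a hypothetical choice, reasonable for pixel values in [0, 1]); the class and method names are hypothetical. In an actual training loop the resulting weight must be treated as a constant, with no gradient flowing into γ.

```python
import math


class ComputedGamma:
    """Deterministic replacement for a learned gamma: gamma^2 tracks the
    minibatch mse and is kept monotonically non-increasing by a running minimum."""

    def __init__(self, initial_mse=1.0):
        self.gamma_sq = initial_mse

    def update(self, batch_mse):
        # Keep the smaller of the previous estimate and the current minibatch mse.
        self.gamma_sq = min(self.gamma_sq, batch_mse)
        return self.gamma_sq

    def gamma(self):
        return math.sqrt(self.gamma_sq)

    def reconstruction_weight(self):
        # Factor 1 / (2 gamma^2) multiplying the mse in the loss.
        return 1.0 / (2.0 * self.gamma_sq)


est = ComputedGamma()
for batch_mse in [0.20, 0.15, 0.17, 0.09, 0.10]:
    est.update(batch_mse)
print(est.gamma_sq)  # 0.09
```

Taking the minimum, rather than the latest value, makes the estimate insensitive to occasional minibatches that happen to be harder to reconstruct.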
Compared with the original approach in [8], the resulting technique is both faster and more accurate.
An additional contribution of our approach is to shed some light on the effect of the balancing technique in [8]. Neglecting constant addends, which play no role in the loss function, the total loss function for the VAE is simply:
$$
\mathcal{L} = \frac{\text{mse}}{2\gamma^2} + KL(Q(z|X)\,\|\,P(z))
$$
So, computing γ according to the previous estimate of the mse essentially has the effect of keeping a constant balance between reconstruction error and KL-divergence during the whole training: as the mse decreases, we normalize it in order to prevent the KL-component from prevailing, which would forbid further improvements in the quality of reconstructions.
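To see why the balance stays constant, it may help to rewrite the total loss under the assumption γ² ≈ mse, with γ held fixed within each training step. This is a worked restatement of the same loss, not an additional result:

$$
\mathcal{L} = \frac{\text{mse}}{2\gamma^2} + KL(Q(z|X)\,\|\,P(z)),
\qquad
2\gamma^2\,\mathcal{L} = \text{mse} + 2\gamma^2\,KL(Q(z|X)\,\|\,P(z)),
\qquad
2\gamma^2 \approx 2\,\text{mse}
$$

The two forms differ only by a positive factor at each step. The second one shows that computing γ acts like a KL weight λ = 2γ² applied to an mse reconstruction term, a weight that shrinks in proportion to the current reconstruction error, while in the first form the reconstruction term keeps a value of about 1/2 throughout training.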

Empirical evaluation
We compared our proposed Two-Stage VAE with computed γ against the original model with learned γ, using the same network architectures. In particular, we worked with many different variants of the so-called ResNet version, schematically described in Figure 1 (pictures are borrowed from [8]). In all our experiments, we used a batch size of 100 and adopted Adam, with TensorFlow's default hyperparameters, as the optimizer. Other hyperparameters, as well as additional architectural details, are described below, where we discuss the cases of CIFAR-10 and CelebA separately.
In general, in all our experiments, we observed a high sensitivity of FID scores to the learning rate and to the deployment of auxiliary regularization techniques. As we shall discuss in Section 5, modifying these training configurations may easily result in a different number of inactive latent variables at the end of training. Having too few or too many active variables may eventually compromise generative sampling, for opposite reasons: few active variables usually compromise reconstruction quality, while an excessive number of active variables makes controlling the shape of the latent space noticeably harder. The code is available on GitHub. Checkpoints for CIFAR-10 and CelebA are available at the project's page.

Fig. 1 ResNet architecture. (A) Scale block: we mostly worked with a single residual block; two or more blocks make the architecture noticeably heavier and slower to train, with no remarkable improvement. (B) Encoder: the input is first transformed by a convolutional layer and then passed to a chain of Scale blocks; after each Scale block, the input is downsampled with a convolutional layer with stride 2, and the number of channels is doubled. After N Scale blocks, the feature map is flattened to a vector and then fed to another Scale block composed of fully connected layers of dimension 512. The output of this Scale block is used to produce the means and variances of the k latent variables. Following [8], N = 3 and k = 64 for CIFAR-10; for CelebA, we tested many different configurations. (C) Decoder: the latent representation z is first passed through a fully connected layer, reshaped to 2D, and then passed through a sequence of deconvolutions halving the number of channels.

Cifar10
For CIFAR-10, we got relatively good results with the basic ResNet architecture with 3 Scale blocks, a single Resblock for every Scale block, and 64 latent variables. We trained our model for 700 epochs on the first VAE and 1400 epochs on the second VAE; the initial learning rate was 0.0001, halved every 200 epochs on the first VAE and every 100 epochs on the second VAE. Details about the evolution of reconstruction and generative error during training are provided in Figure 2 and Table 1. The data refer to ten different but "uniform" trainings ending with the same number of active latent variables (17 in this case). A few pathological trainings, resulting in lower or higher sparsity (and worse FID scores), have been removed from the statistics. In Table 2, we compare our approach with the original version with learned γ [8]. Since some people had problems replicating the results in [8] (see the discussion on OpenReview), we repeated the experiment (also in order to compute the reconstruction FID). Using the learning configuration suggested by the authors, namely 1000 epochs for the first VAE, 2000 epochs for the second one, and an initial learning rate of 0.0001, halved every 300 and 600 epochs for the two stages respectively, we obtained results essentially in line with those declared in [8].
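The step-decay schedules described above can be expressed as a one-line function. This is a sketch with a hypothetical name; the exact epoch at which each halving takes effect in the authors' code is an assumption.

```python
def step_decay_lr(epoch, initial_lr=1e-4, halve_every=200):
    """Return the learning rate for a given (0-based) epoch, halving
    initial_lr every `halve_every` epochs."""
    return initial_lr * 0.5 ** (epoch // halve_every)


# CIFAR-10, computed gamma: the first stage (700 epochs) halves every 200 epochs,
# the second stage (1400 epochs) halves every 100 epochs.
first_stage = [step_decay_lr(e, halve_every=200) for e in range(700)]
second_stage = [step_decay_lr(e, halve_every=100) for e in range(1400)]
print(first_stage[-1])  # 1.25e-05
```

The learned-γ baseline uses the same function with `halve_every=300` and `halve_every=600` for the two stages.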
For the sake of completeness, we also compare with the FID scores of the recent RAE-l2 model [10] (the variance was not reported by the authors). In this case, the comparison is purely indicative, since in [10] they work, for CIFAR-10, with a latent space of dimension 128. This also explains their particularly good reconstruction error and the small number of training epochs.

CelebA
In the case of CelebA, we had more trouble replicating the results of [8], although we were working with their own code. As we shall see, this was partly due to a mistake on our side, which pushed us to an extensive investigation of different architectures. In Table 3 we summarize some of the results we obtained over a large variety of network configurations. The metrics given in the table refer to the following models:
All models have been trained with Adam, with an initial learning rate of 0.0001, halved every 48 epochs in the first stage and every 120 epochs in the second stage. According to the results in Table 3, we can make a few noteworthy observations:
1. for a given model, the technique computing γ systematically outperforms the version learning it, both in reconstruction and generation, on both stages;
2. after the first 40 epochs, FID scores (including reconstruction FID) do not seem to improve any further, and can even get worse, in spite of the fact that the mean square error keeps decreasing; this contrasts with the intuitive idea that the FID REC score should be proportional to the mse;
3. the sum in the variance law is far from one, which seems to suggest that the KL-divergence is too weak in this case; this explains the mediocre generative scores of the first stage, and the noticeable improvement obtained with the second stage;
4. l2-regularization, as advocated in [10], does seem to have some beneficial effect.
We spent quite a lot of time trying to figure out the reasons for the discrepancy between our observations and the results claimed in [8]. Inspecting the elements of the dataset with the worst reconstruction errors, we noticed a particularly bad quality in some of the images, resulting from the resizing of the 128x128 face crop to the canonical dimension 64x64 expected by the neural network. The resizing function used in the public source code of [8] was the deprecated imresize function of the SciPy library. Following the suggestion in the documentation, we replaced the call to imresize with a call to Pillow: numpy.array(Image.fromarray(arr).resize()). Unfortunately, and surprisingly, the default resizing mode of Pillow is Nearest Neighbours which, as shown in Figure 3, introduces annoying jaggies that noticeably deteriorate the quality of the images. This probably also explains the anomalous behaviour of FID REC with respect to the mean squared error: the Variational Autoencoder fails to reconstruct images with high-frequency jaggies, while it keeps improving on smoother images. This can be experimentally confirmed by the fact that, while the minimum mse keeps decreasing during training, the maximum stabilizes after a while. So, in spite of the fact that the average mse decreases, the overall distribution of reconstructed images may remain far from the distribution of real images, and possibly get even more distant.

Fig. 3 Effect of resizing mode on a few CelebA samples. Nearest Neighbours produces bad staircase effects; bilinear, which is the common choice, is particularly smooth and suits VAEs well; bicubic is slightly sharper.
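The staircase effect can be reproduced on a toy example. The sketch below compares 2x nearest-neighbour downsampling with 2x2 block averaging on a hard diagonal edge; block averaging is only a simple stand-in for a smoothing filter such as bilinear interpolation, not Pillow's implementation, and the function names are hypothetical. With Pillow itself, passing the filter explicitly, e.g. `Image.fromarray(arr).resize((64, 64), resample=Image.BILINEAR)`, avoids depending on the library default, which changed from nearest neighbour to bicubic in Pillow 7.0.

```python
def downsample_nearest(img):
    """2x downsampling that keeps the top-left pixel of each 2x2 block."""
    return [row[::2] for row in img[::2]]


def downsample_box(img):
    """2x downsampling that averages each 2x2 block (assumes even height and width)."""
    h, w = len(img), len(img[0])
    return [
        [(img[i][j] + img[i][j + 1] + img[i + 1][j] + img[i + 1][j + 1]) / 4.0
         for j in range(0, w, 2)]
        for i in range(0, h, 2)
    ]


# A 4x4 image with a hard diagonal edge (1 above the diagonal, 0 elsewhere).
edge = [[1.0 if j > i else 0.0 for j in range(4)] for i in range(4)]
print(downsample_nearest(edge))  # [[0.0, 1.0], [0.0, 0.0]]
print(downsample_box(edge))      # [[0.25, 1.0], [0.0, 0.25]]
```

Nearest neighbour only ever outputs the original values 0 and 1, producing a hard step, while averaging yields intermediate values along the edge, which is the kind of smooth content a VAE reconstructs well.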
Resizing images with the traditional bilinear interpolation produces a substantial improvement, but it is not sufficient to obtain the expected generative scores.
Another essential component is, again, the balance between reconstruction error and KL-divergence. As observed above, in the case of CelebA the KL-divergence seems too weak, as clearly testified by the moments of latent variables expressed by the variance law. As a matter of fact, in the loss function of [8], both mse and KL-divergence are computed as reduced sums, respectively over pixels and over latent variables. Now, passing from CIFAR-10 to CelebA, we multiplied the number of pixels by four, going from 32x32 to 64x64, but kept a constant number of latent variables. So, in order to keep the same balance we used for CIFAR-10, we should multiply the KL-divergence by a factor of 4.
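The rescaling argument amounts to simple arithmetic on pixel counts. A hypothetical helper, assuming both terms are reduced sums and that the number of latent variables and channels is unchanged (channels cancel out in the ratio), could be:

```python
def kl_weight_for_resolution(height, width, ref_height=32, ref_width=32, ref_weight=1.0):
    """KL weight that preserves the reconstruction/KL balance of a reference
    resolution when the reconstruction error is summed over pixels and the
    KL-divergence over a fixed number of latent variables."""
    return ref_weight * (height * width) / (ref_height * ref_width)


print(kl_weight_for_resolution(64, 64))  # 4.0: from 32x32 (CIFAR-10) to 64x64 (CelebA)
```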
Finally, learning seems to proceed quite fast in the case of CelebA, which suggests working with a lower initial learning rate: 0.00005. We also kept l2 regularization on the downsampling and upsampling layers.
With these simple expedients, we were already able to improve on the generative scores in [8] (see Table 4), but not on those of [10].
Analyzing the moments of the distribution of latent variables generated during the second stage, we observed that the actual variance was noticeably below the expected unitary variance (around 0.85). The simplest solution consists in normalizing the generated latent variables to meet the expected variance (this point is a bit outside the scope of this contribution, and will be better investigated in a forthcoming article). This final precaution caused a sudden improvement in the FID score of generated images, allowing us to obtain, to the best of our knowledge, the best generative scores ever produced for CelebA with a variational approach.
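One possible reading of this normalization is sketched below. It assumes that each latent dimension is rescaled independently around its empirical mean over a batch of generated codes; the text does not specify the exact procedure, and the function name is hypothetical.

```python
import math


def normalize_latents(samples, target_var=1.0):
    """Rescale each latent dimension of a batch of generated codes so that its
    empirical variance equals target_var; the empirical mean is preserved."""
    n, k = len(samples), len(samples[0])
    out = [list(s) for s in samples]
    for j in range(k):
        col = [s[j] for s in samples]
        mean = sum(col) / n
        var = sum((c - mean) ** 2 for c in col) / n
        scale = math.sqrt(target_var / var) if var > 0 else 1.0  # leave constant dims untouched
        for i in range(n):
            out[i][j] = mean + (col[i] - mean) * scale
    return out


# Dimension 0 has variance 0.25 and gets stretched; dimension 1 already has variance 1.
batch = [[-0.5, 0.0], [0.5, 2.0]]
print(normalize_latents(batch))  # [[-1.0, 0.0], [1.0, 2.0]]
```

With a variance around 0.85, as observed here, each dimension would be stretched by a factor of roughly 1/sqrt(0.85), about 1.08, before being fed to the generator.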
In Figure 4 we provide examples of randomly generated faces. Note the particularly sharp quality of the images, which is unusual for variational approaches.

Discussion
The effectiveness of the balancing policy between reconstruction error and KL-regularization addressed in [8], and revisited in this article, seems to lie in its laziness in the choice of the latent representation.
A Variational Autoencoder computes, for each latent variable z and each sample X, an expected value µ_z(X) and a variance σ²_z(X) around it. During training, the variance σ²_z(X) usually drops very fast to values close to 0, reflecting the fact that the network is highly confident in its choice of µ_z(X). The KL-component in the loss function can be understood as a mechanism aimed at reducing this confidence, by forcing a non-negligible variance. As an effect of the KL-regularization, some latent variables may even be neglected by the VAE, inducing sparsity in the resulting encoding [3]. The "collapsed" variables have, for any X, a value of µ_z(X) close to 0 and a variance σ²_z(X) close to 1. So, typically, at a relatively early stage of training, the mean variance E_X σ²_z(X) of each latent variable z gets either close to 0, if the variable is exploited, or close to 1, if the variable is neglected (see Figure 5).

Fig. 5 Relevant variables have a variance close to 0, while inactive variables have a variance going to 1. The picture was borrowed from [3] and refers to the first epoch of training of a dense VAE on the MNIST data set.
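Since inactive variables are recognizable by a mean variance close to 1, counting them reduces to thresholding. The sketch below uses a threshold of 0.5, an arbitrary illustrative choice not taken from the text; the function name is hypothetical.

```python
def count_inactive(mean_variances, threshold=0.5):
    """Return the number and indices of latent variables whose mean variance
    E_X[sigma_z^2(X)] exceeds threshold (i.e. is closer to 1 than to 0)."""
    inactive = [j for j, v in enumerate(mean_variances) if v > threshold]
    return len(inactive), inactive


print(count_inactive([0.02, 0.97, 0.01, 0.8]))  # (2, [1, 3])
```

Variables still sitting in the intermediate "limbo" region are counted as inactive by this rule, so the count is only meaningful once training has stabilized.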
Traditional balancing policies addressed in the literature start with a low value for the KL-regularization and increase it during training. The general idea is to privilege the quality of reconstruction at the beginning, and then try to induce a better coverage of the latent space. Unfortunately, this ex-post reshaping of the latent space seems hard to achieve in practice.
The balancing policy discussed in this article does the opposite: it starts by attributing a relatively high importance to the KL-divergence, to balance the high initial reconstruction error, and progressively reduces its relevance in proportion to the improvement of the reconstruction. In this way, the relative importance of the two components of the loss function remains constant during training.
The practical effect is that latent variables are kept for a long time in a sort of limbo from which, one at a time, they are retrieved and put to work by the autoencoder, as soon as it realizes how they can contribute to the reconstruction.
This behaviour is evident when looking at the evolution of the mean variance E_X σ²_z(X) of latent variables during training (not to be confused with the variance of the mean values µ_z(X), which, according to the variance law, should approximately be the complement to 1 of the former).
In Figure 6 we see the evolution of the variance of the 64 latent variables during the first epoch of training on the CIFAR-10 data set: even after a full epoch, the "status" of most latent variables is still uncertain. During the next 50 epochs, in a very slow process, some of the "dormant" latent variables are woken up by the autoencoder, causing their mean variance to move towards 0 (see Figure 7).
As training progresses, fewer and fewer variables change their status, until the process finally stabilizes.
It would be nice to think, as hinted at in [8], that the number of active latent variables at the end of training corresponds to the actual dimensionality of the data manifold. Unfortunately, this number still depends on too many external factors to justify such a claim. For instance, a mere modification of the learning rate noticeably affects the sparsity of the resulting latent space, as shown in Table 5, where we compare, for different initial learning rates (l.r.), the final number of inactive variables, FID scores, and mean square error. Specifically, a high learning rate appears to conflict with the lazy way we would like latent variables to be chosen for activation; this typically results in less sparsity, which is not always beneficial for generative purposes. The annoying point is that, with respect to the dimensionality of the latent space with the best generative FID, activating more variables can result in a lower reconstruction error, which should not be the case if we had correctly identified the dimensionality of the data manifold.

Fig. 7 Evolution of the mean variance of the 64 latent variables during the first 50 epochs of training on CIFAR-10. One by one, latent variables are retrieved from the limbo (variance around 0.8) and put to work by the autoencoder.
So, while the balancing strategy discussed in this article (similarly to the one in [8]) is eventually beneficial, it could still take advantage of some tuning.

Conclusions
In this article, we stressed the importance of keeping a constant balance between reconstruction error and Kullback-Leibler divergence during the training of Variational Autoencoders. We did so by normalizing the reconstruction error by an estimate of its current value, derived from minibatches. We developed the technique through an investigation of the loss function used in [8], where the balancing parameter was instead learned during training. Our technique seems to outperform all previous variational approaches, allowing us to obtain unprecedented FID scores for traditional datasets such as CIFAR-10 and CelebA.
In spite of its relevance, the policy of keeping a constant balance does not seem to entirely solve the balancing issue, which still seems to depend on many additional factors, such as the network architecture, the complexity and resolution of the dataset, or training parameters such as the learning rate.
Also, the regularization effect of the KL-component must be better understood, since it frequently fails to induce the expected distribution of latent variables, possibly requiring and justifying ex-post adjustments.
Credits: All innovative ideas and results contained in this article are to be credited to the first author. The second author contributed mostly to the experimental side.