Scalable and Practical Natural Gradient for Large-Scale Deep Learning

Large-scale distributed training of deep neural networks results in models with worse generalization performance as a result of the increase in the effective mini-batch size. Previous approaches attempt to address this problem by varying the learning rate and batch size over epochs and layers, or ad hoc modifications of batch normalization. We propose Scalable and Practical Natural Gradient Descent (SP-NGD), a principled approach for training models that allows them to attain similar generalization performance to models trained with first-order optimization methods, but with accelerated convergence. Furthermore, SP-NGD scales to large mini-batch sizes with a negligible computational overhead as compared to first-order methods. We evaluated SP-NGD on a benchmark task where highly optimized first-order methods are available as references: training a ResNet-50 model for image classification on ImageNet. We demonstrate convergence to a top-1 validation accuracy of 75.4% in 5.5 minutes using a mini-batch size of 32,768 with 1,024 GPUs, as well as an accuracy of 74.9% with an extremely large mini-batch size of 131,072 in 873 steps of SP-NGD.


INTRODUCTION
As the size of deep neural network models and the data on which they are trained continue to increase rapidly, the demand for distributed parallel computing is increasing. A common approach for achieving distributed parallelism in deep learning is the data-parallel approach, where the data is distributed across different processes while the model is replicated across them. When the mini-batch size per process is kept constant to increase the ratio of computation to communication, the effective mini-batch size over the entire system grows in proportion to the number of processes.
When the mini-batch size is increased beyond a certain point, the generalization performance starts to degrade. This generalization gap caused by large mini-batch sizes has been studied extensively for various models and datasets [1]. Hoffer et al. [2] attribute this generalization gap to the limited number of updates, and suggest training longer. This has led to strategies such as scaling the learning rate in proportion to the mini-batch size, while using the first few epochs to gradually warm up the learning rate [3]. Such methods have enabled training with mini-batch sizes of 8K, where ResNet-50 [5] could be trained on ImageNet [4] for 90 epochs to achieve 76.3% top-1 validation accuracy in 60 minutes [6]. Combining this learning rate scaling with other techniques such as RMSprop warm-up, batch normalization without moving averages, and a slow-start learning rate schedule, Akiba et al. [7] were able to train the same dataset and model with a mini-batch size of 32K to achieve 74.9% accuracy in 15 minutes.
More complex approaches for manipulating the learning rate have been proposed, such as LARS [8], where a different learning rate is used for each layer by normalizing it with the ratio between the layer-wise norms of the weights and gradients. This enabled training with a mini-batch size of 32K without ad hoc modifications, achieving 74.9% accuracy in 14 minutes (64 epochs) [8]. It has been reported that combining LARS with counterintuitive modifications to batch normalization can yield 75.8% accuracy even for a mini-batch size of 65K [9].
Using small batch sizes to encourage rapid convergence in early epochs, and then progressively increasing the batch size, is yet another successful approach. Using such an adaptive batch size method, Mikami et al. [10] were able to train in 122 seconds with an accuracy of 75.3%, and Yamazaki et al. [11] were able to train in 75 seconds with an accuracy of 75.1%. The hierarchical synchronization of mini-batches has also been proposed [12], but to the best of the authors' knowledge such methods have not been tested at scale.
In this work, we take a more mathematically rigorous approach to close the generalization gap when large mini-batches are used. Our approach builds on Natural Gradient Descent (NGD) [14], a second-order optimization method that leverages curvature information to accelerate optimization. This approach is made feasible by the use of large mini-batches, which enables stable estimation of curvature even in models with a large number of parameters. To this end, we propose an efficient distributed NGD design that scales to massively parallel settings and large mini-batch sizes. In particular, we demonstrate scalability to batch sizes of 32,768 over 1024 GPUs across 256 nodes. Another unique aspect of our approach is the accuracy with which we can approximate the Fisher information matrix (FIM) when compared to other second-order optimization methods. Unlike methods that use very crude approximations of the FIM, such as TONGA [15] and Hessian-free methods [16], we adopt the Kronecker-Factored Approximate Curvature (K-FAC) method [17]. The two main characteristics of K-FAC are that it converges faster than first-order stochastic gradient descent (SGD) methods, and that it can tolerate relatively large mini-batch sizes without any ad hoc modifications. K-FAC has been successfully applied to convolutional neural networks [18], distributed memory training of ImageNet [19], recurrent neural networks [20], Bayesian deep learning [21], reinforcement learning [22], and Transformer models [23].
Our contributions are:
• Extremely large mini-batch training. We show for the first time that approximate NGD can achieve generalization capability similar to that of highly optimized SGD, using ResNet-50 training for ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch sizes of 4,096, 8,192, 16,384, 32,768 and 65,536. We also achieved 74.9% with an extremely large mini-batch size of 131,072, which took only 873 steps.
• Scalable natural gradient. We propose a distributed NGD design using data and model hybrid parallelism that shows superlinear scaling up to 64 GPUs.
• Practical natural gradient. We propose practical NGD techniques based on an analysis of the FIM estimation in large mini-batch settings. Our practical techniques make the overhead of NGD compared to SGD almost negligible. Combining these techniques with our distributed NGD, we see ideal scaling up to 1024 GPUs as shown in Figure 5.
Using 1024 NVIDIA Tesla V100 GPUs, we achieve 75.4% top-1 accuracy with ResNet-50 on ImageNet in 5.5 minutes (1760 steps = 45 epochs, including a validation after each epoch). The comparison is shown in Figure 1 and Table 1.
A preliminary version of this manuscript was published previously [24]. Since then, the performance optimization of distributed second-order optimization has been studied [25], and our distributed NGD framework has been applied to accelerate Bayesian deep learning with the natural gradient at ImageNet scale [26]. We extend the previous work and propose the Scalable and Practical Natural Gradient Descent (SP-NGD) framework, which includes a more detailed analysis of the FIM estimation and significant improvements to the performance of the distributed NGD.

RELATED WORK
With respect to large-scale distributed training of deep neural networks, there have been very few studies that use second-order methods. At a smaller scale, there have been previous studies that used K-FAC to train ResNet-50 on ImageNet [19]. However, the SGD baseline they used as a reference did not reach state-of-the-art top-1 validation accuracy (only around 70%), so the advantage of K-FAC over SGD that they claim was not obvious from the results. In the present work, we compare the top-1 validation accuracy with the state-of-the-art SGD methods for large mini-batches mentioned in the introduction (Table 1).
The previous studies that used K-FAC to train ResNet-50 on ImageNet [19] also did not consider large mini-batches and trained only with a mini-batch size of 512 on 8 GPUs. In contrast, the present work uses mini-batch sizes up to 131,072, which is equivalent to 32 per GPU on 4096 GPUs, and we are able to achieve a much higher top-1 validation accuracy of 74.9%. Note that such large mini-batch sizes can also be achieved by accumulating the gradient over multiple iterations before updating the parameters, which mimics the behavior of execution on many GPUs without actually running on that many GPUs.
The previous studies using K-FAC also suffered from a large communication overhead, since their TensorFlow [27] implementation of K-FAC used a parameter-server approach with a single parameter server. Since all workers must send their gradients to the parameter server and receive the latest model parameters from it, the parameter server becomes a huge communication bottleneck, especially at large scale. Our implementation uses a decentralized approach based on MPI/NCCL¹ collective communications among the processes. The decentralized approach has been used in high performance computing for a long time, and is known to scale to thousands of GPUs without modification. Although software like Horovod² can alleviate the problems with parameter servers by working as a TensorFlow wrapper for NCCL, a workable realization of K-FAC requires solving many engineering and modeling challenges, and our solution is the first one that succeeds on a large-scale task.

Mini-batch Stochastic Learning
Throughout this paper, we use E[·] to denote the empirical expectation among the samples in the mini-batch {(x, t)}, and compute the cross-entropy loss for supervised learning as

$$L(\theta) = E[-\log p_\theta(t \mid x)], \tag{1}$$

where x, t are the training input and label (one-hot vector), and p_θ(t|x) is the likelihood of each sample (x, t) calculated by the probabilistic model using a feed-forward deep neural network (DNN) with the parameters θ ∈ R^N. For the standard mini-batch stochastic gradient descent (SGD), the parameters θ are updated based on the gradient of the loss function at the current point:

$$\theta^{(t+1)} \leftarrow \theta^{(t)} - \eta \nabla L(\theta^{(t)}), \tag{2}$$

where η > 0 is the learning rate.

Natural Gradient Descent in Deep Learning
Natural Gradient Descent (NGD) [14] is an optimizer which updates the parameters using the first-order gradient of the loss function preconditioned by the Fisher information matrix (FIM) of the probabilistic model:

$$\theta^{(t+1)} \leftarrow \theta^{(t)} - \eta (F + \lambda I)^{-1} \nabla L(\theta^{(t)}). \tag{3}$$

The FIM F ∈ R^{N×N} of a DNN with the learnable parameter θ ∈ R^N is defined as:

$$F := E_{x \sim q}\left[ E_{y \sim p_\theta}\left[ \nabla_\theta \log p_\theta(y \mid x) \, \nabla_\theta \log p_\theta(y \mid x)^\top \right] \right]. \tag{4}$$

E_v[·] is an expectation w.r.t. the random variable v, and q is the training data distribution. To limit the step size, a damping value λ > 0 is added to the diagonal of F before inverting it. In the training of DNNs, the FIM may be thought of as the curvature matrix in parameter space [14], [17], [28].
To realize an efficient NGD training procedure for deep neural networks, we make the following approximations to the FIM:
• Layer-wise block-diagonal approximation. We assume that the correlation between parameters in different layers (Figure 2) is negligible and can be ignored. This assumption significantly reduces the computational cost of inverting F, especially when N is large.
• Stochastic natural gradient. We approximate the expectation over the input data distribution E_{x∼q}[·] using the empirical expectation over a mini-batch E[·]. This enables estimation of F during mini-batch stochastic learning.
• Monte Carlo estimation. We approximate the expectation over the model predictive distribution E_{y∼p_θ}[·] using a single Monte Carlo sample (a single backward-pass). We note that for a K-class classification model, K backward-passes are required to evaluate this expectation exactly.

Using these approximations, we estimate the FIM F̂_ℓ ∈ R^{N_ℓ×N_ℓ} for the ℓ-th layer using a Monte Carlo sample y ∼ p_θ(y|x) for each input x in a mini-batch as

$$\hat{F}_\ell = E\left[ \nabla_{w_\ell} \log p_\theta(y \mid x) \, \nabla_{w_\ell} \log p_\theta(y \mid x)^\top \right]. \tag{5}$$

With this F̂_ℓ, the parameters w_ℓ ∈ R^{N_ℓ} for the ℓ-th layer are then updated using the FIM-preconditioned gradients:

$$w_\ell^{(t+1)} \leftarrow w_\ell^{(t)} - \eta \left( \hat{F}_\ell + \lambda I \right)^{-1} \nabla_{w_\ell} L^{(t)}. \tag{6}$$

Here ∇_{w_ℓ} L^{(t)} ∈ R^{N_ℓ} denotes the gradient of the loss function w.r.t. w_ℓ for w_ℓ = w_ℓ^{(t)}.
Fig. 2. Illustration of Fisher information matrix approximations for feed-forward deep neural networks used in this work.

K-FAC
Kronecker-Factored Approximate Curvature (K-FAC) [17] is a second-order optimization method for deep neural networks that is based on an accurate and mathematically rigorous approximation of the FIM. Using K-FAC, we further approximate the FIM F̂_ℓ for the ℓ-th layer as a Kronecker product of two matrices (Figure 2):

$$\hat{F}_\ell \approx G_\ell \otimes A_{\ell-1}. \tag{7}$$

This is called Kronecker factorization, and G_ℓ, A_{ℓ−1} are called Kronecker factors. G_ℓ is computed from the gradient of the loss w.r.t. the output of the ℓ-th layer, and A_{ℓ−1} is computed from the activation of the (ℓ−1)-th layer (the input of the ℓ-th layer).
The definition and the sizes of the Kronecker factors G_ℓ / A_{ℓ−1} depend on the dimension of the output/input and the type of layer [17], [18], [20], [23].

K-FAC for fully-connected layers.
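In a fully-connected (FC) layer of a feed-forward DNN, the output s_ℓ ∈ R^{d_ℓ} is calculated as

$$s_\ell \leftarrow W_\ell a_{\ell-1}, \tag{8}$$

where a_{ℓ−1} ∈ R^{d_{ℓ−1}} is the input to this layer (the activation from the previous layer), and W_ℓ ∈ R^{d_ℓ×d_{ℓ−1}} is the weight matrix (the bias is ignored for simplicity). The Kronecker factors for this FC layer are defined as

$$G_\ell := E\left[ \nabla_{s_\ell} \log p_\theta(y \mid x) \, \nabla_{s_\ell} \log p_\theta(y \mid x)^\top \right] \quad \text{and} \quad A_{\ell-1} := E\left[ a_{\ell-1} a_{\ell-1}^\top \right] \tag{9}$$

[17]. From this definition, we can see that K-FAC is based on the assumption that the input to the layer and the gradient w.r.t. the layer output are statistically independent.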
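The factor definitions for an FC layer can be illustrated with a small NumPy sketch. It assumes the standard K-FAC forms G = E[g gᵀ] and A = E[a aᵀ], where g is the per-sample gradient of log p w.r.t. the layer output; the function names `kfac_factors_fc` and `exact_fisher_block` are hypothetical and are not taken from the authors' implementation.

```python
import numpy as np

def kfac_factors_fc(a_prev, grad_s):
    """Kronecker factors for one FC layer.

    a_prev: (B, d_in) layer inputs; grad_s: (B, d_out) per-sample gradients
    of log p w.r.t. the layer output.
    """
    batch = a_prev.shape[0]
    A = a_prev.T @ a_prev / batch   # E[a a^T], shape (d_in, d_in)
    G = grad_s.T @ grad_s / batch   # E[g g^T], shape (d_out, d_out)
    return G, A

def exact_fisher_block(a_prev, grad_s):
    """Layer-wise Fisher block E[(g kron a)(g kron a)^T] for comparison."""
    batch = a_prev.shape[0]
    # Per-sample gradient w.r.t. W is the outer product g a^T, flattened row-major.
    per_sample = np.einsum("bi,bj->bij", grad_s, a_prev).reshape(batch, -1)
    return per_sample.T @ per_sample / batch
```

For a single sample, `np.kron(G, A)` equals the exact block; for larger batches the two differ by exactly the cross-correlation that the independence assumption discards.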

K-FAC for convolutional layers
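In a convolutional (Conv) layer of a feed-forward DNN, the output S_ℓ ∈ R^{c_ℓ×h_ℓ×w_ℓ} is calculated as

$$M_{S_\ell} \leftarrow W_\ell M_{A_{\ell-1}}, \qquad S_\ell \leftarrow \mathrm{reshape}(M_{S_\ell}), \tag{10}$$

where M_{A_{ℓ−1}} ∈ R^{c_{ℓ−1}k²×h_ℓ w_ℓ} is the input to this layer, and W_ℓ ∈ R^{c_ℓ×c_{ℓ−1}k²} is the weight matrix (the bias is ignored for simplicity). c_ℓ, c_{ℓ−1} are the numbers of output and input channels, respectively, and k is the kernel size (assuming square kernels for simplicity). The Kronecker factors for this Conv layer are defined as

$$G_\ell := E\left[ \mathrm{reshape}(\nabla_{S_\ell} \log p_\theta(y \mid x)) \, \mathrm{reshape}(\nabla_{S_\ell} \log p_\theta(y \mid x))^\top \right] \quad \text{and} \quad A_{\ell-1} := \frac{1}{h_\ell w_\ell} E\left[ M_{A_{\ell-1}} M_{A_{\ell-1}}^\top \right], \tag{11}$$

where reshape(·) maps the gradient to a c_ℓ × h_ℓ w_ℓ matrix, so that G_ℓ ∈ R^{c_ℓ×c_ℓ} and A_{ℓ−1} ∈ R^{c_{ℓ−1}k²×c_{ℓ−1}k²} [18].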

Inverting Kronecker-factored FIM
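By the property of the Kronecker product and the Tikhonov damping technique used in [17], the inverse of F̂_ℓ + λI is approximated by the Kronecker product of the inverse of each damped Kronecker factor

$$\left( \hat{F}_\ell + \lambda I \right)^{-1} \approx \left( G_\ell + \frac{1}{\pi_\ell} \sqrt{\lambda} I \right)^{-1} \otimes \left( A_{\ell-1} + \pi_\ell \sqrt{\lambda} I \right)^{-1}, \tag{12}$$

where π_ℓ² is the average eigenvalue of A_{ℓ−1} divided by the average eigenvalue of G_ℓ. π_ℓ > 0 because both G_ℓ and A_{ℓ−1} are positive-semidefinite matrices as defined above.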
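A minimal NumPy sketch of this damped inversion follows. It assumes the factored Tikhonov form from [17], (G + (√λ/π) I)⁻¹ ⊗ (A + π√λ I)⁻¹, and uses trace/dimension as the average eigenvalue. The function names are hypothetical, and the sketch assumes both traces are nonzero.

```python
import numpy as np

def damped_kfac_inverses(G, A, lam):
    """Invert each Kronecker factor after splitting the damping lam between them."""
    # pi^2 = (average eigenvalue of A) / (average eigenvalue of G)
    pi = np.sqrt((np.trace(A) / A.shape[0]) / (np.trace(G) / G.shape[0]))
    G_inv = np.linalg.inv(G + (np.sqrt(lam) / pi) * np.eye(G.shape[0]))
    A_inv = np.linalg.inv(A + (pi * np.sqrt(lam)) * np.eye(A.shape[0]))
    return G_inv, A_inv

def precondition(grad_W, G_inv, A_inv):
    """Apply (G_inv kron A_inv) to the row-major vec of grad_W without forming the product.

    Uses (P kron Q) vec(X) = vec(P X Q^T); A_inv is symmetric, so Q^T = Q.
    """
    return G_inv @ grad_W @ A_inv
```

The point of the factorization is visible in `precondition`: only a d_out × d_out and a d_in × d_in matrix are inverted, and the full (d_out·d_in)² matrix is never materialized.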

PRACTICAL NATURAL GRADIENT
K-FAC for FC layers (9) and Conv layers (11) enables us to realize NGD in training deep ConvNets [18], [19]. For deep and wide neural architectures with a huge number of learnable parameters, however, even NGD with K-FAC has considerable overhead compared to SGD because of the extra computation for the FIM. In this section, we introduce practical techniques to accelerate NGD for such huge neural architectures. Using our techniques, we are able to reduce the overhead of NGD to an almost negligible amount, as shown in Section 7.
Algorithm 1: Natural gradient with the stale statistics.
input : set of the statistics S (damped inverses)
input : initial parameters θ
output:

Fast Estimation with Empirical Fisher
Instead of using an estimation by a single Monte Carlo sampling defined in Eq. (5) (F̂_{ℓ,1mc}), we adopt the empirical Fisher [17] to estimate the FIM F̂_ℓ:

$$\hat{F}_{\ell,\mathrm{emp}} := E\left[ \nabla_{w_\ell} \log p_\theta(t \mid x) \, \nabla_{w_\ell} \log p_\theta(t \mid x)^\top \right].$$

We implemented an efficient F̂_{ℓ,emp} computation in the Chainer framework [29] that allows us to compute F̂_{ℓ,emp} during the forward-pass and the backward-pass for the loss L(θ)³. Therefore, we do not need an extra backward-pass to compute F̂_{ℓ,emp}, while an extra backward-pass is necessary for computing F̂_{ℓ,1mc}. This difference is critical especially for a deeper network, which takes longer for a backward-pass.
3. The same empirical Fisher computation can be implemented on PyTorch.
Fig. 3. Estimate the interval (steps) ∆ until the next step t_X^next to refresh the statistics X.
Although it has been argued in the literature [30], [31] that F̂_{ℓ,emp} is not a proper approximation of the FIM and that F̂_{ℓ,1mc} is a better estimate, we do not see any difference in either the convergence behavior or the final model accuracy between NGD with F̂_{ℓ,emp} and NGD with F̂_{ℓ,1mc} in training deep ConvNets for ImageNet classification, as shown in Section 7.
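The difference between the two estimators can be made concrete with a toy softmax classifier. This is only an illustrative sketch, not the authors' Chainer implementation: the model (logits = W x), the data, and the helper names are all hypothetical. The two estimates differ only in which label enters the log-likelihood gradient: a label sampled from the model (1mc) or the training label (emp).

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fisher_from_labels(x, probs, y):
    """E[grad grad^T] of log p(y|x) w.r.t. vec(W) for logits = W x."""
    onehot = np.eye(probs.shape[1])[y]
    # d log p(y|x) / d logits = onehot(y) - probs; the chain rule gives an outer product with x
    grads = np.einsum("bi,bj->bij", onehot - probs, x).reshape(len(x), -1)
    return grads.T @ grads / len(x)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4))             # toy inputs (hypothetical)
W = rng.normal(size=(3, 4))               # toy weights (hypothetical)
t = rng.integers(0, 3, size=256)          # training labels
probs = softmax(x @ W.T)
y = np.array([rng.choice(3, p=p) for p in probs])   # one Monte Carlo label per input

F_1mc = fisher_from_labels(x, probs, y)   # needs gradients for sampled labels (extra backward-pass)
F_emp = fisher_from_labels(x, probs, t)   # reuses the gradients already computed for the loss
```

In a real network the `F_emp` path is cheaper because the gradient of log p(t|x) is, up to sign, the gradient already computed for the cross-entropy loss.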

Practical FIM Estimation for BatchNorm Layers
It is often the case that a deep ConvNet has Conv layers that are followed by BatchNorm layers [32]. The ℓ-th BatchNorm layer after the (ℓ−1)-th Conv layer has scale γ_ℓ ∈ R^{c_{ℓ−1}} and bias β_ℓ ∈ R^{c_{ℓ−1}} to be applied to the normalized features. When we see these parameters as learnable parameters, we can define

$$w_\ell := \left( \gamma_{\ell,1}, \beta_{\ell,1}, \ldots, \gamma_{\ell,c_{\ell-1}}, \beta_{\ell,c_{\ell-1}} \right)^\top \in R^{2c_{\ell-1}},$$

where γ_{ℓ,i} and β_{ℓ,i} are the i-th elements of γ_ℓ and β_ℓ, respectively (i = 1, ..., c_{ℓ−1}). For this BatchNorm layer, the FIM F̂_ℓ (5) is a 2c_{ℓ−1} × 2c_{ℓ−1} matrix, and the computational cost of inverting this matrix cannot be ignored when c_{ℓ−1} (the number of output channels of the previous Conv layer) is large (e.g., c_out = 1024 for a Conv layer in ResNet-50 [5]).

Unit-wise Natural Gradient
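We approximate this FIM by applying the unit-wise natural gradient [33] to the learnable parameters of BatchNorm layers (Figure 2). A "unit" in a neural network is a collection of input/output nodes connected to each other. The "unit-wise" natural gradient only takes into account the correlation of the parameters in the same node. Hence, for a BatchNorm layer, we only consider the correlation between γ_{ℓ,c} and β_{ℓ,c} of the same channel c:

$$\hat{F}_{\ell,\mathrm{unitBN}} := \mathrm{diag}\left( \hat{F}_\ell^{(1)}, \ldots, \hat{F}_\ell^{(c_{\ell-1})} \right), \qquad \hat{F}_\ell^{(i)} := E\begin{bmatrix} (\nabla_{\gamma_{\ell,i}})^2 & \nabla_{\gamma_{\ell,i}} \nabla_{\beta_{\ell,i}} \\ \nabla_{\beta_{\ell,i}} \nabla_{\gamma_{\ell,i}} & (\nabla_{\beta_{\ell,i}})^2 \end{bmatrix},$$

where ∇_{γ_{ℓ,i}}, ∇_{β_{ℓ,i}} are the i-th elements of ∇_{γ_ℓ} log p_θ(y|x), ∇_{β_ℓ} log p_θ(y|x), respectively. The number of elements to be computed and communicated is significantly reduced from 4c²_{ℓ−1} to 4c_{ℓ−1}, and we can get the inverse (F̂_{ℓ,unitBN} + λI)^{−1} with little computational cost using the inverse matrix formula for each 2 × 2 block:

$$\begin{bmatrix} a & b \\ c & d \end{bmatrix}^{-1} = \frac{1}{ad - bc} \begin{bmatrix} d & -b \\ -c & a \end{bmatrix}.$$

We observed that the unit-wise approximation of F̂_ℓ for BatchNorm does not affect the accuracy of a deep ConvNet on ImageNet classification, as shown in Section 7.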
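The per-channel 2 × 2 inversion can be written in closed form for all channels at once. The following NumPy sketch assumes per-sample gradients of log p w.r.t. γ and β are available as (batch, channels) arrays; the function and argument names are hypothetical.

```python
import numpy as np

def unitwise_bn_precondition(g_gamma, g_beta, grad_gamma, grad_beta, lam):
    """Apply (F_unitBN + lam*I)^-1 to the loss gradient, channel by channel.

    g_gamma, g_beta: (B, C) per-sample gradients of log p w.r.t. scale and bias.
    grad_gamma, grad_beta: (C,) loss gradients to be preconditioned.
    """
    # Entries of each damped 2x2 block [[a, b], [b, d]] (symmetric, so c == b)
    a = np.mean(g_gamma ** 2, axis=0) + lam
    b = np.mean(g_gamma * g_beta, axis=0)
    d = np.mean(g_beta ** 2, axis=0) + lam
    det = a * d - b * b            # positive: each block is PSD plus lam*I with lam > 0
    # Closed-form 2x2 inverse applied to (grad_gamma, grad_beta)
    ng_gamma = (d * grad_gamma - b * grad_beta) / det
    ng_beta = (-b * grad_gamma + a * grad_beta) / det
    return ng_gamma, ng_beta
```

Only the three distinct entries per channel are stored, so the cost grows linearly with the number of channels rather than cubically.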

Natural Gradient with Stale Statistics
To achieve even faster training with NGD, it is critical to utilize stale statistics in order to avoid re-computing the matrices A_{ℓ−1}, G_ℓ and F̂_{ℓ,unitBN} (Figure 2) at every step. Previous work on K-FAC [17] used a simple strategy in which the Kronecker factors are refreshed only once every 20 steps. However, as observed in [24], the statistics fluctuate rapidly at the beginning of training, and this simple strategy seriously harms convergence. It was also observed that the degree of fluctuation of the statistics depends on the mini-batch size, the layer, and the type of the statistics (e.g., statistics with a larger mini-batch fluctuate less than those with a smaller mini-batch). Although the previous strategy [24] to reduce the refresh frequency worked without any degradation of the accuracy, it requires prior observation of the fluctuation of the statistics, and its effect on training time has not been well studied.

Adaptive Frequency To Refresh Statistics
We propose an improved strategy which adaptively determines the frequency at which each statistic is refreshed, based on its staleness during training. Our strategy is shown in Algorithm 1. We calculate the timing (step) to refresh each statistic based on the acceptable interval (steps) estimated in Algorithm 2 (Figure 3). In Algorithm 2, matrix A is considered to be similar to matrix B when ‖A − B‖_F / ‖B‖_F < α, where ‖·‖_F is the Frobenius norm, and α > 0 is the threshold⁴. We tuned α by running training for a few epochs to check that the threshold is not too large (i.e., that it preserves the same convergence). The algorithm aims to find the acceptable interval ∆ such that the statistic X calculated at step t = t_X + ∆ is similar to that calculated at step t = t_X. With this strategy, we can estimate the acceptable interval for each statistic and skip the computation efficiently: it keeps almost the same training time as the original, while reducing the cost of constructing and inverting A_{ℓ−1}, G_ℓ and F̂_{ℓ,unitBN} as much as possible. This significantly reduces the overhead of NGD. We observe the effectiveness of our approach in Section 7.
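The similarity test above is easy to state in code. The interval update rule below is only an assumption made for illustration, since Algorithm 2 itself is not reproduced here: the sketch halves the interval when the new statistic differs from the last one, keeps it when it differs only from the one before last, and otherwise extends it by the previous interval. All function names are hypothetical.

```python
import math

def frobenius(m):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(v * v for row in m for v in row))

def is_similar(a, b, alpha):
    """True when ||A - B||_F / ||B||_F < alpha (assumes B is nonzero)."""
    diff = [[x - y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]
    return frobenius(diff) / frobenius(b) < alpha

def next_interval(x_new, x_last, x_before_last, d_last, d_before_last, alpha):
    """Hypothetical rule for the number of steps until the next refresh."""
    if not is_similar(x_new, x_last, alpha):
        return max(1, d_last // 2)       # statistic changed: refresh sooner
    if not is_similar(x_new, x_before_last, alpha):
        return d_last                    # stable only recently: keep the interval
    return d_last + d_before_last        # stable for two refreshes: wait longer
```

With a rule of this kind, the interval stays at one step while the statistics fluctuate early in training and grows once they settle, without any prior measurement of the fluctuation.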

SCALABLE NATURAL GRADIENT
Based on the practical estimation of the FIM proposed in the previous section, we designed a distributed parallelization scheme among multiple GPUs so that the overhead of NGD compared to SGD decreases as the number of GPUs (processes) is increased. In Stage 2, two procedures are performed in parallel: communication among all the processes and a backward-pass in each process. Since Stage 1 is done in a data-parallel fashion, each process computes the statistics only for its own part of the mini-batch. In order to compute these statistics for the entire mini-batch, we need to average them over all the processes. This is performed using a ReduceScatterV collective communication, which transitions our approach from data-parallelism to model-parallelism by reducing (taking the sum of) A_{ℓ−1} for different ℓ to different processes. This collective is much more efficient than an AllReduce, where A_{ℓ−1} for all ℓ are reduced to all the processes (Figure 4). While A_{ℓ−1} is communicated, each process also performs a backward-pass to get the gradient ∇_{w_ℓ} L, the Kronecker factor G_ℓ for Conv and FC layers, and F̂_{ℓ,unitBN} for BatchNorm layers, for each ℓ.

Distributed Natural Gradient
In Stage 3, G_ℓ, F̂_{ℓ,unitBN} and ∇_{w_ℓ} L are communicated in the same way as A_{ℓ−1}, by a ReduceScatterV collective. At this point, only a single process has the FIM estimate F̂_ℓ and the gradient ∇_{w_ℓ} L with the statistics for the entire mini-batch for the ℓ-th layer. In Stage 4, only the process that has the FIM computes the matrix inverse and applies the NGD update (6) to the weights w_ℓ of the ℓ-th layer. Hence, these computations are performed in a model-parallel fashion. When the number of layers is larger than the number of processes, multiple layers are handled by a single process.
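How layers are mapped to processes is not specified here, so the following sketch simply assumes a round-robin assignment to illustrate the model-parallel stage; `assign_layers` is a hypothetical helper.

```python
def assign_layers(num_layers, num_procs):
    """Map each layer index to the rank that inverts its FIM (round-robin, an assumption)."""
    owners = {rank: [] for rank in range(num_procs)}
    for layer in range(num_layers):
        owners[layer % num_procs].append(layer)
    return owners

# Example: 107 layers spread over 64 processes
owners = assign_layers(107, 64)
busiest = max(len(layers) for layers in owners.values())
```

Under this assignment, no process inverts more than two layers' statistics for 107 layers on 64 processes, and beyond 107 processes the extra ranks own no layer at all.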
Once the weights w_ℓ of each ℓ are updated, we synchronize the updated weights among all the processes by calling an AllGatherV collective (Figure 4), and we switch back to data-parallelism. Combined with the practical estimation of the FIM proposed in the previous section, this lets us eliminate a significant amount of the communication required for the Kronecker factors A_{ℓ−1}, G_ℓ and F̂_{ℓ,unitBN}. Therefore, the amount of communication for our distributed NGD is similar to that of distributed SGD, where the AllReduce for the gradient ∇_{w_ℓ} L is implemented as a ReduceScatter+AllGather.
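The equivalence between an AllReduce and a ReduceScatter followed by an AllGather can be checked with a single-process simulation. This is a toy sketch in plain Python that models each rank's buffer as a list of per-chunk values; it does not call MPI or NCCL, and the function names are hypothetical.

```python
def reduce_scatter(per_rank_chunks):
    """per_rank_chunks[r][c] is rank r's local contribution to chunk c.

    Returns a list whose element c is the summed chunk c, held only by rank c.
    """
    n = len(per_rank_chunks)
    return [sum(per_rank_chunks[r][c] for r in range(n)) for c in range(n)]

def all_gather(owned):
    """Every rank receives the chunk owned by every other rank."""
    return [list(owned) for _ in owned]

def all_reduce(per_rank_chunks):
    """Reference result: every rank holds the sum of every chunk."""
    n = len(per_rank_chunks)
    total = [sum(per_rank_chunks[r][c] for r in range(n)) for c in range(n)]
    return [list(total) for _ in range(n)]

local = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]   # 3 ranks, 3 chunks each
owned = reduce_scatter(local)                        # rank c now owns the reduced chunk c
synced = all_gather(owned)
```

In the distributed NGD design, the model-parallel inversion and update happen between the two halves, while each reduced chunk lives on exactly one rank.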

Further acceleration
Our data-parallel and model-parallel hybrid approach allows us to minimize the overhead of NGD in a distributed setting. However, NGD still has a large overhead compared to SGD. There are two hotspots in our distributed NGD design. The first is the construction of the statistics A_{ℓ−1}, G_ℓ, and F̂_{ℓ,unitBN}, which cannot be done in a model-parallel fashion. The second is the communication (ReduceScatterV) for distributing these statistics. In this section, we discuss how we accelerate these two hotspots to achieve even faster training.
Mixed-precision computation. We use the Tensor Cores in the NVIDIA Volta Architecture⁵ to construct the statistics. This more than doubles the speed of this part of the calculation. One might think that this low-precision computation affects the overall accuracy of the training, but in our experiments we do not find any difference between training with half-precision floating point computation and training with full-precision floating point computation.

Symmetry-aware communication. The statistics matrices A_{ℓ−1}, G_ℓ, and F̂_{ℓ,unitBN} are symmetric matrices. We exploit this property to reduce the amount of communication without loss of information. To communicate a symmetric matrix of size N × N, we only need to send the upper triangular matrix with N(N + 1)/2 elements.
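A short sketch of the packing step, written in plain Python with hypothetical helper names: the sender flattens the upper triangle (diagonal included) and the receiver mirrors it back.

```python
def pack_upper(m):
    """Flatten the upper triangle (including the diagonal) of a symmetric matrix."""
    n = len(m)
    return [m[i][j] for i in range(n) for j in range(i, n)]

def unpack_upper(packed, n):
    """Rebuild the full symmetric n x n matrix from its packed upper triangle."""
    out = [[0.0] * n for _ in range(n)]
    k = 0
    for i in range(n):
        for j in range(i, n):
            out[i][j] = out[j][i] = packed[k]   # mirror across the diagonal
            k += 1
    return out

sym = [[4.0, 1.0, 2.0],
       [1.0, 5.0, 3.0],
       [2.0, 3.0, 6.0]]
packed = pack_upper(sym)   # 3 * 4 / 2 = 6 values instead of 9
```

For large N the message size approaches half of the full matrix, since N(N + 1)/2 ≈ N²/2.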
In addition to these two optimizations, we also adopted the performance optimizations done by [25]:
• Explicitly use NHWC (mini-batch, height, width, and channels) format for the input/output data (tensor) of Conv layers instead of NCHW format. This lets the cuDNN API fully benefit from the Tensor Cores.
• Data I/O pipeline using the NVIDIA Data Loading Library (DALI)⁶.
• Hierarchical AllReduce collective proposed by Ueno et al. [34], which alleviates the latency of the ring-AllReduce communication among a large number of GPUs.
5. https://www.nvidia.com/en-us/data-center/tensorcore/

TRAINING FOR IMAGENET CLASSIFICATION
The behavior of NGD on large models and datasets has not been studied in depth. Also, there are very few studies that use NGD (K-FAC) for large mini-batches (over 4K) using distributed parallelism at scale [19]. In contrast to SGD, whose hyperparameters have been optimized by many practitioners even for large mini-batches, there is very little insight into how to tune hyperparameters for NGD. We therefore explored several methods, which we call training schemes, to achieve higher accuracy in our experiments. In this section, we describe the training schemes used in our large mini-batch training with NGD for ImageNet classification.

Data augmentation
To achieve good generalization performance while keeping the benefit of the fast convergence that comes from NGD, we adopt the data augmentation techniques commonly used for training networks with large mini-batch sizes. We resize all the images in ImageNet to 256 × 256, ignoring the aspect ratio of the original images, and compute the mean value of the upper left portion (224 × 224) of the resized images. When reading an image, we randomly crop a 224 × 224 image from it, randomly flip it horizontally, subtract the mean value, and scale every pixel to [0, 1].
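Running mixup. We extend mixup [35] to increase its regularization effect. We synthesize virtual training samples from raw samples and virtual samples from the previous step (while the original mixup method synthesizes new samples only from the raw samples):

$$\tilde{x}^{(t)} = \lambda \cdot x^{(t)} + (1 - \lambda) \cdot \tilde{x}^{(t-1)}, \qquad \tilde{t}^{(t)} = \lambda \cdot t^{(t)} + (1 - \lambda) \cdot \tilde{t}^{(t-1)}.$$

x^{(t)}, t^{(t)} is a raw input and label (one-hot vector), and x̃^{(t)}, t̃^{(t)} is a virtual input and label for the t-th step. λ is sampled from the Beta distribution

$$\mathrm{Beta}(\lambda; \alpha, \beta) = \frac{\lambda^{\alpha-1} (1 - \lambda)^{\beta-1}}{B(\alpha, \beta)},$$

with the beta function B(α, β), where we set α = β = α_mixup.
6. https://developer.nvidia.com/DALI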
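A minimal sketch of running mixup, assuming the update is the convex combination x̃(t) = λ·x(t) + (1 − λ)·x̃(t−1) (and likewise for labels) with λ ~ Beta(α_mixup, α_mixup). The class name is hypothetical, and using the raw batch as the first virtual batch is an assumption of this sketch.

```python
import numpy as np

class RunningMixup:
    """Mix each raw batch with the virtual batch produced at the previous step."""

    def __init__(self, alpha_mixup, seed=0):
        self.alpha = alpha_mixup
        self.rng = np.random.default_rng(seed)
        self.prev_x = None
        self.prev_t = None

    def __call__(self, x, t):
        if self.prev_x is None:
            # Assumption: the first virtual batch is the raw batch itself
            virt_x, virt_t = x, t
        else:
            lam = self.rng.beta(self.alpha, self.alpha)
            virt_x = lam * x + (1.0 - lam) * self.prev_x
            virt_t = lam * t + (1.0 - lam) * self.prev_t
        # Keep the virtual batch so the next step can mix with it
        self.prev_x, self.prev_t = virt_x, virt_t
        return virt_x, virt_t
```

Because each virtual batch feeds into the next one, a sample's influence decays geometrically over later steps instead of vanishing after one step as in the original mixup.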
Random erasing with zero value. We also implemented Random Erasing [36]. We set the elements within the erasing region of each input to zero instead of a random value as used in the original method. We set the erasing probability p = 0.5, the erasing area ratio S_e ∈ [0.02, 0.25], and the erasing aspect ratio r_e ∈ [0.3, 1]. We randomly switch the size of the erasing area from (H_e, W_e) to (W_e, H_e).
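A sketch of this zero-valued variant is given below. It assumes H_e = √(S·r_e) and W_e = √(S/r_e) for an erased area S, as in the original Random Erasing method, and a 50% chance of swapping the two sides. Skipping the erase when the rectangle does not fit is an assumption of this sketch, and the function name is hypothetical.

```python
import math
import numpy as np

def random_erase_zero(img, rng, p=0.5, area=(0.02, 0.25), aspect=(0.3, 1.0)):
    """img: array of shape (C, H, W). Returns a copy with one rectangle set to zero."""
    out = img.copy()
    if rng.random() >= p:
        return out                              # no erasing for this sample
    _, height, width = img.shape
    s_e = rng.uniform(*area) * height * width   # erased area in pixels
    r_e = rng.uniform(*aspect)                  # aspect ratio of the erased region
    h_e = int(math.sqrt(s_e * r_e))
    w_e = int(math.sqrt(s_e / r_e))
    if rng.random() < 0.5:                      # randomly switch (H_e, W_e) to (W_e, H_e)
        h_e, w_e = w_e, h_e
    if h_e == 0 or w_e == 0 or h_e > height or w_e > width:
        return out                              # assumption: skip if the region does not fit
    top = rng.integers(0, height - h_e + 1)
    left = rng.integers(0, width - w_e + 1)
    out[:, top:top + h_e, left:left + w_e] = 0  # zero instead of random values
    return out
```

A typical call would be `random_erase_zero(image, np.random.default_rng(0))` on a mean-subtracted, scaled 3 × 224 × 224 image.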

Learning rate and momentum
The learning rate used for all of our experiments is scheduled by polynomial decay. The learning rate η^{(e)} for the e-th epoch is determined as follows:

$$\eta^{(e)} = \eta^{(0)} \cdot \left( 1 - \frac{e - e_{start}}{e_{end} - e_{start}} \right)^{p_{decay}}.$$

η^{(0)} is the initial learning rate, and e_start, e_end are the epochs at which the decay starts and ends. The decay rate p_decay controls the speed of the learning rate decay. We use the momentum method for NGD updates. Because the learning rate decays rapidly in the final stage of training with polynomial decay, the current update can become smaller than the previous update. We therefore adjust the momentum rate m^{(e)} for the e-th epoch so that the ratio between m^{(e)} and η^{(e)} is fixed throughout training:

$$m^{(e)} = \frac{m^{(0)}}{\eta^{(0)}} \cdot \eta^{(e)},$$

where m^{(0)} is the initial momentum rate. The weights are updated as follows:

$$w_\ell^{(t+1)} \leftarrow w_\ell^{(t)} - \eta^{(e)} \left( \hat{F}_\ell + \lambda I \right)^{-1} \nabla_{w_\ell} L^{(t)} + m^{(e)} v^{(t)}, \tag{23}$$

where v^{(t)} = w_ℓ^{(t)} − w_ℓ^{(t−1)}.
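The schedule can be sketched as follows, assuming the polynomial form η(e) = η(0)·(1 − (e − e_start)/(e_end − e_start))^p_decay and a momentum rate that keeps m(e)/η(e) constant. Holding the rate at η(0) before e_start and at zero after e_end is an assumption of this sketch, and the numeric values in the example are placeholders rather than the paper's hyperparameters.

```python
def learning_rate(epoch, eta0, e_start, e_end, p_decay):
    """Polynomial decay of the learning rate between e_start and e_end."""
    if epoch < e_start:
        return eta0            # assumption: constant before the decay starts
    if epoch >= e_end:
        return 0.0             # assumption: fully decayed after e_end
    frac = (epoch - e_start) / (e_end - e_start)
    return eta0 * (1.0 - frac) ** p_decay

def momentum_rate(epoch, eta0, m0, e_start, e_end, p_decay):
    """Momentum scaled so that the ratio m / eta stays fixed during training."""
    return m0 / eta0 * learning_rate(epoch, eta0, e_start, e_end, p_decay)

# Placeholder values for illustration only
schedule = [learning_rate(e, eta0=0.8, e_start=0, e_end=10, p_decay=2.0) for e in range(10)]
```

Tying the momentum rate to the learning rate keeps the momentum term m(e)·v(t) from dominating the update once η(e) has become small.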

Weights rescaling
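To prevent the scale of the weights from becoming too large, we apply the Normalizing Weights [37] technique to the w_ℓ of FC and Conv layers. We rescale w_ℓ to have a norm of √(2 · d_out) after (23):

$$w_\ell^{(t+1)} \leftarrow \sqrt{2 \cdot d_{out}} \cdot \frac{w_\ell^{(t+1)}}{\lVert w_\ell^{(t+1)} \rVert + \epsilon},$$

where we use ε = 1 · 10^{−9} to stabilize the computation. d_out is the output dimension or the number of output channels of the layer.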
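A small sketch of this rescaling, assuming the norm is taken over all entries of the layer's weight tensor and that ε is added to the norm in the denominator; the function name is hypothetical.

```python
import numpy as np

def rescale_weights(w, d_out, eps=1e-9):
    """Rescale w so that its norm is (approximately) sqrt(2 * d_out)."""
    norm = np.linalg.norm(w)          # Frobenius / L2 norm over all entries
    return np.sqrt(2.0 * d_out) * w / (norm + eps)

w = np.random.default_rng(0).normal(size=(64, 128))   # toy FC weight, d_out = 64
w_scaled = rescale_weights(w, d_out=64)
```

Only the scale changes: the direction of the weight tensor after the update (23) is preserved.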

EXPERIMENTS
We train ResNet-50 [5] on ImageNet [4] in all of our experiments. We use the same hyperparameters for the same mini-batch size. The hyperparameters for our results are shown in Table 2. We implement all of our methods using Chainer [29]. Our Chainer extension is available at https://github.com/tyohei/chainerkfac. We initialize the weights with the HeNormal initializer of Chainer⁷ using the default parameters.

Experiment Environment
We conduct all experiments on the ABCI (AI Bridging Cloud Infrastructure) 8

Scalability
We measure the scalability of our distributed NGD implementation for training ResNet-50 on ImageNet. Figure 5 shows the time for one step with different numbers of GPUs and different techniques proposed in Section 4. Note that since we fix the number of images to be processed per GPU (=32), doubling the number of GPUs means doubling the total number of images (mini-batch size) to be processed in a step (e.g., 32K images are processed with 1024 GPUs in a step). In distributed training with multiple GPUs, it is considered ideal if this plot shows a flat line parallel to the x-axis, that is, the time per step is independent of the number of GPUs, and the number of images processed in a certain time increases linearly with the number of GPUs. From 1 GPU to 64 GPUs, however, we observed superlinear scaling. For example, the time per step with 64 GPUs is 300% faster than that with 1 GPU for emp+fullBN. This is the consequence of our model-parallel design, since ResNet-50 has 107 layers in total when all the Conv, FC, and BatchNorm layers are accounted for. With more than 128 GPUs, we observe slight performance degradation due to the communication overhead of the ReduceScatterV and AllGatherV collectives. Yet for emp+unitBN+stale, we see almost ideal scaling from 128 GPUs to 1024 GPUs.
Moreover, with 512 GPUs, which corresponds to BS=16K, we again see superlinear scaling. We discuss this in the next sub-section.

Effectiveness of Practical Natural Gradient
We examine the effectiveness of our practical NGD approaches proposed in Section 4 for training ResNet-50 on ImageNet with extremely large mini-batches. We show that our practical techniques make the overhead of NGD close to negligible and improve training time significantly. The summary of the training time is shown in Table 1 and Figure 1.
Natural Gradient by Empirical Fisher. We compare the time and model accuracy in training by NGD with the empirical Fisher and that with a Fisher estimation by a single Monte Carlo sampling ($F_{\ell,\mathrm{emp}}$ vs $F_{\ell,\mathrm{1mc}}$). In Figure 5, the time per step for each training is labeled as emp and 1mc, respectively. Due to the extra backward pass required for constructing $F_{\ell,\mathrm{1mc}}$, 1mc is slower than emp at any number of GPUs. We do not see any difference in the convergence behavior (accuracy vs steps) or the final accuracy for training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K, 65K, 131K}. Note that we used the same hyperparameters tuned for emp for each BS (shown in Table 2), because limited computational resources prevented us from tuning them separately for 1mc.
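The difference between the two estimators compared above can be illustrated on a toy softmax output. This sketch works with the per-sample Fisher with respect to the logits only, which is an assumption made to keep the example small; it is not the layer-wise Kronecker-factored form used in the paper. The empirical Fisher reuses the gradient for the true label (already available from the ordinary backward pass), while the single-sample Monte Carlo estimate needs a gradient for a label drawn from the model's own prediction, which is why it costs an extra backward pass. All function names are hypothetical.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def logit_grad(probs, label):
    """Gradient of log p(label | x) with respect to the logits."""
    return [(1.0 if k == label else 0.0) - p for k, p in enumerate(probs)]


def outer(g):
    return [[a * b for b in g] for a in g]


def fisher_emp(logits, true_label):
    """Empirical Fisher: outer product of the gradient for the observed label."""
    return outer(logit_grad(softmax(logits), true_label))


def fisher_1mc(logits, rng):
    """One-sample Monte Carlo Fisher: the label is drawn from the model itself."""
    probs = softmax(logits)
    sampled = rng.choices(list(range(len(probs))), weights=probs, k=1)[0]
    return outer(logit_grad(probs, sampled))


def average(mats):
    n, d = len(mats), len(mats[0])
    return [[sum(m[i][j] for m in mats) / n for j in range(d)] for i in range(d)]


logits = [2.0, 0.5, -1.0]
probs = softmax(logits)
emp = fisher_emp(logits, true_label=0)
rng = random.Random(0)
mc_mean = average([fisher_1mc(logits, rng) for _ in range(20000)])
# Exact Fisher with respect to the logits: diag(p) - p p^T.
exact = [[(probs[i] if i == j else 0.0) - probs[i] * probs[j]
          for j in range(3)] for i in range(3)]
```

Averaging many Monte Carlo samples approaches the exact Fisher, whereas the empirical Fisher depends on the observed label and in general does not.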
Unit-Wise Natural Gradient. We also compare training with natural gradient and that with unit-wise natural gradient on BatchNorm parameters ($F_\ell$ vs $F_{\ell,\mathrm{unitBN}}$). In Figure 5, the time per step for each method is labeled as fullBN and unitBN, respectively. From [...] is better) of the communication volume for the statistics, and speedup (emp+unitBN vs emp+unitBN+stale) is shown in Table 2. The communication amount (bytes) in the ReduceScatterV collective in each step during training and the reduction rate are plotted in Figure 6. With BS=16K and 32K, we can reduce the communication amount for the statistics ($A_{\ell-1}$, $G_\ell$, and $F_\ell$) to 5.4% and 7.8%, respectively.
We might be able to attribute this significant reduction rate to the fact that the statistics with larger BS (16K, 32K) are more stable than those with smaller BS (4K, 8K). Note that although we show the reduction rate of the amount of communication, this rate can also be used to estimate the reduction in the amount of computation for the statistics, and the cost of inverting them is removed as well. With these improvements on NGD, we see almost ideal scaling from 128 GPUs to 1024 GPUs, which corresponds to BS=4K to 32K.

Training ResNet-50 on ImageNet with NGD in 5.5 minutes
Finally, we combine all the practical techniques: empirical Fisher, unit-wise NGD, and NGD with stale statistics. Using 1024 NVIDIA Tesla V100 GPUs, we achieve 75.4% top-1 accuracy with ResNet-50 for ImageNet in 5.5 minutes (1760 steps = 45 epochs, including a validation after each epoch). We used the same hyperparameters shown in Table 2. The training time and the validation accuracy are competitive with the results reported by related work that uses SGD for training (the comparison is shown in Table 1). We refer to our training method as Scalable and Practical NGD (SP-NGD).
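The relation between steps, mini-batch size, and epochs quoted above can be checked with a few lines of arithmetic. The training set size of 1,281,167 images is the commonly cited size of the ImageNet (ILSVRC-2012) training split and is an assumption here, since the text does not state it. The function names are hypothetical.

```python
# Assumed ILSVRC-2012 training set size (not stated in the text).
IMAGENET_TRAIN_IMAGES = 1_281_167


def global_batch_size(num_gpus, per_gpu=32):
    """Global mini-batch size when every GPU processes per_gpu images per step."""
    return num_gpus * per_gpu


def epochs_processed(steps, batch_size, dataset_size=IMAGENET_TRAIN_IMAGES):
    """Number of passes over the dataset covered by the given number of steps."""
    return steps * batch_size / dataset_size


bs = global_batch_size(1024)         # 1024 GPUs x 32 images per GPU
epochs = epochs_processed(1760, bs)  # roughly 45 epochs
```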

DISCUSSION AND FUTURE WORK
In this work, we proposed Scalable and Practical Natural Gradient Descent (SP-NGD), a framework which combines i) a large-scale distributed computational design with data and model hybrid parallelism for Natural Gradient Descent (NGD) [14] and ii) practical Fisher information estimation techniques, including Kronecker-Factored Approximate Curvature (K-FAC) [17], which alleviate the computational overhead of NGD over SGD. Using our SP-NGD framework, we showed the advantages of NGD over first-order stochastic gradient descent (SGD) for training ResNet-50 on ImageNet classification with extremely large mini-batches. We introduced several schemes for training with NGD using mini-batch sizes up to 131,072 and achieved over 75% top-1 accuracy in far fewer steps than existing work using SGD with large mini-batches. Contrary to prior claims that models trained with second-order methods do not generalize as well as those trained with SGD, we were able to show that this is not at all the case, even for extremely large mini-batches. Our SP-NGD framework allowed us to train on 1024 GPUs and achieve 75.4% in 5.5 minutes. This is the first work which observes the relationship between the FIM of ResNet-50 and its training on large mini-batches ranging from 4K to 131K. The advantage that we have in designing better optimizers by taking this approach is that we are starting from the most mathematically rigorous form, and every improvement that we make is a systematic design decision based on observation of the FIM. Even if we end up having similar performance to the best known first-order methods, at least we will have a better understanding of why it works by starting from second-order methods including NGD.
More accurate and efficient FIM estimation. We showed that NGD using the empirical Fisher matrix [17] (emp) is much faster than that with an estimation using a single Monte Carlo sample (1mc), which is widely used by related work on the approximate natural gradient. Although it is stated in the literature [30], [31] that emp is not a good approximation of the NGD, we observed the same convergence behavior as 1mc for training ResNet-50 on ImageNet. We might be able to attribute this result to the fact that emp is a good enough approximation to keep the behavior of the true NGD, or that even 1mc is not a good approximation. To know whether these hypotheses are correct and to examine the actual value of the true NGD, we need a more accurate and efficient estimation of the NGD with less computational cost.
Towards Bayesian deep learning. NGD has been applied to Bayesian deep learning for estimating the posterior distribution of the network parameters. For example, K-FAC [17] has been applied to Bayesian deep learning to realize Noisy Natural Gradient [21], and our distributed NGD has been applied to that at ImageNet scale [26]. We similarly expect that our SP-NGD framework will accelerate Bayesian deep learning research using natural gradient methods.

Fig. 1. Top-1 validation accuracy vs the number of steps to converge (left) and vs training time (right) of ResNet-50 on ImageNet (1000 classes) classification by related work with SGD and this work with Scalable and Practical NGD (SP-NGD).

Figure 4 shows the overview of our design, which shows a single step of training with our distributed NGD. We use the term Stage to refer to each phase of computation and communication, which is indicated at the top of the figure. Algorithm 3 shows the pseudocode of our distributed NGD design.
4. α = 0.1 for all the experiments in this work.

Algorithm 3: Distributed Natural Gradient
while not converged do
    // Stage 1
    foreach $\ell = 1, \ldots, L$ do
        forward in $\ell$-th layer
        if $\ell$-th layer is Conv or FC then
            compute $A_{\ell-1}$
    end
    // Stage 2
    ReduceScatterV($A_{0:L-1}$)
    foreach $\ell = L, \ldots, 1$ do
        backward in $\ell$-th layer
        if $\ell$-th layer is Conv or FC then
            compute $G_\ell$
        else if $\ell$-th layer is BatchNorm then
            compute $F_{\ell,\mathrm{unitBN}}$
    end
    // Stage 3
    ReduceScatterV($G_{1:L}$ / $F_{1:L,\mathrm{unitBN}}$ and $\nabla_{w_{1:L}} L$)
    // Stage 4
    for $\ell = 1, \ldots, L$ do in parallel
        update $w_\ell$ by natural gradient (6)
    end
    // Stage 5
    AllGatherV($w_{1:L}$)
end
return $\theta = (w_1, \ldots, w_L)$

In Stage 1, each process (GPU) receives a different part of the mini-batch and performs a forward pass in which the Kronecker factor $A_{\ell-1}$ is computed for the received samples if the $\ell$-th layer is a Conv layer or an FC layer.
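The data flow of the two collectives used in Algorithm 3 can be illustrated with a single-process simulation. This is only a sketch of the communication pattern under stated assumptions: the reduction is taken to be a sum, each layer is assigned to exactly one owning process, per-layer statistics are scalars, and the update rule in the middle is a placeholder rather than the natural gradient update (6). The function names and the `owner` mapping are hypothetical and do not correspond to any real communication library API.

```python
def reduce_scatter_v(per_process, owner):
    """Simulated ReduceScatterV.

    per_process[p][l] is process p's local value for layer l.
    Returns out[p] = {l: sum over all processes} for the layers owned by p.
    Processes may own different numbers of layers (variable-sized parts).
    """
    out = [dict() for _ in per_process]
    for layer, p_owner in enumerate(owner):
        out[p_owner][layer] = sum(stats[layer] for stats in per_process)
    return out


def all_gather_v(per_process_parts, num_layers):
    """Simulated AllGatherV: every process receives the parts of all processes."""
    full = [None] * num_layers
    for part in per_process_parts:
        for layer, value in part.items():
            full[layer] = value
    return [list(full) for _ in per_process_parts]


# Two processes, three layers; process 0 owns layers 0 and 2, process 1 owns layer 1.
owner = [0, 1, 0]
local_grads = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
weights = [1.0, 1.0, 1.0]

reduced = reduce_scatter_v(local_grads, owner)            # Stage 3 pattern
updated = [{l: weights[l] - 0.01 * g for l, g in part.items()}
           for part in reduced]                           # Stage 4 placeholder update
gathered = all_gather_v(updated, num_layers=3)            # Stage 5 pattern
```

After the final gather every simulated process holds the same full set of weights, although each layer was updated by only one process. This is the model-parallel part of the design.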

Fig. 4. (Left) Overview of our distributed natural gradient descent (a single step of training). (Right) Illustrations of AllReduce, ReduceScatterV, and AllGatherV collectives. Different colors correspond to data (and its communication) from different data sources.

Fig. 6. The communication amount (bytes) for the statistics ($A_{\ell-1}$, $G_\ell$, $F_{\ell,\mathrm{unitBN}}$) in each step in training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K} (stacked graph: the amount for G/F is stacked on the amount for A). A and G/F correspond to the communication amount for $A_{\ell-1}$ and $G_\ell$/$F_{\ell,\mathrm{unitBN}}$, respectively. The reduction rate (smaller is better) of the communication amount for all the statistics throughout the training is shown with the percentage (%).

Kazuki Osawa received his BS and MS from Tokyo Institute of Technology in 2016 and 2018, respectively. He is currently a PhD candidate at Tokyo Institute of Technology and a Research Fellow of Japan Society for the Promotion of Science (JSPS). His research interests include optimization, approximate Bayesian inference, and distributed computing for deep learning. He is a student member of IEEE.

Yohei Tsuji received his BS and MS from Tokyo Institute of Technology in 2017 and 2019, respectively. He is currently a PhD student at Tokyo Institute of Technology. His research interests include high performance computing for machine learning and probabilistic programming.

Yuichiro Ueno received his BS from Tokyo Institute of Technology in 2019. He is currently a master course student at Tokyo Institute of Technology. His research interests include a range of high-performance computing, such as GPU computing and networking, and its application to deep learning. He is a student member of ACM.

Akira Naruse is a senior developer technology engineer at NVIDIA. He holds the MS degree in computer science from Nagoya University. Prior to joining NVIDIA, he was a research engineer at Fujitsu Laboratory and was involved in various high performance computing projects. His main interest is the performance analysis and optimization of scientific computing and deep learning applications on very large systems.

Chuan-Sheng Foo is a Scientist at the Institute for Infocomm Research, A*STAR. He received his BS, MS and PhD from Stanford University. His research focuses on developing deep learning algorithms that can learn from less labeled data, inspired by applications in healthcare and medicine.

Rio Yokota received his BS, MS, and PhD from Keio University in 2003, 2005, and 2009, respectively. He is currently an Associate Professor at GSIC, Tokyo Institute of Technology. His research interests range from high performance computing and hierarchical low-rank approximation methods to scalable deep learning. He was part of the team that won the Gordon Bell prize for price/performance in 2009.

TABLE 1
Training time and top-1 single-crop validation accuracy of ResNet-50 for ImageNet reported by related work and this work.

Algorithm 2 :
Estimate the acceptable interval until next refresh based on the staleness of statistics.
input: current statistics $X$
input: last statistics $X_{-1}$
input: statistics before the last $X_{-2}$

TABLE 2
The hyperparameters of the training with large mini-batch size (BS) used for our schemes in Section 6 and top-1 single-crop validation accuracy of ResNet-50 for ImageNet. reduction and speedup correspond to the reduction rate of the communication amount and the speedup that comes from it, respectively, for emp+unitBN+stale compared to emp+unitBN in Figure 5.

ABCI is operated by the National Institute of Advanced Industrial Science and Technology (AIST) in Japan. ABCI has 1088 nodes with four NVIDIA Tesla V100 GPUs per node. Due to the additional memory required by NGD, all of our experiments use a mini-batch size of 32 per GPU. We were only given a 24 hour window to use the full machine, so we had to tune the hyperparameters on a smaller number of nodes while mimicking the global mini-batch size of the full node run. For large mini-batch size experiments which cannot be executed directly (BS=65K and 131K require 2048 and 4096 GPUs, respectively), we used an accumulation method to mimic the behavior by accumulating the statistics $A_{\ell-1}$, $G_\ell$, $F_{\ell,\mathrm{unitBN}}$, and $\nabla_{w_\ell} L$ over multiple steps.
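The accumulation method mentioned in the experiment environment description can be illustrated on a toy statistic. The sketch assumes that a statistic such as a Kronecker factor is a mean of outer products over the samples in the mini-batch, and that all accumulation steps use the same number of samples. Under those assumptions, averaging the per-step statistics reproduces the statistic of the large mini-batch. The function names and the toy vectors are hypothetical.

```python
def outer(a, b):
    return [[ai * bj for bj in b] for ai in a]


def mean_outer(vectors):
    """(1/N) * sum_i a_i a_i^T for a list of equal-length vectors."""
    n, d = len(vectors), len(vectors[0])
    acc = [[0.0] * d for _ in range(d)]
    for v in vectors:
        o = outer(v, v)
        for i in range(d):
            for j in range(d):
                acc[i][j] += o[i][j] / n
    return acc


def accumulated_mean_outer(vectors, micro_batch):
    """Average the statistic computed on equal-sized micro-batches over steps."""
    chunks = [vectors[i:i + micro_batch]
              for i in range(0, len(vectors), micro_batch)]
    d = len(vectors[0])
    acc = [[0.0] * d for _ in range(d)]
    for chunk in chunks:
        part = mean_outer(chunk)          # statistic from one accumulation step
        for i in range(d):
            for j in range(d):
                acc[i][j] += part[i][j] / len(chunks)
    return acc


samples = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
full_batch = mean_outer(samples)                       # one large mini-batch
accumulated = accumulated_mean_outer(samples, micro_batch=2)  # two smaller steps
```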