WAFFLe: Weight Anonymized Factorization for Federated Learning

In domains where data are sensitive or private, there is great value in methods that can learn in a distributed manner without the data ever leaving the local devices. In light of this need, federated learning has emerged as a popular training paradigm. However, many federated learning approaches trade transmitting data for communicating updated weight parameters for each local device. Therefore, a successful breach that would have otherwise directly compromised the data instead grants whitebox access to the local model, which opens the door to a number of attacks, including exposing the very data federated learning seeks to protect. Additionally, in distributed scenarios, individual client devices commonly exhibit high statistical heterogeneity. Many common federated approaches learn a single global model; while this may do well on average, performance degrades when the i.i.d. assumption is violated, underfitting individuals further from the mean and raising questions of fairness. To address these issues, we propose Weight Anonymized Factorization for Federated Learning (WAFFLe), an approach that combines the Indian Buffet Process with a shared dictionary of weight factors for neural networks. Experiments on MNIST, FashionMNIST, and CIFAR-10 demonstrate WAFFLe's significant improvements to local test performance and fairness while simultaneously providing an extra layer of security.


I. INTRODUCTION
With the rise of the Internet of Things (IoT), the proliferation of smartphones, and the digitization of records, modern systems generate increasingly large quantities of data. These data provide rich information about each individual, opening the door to highly personalized intelligent applications, but this knowledge can also be sensitive: images of faces, typing histories, medical records, and survey responses are all examples of data that should be kept private. Federated learning [1] has been proposed as a possible solution to this problem. By keeping user data on each local client device and only sharing model updates with the global server, federated learning represents a possible strategy for training machine learning models on heterogeneous, distributed networks in a privacy-preserving manner. While such a paradigm has demonstrated promise, a number of challenges remain for federated learning [2].
As with centralized distributed learning settings [3], many federated learning algorithms focus on learning a single global model. However, due to variation in user characteristics, personal data are likely to exhibit significant statistical heterogeneity. To simulate this, federated learning algorithms are commonly tested in non-i.i.d. settings [1,4,5,6], but data are often equally represented across clients and ultimately a single global model is typically learned. As is usually the case for one-size-fits-all solutions, while the model may perform acceptably on average for many users, some clients may see poor performance. Questions of fairness [7,8] may arise if performance is compromised for individuals in the minority in favor of the majority.
Another challenge for federated learning is security. Data privacy is the primary motivation for keeping user data local on each device, rather than gathering it in a centralized location for training. In traditional distributed learning systems, data are exposed to additional vulnerabilities while being transmitted to and while residing in the central data repository. In lieu of the data, many federated learning approaches require clients to send weight updates to train the aggregated model. However, the threat of membership inference attacks [9,10] or model inversion [11,12] means that private data on each device can still be compromised if federated learning updates are intercepted or if the central server is breached.
We propose Weight Anonymized Factorization for Federated Learning (WAFFLe), leveraging Bayesian nonparametrics and neural network weight factorization to address these issues. We make the following contributions: i) Rather than learning a single global model, we learn a dictionary of rank-1 weight factor matrices. By selecting and weighting these factors, each local device can have a model customized to its unique data distribution, while sharing the learning burden of the weight factors across devices. ii) We employ the Indian Buffet Process [13] as a prior to encourage factor sparsity and reuse, performing variational inference to infer the distribution of factors for each client. iii) While updates to the dictionary of factors are transmitted to the server, the distribution capturing which factors a client uses is kept local. This adds an extra insulating layer of security by obfuscating which factors a client is using, hindering an adversary's ability to perform membership inference attacks or dataset reconstruction. iv) Finally, individually customized models result in greater fairness.
We perform experiments on MNIST [14], FMNIST [15], and CIFAR-10 [16] in settings exhibiting strong statistical heterogeneity. We observe that the model customization central to WAFFLe's design leads to higher performance for each client's local distribution, while also being significantly fairer across all clients. Finally, we perform membership inference [9] and model inversion [11] attacks on WAFFLe, showing that it is much harder to expose user data than with FedAvg [1].

A. SHARED DICTIONARY OF WEIGHT FACTORS
Single Global Model: Consider N client devices, with the i-th device having data distribution D_i, which may differ as a function of i. In many distributed learning settings, a single global model is learned and deployed to all N clients. Thus, assuming a multilayer perceptron (MLP) architecture (see Appendix A for the extension to convolutional layers) with layers ℓ = 1, ..., L, the set of weights θ = {W^ℓ}_{ℓ=1}^L is shared across all clients. To satisfy the global objective, θ is learned to minimize the loss on average across all clients; this is the strategy of many federated learning methods. For example, FedAvg [1] minimizes the objective

$$\min_{\theta} \sum_{i=1}^{N} p_i \, \mathcal{L}_i(\theta), \qquad (1)$$

where $\mathcal{L}_i(\theta) := \mathbb{E}_{x_i \sim \mathcal{D}_i}[\ell_i(x_i; \theta)]$ is the local objective function, N is the number of clients, and p_i ≥ 0 is the weight of each device i. However, given statistical heterogeneity, such a one-size-fits-all approach may lead to the global model underfitting certain clients; how well a client is fit often depends on how close its local distribution is to the population distribution. As a result, this model may be viewed as less fair to clients with less common traits.

Individual Local Models: On the other extreme, we may alternatively consider learning N local models θ_i = {W_i^ℓ}_{ℓ=1}^L, each trained only on D_i. In this case, each set of weights θ_i is maximally specific to the data distribution of client i. However, each client typically has limited data, which may be insufficient for training a full model without overfitting, and the total number of parameters that must be learned across all clients scales with N. Additionally, learning N separate models does not leverage similarities between client data distributions or the shared learning task.
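For concreteness, a minimal Python sketch of the single-global-model formulation follows; it is an illustration rather than a reference implementation, and the client interface (the weight p and the local loss estimate) is our assumption.
```python
def global_objective(theta, clients):
    """F(theta) = sum_i p_i * L_i(theta); each client exposes its weight p
    and an empirical estimate of its local expected loss L_i(theta)."""
    return sum(c.p * c.local_loss(theta) for c in clients)

def fedavg_aggregate(client_states, client_weights):
    """Parameter-wise weighted average of the full weight sets (state dicts)
    returned by the participating clients, as in FedAvg-style aggregation."""
    total = sum(client_weights)
    return {name: sum(w * s[name] for w, s in zip(client_weights, client_states)) / total
            for name in client_states[0]}
```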
Shared Weight Factors: To make more efficient use of data, we instead propose a compromise between a single global model and N individual local models. Specifically, we allow each client's model to be personalized to the client's local distribution, but with all models sharing a dictionary of jointly learned components. Using a layer-wise decomposition [17], we construct each weight matrix with the factorization

$$W_i^{\ell} = W_a^{\ell}\, \mathrm{diag}(\lambda_i^{\ell})\, W_b^{\ell}, \qquad (2)$$

where W_a^ℓ ∈ R^{J×F} and W_b^ℓ ∈ R^{F×M} are global parameters shared across clients and λ_i^ℓ ∈ R^F is a client-specific vector. Note that this construction is not a post-training processing step, such as a singular value decomposition (SVD) of trained weights, but a parameterization of the weights adopted before training. The factorization can be equivalently expressed as

$$W_i^{\ell} = \sum_{k=1}^{F} \lambda_{i,k}^{\ell}\, \big(w_{a,k}^{\ell} \otimes w_{b,k}^{\ell}\big), \qquad (3)$$

where w_{a,k}^ℓ is the k-th column of W_a^ℓ, w_{b,k}^ℓ is the k-th row of W_b^ℓ, and ⊗ denotes the outer product. Written this way, the interpretation of the corresponding column-row pairs w_{a,k}^ℓ and w_{b,k}^ℓ as weight factors is more apparent: W_a^ℓ and W_b^ℓ together comprise a global dictionary of the weight factors, and λ_i^ℓ can be viewed as the factor scores with which client i selects the corresponding rank-1 matrices. Differences in λ_i^ℓ between clients allow each model to be customized to its client's data distribution (see Figure 1), while sharing the underlying factors W_a^ℓ and W_b^ℓ enables learning from the data of all clients.
We constitute each client's factor scores λ_i^ℓ as the element-wise product

$$\lambda_i^{\ell} = r^{\ell} \odot b_i^{\ell}, \qquad (4)$$

where r^ℓ ∈ R^F indicates the strength of each factor and b_i^ℓ ∈ {0, 1}^F is a binary vector indicating the active factors. As explained below, b_i^ℓ is typically sparse, so in general each client uses only a small subset of the available weight factors. Throughout this work, we use the absence of the ℓ superscript (e.g., λ_i) to refer to the entire collection across all layers for which this factorization is done. We learn point estimates for W_a, W_b, and r.

[Figure 1: Each client uses a sparse diagonal matrix Λ_i^ℓ = diag(λ_i^ℓ), specifying the combination of weight factors that constitute its own personalized model. Neither the client data D_i nor the factor selections Λ_i^ℓ leave the local device.]
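A minimal PyTorch sketch of one such factorized layer may help make Eqs. (2)-(4) concrete. This is our own illustration, not code from a reference implementation: W_a, W_b, and r are the global dictionary parameters, while the (approximately binary) selection vector b_i is supplied by the client and never leaves the device.
```python
import torch
import torch.nn as nn

class FactorizedLinear(nn.Module):
    """One WAFFLe-style layer: W_i = W_a diag(lambda_i) W_b, with
    lambda_i = r * b_i element-wise. A sketch under assumed shapes."""
    def __init__(self, in_dim, out_dim, num_factors):
        super().__init__()
        self.W_a = nn.Parameter(0.01 * torch.randn(out_dim, num_factors))  # global
        self.W_b = nn.Parameter(0.01 * torch.randn(num_factors, in_dim))   # global
        self.r = nn.Parameter(torch.ones(num_factors))                     # global strengths

    def forward(self, x, b_i):
        # b_i: the client's (relaxed) binary factor selections, kept local
        lam = self.r * b_i                 # factor scores lambda_i, Eq. (4)
        W_i = (self.W_a * lam) @ self.W_b  # = W_a @ diag(lam) @ W_b, Eq. (2)
        return x @ W_i.t()
```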

Algorithm 1 Updating Scheme in Each Communication Round
Input: local training epochs E, learning rate η
Server randomly selects a subset S_t of clients
Server sends {W_a, r, W_b} to each client i ∈ S_t
function LOCALUPDATE(W_a, r, W_b, D_i)
    for e = 1, ..., E do
        for minibatch b ∈ D_i do
            Update {W_a, r, W_b, π_i, c_i, d_i} by minimizing (11)
        end for
    end for
    return W_a, r, W_b, π_i, c_i, d_i
end function
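As a rough sketch of the client-side step above, the snippet below trains a single factorized output layer on a client's data with a relaxed Bernoulli mask. It is an illustration under assumptions: all parameters are torch Parameters, the temperature and names are ours, and the KL regularizer of Eq. (11) is omitted for brevity.
```python
import torch
import torch.nn.functional as F
from torch.distributions import RelaxedBernoulli

def local_update(global_params, pi_logits_i, data_loader, epochs=1, lr=0.04, temp=0.1):
    """Client-side step: update {W_a, W_b, r} and the local variational
    logits pi_logits_i; only the global parameters are sent back."""
    opt = torch.optim.SGD(list(global_params.values()) + [pi_logits_i], lr=lr)
    for _ in range(epochs):
        for x, y in data_loader:
            b_i = RelaxedBernoulli(temp, logits=pi_logits_i).rsample()  # soft mask
            lam = global_params["r"] * b_i
            W_i = (global_params["W_a"] * lam) @ global_params["W_b"]
            loss = F.cross_entropy(x @ W_i.t(), y)  # + KL term in the full objective
            opt.zero_grad(); loss.backward(); opt.step()
    return global_params  # pi_logits_i stays on the device
```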

B. THE INDIAN BUFFET PROCESS
Desiderata: Within the context of federated learning with statistical heterogeneity, there are a number of properties we would like the client factor scores to have collectively. First, λ_i should be sparse, which encourages consolidation of related knowledge while minimizing interference: client A should be able to update the global factors during training without destroying client B's ability to perform its own task. This encourages fairness, as in settings with multiple subpopulations such interference is most likely to come at the smaller groups' expense. On the other hand, we would also like factors to be reused among clients. While data may be non-i.i.d. across clients, there are often some similarities; shared factors thus distribute learning across all clients' data, avoiding the scenario of N independent models. Finally, in the distributed settings considered in federated learning, the total number of nodes is rarely pre-defined, so there must be a way to gracefully expand the system to accommodate new clients without re-initializing the whole model. This includes both increasing server-side capacity if necessary and initializing new clients.
Prior: Given these desiderata, the Indian Buffet Process (IBP) [13] is a natural choice. As a prior, the IBP regularizes client factors to be sparse, and new factors are introduced only at a harmonic rate, preferring to reuse existing factors as much as possible over initializing new ones. This Bayesian nonparametric approach allows the data to dictate client factor assignment, factor reuse, and server-side model expansion. We use the stick-breaking construction of the IBP [18] as a prior for the factor selection:

$$v_{i,\kappa}^{\ell} \sim \mathrm{Beta}(\alpha, 1), \qquad \pi_{i,k}^{\ell} = \prod_{\kappa=1}^{k} v_{i,\kappa}^{\ell}, \qquad b_{i,k}^{\ell} \sim \mathrm{Bernoulli}(\pi_{i,k}^{\ell}),$$

where k indexes the factor, π_{i,k}^ℓ denotes the probability of the k-th factor being active, and α is a hyperparameter controlling the expected number of active factors and the rate at which new factors are incorporated. Note that in the stick-breaking construction, π_{i,k}^ℓ is generated as a cumulative product of the Beta random variables v_{i,κ}^ℓ.

Inference: We learn the posterior distribution of the random variables ϕ_i = {b_i, v_i}. Exact inference of the posterior is intractable, so we employ variational inference with a mean-field approximation to determine the active factors for each client device, using the variational distributions

$$q(b_{i,k}^{\ell}) = \mathrm{Bernoulli}(\pi_{i,k}^{\ell}), \qquad q(v_{i,k}^{\ell}) = \mathrm{Kumaraswamy}(c_{i,k}^{\ell}, d_{i,k}^{\ell}),$$

learning the variational parameters {π_i, c_i, d_i} for each queried client using Bayes by Backprop [19]. As a differentiable parameterization is needed, we use the Kumaraswamy distribution [20] as a replacement for the Beta distribution of v_i and utilize a soft relaxation of the Bernoulli distribution [21]. The objective for each client is to maximize the variational lower bound

$$\mathcal{L}_i = \mathbb{E}_{q(\phi_i)}\big[\log p(Y_i \mid X_i, \phi_i, W_a, W_b, r)\big] - \mathrm{KL}\big(q(\phi_i)\,\|\,p(\phi_i)\big), \qquad (11)$$

where the first term provides label supervision and the second term, denoted R, regularizes the posterior with the IBP prior. The KL divergence in R is approximated by sampling from the posterior distribution.
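The sketch below illustrates the two inference ingredients named above: reparameterized Kumaraswamy sampling for the stick-breaking weights, and a single-sample Monte Carlo estimate of the KL term against the Beta(α, 1) prior. It is a hand-rolled illustration of the standard densities, not reference code.
```python
import math
import torch

def kumaraswamy_rsample(c, d, eps=1e-6):
    """Differentiable sample from Kumaraswamy(c, d) via its inverse CDF."""
    u = torch.rand_like(c).clamp(eps, 1 - eps)
    return (1 - (1 - u).pow(1 / d)).pow(1 / c)

def stick_breaking(c, d):
    """v_k ~ Kumaraswamy(c_k, d_k); pi_k = prod_{kappa <= k} v_kappa."""
    v = kumaraswamy_rsample(c, d)
    return torch.cumprod(v, dim=-1), v

def kl_kumaraswamy_beta_mc(v, c, d, alpha):
    """One-sample Monte Carlo estimate of KL(q(v) || Beta(alpha, 1)),
    i.e. log q(v) - log p(v) evaluated at the sampled v."""
    log_q = c.log() + d.log() + (c - 1) * v.log() + (d - 1) * torch.log1p(-v.pow(c))
    log_p = math.log(alpha) + (alpha - 1) * v.log()  # Beta(alpha, 1) density
    return (log_q - log_p).sum()
```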

C. CLIENT-SERVER COMMUNICATION
Training: Before training begins, the global weight factors {W_a, W_b} and the factor strengths r are initialized by the server. Once initialized, each training round begins with {W_a, W_b, r} being sent to a selected subset of clients. Each sampled client then trains the model on its own private dataset D_i for E epochs, updating not only the weight factor dictionary {W_a, W_b} and the factor strengths r, but also its own variational parameters {π_i, c_i, d_i}, which control which factors it uses. Once local training is finished, each client sends {W_a, W_b, r} back to the server, but not {π_i, c_i, d_i}, which remain with the client alongside its data D_i. After the server has received the updates from all queried clients, the new values of {W_a, W_b, r} are aggregated with a simple averaging step. The process then repeats, with the server selecting a new subset of clients to query and sending the updated set of global parameters, until the desired number of communication rounds has passed. This process is summarized in Algorithms 1 and 2.
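A rough reconstruction of the server-side loop (Algorithm 2 is not reproduced in this section, so the names below are our assumptions) might look like the following sketch. Note that the clients' variational parameters never appear in what the server receives or averages.
```python
import random
import torch

def run_training(global_params, clients, num_rounds, clients_per_round):
    """Server loop: sample clients, send {W_a, W_b, r}, average the returns."""
    for _ in range(num_rounds):
        selected = random.sample(clients, clients_per_round)
        updates = [c.local_update({k: v.clone() for k, v in global_params.items()})
                   for c in selected]
        for name in global_params:  # simple (unweighted) averaging step
            global_params[name] = torch.stack([u[name] for u in updates]).mean(dim=0)
    return global_params
```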
Evaluation: When a client enters evaluation mode, it requests the current version of the global parameters {W_a, W_b, r} from the server. If the client has previously been queried for federated training, the local model consists of the aggregated global parameters and the factor score vector generated by its own local variational parameters {π_i}. Otherwise, the client uses only the aggregated {W_a, W_b, r}. Note that if a client has been previously queried, the most recently cached copy of the global parameters is an option if a network connection is unavailable or too expensive; in our experiments, we assume clients are able to request the most up-to-date parameters.
Security: Data security is one of the central tenets of federated learning. Simpler, more standard methods of training a model could be utilized if all data were first aggregated at a central server. However, sensitive client data being intercepted during transmission, or the server's data repository being breached by an attacker, are major concerns, motivating federated learning's approach of keeping the data on the local device. On the other hand, keeping the data client-side may not be sufficient. Just as data can be compromised in transit or at the central database in non-federated settings, federated training updates are similarly vulnerable. In methods like FedAvg, this update is the entirety of the model's parameters. Effectively, FedAvg trades yielding the data immediately for surrendering whitebox access to the model, which opens the model to a wide range of malicious activities [22,11,9,23,12], including, critically, exposing the very data that federated learning aims to protect. With WAFFLe, clients transmit back the entire dictionary of weight factors {W_a, W_b} and r, but not {π_i, c_i, d_i}. As such, knowledge of which specific factors a particular client uses is kept local. Therefore, even if messages are intercepted, an adversary cannot completely reconstruct the model, hampering their ability to perform attacks that recover the data.

III. RELATED WORK
A. STATISTICAL HETEROGENEITY
Statistical heterogeneity of the data distributions across client devices has long been recognized as a challenge for federated learning. Despite acknowledging this heterogeneity, many federated learning algorithms focus on learning a single global model [1]; such an approach often suffers from model divergence, as local models may vary significantly from each other. To address this, a number of works break away from the single-global-model formulation. Several [4,24] have cast federated learning as a multi-task learning problem, with each client treated as a separate task. FedProx [25] adds a proximal term to account for statistical heterogeneity by limiting the impact of local updates. Others study federated learning within a model-agnostic meta-learning framework [26,27]. [28] recognize performance degradation from non-i.i.d. data and propose global sharing of a small subset of data, which, while effective, may compromise privacy. In settings of high statistical heterogeneity, fairness is also a natural question. AFL [7] and q-FFL [8] propose methods of focusing the optimization objective on the clients with the worst performance, though they do not change the network itself to model different data distributions.

B. PRESERVING DATA SAFETY
While much progress has been made in machine learning with public datasets [14,16,29], in real-world settings, data are often sensitive, potentially for proprietary [30,31], security [32], or privacy [33] reasons. Protecting user data is one of the primary motivations for federated learning in the first place. Approaches include releasing artificial data [34,35], homomorphic encryption [36], or differential privacy [37,38,39]. However, artificial data can still strongly resemble the original data, and sharing the model architecture and its parameters presents risks associated with whitebox access, leaving the data vulnerable to attacks such as membership inference [9] or model inversion [11,23,12].

C. BAYESIAN NONPARAMETRIC FEDERATED LEARNING
Several previous works have applied Bayesian nonparametrics to federated learning, primarily as a means of parameter matching during aggregation. Instead of averaging parameters element-wise without considering the meaning of each parameter, past works have proposed using the Beta-Bernoulli Process [40] for matching parameters, first for fully connected layers [41] and later extended by [42] to convolutions and LSTMs [43]. In contrast, our method utilizes Bayesian nonparametrics to model rank-1 factors for multi-task learning, rather than for the aggregation stage.

D. PERSONALIZED FEDERATED LEARNING
Personalized FL models have become a recent focus. One approach is to mix the global and local model parameters during optimization [44]; however, this requires meta-features from each client, which partially violates the privacy goal of FL. Another common strategy is splitting the neural network [45,46]: the model is divided into two parts, a feature extractor and personalized layers. The feature extractor is aggregated and shared by the server, and both parts are trained to form a personalized model. Recent work also explores meta-learning, particularly model-agnostic meta-learning (MAML) [47]. For example, Per-FedAvg [48] builds a meta-model initialization that is then updated by a gradient step to yield a personalized model; however, meta-optimization often requires computing second-order derivatives, which can be computationally prohibitive for FL. pFedMe [49] proposes decoupling the optimization of personalized models from learning the global model: it keeps the learning process of FedAvg while optimizing the personalized models in parallel, showing performance superior to Per-FedAvg [48].

IV. EXPERIMENTS
A. EXPERIMENTAL SETUP
1) Statistical Heterogeneity
Settings with higher statistical heterogeneity are more challenging for federated learning than settings where data are i.i.d. across clients, and they are also more representative of the real world, so we focus our experiments on the former. We consider two forms of statistical heterogeneity.
Unimodal non-i.i.d.: We first consider the non-i.i.d. setting introduced by [1]. This is a widely used evaluation setting, commonly referred to simply as "non-i.i.d." or "heterogeneous" in other federated learning works to distinguish it from completely i.i.d. data splits; we refer to it as unimodal non-i.i.d. to distinguish it from our second setting, which is also non-i.i.d. The primary purpose of such a partition is to investigate the behavior of federated averaging algorithms when each client has data from only a subset of the classes, with Z classes per client.
This type of partition begins by sorting all data by class. Given N client devices, the samples from each class are evenly divided into single-class shards, resulting in NZ shards across all classes. These shards are then randomly distributed to the N clients such that each receives Z shards. The data in the Z shards of each client are then shuffled together and split into a local training and test set, ensuring that the local test set for each client is representative of its own private data distribution. While this setting can be challenging, it has the property that the classes present in each client's data are equally represented in the global data distribution. As a result, a single global model may perform reasonably uniformly across all clients.
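A compact NumPy sketch of this sort-and-shard procedure (our illustration; the subsequent local train/test split is omitted):
```python
import numpy as np

def unimodal_noniid_split(labels, num_clients, Z, seed=0):
    """Sort by class, cut into num_clients * Z single-class shards, and
    hand Z randomly chosen shards to each client. Returns index arrays."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")        # samples grouped by class
    shards = np.array_split(order, num_clients * Z)  # N*Z (near-)equal shards
    shard_ids = rng.permutation(num_clients * Z)
    return [np.concatenate([shards[s] for s in shard_ids[i * Z:(i + 1) * Z]])
            for i in range(num_clients)]
```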
Multimodal non-i.i.d.: While the unimodal non-i.i.d. partition does explore the non-i.i.d. nature of the class distribution among clients, it does not adequately characterize the tendency for subpopulations to exist, with some being more prevalent than others. We propose a new non-i.i.d. setting to capture this, which we call multimodal non-i.i.d., as each subpopulation group can be thought of as a mode of the overall distribution. In the real world, a mode may correspond to age, gender, ethnicity, wealth, or a number of other demographic factors. The number of subpopulation groups is arbitrary, but we choose two for simplicity, creating "majority" and "minority" subpopulations. In our experiments, the two modes are odd digits (N_1 = 100) versus even digits (N_2 = 20) for MNIST [14], footwear and shirts (N_2 = 20) versus everything else (N_1 = 90) for FMNIST [15], and animals (N_1 = 90) versus vehicles (N_2 = 20) for CIFAR-10 [16], where N_1 and N_2 are the numbers of clients in the majority and minority subpopulations, respectively.
Once the classes have been separated by group, the process proceeds similarly to the unimodal non-i.i.d. partition, with the data divided into shards and then randomly allocated to clients within each subpopulation. We make the shards equal in size both within and across modes, so when there are more data shards available than clients to receive them, we discard the unallocated data. Just as for unimodal non-i.i.d., local training and test sets are created for each client from its allocated data. An example multimodal non-i.i.d. partition for MNIST is shown in Figure 2. Compared with unimodal non-i.i.d., the difference is that there is now a 5:1 ratio of odd to even digits in the total population, leaving the clients holding only even digits in the minority of the global population. Note that the multimodal non-i.i.d. setting is not a distributed imbalanced-classification setting: the classes a client holds are evenly represented when its shards are aggregated, unlike in imbalanced classification. Rather, the hidden subpopulations lead to unfair model performance, caused by the uniform sampling of clients.
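Reusing the unimodal helper sketched above, the multimodal partition can be illustrated as below (our own sketch; the equalization and discarding of surplus shards across modes is elided). Here `groups` maps a subpopulation name to its class list, e.g. `{"odd": [1, 3, 5, 7, 9], "even": [0, 2, 4, 6, 8]}`.
```python
import numpy as np

def multimodal_noniid_split(labels, groups, clients_per_group, Z, seed=0):
    """Apply the single-mode sharding within each subpopulation's classes."""
    labels = np.asarray(labels)
    splits = {}
    for name, classes in groups.items():
        idx = np.where(np.isin(labels, classes))[0]          # this mode's samples
        local = unimodal_noniid_split(labels[idx], clients_per_group[name], Z, seed)
        splits[name] = [idx[client] for client in local]     # map back to global ids
    return splits
```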

2) Model Architecture and Training Setting
For MNIST [14] digit recognition, we use a multilayer perceptron with one hidden layer of 200 units and ReLU activations [50]. Based on this model, we construct WAFFLe with F = 120 factors. The traditional 60K training examples are partitioned into local training and test sets as described in Section IV-A. Stochastic gradient descent (SGD) with learning rate η = 0.04 is employed for all methods.

[Figure 2: Example multimodal non-i.i.d. partition for MNIST. The primary difference from the unimodal setting is the grouping of the data into two subpopulations (here referred to as "Majority" and "Minority") before sharding and allocating Z shards to each client.]

For FMNIST [15] fashion recognition, we use a convolutional network consisting of two 5×5 convolution layers with 16 and 32 output channels, respectively. Each convolution layer is followed by a 2×2 max-pooling operation with ReLU activations. A fully connected layer with a softmax is added for the output. Based on this model, we construct WAFFLe by factorizing only the convolution layers, with F = 25 factors. As with MNIST, the traditional 60K training examples are used to form the two local sets. SGD with learning rate η = 0.02 is used as the optimizer for all methods.
For CIFAR-10 [16], we use a convolutional network consisting of two 3×3 convolution layers, each with 16 output channels. Each convolution layer is followed by a 2×2 max-pooling operation with ReLU activations. These two convolutions are followed by two fully connected layers with hidden sizes 80 and 60, with a softmax applied for the final output probabilities. To construct WAFFLe, we set the number of factors to F = 10 for the two convolution layers, F = 80 for the first fully connected layer, and F = 40 for the second fully connected layer. The 50K training examples are used for constructing the local training and test sets. SGD with learning rate η = 0.02 is utilized for all methods.
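For reference, one plausible PyTorch rendering of the CIFAR-10 base network described above is sketched below; the padding and the exact layout of the classifier head (read here as 80 → 60 → 10) are our assumptions, as the text does not pin them down.
```python
import torch.nn as nn

class CIFARNet(nn.Module):
    """A sketch of the CIFAR-10 base architecture under assumed padding."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.MaxPool2d(2), nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 8 * 8, 80), nn.ReLU(),
            nn.Linear(80, 60), nn.ReLU(),
            nn.Linear(60, num_classes),  # softmax applied via the loss function
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```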

B. LOCAL TEST PERFORMANCE
We compare WAFFLe with FedAvg [1], the fairness-oriented q-FFL [8], and FedProx [25], which augments FedAvg with a proximal term designed for high statistical heterogeneity. We record local test performance averaged across all clients for both types of non-i.i.d. data allocation in Table 2, along with the total number of learnable parameters. WAFFLe performs well despite strong statistical heterogeneity, as each client can learn a personalized model by selecting different factors from {W_a, W_b}; having a model specific to each data distribution results in higher local test performance than the baselines. This advantage is especially apparent when the data are distributed multimodal non-i.i.d., mainly because WAFFLe more effectively models underrepresented clients. Interestingly, we find that WAFFLe outperforms the baselines particularly significantly on CIFAR-10, the most challenging of the tested datasets, with WAFFLe's local test performance outstripping the other methods by 18.8% and 20.9% in the unimodal and multimodal settings, respectively. This demonstrates WAFFLe's ability to scale to complex tasks beyond MNIST, a common federated learning test bed. Notably, even though WAFFLe effectively learns a different model for each client, this does not incur the computation or memory costs typically associated with independent models. WAFFLe's number of communication rounds is largely the same, and by sharing rank-1 factors, each weight factor can be represented compactly, resulting in a total parameter count smaller than that of the single model used by the baselines, despite the same underlying architecture.

C. TRAINING EFFICIENCY COMPARISON
We plot local test accuracy against the global epoch for FedAvg, FedProx, and WAFFLe on MNIST, FMNIST, and CIFAR-10, averaged over all clients, for unimodal non-i.i.d. data in Figure 4. A similar comparison is made between FedAvg and WAFFLe for multimodal non-i.i.d. data in Figure 5, with the majority and minority learning curves shown separately. In both cases, the clear gap between curves shows that WAFFLe achieves better performance throughout training. Notably, WAFFLe converges at a similar rate as FedAvg with respect to the global epoch number; this is important, as the number of communication rounds is often considered one of the primary bottlenecks in federated learning.
In the multimodal non-i.i.d. case, the difference is especially stark for the minority subpopulation, which lags significantly behind the majority when modeled with FedAvg's one-size-fits-all approach. Interestingly, in addition to reaching lower values, the FedAvg minority's training curve is not as smooth, with large dips and spikes, especially when compared with the majority subpopulation's curve. We hypothesize that this may be because the smaller subpopulation is more vulnerable to going unrepresented during client sampling, which may lead to catastrophic forgetting [51]; we find this an interesting direction for future research. In comparison, the WAFFLe minority, with its separate set of customized weight factors, has a much smoother training trajectory.

D. FAIRNESS
Average performance over all clients, as in Table 2, is a commonly reported metric, but we argue that it does not reveal the full story. We report subpopulation mean performance and overall population variance in Table 1. We observe that FedAvg, which learns a single global model, focuses on minimizing mean error across the population, resulting in stronger performance for the clients in the majority. As a result, however, clients in the minority are severely compromised, as evidenced by the large difference ("Gap") between majority (Major.) and minority (Minor.) values in Table 1; for example, FedAvg's performance for the "evens" group of clients is almost 30% lower than that of the "odds" group. This gap is especially clear when visualizing the distributions of final local test performance for clients in the majority and minority groups (Figure 3). The underfitting can also be seen throughout training in the "FedAvg_Minority" curve of Figure 5, which lags far below "FedAvg_Majority" on all three datasets. On the other hand, because of WAFFLe's shared weight factor dictionary design (Equation 3), different knowledge can be encoded in separate weight factors, which can be used by different parts of the population. As a result, despite certain classes being underrepresented (both in terms of clients and total samples) in the training set, WAFFLe is able to model them successfully, with performance on par with the overall population. Notably, we achieve this without explicitly enforcing fairness through client sampling during training [7,8], which could be incorporated to further encourage uniform performance across clients.
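The quantities reported in Table 1 are straightforward to compute; a small helper mirroring its columns (our own naming) might look like:
```python
import numpy as np

def fairness_summary(client_acc, is_minority):
    """Per-subpopulation means, the majority-minority gap, and the
    across-client variance of local test accuracy."""
    client_acc, is_minority = np.asarray(client_acc), np.asarray(is_minority)
    major = client_acc[~is_minority].mean()
    minor = client_acc[is_minority].mean()
    return {"Major.": major, "Minor.": minor,
            "Gap": major - minor, "Variance": client_acc.var()}
```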

E. DATA SAFETY
A primary objective of federated learning is to keep data safe. However, as mentioned in Section II-C, the predominant federated learning strategy of each client sending its entire updated model's weights still leaves the client's data vulnerable. Membership inference attacks (MIAs) [9,10] can be used to infer whether a given data query was used for model training, leveraging the tendency of machine learning models to overfit or memorize training data. As such, a successful MIA can be used by an attacker to surmise the content of a client's private data from the model. We compare a LeNet [14] FedAvg [1] model with an analogous WAFFLe model, training both on 1000 CIFAR-10 samples per client. We attack both with an MIA inspired by [9], using a small ensemble of 3 "shadow" models. As shown in Table 3, this simple attack achieves a high success rate at identifying a FedAvg client's training data, as intercepting the training update yields the full model. On the other hand, WAFFLe's training updates send only partial model information, as the identity of the active factors is kept private. As a result, the MIA success rate on WAFFLe is only moderately higher than random chance (50%), meaning it is significantly harder to identify the private training data for WAFFLe than for FedAvg.
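For intuition, a bare-bones shadow-model MIA in the spirit of [9] is sketched below; it is not the attack used in our experiments, and the model interface (`predict_proba`) is an assumption. Shadow models trained on data the attacker controls supply labeled (confidence vector, membership) pairs for a simple attack classifier.
```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def fit_membership_attack(shadow_models, shadow_members, shadow_nonmembers):
    """Each shadow model contributes confidences on data it did / did not
    train on; the attack model learns to tell the two apart."""
    X, y = [], []
    for model, inside, outside in zip(shadow_models, shadow_members, shadow_nonmembers):
        X += [model.predict_proba(inside), model.predict_proba(outside)]
        y += [np.ones(len(inside)), np.zeros(len(outside))]
    return LogisticRegression(max_iter=1000).fit(np.vstack(X), np.concatenate(y))

# Usage: attack.predict(target_model.predict_proba(queries)) estimates membership.
```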
We also perform a model inversion attack [11,23] on both FedAvg and WAFFLe. Unlike MIAs, which must start from a query data input, model inversion attacks seek to reconstruct the inputs used to train a model from the trained model itself; successful inversion attacks pose a significant risk from a data security perspective. We perform a model inversion attack on FedAvg and WAFFLe models trained on FMNIST, showing randomly selected reconstructions recovered from an individual user in Figure 6. Importantly, reconstructions from FedAvg are significantly sharper, with the class identity far clearer than for WAFFLe, meaning FedAvg is more vulnerable to model inversion attacks. We report two quantitative metrics [52] for evaluating the model inversion attack in Table 4. i) Peak Signal-to-Noise Ratio (PSNR) is the ratio of an image's maximum squared pixel fluctuation to the mean squared error between the target image and the reconstructed image. The higher the PSNR, the better the quality of the reconstructed image; clear reconstructions reveal the identity information of the client's data. For each class (e.g., T-shirt), we compute the PSNR between the reconstructed image and the average of randomly selected training images of that class, and report the average PSNR over all classes for FedAvg and WAFFLe. ii) Attack Accuracy (Attack Acc) is the accuracy of a separately trained evaluation classifier on the reconstructed images; if the evaluation classifier achieves high accuracy, the reconstructed image is considered to expose identity information about the target label. Our evaluation classifier reaches 96.67% accuracy and is used to classify the images reconstructed from FedAvg and WAFFLe, with the average attack accuracy reported. For both PSNR and attack accuracy, lower values indicate a more secure method. In Table 4, WAFFLe outperforms FedAvg on both metrics, proving more secure against model inversion attacks.
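The PSNR metric used above is simple to state in code; a minimal version (our own helper) is:
```python
import numpy as np

def psnr(reference, reconstruction, max_val=1.0):
    """Peak signal-to-noise ratio: higher means a sharper reconstruction,
    i.e. a less private (more vulnerable) model."""
    mse = np.mean((np.asarray(reference) - np.asarray(reconstruction)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)
```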

F. PERSONALIZATION
We further conduct experiments comparing against two personalized FL methods, FedPer [46] and pFedMe [49], on CIFAR-10 under both unimodal and multimodal settings with Z = 2. The local test performance is reported in Table 5. WAFFLe outperforms FedPer by offering personalization across multiple layers, whereas FedPer personalizes only the last layers of the network. Moreover, WAFFLe and FedPer outperform pFedMe by 9.19% and 6.96%, respectively, under the multimodal setting, highlighting that WAFFLe is specifically designed for statistical heterogeneity: each client can select different weight factors, effectively learning a personalized model.

G. ABLATION STUDIES
1) Degree of Statistical Heterogeneity (Z)
WAFFLe was shown to excel when Z = 2, a strongly non-i.i.d. setting in which each client has samples from only two classes. We also ran experiments in unimodal settings with less statistical heterogeneity, Z = {3, 4}. Although convergence takes longer in these cases, WAFFLe still outperforms FedAvg by 7.20% and 2.74%, respectively. The learning curve comparison for Z = 3 is shown in Figure 7. Note that FedProx is FedAvg augmented with an L2 proximal penalty on the local weights during training; the difference between FedProx and FedAvg depends heavily on the coefficient of the proximal term. We follow the setting in previous work [25] and set this coefficient to 0.2, which is not large enough to differentiate FedProx from FedAvg in extremely non-i.i.d. settings.

2) Local Epochs (E)
Training client devices for more local epochs allows the server to collect a bigger update from each device, increasing local computation in exchange for fewer total communication rounds. This is often a desirable trade-off, as communication costs are commonly viewed as the primary bottleneck for federated learning. However, too many local epochs can lead to divergence during the aggregation step. We study the influence of the number of local epochs E for unimodal non-i.i.d. in Table 6 and for multimodal non-i.i.d. in Table 7, using the same settings as in Section IV-A except for reducing the global training epochs T to 50 and the learning rate η to 0.02 for all methods in the multimodal non-i.i.d. scenario. We observe that WAFFLe can handle an increased number of local epochs, improving performance on all three datasets.

3) IBP Sparsity (α) and Number of Factors (F )
At the cost of more parameters, an increasing number of factors F and a higher IBP parameter α give each client more expressivity for modeling its local distribution. We study the influence of α and F for an MLP architecture on MNIST partitioned in the multimodal non-i.i.d. setting in Tables 9 and 10. As expected, the higher α and F are, the better the performance we observe, though in practice we prefer lower α and F for efficiency. On the other hand, the overall difference in local test accuracy does not vary drastically, meaning that WAFFLe is fairly robust to both hyperparameters.
To empirically demonstrate the data-driven sparsity introduced by the IBP prior, we also consider an alternative, non-Bayesian version of our model. Specifically, we replace WAFFLe's inferred per-client weight factors with per-client weight factors optimized by standard gradient descent, using an L1 sparsity penalty on factor usage as a replacement for the sparsity induced by the IBP. We list the test accuracy for the unimodal setting on MNIST in Table 8. Note that our Bayesian formulation (WAFFLe) outperforms the non-Bayesian version while also avoiding tuning the weight of the sparsity term, to which the non-Bayesian version is somewhat sensitive.

V. CONCLUSION
We have introduced WAFFLe, a Bayesian nonparametric framework for federated learning employing shared rank-1 weight factors. This approach allows learning individual models for each client's specific data distribution while still sharing the underlying learning problem in a parameter-efficient manner. Our experiments demonstrate that this model customizability makes WAFFLe successful at improving local test performance and, more importantly, significantly improves fairness in model performance when the data distribution among clients is multimodal. Furthermore, we are able to scale our results to CIFAR-10 and convolutional networks, where we observe the biggest improvements. We also show that by keeping the active factors selected by each model private on each device along with the data, WAFFLe's communication rounds send only partial model information, making it significantly harder to perform membership inference or gradient-based model inversion attacks on the private data.

APPENDIX A GENERALIZING WEIGHT FACTORIZATION TO CONVOLUTIONAL KERNELS
While introducing WAFFLe's formulation in Section II-A, we assumed a multilayer perceptron (MLP) model, as the proposed shared dictionary is especially clear to illustrate with the 2D weight matrices composing fully connected layers. While MLPs are sufficient for simple datasets such as MNIST, more challenging datasets require more complex architectures to achieve the most competitive results. For computer vision, this often means convolutional layers, whose kernels are 4D. While 4D tensors can be similarly decomposed into rank-1 factors with a tensor rank decomposition, such an approach would result in a large increase in the number of parameters in the weight factor dictionary, due to the low spatial dimensions of the convolutional kernels (e.g., 3 × 3) in most commonly used architectures. Instead, we reshape each 4D convolutional kernel into a 2D matrix by combining the three input dimensions (number of input channels, kernel width, and kernel height) into a single input dimension. We then proceed with the formulation in (2). Similar approaches can be taken to generalize our formulation to other types of layers.
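As a concrete sketch of this reshaping (our illustration, with assumed shapes): given a layer whose kernel is (C_out, C_in, kH, kW), the dictionary factorizes the corresponding (C_out, C_in·kH·kW) matrix, and the product is folded back into 4D before the convolution is applied.
```python
import torch
import torch.nn.functional as F

def factorized_conv2d(x, W_a, lam, W_b, c_in, kh, kw, **conv_kwargs):
    """W_a: (C_out, F); lam: (F,); W_b: (F, C_in*kH*kW). Builds the client's
    2D weight matrix via Eq. (2), reshapes it to a 4D kernel, and convolves."""
    kernel = ((W_a * lam) @ W_b).view(-1, c_in, kh, kw)
    return F.conv2d(x, kernel, **conv_kwargs)
```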