Challenges of Training Generative Models on Multiple Tabular Data Sources
Generative Adversarial Networks (GANs) are an emerging methodology for synthesizing data, ranging from images to text to tables. The key idea of GANs is to train two competing neural networks, a generator and a discriminator: the former iteratively generates synthetic data and the latter judges its quality. During training, the discriminator needs access to the original data so that it can give feedback to the generator by comparing the original data with the generated data. However, such direct data access can no longer be taken for granted, given ever-increasing concerns about data privacy. For instance, training a medical image generator on data from multiple hospitals cannot rely on centralized data processing and calls for decentralized, privacy-preserving learning solutions.
In response to this demand, the federated learning (FL) paradigm has emerged. FL features decentralized local processing: machine learning (ML) models are first trained on clients’ local data in parallel and then securely aggregated by the federator. As a result, the local data is never accessed directly by anyone other than its owner, and only intermediate model data is shared. The key design choices in constructing an FL framework for GANs depend on how the training of the generator and discriminator networks is distributed across data sources. On the one hand, discriminators are typically located on the clients’ premises because they need to process the clients’ data. On the other hand, prior work has explored two divergent approaches to training image generators: centrally at the server or locally at the clients. Although tabular data is the dominant data type in industry, there is no prior study on training GANs for tabular data under the FL paradigm.
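To make the aggregation step more concrete, here is a minimal sketch of plain FedAvg-style averaging, assuming each client’s model (e.g., its generator or discriminator) is represented as a list of NumPy arrays with identical shapes across clients, and that clients are weighted only by how many rows they hold. `fedavg` is a hypothetical helper name; the sketch leaves out the secure aggregation protocol and the local GAN training itself.

```python
import numpy as np


def fedavg(client_params, client_sizes):
    """Weighted average of client parameter lists (FedAvg-style).

    client_params: one entry per client; each entry is a list of NumPy arrays
                   (one array per layer), with the same shapes on every client.
    client_sizes:  number of local training rows held by each client.
    """
    sizes = np.asarray(client_sizes, dtype=float)
    weights = sizes / sizes.sum()  # weight each client by its data quantity
    merged = []
    for layer_idx in range(len(client_params[0])):
        # Stack the same layer from every client: shape (num_clients, *layer_shape).
        stacked = np.stack([params[layer_idx] for params in client_params])
        # Contract the client axis with the weights to get the weighted mean.
        merged.append(np.tensordot(weights, stacked, axes=1))
    return merged


# Example: two clients with a single one-layer "model" each.
merged = fedavg([[np.array([1.0, 1.0])], [np.array([3.0, 3.0])]], [100, 300])
print(merged[0])  # 0.25 * 1.0 + 0.75 * 3.0 = 2.5 per entry
```

The federator would broadcast `merged` back to the clients for the next round; weighting purely by data quantity is the simplest choice, and it is exactly this choice that becomes insufficient for tabular GANs, as discussed further below.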
Training state-of-the-art tabular GANs such as CTGAN on decentralized data sources in a privacy-preserving manner presents several additional challenges compared with image GANs. These challenges stem from how current tabular GANs explicitly model each column, whether continuous or categorical, via data-dependent encoding schemes and statistical distributions. For example, CTGAN uses a Variational Gaussian Mixture (VGM) model to estimate the distribution of each continuous column and encodes the values according to the estimated Gaussian mixture.
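The sketch below illustrates CTGAN-style mode-specific normalization of one continuous column. It is a hedged illustration, not CTGAN’s actual code: it assumes scikit-learn’s `BayesianGaussianMixture` as the VGM, `fit_vgm` and `encode` are hypothetical helper names, and the hyperparameters (10 modes, weight concentration prior 0.001) are assumed values. Each value c is represented by a one-hot mode indicator beta and a scalar alpha = (c - mu_k) / (4 * sigma_k) relative to its mode k.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture


def fit_vgm(values, max_modes=10, seed=0):
    """Fit a variational Gaussian mixture to one continuous column (assumed settings)."""
    vgm = BayesianGaussianMixture(
        n_components=max_modes,
        weight_concentration_prior_type="dirichlet_process",
        weight_concentration_prior=0.001,  # small prior pushes unneeded modes towards ~0 weight
        random_state=seed,
    )
    vgm.fit(np.asarray(values, dtype=float).reshape(-1, 1))
    return vgm


def encode(vgm, values):
    """Mode-specific normalization: a scalar alpha and a one-hot mode vector beta per value."""
    x = np.asarray(values, dtype=float).reshape(-1, 1)
    means = vgm.means_.reshape(-1)
    stds = np.sqrt(vgm.covariances_.reshape(-1))  # 1-D data: each covariance is a 1x1 matrix
    # Simplification: take the most probable mode; CTGAN samples the mode from the posterior.
    modes = vgm.predict(x)
    alpha = (x.reshape(-1) - means[modes]) / (4 * stds[modes])
    beta = np.eye(len(means))[modes]
    return np.clip(alpha, -0.99, 0.99), beta  # keep alpha inside (-1, 1)


# Example usage on a synthetic two-mode column (stand-in data, not the Adult dataset).
rng = np.random.default_rng(0)
ages = np.concatenate([rng.normal(25, 3, 500), rng.normal(55, 5, 500)])
vgm = fit_vgm(ages)
alpha, beta = encode(vgm, ages)
print(alpha.shape, beta.shape)  # (1000,) (1000, 10)
```

Note that the encoder depends entirely on the data passed to `fit_vgm`: fitting it on a different sample generally yields different modes, and therefore a different meaning for each column of `beta`.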
Hence, the first challenge is to unify the encoding schemes across data sources that are not independent and identically distributed (non-IID), and to do so in a privacy-preserving manner. (If the server could access data from all clients regardless of privacy concerns, it could simply build a unified encoding scheme on the server side and distribute it to all clients.) Unifying the encoding scheme for categorical columns is relatively easy: since CTGAN uses one-hot encoding, the server can collect the unique classes from each client and then create a global label encoder for every categorical column. For continuous columns, however, the server cannot access the clients’ data, and each client would estimate a different Gaussian mixture from its local data alone, so we cannot initialize a global encoding scheme for continuous columns in the same way.
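Both halves of this challenge can be illustrated in a few lines. The following is a minimal sketch, assuming clients report the set of unique labels of a categorical column to the server and that scikit-learn’s `BayesianGaussianMixture` stands in for the VGM. `build_global_categories`, `one_hot`, and `local_vgm_means` are hypothetical helper names, and the synthetic "young" and "old" age columns are made-up stand-ins for two non-IID clients.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture


def build_global_categories(client_category_sets):
    """Server side: union the unique labels every client reports for one column."""
    global_labels = sorted(set().union(*client_category_sets))
    return {label: idx for idx, label in enumerate(global_labels)}


def one_hot(label_to_idx, labels):
    """Client side: one-hot encode using the shared, global label order."""
    out = np.zeros((len(labels), len(label_to_idx)))
    for row, label in enumerate(labels):
        out[row, label_to_idx[label]] = 1.0
    return out


def local_vgm_means(values, seed=0):
    """Client side: fit a VGM on local data only; return the means of its active modes."""
    vgm = BayesianGaussianMixture(n_components=10, weight_concentration_prior=0.001,
                                  random_state=seed)
    vgm.fit(np.asarray(values, dtype=float).reshape(-1, 1))
    active = vgm.weights_ > 0.005  # ignore modes with negligible weight (assumed threshold)
    return np.sort(vgm.means_.reshape(-1)[active])


# Categorical column: a global encoder is easy to agree on.
mapping = build_global_categories([{"Private", "State-gov"}, {"Private", "Self-emp-not-inc"}])
print(mapping)  # {'Private': 0, 'Self-emp-not-inc': 1, 'State-gov': 2}

# Continuous column: two non-IID clients end up with incompatible local encoders.
rng = np.random.default_rng(0)
young_means = local_vgm_means(rng.normal(25, 3, 1000))  # client holding mostly young people
old_means = local_vgm_means(rng.normal(60, 4, 1000))    # client holding mostly older people
print(young_means, old_means)
```

The categorical mapping only requires sharing label sets, whereas the two locally fitted mixtures place their modes in disjoint ranges, so the mode indicators produced by one client’s encoder mean something different from those produced by the other’s.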
Second, the convergence speed of GANs critically depends on how the local models are merged. For image GANs, the merging weights are determined jointly by the data quantity and the (dis)similarity of class distributions across clients. Tabular GANs additionally need a more fine-grained (dis)similarity measure for deciding the merging weights, i.e., one that captures the differences in every column across clients.
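One way to picture a column-level (dis)similarity mechanism is sketched below. This is a hypothetical weighting scheme for illustration only, not the method our team uses: it assumes each client shares a frequency vector per column over globally agreed categories or bins (which for continuous columns presupposes the unified encoding discussed above, and which itself leaks some information), measures each client’s Jensen-Shannon divergence from the pooled distribution column by column, and scales the data-quantity weight by `exp(-average divergence)`. `js_divergence` and `merging_weights` are hypothetical names.

```python
import numpy as np


def _kl(a, b):
    """Kullback-Leibler divergence (natural log) between two discrete distributions."""
    return float(np.sum(a * np.log(a / b)))


def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence between two frequency vectors (normalized here)."""
    p = np.asarray(p, dtype=float) + eps  # eps avoids log(0) for empty bins
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)
    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)


def merging_weights(client_sizes, client_column_freqs):
    """Hypothetical merge weights: data quantity scaled by per-column similarity.

    client_sizes:        rows held by each client, length K.
    client_column_freqs: client_column_freqs[k][c] is client k's frequency vector
                         over the globally agreed categories/bins of column c.
    """
    sizes = np.asarray(client_sizes, dtype=float)
    n_cols = len(client_column_freqs[0])
    divergences = np.zeros(len(sizes))
    for c in range(n_cols):
        # Per-client distribution of column c, shape (K, num_bins).
        col = np.array([np.asarray(f[c], dtype=float) / np.sum(f[c]) for f in client_column_freqs])
        pooled = sizes @ col / sizes.sum()  # size-weighted pooled distribution
        divergences += [js_divergence(col[k], pooled) for k in range(len(sizes))]
    divergences /= n_cols  # average divergence across columns
    raw = sizes * np.exp(-divergences)  # clients whose columns deviate get down-weighted
    return raw / raw.sum()


# Example: three equally sized clients, one categorical column with two categories.
weights = merging_weights([100, 100, 100], [[[90, 10]], [[90, 10]], [[10, 90]]])
print(weights)  # the third client, whose column distribution deviates, gets the smallest weight
```

The exponential penalty is an arbitrary design choice made for the sketch; any monotone mapping from divergence to weight, or a per-column weighting instead of an average, would fit the same idea.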
Now that we have clarified the challenges above, let’s look at an example of what happens when a biased global encoding scheme is used for a continuous column.
The figure above shows four subplots. Subplot (1) is the original distribution of the age column in the Adult dataset. In (2), we assume that the server possesses 1% of the original data. In this case, we do not try to build a global encoder for the age column from all the clients; we simply use the server’s 1% data to initialize an encoder with VGM and distribute it to all the clients. For (3) and (4), we run a regular federated learning (FL) experiment with 5 clients, each of which trains a local CTGAN model. When the biased encoder fitted on the server’s 1% data is used, the age column generated at the end of FL training looks like (3). When the encoder for the age column is instead initialized from 100% of the data, the same FL training generates the age column shown in (4). Clearly, a poor encoder can ruin the training of CTGAN in a distributed framework.
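The encoder setup of this experiment can be sketched as follows. This is a minimal sketch under stated assumptions: the age column is replaced by a synthetic stand-in (the real Adult data is not loaded), scikit-learn’s `BayesianGaussianMixture` plays the role of the VGM with assumed hyperparameters, the data is split randomly across the 5 clients (the partitioning used in the original experiment is not stated), `fit_vgm` is a hypothetical helper, and the CTGAN training itself is omitted.

```python
import copy

import numpy as np
from sklearn.mixture import BayesianGaussianMixture


def fit_vgm(values, seed=0):
    """Fit a 10-mode variational Gaussian mixture to one continuous column (assumed settings)."""
    vgm = BayesianGaussianMixture(n_components=10, weight_concentration_prior=0.001,
                                  random_state=seed)
    return vgm.fit(np.asarray(values, dtype=float).reshape(-1, 1))  # fit() returns the model


rng = np.random.default_rng(0)
# Synthetic, skewed stand-in for the Adult "age" column.
age = np.clip(rng.gamma(shape=4.0, scale=5.5, size=10_000) + 17, 17, 90)

# Server: holds a random 1% share and fits the (possibly biased) encoder on it.
server_share = rng.choice(age, size=len(age) // 100, replace=False)
server_encoder = fit_vgm(server_share)

# Five clients with a random split of the column (assumed partitioning).
client_shards = np.array_split(rng.permutation(age), 5)

# Every client receives an identical copy of the server's encoder, so all clients
# map their local values onto the same set of modes before training CTGAN locally.
client_encoders = [copy.deepcopy(server_encoder) for _ in client_shards]
client_modes = [enc.predict(shard.reshape(-1, 1))
                for enc, shard in zip(client_encoders, client_shards)]

# Baseline corresponding to (4): an encoder fitted on 100% of the column.
full_encoder = fit_vgm(age)
print("avg log-likelihood of full column, 1% encoder:  ", server_encoder.score(age.reshape(-1, 1)))
print("avg log-likelihood of full column, 100% encoder:", full_encoder.score(age.reshape(-1, 1)))
```

Broadcasting one encoder solves the unification problem mechanically, but the quality of that single encoder is then bounded by whatever data the server used to fit it.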
Our team has come up with some solutions that make CTGAN compatible with distributed training. I will introduce them in the next blog post.