Representation learning with CGAN for causal inference

Conditional Generative Adversarial Nets (CGAN) are often used to improve conditional image generation performance. However, there is little research on representation learning with CGAN for causal inference. This paper proposes a new method for finding representation learning functions by adopting the adversarial idea. We apply the CGAN pattern and theoretically demonstrate the feasibility of finding a suitable representation function in the context of two balanced distributions. The theoretical result shows that when the two distributions are balanced, the ideal representation function can be found and can thus be used in further research.


Introduction
Causal inference has many significant applications in real life. Can a drug work for a specific patient population? Does the promulgation of a policy have an impact on leadership approval ratings? Will online classes affect students' academic performance? These are all questions of causal inference [1]. It is difficult to answer such questions by relying on data alone; data alone can only answer questions of relevance. Causal inference refers to a task considering the assumptions, estimation strategies, and study designs that are useful for drawing conclusions based on data. Causal inference depends on the consideration of counterfactual states: mainly, it considers the outcomes that could manifest given exposure to each set of treatment conditions [2]. The process of causal inference always contains two steps: the representation step and the prediction step. The representation step makes the factual distribution and the counterfactual distribution similar; we explain the reason in the next chapter. The prediction step predicts the potential outcome based on the input x and the condition t. This paper introduces a conditional GAN [3] model into the traditional causal inference representation learning model to obtain a better representation function that balances the data distributions, thereby improving the subsequent prediction step. The paper is organized as follows: Chapter 2 introduces our work, Chapters 3 and 4 provide background on Generative Adversarial Networks and Conditional Generative Adversarial Networks, respectively, Chapter 5 illustrates the advantages of our work, and Chapter 6 gives the conclusions and future work.

Our work
2.1. Purpose
The main problem faced in causal inference is the imbalance of the data distribution. We make the following definitions: X is the set of contexts, T is the set of possible actions, and Y is the set of possible outcomes; for t ∈ T, which can be either 0 or 1, the potential outcome Y_t(x) ∈ Y is the outcome for x ∈ X under action t. The quantity of interest is the Individual Treatment Effect, defined as ITE(x) = Y_1(x) − Y_0(x). The essential problem of causal inference is that we can only observe Y_t(x) for one specific value of t, so we have to predict the potential outcome Y_{1−t}(x) for the other action. Moreover, we train on a dataset with a factual, empirical distribution but predict on a counterfactual empirical distribution, and these two distributions are often dissimilar [4]. Therefore, it is necessary to introduce a representation learning function that characterizes the input data so as to make the two data distributions similar. In this paper, we hope to optimize representation learning methods to improve the performance of causal inference models.
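As a concrete illustration of this imbalance, the following sketch builds a hypothetical synthetic dataset (not data from this paper) in which the covariate x also drives treatment assignment, so the factual treated and control covariate distributions differ even though the ITE itself is simple:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical synthetic example: covariate x, binary action t, potential
# outcomes Y_0(x), Y_1(x), and the individual treatment effect Y_1 - Y_0.
n = 10_000
x = rng.normal(0.0, 1.0, n)
y0 = x            # potential outcome under control
y1 = x + 2.0      # potential outcome under treatment
ite = y1 - y0     # ITE(x) = Y_1(x) - Y_0(x); here constant 2 by construction

# Confounded assignment: units with larger x are more likely to be treated,
# so the factual treated and control covariate distributions are imbalanced.
p_treat = 1.0 / (1.0 + np.exp(-2.0 * x))
t = rng.binomial(1, p_treat)

print("mean x | t=1:", x[t == 1].mean())  # shifted up
print("mean x | t=0:", x[t == 0].mean())  # shifted down
print("average ITE:", ite.mean())
```

The gap between the two conditional means of x is exactly the distribution mismatch that the representation function is meant to remove.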

Adversarial Nets
Previous work used neural networks to select representation functions; we add the idea of CGAN [5] to the model. GAN is a new way to train generative models. The model consists of two adversarial components, a generative model G and a discriminative model D; both G and D can be nonlinear mapping functions. The generative model G aims to capture the data distribution, and the discriminative model D aims to estimate the probability that a sample came from the training data rather than from G. The generator constructs a mapping function G(z; θ_g) from a noise distribution p_z(z) to data space. The discriminator D(x; θ_d) outputs a single scalar representing the probability that x came from the data. Finally, G and D are trained following the two-player minimax game with the value function shown below [6]:

min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x)] + E_{z~p_z(z)}[log(1 − D(G(z)))].
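The value function above can be estimated by Monte Carlo sampling. The sketch below does so for a toy fixed generator and discriminator on 1-D data (both mappings are hypothetical stand-ins, not trained networks):

```python
import numpy as np

# Monte-Carlo estimate of the GAN value function
#   V(G, D) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]
# with fixed, hand-picked G and D on 1-D data.

rng = np.random.default_rng(1)

def G(z):
    return 2.0 * z + 1.0                 # toy generator: affine map of noise

def D(x):
    return 1.0 / (1.0 + np.exp(-x))      # toy discriminator: logistic in x

x = rng.normal(1.0, 2.0, 100_000)        # samples from the "data" distribution
z = rng.normal(0.0, 1.0, 100_000)        # samples from the noise prior p_z

v = np.mean(np.log(D(x))) + np.mean(np.log(1.0 - D(G(z))))
print("estimated V(G, D):", v)
```

Both expectations are logs of probabilities, so the estimate is always negative; training moves D to increase it and G to decrease it.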

Conditional Adversarial Nets
Generative adversarial nets can be extended to conditional adversarial nets by giving both the generator and the discriminator some extra information y, known as the condition, which could be class labels or data in any other format. The noise p_z(z) and the extra information y are combined in a joint hidden representation, and the adversarial training framework is highly flexible in how this representation is composed [6]. The sample x and the condition y are presented as inputs to the discriminative function D, and G and D play the two-player minimax game with the following value function:

min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x|y)] + E_{z~p_z(z)}[log(1 − D(G(z|y)))].

As far as we know, we are the first to use CGAN to optimize the representation function to obtain better causal inference results.
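A common way to realize the joint hidden representation is simply to concatenate the condition with the noise (generator side) or the sample (discriminator side). The shape-only sketch below illustrates that wiring; the dimensions and the random linear map standing in for G are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(2)

# CGAN conditioning by concatenation: both networks see the condition y as an
# extra input slice. No training here; only the input/output shapes.
batch, z_dim, y_dim, x_dim = 32, 8, 4, 16
z = rng.normal(size=(batch, z_dim))                         # noise z ~ p_z
y = rng.integers(0, 2, size=(batch, y_dim)).astype(float)   # condition y

gen_in = np.concatenate([z, y], axis=1)          # generator input (z, y)
W = rng.normal(size=(z_dim + y_dim, x_dim))      # stand-in for G's mapping
x_fake = gen_in @ W                              # plays the role of G(z | y)

disc_in = np.concatenate([x_fake, y], axis=1)    # discriminator input (x, y)
print(gen_in.shape, disc_in.shape)
```

Any other composition (embedding the label, feature-wise modulation, etc.) fits the same framework; concatenation is only the simplest choice.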

Adversarial Nets For Representations
In this part, we first introduce the value function of GAN; then, in order to find the desired representation learning function, we apply the CGAN [7] idea and construct conditional adversarial nets for representations. Concretely, we use the noise and the control group data to find the representation function for the treatment group data. Finally, the training procedure is displayed in Algorithm 1.

Generative Adversarial Nets
A Generative Adversarial Network (GAN) includes two parts, a Generator (G) and a Discriminator (D). The former maps random noise to samples, while the latter discriminates generated samples from real ones. On the one hand, the generator strives to capture the data distribution and generate samples that are hard for the discriminator D to differentiate; its purpose is to maximize the probability that D makes a mistake. On the other hand, the discriminator is optimized to distinguish real data from generated data. The training procedure resembles a two-player minimax game with the following value function:

min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x)] + E_{z~p_z(z)}[log(1 − D(G(z)))],

in which x is a ground-truth sample from the true distribution p_data and z is a noise variable sampled from the distribution p_z.
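The two-player game above can be made concrete with a deliberately tiny 1-D GAN trained by manual gradients (a sketch under simplifying assumptions, not the paper's model): the generator G(z) = z + b learns only a shift b to match data drawn from N(3, 1), and the discriminator D(x) = sigmoid(w·x + c) is logistic.

```python
import numpy as np

rng = np.random.default_rng(3)
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

b = 0.0            # generator parameter: G(z) = z + b
w, c = 0.1, 0.0    # discriminator parameters: D(x) = sigmoid(w*x + c)
lr, batch = 0.05, 128

for step in range(4000):
    x_real = rng.normal(3.0, 1.0, batch)          # samples from p_data
    x_fake = rng.normal(0.0, 1.0, batch) + b      # samples from G

    # Ascend the discriminator's value: E[log D(real)] + E[log(1 - D(fake))]
    d_real, d_fake = sigmoid(w * x_real + c), sigmoid(w * x_fake + c)
    w += lr * (np.mean((1 - d_real) * x_real) - np.mean(d_fake * x_fake))
    c += lr * (np.mean(1 - d_real) - np.mean(d_fake))

    # Non-saturating generator update: ascend E[log D(G(z))] w.r.t. b
    d_fake = sigmoid(w * x_fake + c)
    b += lr * np.mean((1 - d_fake) * w)

print("learned shift b:", b)   # should approach the data mean 3.0
```

At equilibrium the generated and real distributions coincide, D outputs roughly 1/2 everywhere, and the gradients vanish, matching the fixed point described in the theory section.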

Adversarial Nets for Representations
As mentioned before, our objective is to find a representation function that characterizes the input data so as to balance the data distributions of the treatment and control groups. Therefore, we borrow the idea of CGAN and use adversarial nets to find the desired representation function Φ(t). Consider a classic causal inference question: given two existing anti-diabetic medications, is A or B better for a given patient? We then have a treatment group and a control group. Our work aims to use the CGAN idea to find a representation learning function that makes the data distribution of the treatment group similar to that of the control group.
First, we define Φ(t) as the representation learning function, where t is the input treatment group data. To obtain the distribution of the generator over Φ(t), we set a prior on the input control group data Con, and G(Con; θ_g) represents the mapping to data space, where G is a multilayer perceptron. We also set another multilayer perceptron D(Φ(t); θ_d). Then we denote by z the noise variable, which is fed into both the discriminator and the generator as an additional input layer to enhance robustness. In the generator, the control group data Con and the noise z are incorporated in a shared hidden representation. In the discriminator, Φ(t) and G(z|Con) are presented as inputs to a discriminative function. The objective function of the two-player minimax game is represented as Eq. (2):

min_G max_D V(D, G) = E_{t~p_t(t)}[log D(Φ(t))] + E_{z~p_z(z)}[log(1 − D(G(z|Con)))].  (2)

During training, we replace the last term of Eq. (2) with −E_{z~p_z(z)}[log D(G(z|Con))] to enhance the stability of the CGAN.
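The reason for swapping log(1 − D(G(z|Con))) for −log D(G(z|Con)) can be shown numerically: when the discriminator confidently rejects a generated sample, the original term's gradient with respect to the discriminator score vanishes, while the non-saturating term's does not. With D = sigmoid(a):

```python
import numpy as np

# d/da log(1 - sigmoid(a)) = -sigmoid(a)      -> ~0 for very negative a
# d/da log(sigmoid(a))     =  1 - sigmoid(a)  -> ~1 for very negative a
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

a = -6.0   # score of a confidently rejected generated sample
grad_saturating = -sigmoid(a)            # almost no learning signal
grad_nonsaturating = 1.0 - sigmoid(a)    # strong learning signal
print(grad_saturating, grad_nonsaturating)
```

Both objectives share the same fixed point, so the substitution changes the optimization dynamics, not the equilibrium of the game.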
We use an iterative, numerical approach to implement the game. In particular, we perform N steps of optimizing D for every one step of optimizing G in each training iteration. As long as G changes slowly enough, D is maintained near its optimal solution. The training procedure is formally displayed in Algorithm 1.
Algorithm 1 Minibatch stochastic gradient descent training of conditional adversarial nets for representations. The number of steps applied to the discriminator, N, is a hyperparameter.

for number of training iterations do
  for N steps do
    • Sample a minibatch of m noise samples {z^(1), …, z^(m)} from the noise prior p_z(z).
    • Sample a minibatch of m control group samples {Con^(1), …, Con^(m)} and m treatment group samples {t^(1), …, t^(m)}.
    • Ascend the stochastic gradient to update the discriminator D:
      ∇_{θ_d} (1/m) Σ_{i=1}^{m} [log D(Φ(t^(i))) + log(1 − D(G(z^(i)|Con^(i))))].
  end for
  • Sample a minibatch of m noise samples {z^(1), …, z^(m)} and m control group samples {Con^(1), …, Con^(m)}.
  • Descend the stochastic gradient to update the generator G:
    ∇_{θ_g} (1/m) Σ_{i=1}^{m} [−log D(G(z^(i)|Con^(i)))].
end for
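The control flow of Algorithm 1 (N discriminator steps per generator step) can be sketched as follows; the gradient functions here are deliberate placeholders, since the real updates depend on the network parameterization:

```python
import numpy as np

rng = np.random.default_rng(4)
N, iters, batch, lr = 5, 10, 64, 0.1   # N discriminator steps per G step

theta_d = np.zeros(2)   # stand-in discriminator parameters
theta_g = np.zeros(2)   # stand-in generator parameters

def disc_grad(theta_d, theta_g, z, con, phi_t):
    return np.full_like(theta_d, 0.01)   # placeholder gradient

def gen_grad(theta_d, theta_g, z, con):
    return np.full_like(theta_g, 0.01)   # placeholder gradient

for it in range(iters):
    for _ in range(N):                        # inner loop: N discriminator steps
        z = rng.normal(size=batch)            # noise minibatch from p_z
        con = rng.normal(size=batch)          # control-group minibatch
        phi_t = rng.normal(size=batch)        # treatment representations Phi(t)
        theta_d += lr * disc_grad(theta_d, theta_g, z, con, phi_t)  # ascend
    z = rng.normal(size=batch)
    con = rng.normal(size=batch)
    theta_g -= lr * gen_grad(theta_d, theta_g, z, con)              # descend

print(theta_d, theta_g)
```

Keeping N > 1 lets D stay near its optimum while G changes slowly, which is exactly the condition the convergence argument below relies on.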

Theoretical Results
This section presents the theoretical result of our work, i.e., applying CGAN to find a representation learning function under which the treatment group data distribution is similar to the control group data distribution. As the samples G(z|Con) are obtained with z ~ p_z and Con ~ p_Con, the generator G induces a distribution p_g. We will show that this minimax game attains its global optimum in the context of p_g = p_Φ, where p_Φ denotes the distribution of the treatment group representation Φ(t), and that Algorithm 1 can obtain the desired result by optimizing Eq. (2). Note that the setting of this section is non-parametric.
First, we determine the optimal discriminator D for any given generator G.

Proposition 1. For any given G, the optimal discriminator D is

D*(x) = p_Φ(x) / (p_Φ(x) + p_g(x)).

Proof. The discriminator D intends to maximize the value function V(G, D).
For any (a, b) ∈ ℝ² \ {(0, 0)}, the function y ↦ a log y + b log(1 − y) reaches its maximum on [0, 1] at y = a/(a + b). The discriminator only needs to be defined inside Supp(p_Φ) ∪ Supp(p_g), and the proof is concluded. Thus, Eq. (2) can be reformulated as:

H(G) = max_D V(G, D) = E_{t~p_t}[log D*(Φ(t))] + E_{x~p_g}[log(1 − D*(x))].  (5)

Theorem 1. The global minimum of H(G) is reached if and only if p_g = p_Φ, i.e., the distribution induced by the generator from the control group data equals the distribution of the treatment group representation. At that point, H(G) = −log 4.

Proof. From Eq. (5) we observe that

H(G) = −log 4 + KL(p_Φ ‖ (p_Φ + p_g)/2) + KL(p_g ‖ (p_Φ + p_g)/2),  (6)

in which KL is the Kullback–Leibler divergence [8]. We thus obtain the Jensen–Shannon divergence [9] between the distribution of the treatment group representation and the generator distribution induced from the control group data:

H(G) = −log 4 + 2 · JSD(p_Φ ‖ p_g).  (7)

Because the Jensen–Shannon divergence between two distributions is always non-negative and is zero only if they are equal, we conclude that H* = −log 4 is the global minimum of H(G), achieved only when p_g = p_Φ, i.e., only in the situation where the distribution induced from the control group data and the distribution of the treatment group representation are balanced.
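Theorem 1 can be checked numerically on discrete distributions: plugging the optimal discriminator D*(x) = p(x)/(p(x) + q(x)) into the value function reproduces the identity H = −log 4 + 2·JSD(p ‖ q), with minimum −log 4 exactly when p = q (the distributions here are arbitrary illustrative examples):

```python
import numpy as np

def jsd(p, q):
    # Jensen-Shannon divergence for discrete distributions
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log(a / b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def h(p, q):
    # Value function at the optimal discriminator D*(x) = p / (p + q)
    d_star = p / (p + q)
    return np.sum(p * np.log(d_star)) + np.sum(q * np.log(1.0 - d_star))

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.4, 0.4, 0.2])

assert np.isclose(h(p, q), -np.log(4.0) + 2.0 * jsd(p, q))   # Eq. (7)
assert np.isclose(h(p, p), -np.log(4.0))                     # balanced case
print(h(p, q), -np.log(4.0))
```

The balanced case (p = q) forces D* = 1/2 everywhere, so both expectations equal log(1/2) and their sum is −log 4.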
Proposition 2. If G and D have enough capacity, then at every step of Algorithm 1 the discriminator D reaches its optimum given G, and p_g is updated so as to improve the criterion

E_{t~p_t}[log D*(Φ(t))] + E_{x~p_g}[log(1 − D*(x))];

then p_g converges to p_Φ. As a result, in the context of p_g = p_Φ, i.e., when the distribution induced from the control group data and the distribution of the treatment group representation are balanced, we find the most suitable representation function Φ(t) for the treatment group data, which makes the control group and treatment group data distributions balanced.

Advantages & Disadvantages
One of our advantages is that Markov chains are no longer needed; moreover, inference is not necessary during learning, and the gradients are obtained by backpropagation alone [7]. Compared to measuring the imbalance between the treatment group and the control group with a single value such as disc, our method balances the distributions directly to improve the performance of causal inference models, and it reflects the effect of the representation more faithfully.
The disadvantages are that the representation model with GAN is hard to optimize, and the neural network for representation learning is not good enough and needs improvement. The neural network for prediction is also far from satisfactory: it cannot fully reflect the actual information gain in the real world, which harms the whole model.

Conclusions & Future work
In this paper, we propose the value function V(G, D) in terms of representations to tackle the problem that data distributions in causal inference are not balanced. We also show the feasibility of using the CGAN pattern to find the desired representation learning function, i.e., a representation function under which the representation distribution equals the other data distribution, in which case the value is −log 4. In further research, we will explore the actual representation function by conducting experiments.

Future work: 1. Adjust the structure of the representation neural network and the prediction neural network to improve the performance of causal inference. 2. Distribution balance is not equal to individual-level balance; we expect more viable solutions for tackling balance at the individual level. 3. Another direction left for future work is making the CGAN more robust to noise [10].