Gibbs Sampling

-- When in doubt, sample more!

Gibbs sampling is a Markov chain Monte Carlo (MCMC) method that lets us sample from a multivariate distribution as long as we know how to sample from each of its conditional distributions. Gibbs sampling only applies when the random variable X has at least 2 dimensions. In other words, X must be a vector, X = [x1, ..., xn]T with n ≥ 2.

The underlying premise of Gibbs sampling is quite simple. Instead of drawing X ∈ Rn (n ≥ 2) from P(X) directly, it draws one component at a time from its full conditional P(xi | x1, ..., xi-1, xi+1, ..., xn). It cycles through i = 1, ..., n and always conditions on the most recent values of the other components. The entire algorithm can be described by the following steps:

    X(0) = [x1(0), ..., xn(0)];
    for t = 1 to T do:
        for i = 1 to n do:
            xi(t) ~ P(xi | x1(t), ..., xi-1(t), xi+1(t-1), ..., xn(t-1))
        end for
    end for

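In more verbose terms: start by seeding X(0) with randomly chosen values xi(0) for all i ∈ [1, n]. Then, on each sweep t, update the components one at a time by drawing xi from its conditional P(xi | x1, ..., xi-1, xi+1, ..., xn). Components already updated in the current sweep (x1, ..., xi-1) enter the conditional with their new values. The remaining components still hold their values from the previous sweep. Once the chain has run long enough, the vectors X(t) are (correlated) samples from P(X).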
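The sweep above can be sketched in Python for a case where every full conditional is known in closed form. The target here is an assumed example, not one from the text: a standard bivariate normal with correlation rho. Its full conditionals are themselves normal, x1 | x2 ~ N(rho * x2, 1 - rho^2), and symmetrically for x2. The function name and its parameters are hypothetical.

```python
import numpy as np


def gibbs_bivariate_normal(rho: float, n_sweeps: int, seed: int = 0) -> np.ndarray:
    """Hypothetical Gibbs sampler for a standard bivariate normal with correlation rho.

    Full conditionals (assumed target): x1 | x2 ~ N(rho * x2, 1 - rho**2)
    and x2 | x1 ~ N(rho * x1, 1 - rho**2).
    """
    rng = np.random.default_rng(seed)
    sd = np.sqrt(1.0 - rho ** 2)          # conditional standard deviation
    x1, x2 = rng.uniform(-3, 3, size=2)   # random starting point X(0)
    samples = []
    for _ in range(n_sweeps):
        x1 = rng.normal(rho * x2, sd)     # conditions on x2 from the previous sweep
        x2 = rng.normal(rho * x1, sd)     # conditions on the x1 just drawn
        samples.append((x1, x2))
    return np.array(samples)


samples = gibbs_bivariate_normal(rho=0.8, n_sweeps=20000)
print(samples.mean(axis=0))          # should be close to [0, 0]
print(np.corrcoef(samples.T)[0, 1])  # should be close to 0.8
```

The second draw uses the freshly updated x1. That is the "most recent values" rule from the pseudocode.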

Let's work through a very simple example to understand the above algorithm in more detail. Let X = [x, y] represent discrete points on a Cartesian grid such that x ∈ {0, 1, 2} and y ∈ {0, 1, 2}. Further, let's assume we know the conditional distributions P(x|y) and P(y|x) as follows:

             x=0    x=1    x=2
P(x|y=0)    0.60   0.20   0.20
P(x|y=1)    0.00   1.00   0.00
P(x|y=2)    0.00   1.00   0.00

             y=0    y=1    y=2
P(y|x=0)    1.00   0.00   0.00
P(y|x=1)    0.17   0.50   0.33
P(y|x=2)    1.00   0.00   0.00
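Because the grid is so small, these tables pin down the joint distribution by hand, which gives a reference to compare a sampler against. Here P(a, b) is shorthand for P(x=a, y=b). The zeros in P(y|x=0) and P(y|x=2) mean P(x, y) can be non-zero at only five points. The ratios within P(x|y=0) and P(y|x=1) then fix everything up to one constant b, which normalisation determines. This assumes the tabulated values are exact (0.17 and 0.33 may themselves be rounded).

```latex
\begin{aligned}
&\text{Support: } P(x, y) > 0 \text{ only at } (0,0),\ (1,0),\ (2,0),\ (1,1),\ (1,2). \quad \text{Let } b = P(1,0). \\
&P(0,0) = \tfrac{0.60}{0.20}\, b = 3b, \qquad P(2,0) = \tfrac{0.20}{0.20}\, b = b \qquad (\text{ratios in } P(x \mid y{=}0)) \\
&P(1,1) = \tfrac{0.50}{0.17}\, b, \qquad P(1,2) = \tfrac{0.33}{0.17}\, b \qquad (\text{ratios in } P(y \mid x{=}1)) \\
&1 = b\left(3 + 1 + 1 + \tfrac{0.50 + 0.33}{0.17}\right) = \tfrac{1.68}{0.17}\, b \;\Rightarrow\; b = \tfrac{0.17}{1.68} \approx 0.101 \\
&P(0,0) = \tfrac{0.51}{1.68} \approx 0.304, \quad P(1,0) = P(2,0) = \tfrac{0.17}{1.68} \approx 0.101, \quad P(1,1) = \tfrac{0.50}{1.68} \approx 0.298, \quad P(1,2) = \tfrac{0.33}{1.68} \approx 0.196
\end{aligned}
```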

A simple Gibbs sampler written in Python could look as follows:

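    import numpy as np
    from numpy.random import choice, randint

    # Conditional distributions from the tables above:
    # p_x_y[y][x] = P(x | y) and p_y_x[x][y] = P(y | x).
    p_x_y = {
        0: [0.60, 0.20, 0.20],
        1: [0.00, 1.00, 0.00],
        2: [0.00, 1.00, 0.00],
    }

    p_y_x = {
        0: [1.00, 0.00, 0.00],
        1: [0.17, 0.50, 0.33],
        2: [1.00, 0.00, 0.00],
    }

    def sample_p_x_given_y(y: int) -> int:
        # Draw x in {0, 1, 2} from the categorical distribution P(x | y).
        return int(choice(3, p=p_x_y[y]))

    def sample_p_y_given_x(x: int) -> int:
        # Draw y in {0, 1, 2} from the categorical distribution P(y | x).
        return int(choice(3, p=p_y_x[x]))

    def gibbs_samples(t: int, burn_in: int = 0) -> np.ndarray:
        # Seed the chain at a random grid point (randint's upper bound is exclusive).
        x = randint(0, 2 + 1)
        y = randint(0, 2 + 1)
        # observations[x, y] counts how often the chain visited (x, y).
        observations = np.zeros((3, 3))

        for i in range(t):
            # One Gibbs sweep: update x given the current y, then y given the new x.
            x = sample_p_x_given_y(y)
            y = sample_p_y_given_x(x)
            # Discard the first burn_in sweeps.
            if i < burn_in:
                continue
            observations[x, y] += 1

        return observations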

You might be wondering what burn_in is, since it wasn't mentioned previously. Although it is not strictly part of the algorithm, it is common practice to throw away the first few iterations. They depend on the arbitrary starting point, and the chain has not yet converged to its stationary distribution, which is the target P(X). These discarded iterations are known as the "burn-in" period. To get a fairly accurate representation of P(X), we can run the Gibbs sampler for ~10k iterations and throw away the first 100 by calling gibbs_samples(10000, 100). This keeps 9,900 samples. The counts (rows x, columns y) and the resulting estimate of the joint distribution look as follows:
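For this small example, we can see exactly why early sweeps are discarded. One sweep (update x, then y) is a Markov chain on the nine grid points. It moves from (x, y) to (x2, y2) with probability P(x2|y) P(y2|x2). The sketch below builds that 9 × 9 transition matrix and solves for its stationary distribution. It then tracks the total variation distance to that distribution when the chain starts at (x, y) = (0, 2), a point with zero probability under the target. The helper names are hypothetical and not part of the code above.

```python
import numpy as np

# Conditional tables from the example (hypothetical array layout):
# P_X_GIVEN_Y[y, x] = P(x | y) and P_Y_GIVEN_X[x, y] = P(y | x).
P_X_GIVEN_Y = np.array([[0.60, 0.20, 0.20],
                        [0.00, 1.00, 0.00],
                        [0.00, 1.00, 0.00]])
P_Y_GIVEN_X = np.array([[1.00, 0.00, 0.00],
                        [0.17, 0.50, 0.33],
                        [1.00, 0.00, 0.00]])


def sweep_transition_matrix() -> np.ndarray:
    """T[3*x + y, 3*x2 + y2] = probability that one sweep moves (x, y) to (x2, y2)."""
    T = np.zeros((9, 9))
    for x in range(3):
        for y in range(3):
            for x2 in range(3):
                for y2 in range(3):
                    # Draw x2 from P(x2 | y), then y2 from P(y2 | x2).
                    T[3 * x + y, 3 * x2 + y2] = P_X_GIVEN_Y[y, x2] * P_Y_GIVEN_X[x2, y2]
    return T


def stationary_distribution(T: np.ndarray) -> np.ndarray:
    """Solve pi @ T = pi subject to sum(pi) = 1, as a least-squares system."""
    n = T.shape[0]
    A = np.vstack([T.T - np.eye(n), np.ones((1, n))])
    b = np.zeros(n + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]


def tv_after_sweeps(T: np.ndarray, pi: np.ndarray, start: tuple, sweeps: int) -> list:
    """Total variation distance to pi after each sweep, starting from a fixed point."""
    dist = np.zeros(T.shape[0])
    dist[3 * start[0] + start[1]] = 1.0  # all probability mass on the starting point
    tv = []
    for _ in range(sweeps):
        dist = dist @ T
        tv.append(0.5 * np.abs(dist - pi).sum())
    return tv


T = sweep_transition_matrix()
pi = stationary_distribution(T)
tv = tv_after_sweeps(T, pi, start=(0, 2), sweeps=50)
print(pi.reshape(3, 3).round(3))  # rows x = 0, 1, 2; columns y = 0, 1, 2
print([round(d, 4) for d in tv[:10]])
```

With these particular tables, the distance shrinks by a factor of about 0.66 per sweep. A 100-sweep burn-in is therefore generous here. Chains that mix more slowly need a correspondingly longer burn-in.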

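Sample counts (rows: x, columns: y)
        y=0    y=1    y=2     sum
x=0    2939      0      0    2939
x=1    1037   2928   1997    5962
x=2     999      0      0     999
sum    4975   2928   1997    9900

Estimated joint distribution P(x, y) (rows: x, columns: y)
         y=0     y=1     y=2    P(x)
x=0     0.297   0.000   0.000   0.297
x=1     0.105   0.296   0.202   0.603
x=2     0.100   0.000   0.000   0.100
P(y)    0.502   0.296   0.202   1.000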

From the generated counts (first table), we can approximate the joint probability table (second table) by dividing each count by the 9,900 retained samples. We can also check how good the simulation was by recomputing the conditional distributions from P(x, y). By the definition of conditional probability, P(x|y) = P(x, y) / P(y) and similarly P(y|x) = P(x, y) / P(x). Using these equations to reconstruct the conditional probabilities, we get:

             x=0     x=1     x=2
P(x|y=0)    0.591   0.208   0.201
P(x|y=1)    0.000   1.000   0.000
P(x|y=2)    0.000   1.000   0.000

             y=0     y=1     y=2
P(y|x=0)    1.000   0.000   0.000
P(y|x=1)    0.174   0.491   0.335
P(y|x=2)    1.000   0.000   0.000
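The reconstruction takes a couple of array operations in NumPy. The sketch below hard-codes the sample counts reported above (rows x, columns y). Normalising the estimated joint by its column sums P(y) gives P(x|y), and normalising by its row sums P(x) gives P(y|x). The variable names are illustrative.

```python
import numpy as np

# Counts from the 9,900 retained sweeps; rows are x = 0, 1, 2 and columns are y = 0, 1, 2.
counts = np.array([[2939,    0,    0],
                   [1037, 2928, 1997],
                   [ 999,    0,    0]], dtype=float)

joint = counts / counts.sum()                           # estimate of P(x, y)
p_x_given_y = joint / joint.sum(axis=0, keepdims=True)  # divide each column by P(y)
p_y_given_x = joint / joint.sum(axis=1, keepdims=True)  # divide each row by P(x)

# Transpose p_x_given_y so that each printed row is P(x | y = k), like the tables above.
print(p_x_given_y.T.round(3))
print(p_y_given_x.round(3))
```

Rounded to three decimals, the two printed arrays match the reconstructed tables.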

As you can see, the reconstruction is close to the true conditional distributions we started with (every entry is within 0.01), showing how simple and effective Gibbs sampling can be! For those who wish to learn more about Gibbs sampling or work through more complex examples, I highly recommend reading the following.