-- When in doubt, sample more!
Gibbs sampling is a Markov chain Monte Carlo (MCMC) method that lets us sample from a joint distribution as long as we know how to sample from its conditional distributions. It is important to note that Gibbs sampling is only applicable when the random variable X has at least 2 dimensions. This means that X must be a vector, i.e. X = [x1, ..., xn]^T with n ≥ 2.
The underlying premise of Gibbs sampling is quite simple. At its heart, it amounts to approximating samples of X ∈ R^n, n ≥ 2, drawn from P(X), by repeatedly sampling from the conditionals P(xi | x1, ..., xj, ..., xn), j ≠ i, for every i ∈ [1, n]. The entire algorithm can be described by the following steps:
```
X(0) = [x1(0), ..., xn(0)]
for t = 0 to T - 1 do:
    for i = 1 to n do:
        xi(t+1) ~ P(xi | x1(t+1), ..., xi-1(t+1), xi+1(t), ..., xn(t))
    end for
end for
```
In more verbose terms, start by seeding X(0) by randomly choosing values xi(0) for all i ∈ [1, n]. Then, using those initial values, sample the next value of each xi from its conditional distribution xi ~ P(xi | x1, ..., xj, ..., xn), j ≠ i, always plugging in the most recently sampled values of the other components.
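To make the loop structure concrete, here is a rough, generic sketch in Python. The `sample_conditional(i, x)` helper is hypothetical: it stands in for whatever routine draws xi from its full conditional given the current values of the other components.

```python
import numpy as np

def gibbs(sample_conditional, x0, T):
    """Generic Gibbs sweep: repeatedly resample each component from its
    full conditional, always using the most recent values of the others."""
    x = np.array(x0, dtype=float)
    samples = []
    for _ in range(T):
        for i in range(len(x)):
            # sample_conditional(i, x) should return a draw of x[i] given x[j], j != i
            x[i] = sample_conditional(i, x)
        samples.append(x.copy())
    return np.array(samples)
```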
Let's work through a very simple example to understand the above algorithm in more detail. Let X = [x, y] represent discrete points on a Cartesian grid such that x ∈ {0, 1, 2} and y ∈ {0, 1, 2}. Further, let's assume we know the conditional distributions P(x|y) and P(y|x):
| x | 0 | 1 | 2 |
|---|---|---|---|
| P(x\|y=0) | 0.60 | 0.20 | 0.20 |
| P(x\|y=1) | 0.00 | 1.00 | 0.00 |
| P(x\|y=2) | 0.00 | 1.00 | 0.00 |
| y | 0 | 1 | 2 |
|---|---|---|---|
| P(y\|x=0) | 1.00 | 0.00 | 0.00 |
| P(y\|x=1) | 0.17 | 0.50 | 0.33 |
| P(y\|x=2) | 1.00 | 0.00 | 0.00 |
A simple Gibbs sampler written in Python would look as follows:
```python
import numpy as np
from numpy.random import multinomial, randint

# Conditional distributions P(x | y) and P(y | x) from the tables above.
p_x_y = {
    0: [0.60, 0.20, 0.20],
    1: [0.00, 1.00, 0.00],
    2: [0.00, 1.00, 0.00],
}
p_y_x = {
    0: [1.00, 0.00, 0.00],
    1: [0.17, 0.50, 0.33],
    2: [1.00, 0.00, 0.00],
}

def sample_p_x_given_y(y: int) -> int:
    # Draw a single sample of x from P(x | y) via a multinomial draw.
    return np.where(multinomial(1, p_x_y[y]) == 1)[0][0]

def sample_p_y_given_x(x: int) -> int:
    # Draw a single sample of y from P(y | x) via a multinomial draw.
    return np.where(multinomial(1, p_y_x[x]) == 1)[0][0]

def gibbs_samples(t: int, burn_in: int = 0) -> np.ndarray:
    # Randomly seed the chain with x, y in {0, 1, 2}.
    x = randint(0, 2 + 1)
    y = randint(0, 2 + 1)
    observations = np.zeros((3, 3))
    for i in range(t):
        # Alternate sampling each variable given the latest value of the other.
        x = sample_p_x_given_y(y)
        y = sample_p_y_given_x(x)
        if i < burn_in:
            continue
        observations[x, y] += 1
    return observations
```
You might be wondering what `burn_in` is, since it wasn't mentioned previously. Although it is not directly part of the algorithm, it is common practice to throw away the first few iterations, as the chain has not yet converged to the target distribution. These discarded iterations are known as the "burn-in" period. To get a fairly accurate representation of P(X) we can run the Gibbs sampler for ~10k iterations, throwing away the first 100, by calling `gibbs_samples(10000, 100)`. The result looks as follows:
| x \ y | 0 | 1 | 2 | sum |
|---|---|---|---|---|
| 0 | 2939 | 0 | 0 | 2939 |
| 1 | 1037 | 2928 | 1997 | 5962 |
| 2 | 999 | 0 | 0 | 999 |
| sum | 4975 | 2928 | 1997 | 9900 |
| x \ y | 0 | 1 | 2 | P(x) |
|---|---|---|---|---|
| 0 | 0.297 | 0.000 | 0.000 | 0.297 |
| 1 | 0.105 | 0.296 | 0.202 | 0.603 |
| 2 | 0.100 | 0.000 | 0.000 | 0.100 |
| P(y) | 0.502 | 0.296 | 0.202 | 1.000 |
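The probability table above is simply the count table normalized by the total number of retained samples. Here is a minimal sketch of that computation, reusing the `observations` array returned by `gibbs_samples` (the `joint`, `p_x`, and `p_y` names are just illustrative):

```python
joint = observations / observations.sum()  # joint distribution P(x, y)
p_x = joint.sum(axis=1)                    # marginal P(x): sum over y (columns)
p_y = joint.sum(axis=0)                    # marginal P(y): sum over x (rows)
```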
From the generated sample counts (first table), we can approximate the joint probability table (second table). We can also verify how good the simulation was by re-computing the conditional distributions from P(x, y). From the definition of conditional probability, P(x|y) = P(x, y) / P(y) and, similarly, P(y|x) = P(x, y) / P(x). Using these equations to reconstruct the conditional probabilities we get:
| x | 0 | 1 | 2 |
|---|---|---|---|
| P(x\|y=0) | 0.591 | 0.208 | 0.201 |
| P(x\|y=1) | 0.000 | 1.000 | 0.000 |
| P(x\|y=2) | 0.000 | 1.000 | 0.000 |
| y | 0 | 1 | 2 |
|---|---|---|---|
| P(y\|x=0) | 1.000 | 0.000 | 0.000 |
| P(y\|x=1) | 0.174 | 0.491 | 0.335 |
| P(y\|x=2) | 1.000 | 0.000 | 0.000 |
As you can see, the reconstruction is quite close to the true conditional distributions we started with, showing how simple and effective Gibbs sampling can be! For those who wish to learn more about Gibbs sampling or work through more complex examples, I highly recommend the following reading.