Gaussian mixture model (GMM)

A Gaussian mixture model is a probabilistic model that assumes all the data points are generated from a mixture of a finite number of Gaussian distributions with unknown parameters.

Interpretation from geometry

Geometrically, $p(x)$ is a weighted sum of $K$ Gaussian densities, where the weights $\alpha_k \ge 0$ satisfy $\sum_{k=1}^{K} \alpha_k = 1$.

$$p(x)=\sum_{k=1}^{K} \alpha_{k} \cdot \mathcal{N}\left(x | \mu_{k}, \Sigma_{k}\right) $$
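To make the formula concrete, here is a minimal Python sketch that evaluates $p(x)$ for one-dimensional data. It assumes scalar variances $\sigma_k^2$ in place of covariance matrices $\Sigma_k$. The helper names `normal_pdf` and `gmm_pdf` are hypothetical and chosen for illustration.

```python
import math

def normal_pdf(x, mu, var):
    """Density of a 1-D Gaussian N(x | mu, var)."""
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def gmm_pdf(x, alphas, mus, variances):
    """Mixture density p(x) = sum_k alpha_k * N(x | mu_k, var_k)."""
    return sum(a * normal_pdf(x, m, v) for a, m, v in zip(alphas, mus, variances))

# Example: two components with weights 0.3 and 0.7.
alphas, mus, variances = [0.3, 0.7], [-2.0, 1.0], [0.5, 1.0]
print(gmm_pdf(0.0, alphas, mus, variances))
```

Each $\alpha_k$ is nonnegative, the weights sum to one, and each Gaussian integrates to one. Therefore the mixture $p(x)$ is itself a valid density.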

Interpretation from mixture model

Setup:

  • $K$, the total number of Gaussian components.

  • $x$, a sample (observed variable).

  • $z$, a latent variable indicating which Gaussian component generated the sample $x$, where

    • $z \in \{c_1, c_2, \dots, c_K\}$.

    • $\sum_{k=1}^K p(z=c_k)= 1$. We denote $p(z=c_k)$ by $p_k$.

Mixture models are usually generative models, which means new data can be drawn from the model's distribution. Specifically, in a Gaussian mixture model (GMM), a new data point is generated in two steps. First, a component $c_k$ is selected according to the probabilities $p_1, \dots, p_K$. Then a value is drawn from the Gaussian distribution of that component. Therefore, we can write $p(x)$ as follows:

$$ \begin{aligned} p(x) &= \sum_z p(x,z) \\ &= \sum_{k=1}^{K} p(x, z=c_k) \\ &= \sum_{k=1}^{K} p(z=c_k) \cdot p(x|z=c_k) \\ &= \sum_{k=1}^{K} p_k \cdot \mathcal{N}(x | \mu_{k}, \Sigma_{k}) \end{aligned} $$

The two interpretations lead to the same result, with the mixture weights $\alpha_k$ playing the role of the prior probabilities $p_k$.
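The two-step generative process (pick a component, then sample from it) is often called ancestral sampling. The sketch below uses Python's standard `random` module. It assumes one-dimensional components described by a mean and a standard deviation, and `sample_gmm` is a hypothetical helper name.

```python
import random

def sample_gmm(n, weights, mus, sigmas, seed=0):
    """Draw n samples: z ~ Categorical(weights), then x ~ N(mu_z, sigma_z^2).

    sigmas are standard deviations (an assumption of this sketch)."""
    rng = random.Random(seed)
    ks = range(len(weights))
    xs, zs = [], []
    for _ in range(n):
        k = rng.choices(ks, weights=weights)[0]  # step 1: choose component c_k with probability p_k
        x = rng.gauss(mus[k], sigmas[k])         # step 2: draw x from that component's Gaussian
        zs.append(k)
        xs.append(x)
    return xs, zs

xs, zs = sample_gmm(10_000, [0.3, 0.7], [-2.0, 3.0], [0.5, 1.0])
# The fraction of samples from the first component should be close to its weight, 0.3.
print(sum(1 for z in zs if z == 0) / len(zs))
```

Keeping the labels `zs` is only possible for synthetic data. In practice $z$ is unobserved, which is exactly why it is treated as a latent variable.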

GMM Derivation

Setup:

  • $X$: observed data, where $X = (x_1, x_2, \dots, x_N)$.

  • $\theta$: parameters of the model, where $\theta=\{p_{1}, p_{2}, \cdots, p_{K}, \mu_{1}, \mu_{2}, \cdots, \mu_{K}, \Sigma_{1}, \Sigma_{2}, \cdots, \Sigma_{K}\}$.

  • $p(x) = \sum_{k=1}^{K} p_k \cdot \mathcal{N}(x | \mu_{k}, \Sigma_{k})$.

  • $p(x,z) = p(z) \cdot p(x|z) = p_z \cdot \mathcal{N}(x | \mu_{z}, \Sigma_{z})$.

  • $p(z|x) = \frac{p(x,z)}{p(x)} = \frac{p_z \cdot \mathcal{N}(x | \mu_{z}, \Sigma_{z})}{\sum_{k=1}^K p_k \cdot \mathcal{N}(x | \mu_{k}, \Sigma_{k})}$.
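The posterior $p(z|x)$ is Bayes' rule in action. We compute the joint $p(x, z=c_k)$ for every component and normalize by their sum, $p(x)$. Below is a minimal 1-D sketch that uses scalar variances instead of $\Sigma_k$. The function names are hypothetical.

```python
import math

def normal_pdf(x, mu, var):
    """Density of a 1-D Gaussian N(x | mu, var)."""
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def posterior(x, weights, mus, variances):
    """Return [p(z = c_k | x) for k = 1..K] via Bayes' rule."""
    joint = [p * normal_pdf(x, m, v) for p, m, v in zip(weights, mus, variances)]  # p(x, z = c_k)
    evidence = sum(joint)                                                          # p(x)
    return [j / evidence for j in joint]

# Symmetric case: x sits exactly between two equally weighted components.
print(posterior(0.0, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]))  # [0.5, 0.5]
```

For points far from every component, all joint terms can underflow to zero. Working in log space avoids this problem.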

Solve by MLE

$$ \begin{aligned} \hat{\theta}_{MLE} &= \underset{\theta}{\operatorname{argmax}}\, p(X) \\ &=\underset{\theta}{\operatorname{argmax}}\, \log p(X) \\ &=\underset{\theta}{\operatorname{argmax}} \sum_{i=1}^{N} \log p\left(x_{i}\right) \\ &=\underset{\theta}{\operatorname{argmax}} \sum_{i=1}^{N} \log \left[\sum_{k=1}^{K} p_{k} \cdot \mathcal{N}\left(x_{i} | \mu_{k}, \Sigma_{k}\right)\right] \end{aligned} $$

As I have mentioned several times in previous posts, the log of a sum is hard to optimize: setting its derivative to zero does not yield a closed-form solution. Therefore, we need an iterative method to find the optimal $\theta$. Since $Z$ is a hidden variable, the EM algorithm is a natural choice.
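The log of a sum prevents a closed-form maximizer, but evaluating the objective is straightforward. It is also useful for monitoring EM iterations. The sketch below computes the 1-D log-likelihood in log space with the log-sum-exp trick, so points far from every component do not underflow to $\log 0$. Scalar variances and the helper names are assumptions of this example.

```python
import math

def log_normal_pdf(x, mu, var):
    """log N(x | mu, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def log_sum_exp(values):
    """Stable log(sum(exp(v))): subtract the max first so the largest exp is exactly 1."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def gmm_log_likelihood(xs, weights, mus, variances):
    """sum_i log sum_k p_k N(x_i | mu_k, var_k), computed in log space."""
    return sum(
        log_sum_exp([math.log(p) + log_normal_pdf(x, m, v)
                     for p, m, v in zip(weights, mus, variances)])
        for x in xs
    )

print(gmm_log_likelihood([0.0, 1.5, -2.0], [0.4, 0.6], [-1.0, 1.0], [1.0, 0.5]))
```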

Solve by EM Algorithm

See my previous post for how the EM algorithm is derived. Here is a brief review of the algorithm:

  1. Initialize parameters $\theta^{(0)}$.

Iterate between steps 2 and 3 until convergence:

  2. Expectation (E) step:

$$ \begin{aligned} Q(\theta, \theta^{(t)}) &= \sum_Z P(Z|X,\theta^{(t)}) \cdot \log p(X,Z|\theta) \\ &= E_{Z \sim P(Z|X,\theta^{(t)})}[\log p(X,Z|\theta)] \end{aligned} $$

  3. Maximization (M) step:

Find the parameters that maximize the $Q(\theta, \theta^{(t)})$ computed in the E step, and update the parameters to $\theta^{(t+1)}$. That is,

$$ \theta^{(t+1)} = \underset{\theta}{\operatorname{argmax}} Q(\theta, \theta^{(t)})$$

The derivation is the following:

E step:

$$ \begin{aligned} Q(\theta, \theta^{(t)}) &= E_{Z \sim P(Z|X,\theta^{(t)})}[\log p(X,Z|\theta)] \\ &= \sum_Z \log p(X,Z|\theta) \cdot P(Z|X,\theta^{(t)})\\ &= \sum_{Z}\left[\log \prod_{i=1}^{N} p\left(x_{i}, z_{i} | \theta\right)\right] \prod_{i=1}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) \\ &= \sum_{Z}\left[\sum_{i=1}^{N} \log p\left(x_{i}, z_{i} | \theta\right)\right] \prod_{i=1}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) && (1) \end{aligned} $$

We can expand equation $(1)$ and first simplify its first term ($i = 1$):

$$ \begin{aligned} & \quad \sum_{Z} \log p\left(x_{1}, z_{1} | \theta\right) \cdot \prod_{i=1}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) \\ &= \sum_{z_1}\sum_{z_2}\cdots\sum_{z_N} \log p\left(x_{1}, z_{1} | \theta\right) \cdot \prod_{i=1}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) \\ &= \sum_{z_1}\sum_{z_2}\cdots\sum_{z_N} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) \cdot \prod_{i=2}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) && (2)\\ &= \sum_{z_1} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) \sum_{z_2}\cdots\sum_{z_N} \prod_{i=2}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right) && (3)\\ &= \sum_{z_1} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) \sum_{z_2}\cdots\sum_{z_N} p\left(z_{2} | x_{2}, \theta^{(t)}\right) \cdots p\left(z_{N} | x_{N}, \theta^{(t)}\right) && (4)\\ &= \sum_{z_1} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) \sum_{z_2} p\left(z_{2} | x_{2}, \theta^{(t)}\right) \cdots \sum_{z_N} p\left(z_{N} | x_{N}, \theta^{(t)}\right) && (5)\\ &= \sum_{z_1} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) && (6) \end{aligned} $$

Remark:

  • $(2) \to (3)$: the factor $\log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right)$ depends only on $z_1$, so we can pull it out of the sums over $z_2, \dots, z_N$.

  • $(4) \to (5)$: we use the same trick as in $(2) \to (3)$.

  • $(5) \to (6)$: each sum $\sum_{z_i} p\left(z_{i} | x_{i}, \theta^{(t)}\right)$ with $i \ge 2$ equals one, because it adds up a probability distribution over all values of $z_i$. Only the sum over $z_1$ remains.

The remaining terms ($i = 2, \dots, N$) simplify in the same way, so we have

$$ \begin{aligned} Q(\theta, \theta^{(t)}) &= \sum_{Z}\left[\sum_{i=1}^{N} \log p\left(x_{i}, z_{i} | \theta\right)\right] \prod_{i=1}^{N} p\left(z_{i} | x_{i}, \theta^{(t)}\right)\\ &= \sum_{z_1} \log p\left(x_{1}, z_{1} | \theta\right) \cdot p\left(z_{1} | x_{1}, \theta^{(t)}\right) + \cdots + \sum_{z_N} \log p\left(x_{N}, z_{N} | \theta\right) \cdot p\left(z_{N} | x_{N}, \theta^{(t)}\right) \\ &= \sum_{i=1}^N \sum_{z_i} \log p\left(x_{i}, z_{i} | \theta\right) \cdot p\left(z_{i} | x_{i}, \theta^{(t)}\right) \\ &= \sum_{k=1}^K \sum_{i=1}^N \log p\left(x_{i}, z_{i} = c_k | \theta\right) \cdot p\left(z_{i} = c_k | x_{i}, \theta^{(t)}\right) \\ &= \sum_{k=1}^K \sum_{i=1}^N [\log p_k + \log \mathcal{N}(x_i | \mu_{k}, \Sigma_{k})] \cdot p\left(z_{i} = c_k | x_{i}, \theta^{(t)}\right) \end{aligned} $$

where

$$ p\left(z_{i} = c_k | x_{i}, \theta^{(t)}\right) = \frac{p_{k}^{(t)} \mathcal{N}\left(x_{i} | \mu_{k}^{(t)}, \Sigma_{k}^{(t)}\right)}{\sum_{j=1}^{K} p_{j}^{(t)} \mathcal{N}\left(x_{i} | \mu_{j}^{(t)}, \Sigma_{j}^{(t)}\right)} $$
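We can check numerically that the sum over all $K^N$ assignments in $(1)$ collapses to the double sum over $i$ and $k$. The sketch below uses a tiny 1-D example (scalar variances in place of $\Sigma_k$, with made-up data and parameters; all helper names are hypothetical). It evaluates both forms, and they should agree up to floating-point error.

```python
import itertools
import math

def log_normal_pdf(x, mu, var):
    """log N(x | mu, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def responsibilities(x, params):
    """p(z = c_k | x, params) for every k, by Bayes' rule. params = (weights, means, variances)."""
    w, mu, var = params
    joint = [w[k] * math.exp(log_normal_pdf(x, mu[k], var[k])) for k in range(len(w))]
    total = sum(joint)
    return [j / total for j in joint]

def log_joint(x, k, params):
    """log p(x, z = c_k | params) = log p_k + log N(x | mu_k, var_k)."""
    w, mu, var = params
    return math.log(w[k]) + log_normal_pdf(x, mu[k], var[k])

def q_brute_force(xs, theta, theta_t):
    """Equation (1): explicit sum over all K^N assignments Z = (z_1, ..., z_N)."""
    K = len(theta[0])
    gammas = [responsibilities(x, theta_t) for x in xs]
    total = 0.0
    for zs in itertools.product(range(K), repeat=len(xs)):
        weight = math.prod(g[z] for g, z in zip(gammas, zs))  # prod_i p(z_i | x_i, theta^(t))
        total += weight * sum(log_joint(x, z, theta) for x, z in zip(xs, zs))
    return total

def q_simplified(xs, theta, theta_t):
    """Final form: sum_k sum_i log p(x_i, z_i = c_k | theta) * p(z_i = c_k | x_i, theta^(t))."""
    K = len(theta[0])
    gammas = [responsibilities(x, theta_t) for x in xs]
    return sum(g[k] * log_joint(x, k, theta) for x, g in zip(xs, gammas) for k in range(K))

xs = [-1.2, 0.3, 2.0, 0.9, 1.4]
theta = ([0.2, 0.5, 0.3], [-1.0, 0.5, 2.0], [0.8, 1.2, 0.5])    # theta being optimized
theta_t = ([1/3, 1/3, 1/3], [-0.5, 0.0, 1.5], [1.0, 1.0, 1.0])  # current estimate theta^(t)
print(q_brute_force(xs, theta, theta_t), q_simplified(xs, theta, theta_t))
```

The brute-force version costs $O(K^N)$, while the simplified form costs $O(NK)$. This difference is why the simplification matters in practice.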

M step:

Because $Q$ splits into a term that depends only on $p$ and a term that depends only on $(\mu, \Sigma)$, we can update $p^{(t+1)}$ and $(\mu^{(t+1)}, \Sigma^{(t+1)})$ separately. We first find $p^{(t+1)} = (p_1^{(t+1)}, p_2^{(t+1)}, \dots, p_K^{(t+1)})$:

$$ p^{(t+1)} = \underset{p}{\operatorname{argmax}} \sum_{k=1}^K \sum_{i=1}^N \log p_k \cdot p (z_{i} = c_k | x_{i}, \theta^{(t)}) $$

subject to the constraint $\sum_{k=1}^K p_k = 1$. This is a constrained optimization problem, so we solve it by introducing a Lagrange multiplier $\lambda$:

$$ \begin{aligned} L\left(p, \lambda \right) &=\sum_{k=1}^{K} \sum_{i=1}^{N} \log p_{k} \cdot p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right) + \lambda\left(\sum_{k=1}^{K} p_{k} - 1\right) && (7)\\ \frac{\partial L}{\partial p_{k}} &= \sum_{i=1}^{N} \frac{1}{p_{k}} p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right)+\lambda =0 \end{aligned} $$

Multiplying both sides by $p_k$ gives the first equation below. Summing it over $k = 1, \dots, K$ gives the second:

$$ \begin{aligned} &\sum_{i=1}^{N} p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right) + p_k \cdot \lambda = 0 \\ &\sum_{i=1}^{N} \sum_{k=1}^{K} p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right) + \sum_{k=1}^{K} p_k \cdot \lambda = 0 \end{aligned} $$

Since $\sum_{k=1}^{K} p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right) = 1$ for every $i$, the double sum equals $N$. Together with $\sum_{k=1}^{K} p_k = 1$, this gives

$$\lambda = -N$$

Substituting $\lambda = -N$ back into the stationarity condition $\partial L / \partial p_k = 0$ derived from $(7)$, we have

$$ p_k^{(t+1)} = \frac{1}{N} \sum_{i=1}^{N} p\left(z_{i}=c_k | x_{i}, \theta^{(t)}\right)$$
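We can sanity-check the closed-form weight update numerically. For fixed responsibilities, $p_k^{(t+1)}$ should score at least as well on the weight part of $Q$ as any other point on the probability simplex. In the sketch below, the responsibilities are made-up numbers for illustration, and the function names are hypothetical.

```python
import math
import random

def weight_objective(weights, gammas):
    """sum_i sum_k log p_k * gamma_ik: the part of Q that depends on the weights."""
    return sum(math.log(w) * g[k] for g in gammas for k, w in enumerate(weights))

def update_weights(gammas):
    """p_k^(t+1) = (1/N) * sum_i gamma_ik."""
    n, K = len(gammas), len(gammas[0])
    return [sum(g[k] for g in gammas) / n for k in range(K)]

# Made-up responsibilities gamma_ik = p(z_i = c_k | x_i, theta^(t)) for N = 4, K = 3; rows sum to 1.
gammas = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.6, 0.1, 0.3]]
p_new = update_weights(gammas)
best = weight_objective(p_new, gammas)

# Count random points on the probability simplex that beat the closed-form solution.
rng = random.Random(0)
beaten = 0
for _ in range(1000):
    raw = [rng.random() for _ in range(3)]
    s = sum(raw)
    candidate = [r / s for r in raw]
    if weight_objective(candidate, gammas) > best + 1e-12:
        beaten += 1
print(p_new, beaten)
```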

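Solving for $\mu^{(t+1)}$ and $\Sigma^{(t+1)}$ follows the same recipe as solving for $p^{(t+1)}$: we set the partial derivatives of $Q$ to zero. The difference is that these parameters have no constraint, so no Lagrange multiplier is needed.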
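Putting the E step and the M step together gives the full algorithm. The sketch below handles one-dimensional data and assumes scalar variances $\sigma_k^2$ instead of covariance matrices. In that case, setting the derivatives to zero gives the standard updates $\mu_k^{(t+1)} = \sum_i \gamma_{ik} x_i / \sum_i \gamma_{ik}$ and $\sigma_k^{2\,(t+1)} = \sum_i \gamma_{ik} (x_i - \mu_k^{(t+1)})^2 / \sum_i \gamma_{ik}$, where $\gamma_{ik} = p(z_i = c_k | x_i, \theta^{(t)})$. The name `em_gmm_1d` and the stopping rule (stop when the log-likelihood improves by less than `tol`) are illustrative choices, not part of the derivation above.

```python
import math
import random

def log_normal_pdf(x, mu, var):
    """log N(x | mu, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def em_gmm_1d(xs, weights, mus, variances, n_iter=200, tol=1e-8):
    """Fit a 1-D GMM with EM, starting from the given initial parameters theta^(0)."""
    n, K = len(xs), len(weights)
    prev_ll = -math.inf
    for _ in range(n_iter):
        # E step: gamma_ik = p(z_i = c_k | x_i, theta^(t)), computed in log space.
        gammas, ll = [], 0.0
        for x in xs:
            logs = [math.log(w) + log_normal_pdf(x, m, v) for w, m, v in zip(weights, mus, variances)]
            norm = log_sum_exp(logs)  # log p(x_i | theta^(t))
            ll += norm
            gammas.append([math.exp(l - norm) for l in logs])
        # M step: closed-form updates.
        nk = [sum(g[k] for g in gammas) for k in range(K)]  # effective number of points per component
        weights = [nk[k] / n for k in range(K)]
        mus = [sum(g[k] * x for g, x in zip(gammas, xs)) / nk[k] for k in range(K)]
        variances = [sum(g[k] * (x - mus[k]) ** 2 for g, x in zip(gammas, xs)) / nk[k]
                     for k in range(K)]
        if ll - prev_ll < tol:  # log-likelihood of theta^(t) stopped improving
            break
        prev_ll = ll
    return weights, mus, variances

# Synthetic data: 300 points from N(-2, 0.5^2) and 700 points from N(3, 1).
rng = random.Random(1)
data = [rng.gauss(-2.0, 0.5) for _ in range(300)] + [rng.gauss(3.0, 1.0) for _ in range(700)]
w, mu, var = em_gmm_1d(data, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
print(w, mu, var)
```

Real implementations usually put a small floor on the variances. A component that collapses onto a single data point drives its variance toward zero and the likelihood toward infinity, which is a degenerate solution.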

K-means

K-means is one of the most popular clustering algorithms. It stores $K$ centroids that define the clusters. A point belongs to a particular cluster if it is closer to that cluster's centroid than to any other centroid.

The Algorithm:

  1. Initialize $K$ centroids.

  2. Iterate until convergence:

    a. Hard-assign each data point to its closest centroid.

    b. Move each centroid to the mean of the data points assigned to it.
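We can sketch the two steps in Python for one-dimensional data. This sketch makes two assumptions: the centroids are initialized with $K$ randomly chosen data points, and a centroid stays in place when its cluster becomes empty. The function name `kmeans_1d` is hypothetical. The function also records, at each iteration, the sum of squared distances from the points to their assigned centroids.

```python
import random

def kmeans_1d(xs, k, n_iter=100, seed=0):
    """K-means on 1-D data; returns (centroids, labels, history of squared-distance sums)."""
    rng = random.Random(seed)
    centroids = rng.sample(xs, k)  # step 1: initialize with k randomly chosen data points
    labels, history = [], []
    for _ in range(n_iter):
        # step 2a: hard-assign each point to its closest centroid
        labels = [min(range(k), key=lambda j: (x - centroids[j]) ** 2) for x in xs]
        history.append(sum((x - centroids[c]) ** 2 for x, c in zip(xs, labels)))
        # step 2b: move each centroid to the mean of its assigned points
        new_centroids = []
        for j in range(k):
            members = [x for x, c in zip(xs, labels) if c == j]
            # keep an empty cluster's centroid where it is (an assumption of this sketch)
            new_centroids.append(sum(members) / len(members) if members else centroids[j])
        if new_centroids == centroids:  # nothing moved, so the assignments are stable
            break
        centroids = new_centroids
    return centroids, labels, history

data = [1.0, 1.2, 0.8, 5.0, 5.3, 4.9]
centroids, labels, history = kmeans_1d(data, 2)
print(sorted(centroids), labels, history)
```

Neither step can increase the sum of squared distances, so the recorded history is non-increasing. This is why the algorithm converges, although possibly to a local minimum that depends on the initialization.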

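The K-means procedure closely mirrors the way we update the parameters of a GMM. Step (a) plays the role of the E step, and step (b) plays the role of the M step. For this reason, K-means is often called an EM-style method for approximating the optimal parameters.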

The objective function is to minimize

$$ L = \sum_{i=1}^{N} \sum_{k=1}^K \gamma_{ik} \cdot || x_i - \mu_k||^2_2$$

where $\gamma_{ik} = 1$ if $x_i$ is assigned to cluster $c_k$, and $\gamma_{ik} = 0$ otherwise. Note that $\gamma_{ik}$ here is a hard label that can only be $0$ or $1$. GMM instead uses soft probabilities $p(z_i = c_k | x_i, \theta) \in [0, 1]$, so GMM is a more general model than K-means. In K-means, we can think of $\gamma_{ik}$ as the latent variable and $\mu$ as the parameter we want to optimize.
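The relationship between soft and hard labels can be seen numerically. Assume, for this 1-D sketch only, that all mixing weights are equal and every component shares the same variance $\sigma^2$. Then $\log p_k$ and the Gaussian normalizing constant are identical for every $k$, so the GMM responsibilities reduce to a softmax of $-(x - \mu_k)^2 / (2\sigma^2)$. As $\sigma^2 \to 0$, they approach K-means' hard assignment to the nearest mean. The helper name is hypothetical.

```python
import math

def responsibilities_shared_var(x, mus, var):
    """p(z = c_k | x) for equal weights and a shared variance var, via a stable softmax."""
    logits = [-(x - m) ** 2 / (2 * var) for m in mus]  # common constants cancel in the softmax
    top = max(logits)
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

mus = [0.0, 2.0, 5.0]
for var in [4.0, 1.0, 0.1, 0.01]:
    # x = 1.2 is closest to the mean 2.0; the labels harden toward [0, 1, 0] as var shrinks
    print(var, [round(r, 3) for r in responsibilities_shared_var(1.2, mus, var)])
```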

