Noisy topic models
Introduction
Matthew Stephens and Zihao Wang suggest a variation on hierarchical Poisson matrix factorization (Cemgil 2009)\( \DeclareMathOperator\Dir{Dirichlet} \DeclareMathOperator\Gam{Gamma} \DeclareMathOperator\Mult{Multinomial} \DeclareMathOperator\Pois{Poisson} \newcommand\E[1]{\left\langle #1 \right\rangle} \newcommand\const{\mathrm{const}} \newcommand\mf{\mathbf{F}} \newcommand\ml{\mathbf{L}} \newcommand\mphi{\boldsymbol{\Phi}} \newcommand\vmu{\boldsymbol{\mu}} \)
\begin{align} x_{ij} &= \sum_{k=1}^K z_{ijk}\\ z_{ijk} &\sim \Pois(l_{ik} \mu_j u_{jk})\\ u_{jk} &\sim \Gam(\theta_{jk}, \theta_{jk}), \end{align}where the Gamma distribution is parameterized by shape and rate, so \(\E{u_{jk}} = 1\) and \(V[u_{jk}] = 1 / \theta_{jk}\). The intuition is that the factors \(f_{jk} = \mu_j u_{jk}\) decompose into a mean expression level \(\mu_j\) shared across topics and a topic-specific departure \(u_{jk}\). After a suitable scaling (made precise below), \(\ml\) and \(\mf\) form a valid topic model in which most topics reflect the average gene expression at most genes, and \(\theta_{jk}\) can be used to find genes which depart from the mean, which could be of biological interest.
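To make the scaling precise (this is the standard correspondence between Poisson matrix factorization and topic models; the rescaled quantities \(\tilde l_{ik}, \tilde f_{jk}, s_i\) are introduced here only for illustration), write
\begin{align} \lambda_{ij} &\triangleq \sum_k l_{ik} f_{jk} = s_i \sum_k \tilde l_{ik} \tilde f_{jk}, & \tilde f_{jk} &= \frac{f_{jk}}{\sum_{j'} f_{j'k}}, & \tilde l_{ik} &= \frac{l_{ik} \sum_{j'} f_{j'k}}{s_i}, & s_i &= \sum_{k'} l_{ik'} \sum_{j'} f_{j'k'}. \end{align}Each row \(\tilde l_{i\cdot}\) and each column \(\tilde f_{\cdot k}\) then lies on the simplex, and \(x_{i\cdot} \mid x_{i+} \sim \Mult(x_{i+}, (\tilde{\ml}\tilde{\mf}')_{i\cdot})\), where \(x_{i+} = \sum_j x_{ij}\).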
Setup
import ctypes
import numba
import numpy as np
import scipy.sparse as ss
import scipy.special as sp
import scipy.stats as st
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'
Methods
Inference
The log joint is
\begin{multline} \ln p = \sum_{i,j,k} \left[ z_{ijk} \ln (l_{ik} \mu_j u_{jk}) - l_{ik} \mu_j u_{jk} - \ln\Gamma(z_{ijk} + 1) \right]\\ + \sum_{j,k} \left[ \theta_{jk}\ln \theta_{jk} + (\theta_{jk} - 1) \ln u_{jk} - \theta_{jk} u_{jk} - \ln\Gamma(\theta_{jk})\right], \end{multline}if \(x_{ij} = \sum_k z_{ijk}\), and \(-\infty\) otherwise. By a variational argument
\begin{align} q^*(z_{ij1}, \ldots, z_{ijK}) &\propto \exp\left(\textstyle\sum_k \left[z_{ijk}(\ln(l_{ik} \mu_j) + \E{\ln u_{jk}}) - \ln\Gamma(z_{ijk} + 1)\right]\right)\\ &= \Mult(x_{ij}, \pi_{ij1}, \ldots, \pi_{ijK}), \qquad \pi_{ijk} \propto l_{ik}\mu_j\exp(\E{\ln u_{jk}})\\ q^*(u_{jk}) &\propto \exp\left(\left(\textstyle\sum_i \E{z_{ijk}} + \theta_{jk} - 1\right) \ln u_{jk} - \left(\textstyle\sum_i l_{ik} \mu_j + \theta_{jk}\right) u_{jk}\right)\\ &= \Gam(\textstyle\sum_i \E{z_{ijk}} + \theta_{jk}, \textstyle\sum_i l_{ik}\mu_j + \theta_{jk})\\ &\triangleq \Gam(\alpha_{jk}, \beta_{jk}), \end{align}where
\begin{align} \E{z_{ijk}} &= x_{ij} \pi_{ijk}\\ \E{u_{jk}} &= \alpha_{jk} / \beta_{jk}\\ \E{\ln u_{jk}} &= \psi(\alpha_{jk}) - \ln \beta_{jk} \end{align}and \(\psi\) denotes the digamma function. The evidence lower bound (ELBO) is
\begin{multline} \ell = \sum_{i,j,k} \left[ \E{z_{ijk}} (\ln (l_{ik} \mu_j) + \E{\ln u_{jk}} - \ln\pi_{ijk}) - l_{ik} \mu_j \E{u_{jk}} \right] - \sum_{i,j} \ln\Gamma(x_{ij} + 1)\\ + \sum_{j,k} \left[ (\theta_{jk} - \alpha_{jk}) \E{\ln u_{jk}} - (\theta_{jk} - \beta_{jk}) \E{u_{jk}} + \theta_{jk}\ln \theta_{jk} - \alpha_{jk} \ln\beta_{jk} - \ln\Gamma(\theta_{jk}) + \ln\Gamma(\alpha_{jk})\right]. \end{multline}To maximize the ELBO,
\begin{align} \frac{\partial\ell}{\partial l_{ik}} &= \sum_j \frac{\E{z_{ijk}}}{l_{ik}} - \mu_j \E{u_{jk}} = 0\\ l_{ik} &= \frac{\sum_j \E{z_{ijk}}}{\sum_j \mu_j \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \mu_j} &= \sum_{i, k} \frac{\E{z_{ijk}}}{\mu_j} - l_{ik} \E{u_{jk}} = 0\\ \mu_j &= \frac{\sum_{i, k} \E{z_{ijk}}}{\sum_{i, k} l_{ik} \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \theta_{jk}} &= 1 + \ln \theta_{jk} + \E{\ln u_{jk}} - \E{u_{jk}} - \psi(\theta_{jk}), \end{align}where the first two conditions yield closed-form updates, and \(\theta_{jk}\) can be updated by gradient ascent with backtracking line search (taking steps with respect to \(\ln\theta_{jk}\) to enforce non-negativity).
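The multinomial form of \(q^*(z_{ij\cdot})\) reflects the fact that independent Poisson counts, conditioned on their sum, are jointly multinomial with probabilities proportional to the rates. A minimal Monte Carlo check of this fact (the rates, total, and sample size below are arbitrary choices for illustration):

import numpy as np
rng = np.random.default_rng(1)
lam = np.array([0.5, 1.0, 2.0])
z = rng.poisson(lam, size=(100000, 3))
x = z.sum(axis=1)
# The conditional mean of z given x = 3 should match the multinomial mean 3 * lam / lam.sum()
z[x == 3].mean(axis=0), 3 * lam / lam.sum()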
Implementation
lgamma = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(
    numba.extending.get_cython_function_address('scipy.special.cython_special', 'gammaln'))
digamma = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(
    numba.extending.get_cython_function_address('scipy.special.cython_special', '__pyx_fuse_1psi'))

def hpmf(x, rank, step=1e-2, atol=1e-4, max_epochs=1000, verbose=False):
    """Fit hierarchical PMF

    rank - number of latent factors
    step - initial step size for VBEM update to log(theta)
    atol - convergence criterion (change in ELBO)
    max_epochs - max number of VBEM updates
    verbose - report ELBO after each epoch

    """
    if not ss.isspmatrix_coo(x):
        x = ss.coo_matrix(x)
    n, p = x.shape
    l = np.random.uniform(size=(n, rank))
    mu = np.ones((p, 1))
    # Initialize pi_ijk uniform over topics
    log_pi = np.full((x.nnz, rank), -np.log(rank))
    alpha = np.ones((p, rank))
    beta = np.ones((p, rank))
    theta = np.ones((p, rank))
    u = alpha / beta
    log_u = sp.digamma(alpha) - np.log(beta)
    # numba requires arguments be array, not coo_matrix
    obj = elbo(x.data, x.row, x.col, log_pi, l, mu, u, log_u)
    for t in range(max_epochs):
        # Expectations wrt variational distribution
        z = x.data.reshape(-1, 1) * np.exp(log_pi)
        # Coordinate updates (in-place)
        update_l(l, z, x.row, u, mu)
        update_mu(mu, z, x.col, u, l)
        update_u(alpha, beta, z, x.col, l, mu, theta)
        # Important: this needs to be vectorized; refresh after updating q(u)
        u = alpha / beta
        log_u = sp.digamma(alpha) - np.log(beta)
        update_z(log_pi, x.row, x.col, l, mu, log_u)
        # Hyperparameter update (in-place)
        update_theta(theta, u, log_u, alpha, beta, step=step)
        update = elbo(x.data, x.row, x.col, log_pi, l, mu, u, log_u)
        if update < obj:
            raise RuntimeError('objective decreased')
        elif abs(update - obj) < atol:
            return l, mu, alpha, beta, theta
        else:
            obj = update
            if verbose:
                print(f'[{t}] elbo={obj:.2g}')
    raise RuntimeError('max_epochs exceeded')

# @numba.njit(parallel=True)
def update_l(l, z, row, u, mu):
    d = mu.reshape(-1, 1) * u  # d[j,k] = mu_j E[u_jk]
    for i in numba.prange(l.shape[0]):
        zi = z[row == i]
        if zi.shape[0] == 0:
            continue
        for k in numba.prange(z.shape[1]):
            l[i,k] = zi[:,k].sum() / d[:,k].sum()

# @numba.njit(parallel=True)
def update_mu(mu, z, col, u, l):
    d = u @ l.sum(axis=0)  # d[j] = sum_{i,k} l_ik E[u_jk]
    for j in numba.prange(mu.shape[0]):
        zj = z[col == j]
        if zj.shape[0] == 0:
            continue
        mu[j] = zj.sum() / d[j]

# @numba.njit(parallel=True)
def update_u(alpha, beta, z, col, l, mu, theta):
    lk = l.sum(axis=0)  # lk[k] = sum_i l_ik
    for j in numba.prange(alpha.shape[0]):
        zj = z[col == j]
        if zj.shape[0] == 0:
            continue
        for k in numba.prange(alpha.shape[1]):
            alpha[j,k] = zj[:,k].sum() + theta[j,k]
            beta[j,k] = mu[j,0] * lk[k] + theta[j,k]

# @numba.njit(parallel=True)
def update_z(log_pi, row, col, l, mu, log_u):
    for t in numba.prange(log_pi.shape[0]):
        i = row[t]
        j = col[t]
        log_pi[t] = np.log(l[i]) + np.log(mu[j]) + log_u[j]
        # Normalize via log-sum-exp
        m = log_pi[t].max()
        log_pi[t] -= m + np.log(np.exp(log_pi[t] - m).sum())

def theta_loss(theta, u, log_u, alpha, beta):
    # Terms of the ELBO which depend on theta
    return (theta - alpha) * log_u - (theta - beta) * u + theta * np.log(theta) - lgamma(theta)

# @numba.njit(parallel=True)
def update_theta(theta, u, log_u, alpha, beta, step=1, c=0.5, tau=0.5, max_iters=32, eps=1e-15):
    for j in numba.prange(theta.shape[0]):
        for k in numba.prange(theta.shape[1]):
            # Important: take steps wrt log_theta to avoid non-negativity constraint
            log_theta = np.log(theta[j,k])
            # Chain rule: d(loss)/d(log_theta) = d(loss)/d(theta) * theta
            d = (1 + np.log(theta[j,k]) + log_u[j,k] - u[j,k] - digamma(theta[j,k])) * theta[j,k]
            loss = theta_loss(theta[j,k], u[j,k], log_u[j,k], alpha[j,k], beta[j,k])
            s = step
            iters = max_iters
            update = theta_loss(np.exp(log_theta + s * d), u[j,k], log_u[j,k], alpha[j,k], beta[j,k])
            # Backtrack until the sufficient increase (ascent Armijo) condition holds
            while (not np.isfinite(update) or update < loss + c * s * d * d) and iters > 0:
                s *= tau
                update = theta_loss(np.exp(log_theta + s * d), u[j,k], log_u[j,k], alpha[j,k], beta[j,k])
                iters -= 1
            if iters > 0:
                theta[j,k] = np.exp(log_theta + s * d) + eps

def elbo(data, row, col, log_pi, l, mu, u, log_u):
    # TODO: this is E_q[ln p] over the nonzero entries only
    temp = np.zeros(data.shape[0])
    for t in numba.prange(data.shape[0]):
        i = row[t]
        j = col[t]
        # Important: this has shape (k,)
        temp[t] = (data[t] * np.exp(log_pi[t]) * (np.log(l[i]) + np.log(mu[j]) + log_u[j])
                   - l[i] * mu[j] * u[j]).sum() - lgamma(data[t] + 1)
    return temp.sum()
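As a post-processing sketch, the estimates returned by hpmf can be rescaled into a topic model following the argument in the introduction, plugging in the posterior mean \(\E{u_{jk}} = \alpha_{jk}/\beta_{jk}\). The helper to_topic_model is hypothetical (not part of the methods above):

def to_topic_model(l, mu, alpha, beta):
    """Rescale fitted loadings/factors so rows of L and columns of F lie on the simplex"""
    f = mu.reshape(-1, 1) * (alpha / beta)  # E[f_jk] = mu_j E[u_jk]
    scale = f.sum(axis=0)
    f = f / scale                           # columns of F sum to one
    l = l * scale                           # absorb the column scales into the loadings
    return l / l.sum(axis=1, keepdims=True), f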
Simulation
Simulate from a noisy topic model
\begin{align} x_{ij} \mid s_i, \lambda_{ij} &\sim \Pois(s_i \lambda_{ij})\\ \lambda_{ij} &= (\ml\mf')_{ij}\\ l_{i\cdot} &\sim \Dir(\boldsymbol{1}_K)\\ f_{jk} &= \mu_j u_{jk}\\ \vmu &\sim \Dir(\boldsymbol{1}_p)\\ u_{jk} &\sim \Gam(1/\phi_{jk}, 1/\phi_{jk})\\ \phi_{jk} &\sim \operatorname{Discrete}(\cdot), \end{align}so that \(\E{u_{jk}} = 1\) and \(V[u_{jk}] = \phi_{jk}\). (In the code below, the size factor \(s_i \equiv s\) is constant.)

def simulate(n, p, k, s=1e4, seed=0):
    np.random.seed(seed)
    mu = np.random.dirichlet(np.ones(p))
    # Inflate the variance of u at 1% of gene/topic pairs
    phi = np.ones((p, k))
    idx = np.random.uniform(size=(p, k)) <= 0.01
    phi[idx] = 2
    # shape = 1 / phi, scale = phi gives E[u] = 1, V[u] = phi
    u = st.gamma(a=1 / phi, scale=phi).rvs(size=(p, k))
    f = mu.reshape(-1, 1) * u
    l = np.random.dirichlet(np.ones(k), size=n)
    lam = l @ f.T
    x = np.random.poisson(s * lam)
    return x, l, mu, phi
Results
Simulated example
x, l, mu, phi = simulate(n=100, p=10000, k=5, s=1e3)
x = ss.coo_matrix(x)
x
<100x10000 sparse matrix of type '<class 'numpy.int64'>' with 89681 stored elements in COOrdinate format>
Report the largest observed count, and the sparsity of the data.
x.max(), (x > 0).mean()
(8, 0.08968100000000152)
Analyze a subset of the data.
y = x.tocsc()[:,:6].tocoo()
res = hpmf(y, rank=5, verbose=True)