Noisy topic models

Matthew Stephens and Zihao Wang suggest a variation on hierarchical Poisson matrix factorization (Cemgil 2009)\( \DeclareMathOperator\Dir{Dirichlet} \DeclareMathOperator\Gam{Gamma} \DeclareMathOperator\Mult{Multinomial} \DeclareMathOperator\Pois{Poisson} \newcommand\E[1]{\left\langle #1 \right\rangle} \newcommand\const{\mathrm{const}} \newcommand\mf{\mathbf{F}} \newcommand\ml{\mathbf{L}} \newcommand\mphi{\boldsymbol{\Phi}} \newcommand\vmu{\boldsymbol{\mu}} \)

\begin{align} x_{ij} &= \sum_{k=1}^K z_{ijk}\\ z_{ijk} &\sim \Pois(l_{ik} \mu_j u_{jk})\\ u_{jk} &\sim \Gam(\theta_{jk}, \theta_{jk}), \end{align}

where the Gamma distribution is parameterized by shape and rate, \(\E{u_{jk}} = 1\), and \(V[u_{jk}] = 1 / \theta_{jk}\). The intuition is to rewrite factors \(f_{jk} = \mu_j u_{jk}\). After a suitable scaling, \(\ml\) and \(\mf\) are then a valid topic model in which most topics reflect the average gene expression at most genes, and \(\theta_{jk}\) can be used to find genes which depart from the mean, which could be of biological interest.


import ctypes
import numba
import numpy as np
import scipy.sparse as ss
import scipy.special as sp
import scipy.stats as st
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
# Plot styling: white figure background and the notebook's sans-serif font.
plt.rcParams['figure.facecolor'] = 'w'
# The rcParams key must be a valid setting name; an empty key raises KeyError.
plt.rcParams['font.family'] = 'Nimbus Sans'



The log joint is

\begin{multline} \ln p = \sum_{i,j,k} \left[ z_{ijk} \ln (l_{ik} \mu_j u_{jk}) - l_{ik} \mu_j u_{jk} - \ln\Gamma(z_{ijk} + 1) \right]\\ + \sum_{j,k} \left[ \theta_{jk}\ln \theta_{jk} + (\theta_{jk} - 1) \ln u_{jk} - \theta_{jk} u_{jk} - \ln\Gamma(\theta_{jk})\right], \end{multline}

if \(x_{ij} = \sum_k z_{ijk}\), and \(-\infty\) otherwise. By a variational argument

\begin{align} q^*(z_{ij1}, \ldots, z_{ijK}) &\propto \exp\left(\textstyle\sum_k \left[ z_{ijk}(\ln(l_{ik} \mu_j) + \E{\ln u_{jk}}) - \ln\Gamma(z_{ijk} + 1) \right]\right)\\ &= \Mult(x_{ij}, \pi_{ij1}, \ldots, \pi_{ijK}), \qquad \pi_{ijk} \propto l_{ik}\mu_j\exp(\E{\ln u_{jk}})\\ q^*(u_{jk}) &\propto \exp\left((\textstyle\sum_i \E{z_{ijk}} + \theta_{jk} - 1) \ln u_{jk} - (\textstyle\sum_i l_{ik} \mu_j + \theta_{jk}) u_{jk}\right)\\ &= \Gam(\textstyle\sum_i \E{z_{ijk}} + \theta_{jk}, \textstyle\sum_i l_{ik}\mu_j + \theta_{jk})\\ &\triangleq \Gam(\alpha_{jk}, \beta_{jk}). \end{align}


\begin{align} \E{z_{ijk}} &= x_{ij} \pi_{ijk}\\ \E{u_{jk}} &= \alpha_{jk} / \beta_{jk}\\ \E{\ln u_{jk}} &= \psi(\alpha_{jk}) - \ln \beta_{jk} \end{align}

Here the expectations are taken with respect to \(q^*\), and \(\psi\) denotes the digamma function. The evidence lower bound (ELBO) is

\begin{multline} \ell = \sum_{i,j,k} \left[ \E{z_{ijk}} (\ln (l_{ik} \mu_j) + \E{\ln u_{jk}} - \ln\pi_{ijk}) - l_{ik} \mu_j \E{u_{jk}} \right] - \sum_{i,j} \ln\Gamma(x_{ij} + 1)\\ + \sum_{j,k} \left[ (\theta_{jk} - \alpha_{jk}) \E{\ln u_{jk}} - (\theta_{jk} - \beta_{jk}) \E{u_{jk}} + \theta_{jk}\ln \theta_{jk} - \alpha_{jk} \ln\beta_{jk} - \ln\Gamma(\theta_{jk}) + \ln\Gamma(\alpha_{jk})\right]. \end{multline}

To maximize the ELBO,

\begin{align} \frac{\partial\ell}{\partial l_{ik}} &= \sum_j \frac{\E{z_{ijk}}}{l_{ik}} - \mu_j \E{u_{jk}} = 0\\ l_{ik} &= \frac{\sum_j \E{z_{ijk}}}{\sum_j \mu_j \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \mu_j} &= \sum_{i, k} \frac{\E{z_{ijk}}}{\mu_j} - l_{ik} \E{u_{jk}} = 0\\ \mu_j &= \frac{\sum_{i, k} \E{z_{ijk}}}{\sum_{i, k} l_{ik} \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \theta_{jk}} &= 1 + \ln \theta_{jk} + \E{\ln u_{jk}} - \E{u_{jk}} - \psi(\theta_{jk}) \end{align}

where \(\theta_{jk}\) can be updated via gradient ascent with line search.


# Scalar special functions exposed as ctypes callables (double -> double), built
# from the addresses of the compiled routines in scipy.special.cython_special.
# NOTE(review): presumably wrapped this way so the njit-compiled helpers below
# could call them; '__pyx_fuse_1psi' looks like the real-valued specialization
# of the fused psi function -- confirm against the installed scipy.
_double_unary = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
_gammaln_address = numba.extending.get_cython_function_address(
  'scipy.special.cython_special', 'gammaln')
_psi_address = numba.extending.get_cython_function_address(
  'scipy.special.cython_special', '__pyx_fuse_1psi')
lgamma = _double_unary(_gammaln_address)
digamma = _double_unary(_psi_address)

def _elbo_gamma_terms(alpha, beta, theta, u, log_u):
  """Return the ELBO terms contributed by p(u | theta) and q(u) = Gamma(alpha, beta)

  This is E_q[ln p(u | theta)] - E_q[ln q(u)] summed over all (j, k), where
  u = E_q[u] and log_u = E_q[ln u].

  """
  return (theta * np.log(theta) - alpha * np.log(beta)
          + (theta - alpha) * log_u - (theta - beta) * u
          - sp.gammaln(theta) + sp.gammaln(alpha)).sum()

def hpmf(x, rank, step=1e-2, atol=1e-4, max_epochs=1000, verbose=False):
  """Fit hierarchical PMF

  x - count matrix (n x p); converted to scipy.sparse.coo_matrix if needed
  rank - number of latent factors
  step - initial step size for VBEM update to log(theta)
  atol - convergence criterion (change in ELBO)
  max_epochs - max number of VBEM updates
  verbose - report ELBO after each epoch

  Returns (l, mu, alpha, beta, theta), where l is (n, rank), mu is (p, 1), and
  alpha, beta, theta are (p, rank).

  Raises RuntimeError if the objective is not finite, decreases by more than
  atol, or has not converged after max_epochs updates.

  """
  if not ss.isspmatrix_coo(x):
    x = ss.coo_matrix(x)
  n, p = x.shape
  l = np.random.uniform(size=(n, rank))
  mu = np.ones((p, 1))
  # Uniform responsibilities. The fill value must be a float: an integer fill
  # gives an integer array, and the in-place updates would then be truncated.
  log_pi = np.full((x.nnz, rank), -np.log(rank))
  alpha = np.ones((p, rank))
  beta = np.ones((p, rank))
  theta = np.ones((p, rank))
  counts = x.data.reshape(-1, 1)

  # Expectations wrt variational distribution
  u = alpha / beta
  # Important: this needs to be vectorized
  log_u = sp.digamma(alpha) - np.log(beta)
  # elbo() only receives (u, log_u), so the terms that need alpha, beta and
  # theta are added here.
  obj = (elbo(x.data, x.row, x.col, log_pi, l, mu, u, log_u)
         + _elbo_gamma_terms(alpha, beta, theta, u, log_u))
  for t in range(max_epochs):
    u = alpha / beta
    log_u = sp.digamma(alpha) - np.log(beta)
    z = counts * np.exp(log_pi)
    # Coordinate updates (in-place)
    update_l(l, z, x.row, u, mu)
    # mu is indexed by column, so entries must be grouped by column index
    update_mu(mu, z, x.col, u, l)
    update_u(alpha, beta, z, x.col, l, mu, theta)
    # q(u) just changed; refresh its expectations before they are used again
    u = alpha / beta
    log_u = sp.digamma(alpha) - np.log(beta)
    update_z(log_pi, x.row, x.col, l, mu, log_u)
    # Hyperparameter update (in-place; the return value is not relied upon)
    update_theta(theta, u, log_u, alpha, beta, step=step)

    update = (elbo(x.data, x.row, x.col, log_pi, l, mu, u, log_u)
              + _elbo_gamma_terms(alpha, beta, theta, u, log_u))
    if not np.isfinite(update):
      raise RuntimeError('objective is not finite')
    # Test convergence first so that round-off sized decreases at the optimum
    # are not reported as failures.
    if abs(update - obj) < atol:
      return l, mu, alpha, beta, theta
    if update < obj:
      raise RuntimeError('objective decreased')
    obj = update
    if verbose:
      print(f'[{t}] elbo={update:.2g}')
  raise RuntimeError('max_epochs exceeded')

# NOTE: the numba.njit decorator was disabled here; this version is plain
# vectorized NumPy and does not need it.
def update_l(l, z, row, u, mu):
  """Update loadings in place: l[i,k] = sum_j E[z_ijk] / sum_j mu_j E[u_jk]

  l - loadings (n, K), overwritten
  z - E[z] for each nonzero entry (nnz, K)
  row - row index of each nonzero entry (nnz,)
  u - E[u] (p, K)
  mu - per-column means, (p,) or (p, 1)

  Rows with no nonzero entries have a zero numerator and are set to 0.

  """
  # Scatter-add visits each nonzero entry once instead of masking `row == i`
  # for every row.
  numer = np.zeros(l.shape)
  np.add.at(numer, row, z)
  denom = (mu.reshape(-1, 1) * u).sum(axis=0)
  l[:] = numer / denom

# NOTE: the numba.njit decorator was disabled here; this version is plain
# vectorized NumPy and does not need it.
def update_mu(mu, z, col, u, l):
  """Update column means in place: mu[j] = sum_{i,k} E[z_ijk] / sum_{i,k} l_ik E[u_jk]

  mu - per-column means, (p,) or (p, 1), overwritten
  z - E[z] for each nonzero entry (nnz, K)
  col - column index of each nonzero entry (nnz,). mu is indexed by column,
        so the third argument must hold column indices, not row indices.
  u - E[u] (p, K)
  l - loadings (n, K)

  Columns with no nonzero entries have a zero numerator and are set to 0.

  """
  numer = np.zeros(mu.shape[0])
  np.add.at(numer, col, z.sum(axis=1))
  # sum_{i,k} l_ik u_jk = sum_k u_jk (sum_i l_ik)
  denom = u @ l.sum(axis=0)
  mu[:] = (numer / denom).reshape(mu.shape)

# NOTE: the numba.njit decorator was disabled here; this version is plain
# vectorized NumPy and does not need it.
def update_u(alpha, beta, z, col, l, mu, theta):
  """Update q(u_jk) = Gamma(alpha_jk, beta_jk) in place

  alpha[j,k] = sum_i E[z_ijk] + theta[j,k]
  beta[j,k] = mu[j] * sum_i l[i,k] + theta[j,k]

  alpha, beta - variational shape and rate (p, K), overwritten
  z - E[z] for each nonzero entry (nnz, K)
  col - column index of each nonzero entry (nnz,)
  l - loadings (n, K)
  mu - per-column means, (p,) or (p, 1)
  theta - prior shape/rate (p, K)

  Columns with no nonzero entries still get updated: their alpha equals theta
  and their beta still includes the exposure term.

  """
  numer = np.zeros(alpha.shape)
  np.add.at(numer, col, z)
  alpha[:] = numer + theta
  beta[:] = mu.reshape(-1, 1) * l.sum(axis=0) + theta

# NOTE: the numba.njit decorator was disabled here; this version is plain
# vectorized NumPy and does not need it.
def update_z(log_pi, row, col, l, mu, log_u):
  """Update log responsibilities in place

  For each nonzero entry t = (i, j), pi[t,k] is proportional to
  l[i,k] * mu[j] * exp(E[ln u_jk]), normalized to sum to 1 over k.

  log_pi - log responsibilities (nnz, K), overwritten; must be a float array
  row, col - row and column index of each nonzero entry (nnz,)
  l - loadings (n, K)
  mu - per-column means, (p,) or (p, 1)
  log_u - E[ln u] (p, K)

  """
  logits = np.log(l[row]) + np.log(mu.reshape(-1, 1)[col]) + log_u[col]
  # Log-softmax over k; subtracting the row maximum first avoids overflow
  logits -= logits.max(axis=1, keepdims=True)
  logits -= np.log(np.exp(logits).sum(axis=1, keepdims=True))
  log_pi[:] = logits

def theta_loss(theta, u, log_u, alpha, beta):
  """Return the theta-dependent part of the ELBO (to be maximized)

  Up to terms constant in theta this is E_q[ln p(u | theta)] - E_q[ln q(u)],

    theta ln(theta) + (theta - alpha) E[ln u] - (theta - beta) E[u] - ln Gamma(theta)

  Works elementwise on scalars or arrays.

  """
  return (theta - alpha) * log_u - (theta - beta) * u + theta * np.log(theta) - sp.gammaln(theta)

# NOTE: the numba.njit decorator was disabled here; this version uses
# scipy.special directly and plain Python loops.
def update_theta(theta, u, log_u, alpha, beta, step=1, c=0.5, tau=0.5, max_iters=32, eps=1e-15):
  """Take one backtracking gradient ascent step on log(theta), in place

  theta - prior shape/rate (p, K), overwritten
  u, log_u - E[u] and E[ln u] (p, K)
  alpha, beta - variational shape and rate (p, K)
  step - initial step size of the line search
  c - Armijo sufficient-increase constant
  tau - factor by which the step size shrinks after a rejected trial
  max_iters - max number of trial steps per element
  eps - added to accepted values to keep theta strictly positive

  Each element gets its own line search starting from `step`. An element whose
  line search is exhausted is left unchanged. Returns theta (the same array).

  """
  for j in range(theta.shape[0]):
    for k in range(theta.shape[1]):
      # Important: take steps wrt log_theta to avoid non-negativity constraint
      log_theta = np.log(theta[j,k])
      # d(objective)/d(log theta) = theta * d(objective)/d(theta)
      d = (1 + log_theta + log_u[j,k] - u[j,k] - sp.digamma(theta[j,k])) * theta[j,k]
      loss = theta_loss(theta[j,k], u[j,k], log_u[j,k], alpha[j,k], beta[j,k])
      # Local copy, so one element's backtracking does not shrink the
      # starting step for the remaining elements
      s = step
      for _ in range(max_iters):
        proposal = np.exp(log_theta + s * d)
        update = theta_loss(proposal, u[j,k], log_u[j,k], alpha[j,k], beta[j,k])
        # Armijo condition for ascent along the gradient direction
        if np.isfinite(update) and update >= loss + c * s * d * d:
          theta[j,k] = proposal + eps
          break
        s *= tau
  return theta

def elbo(data, row, col, log_pi, l, mu, u, log_u):
  """Return the likelihood and q(z) terms of the evidence lower bound

  Computes

    sum_{i,j,k} [E[z_ijk] (ln(l_ik mu_j) + E[ln u_jk] - ln pi_ijk) - l_ik mu_j E[u_jk]]
      - sum_{i,j} ln Gamma(x_ij + 1)

  data - nonzero counts (nnz,)
  row, col - row and column index of each nonzero entry (nnz,)
  log_pi - log responsibilities (nnz, K)
  l - loadings (n, K)
  mu - per-column means, (p,) or (p, 1)
  u, log_u - E[u] and E[ln u] (p, K)

  """
  # TODO: the prior and entropy terms for u are not included; they need
  # alpha, beta and theta, which are not arguments of this function.
  data = np.asarray(data)
  mu = np.reshape(mu, (-1, 1))
  # Expected counts E[z] = x * pi. Accumulate in floating point: the counts
  # are typically integers, and an integer buffer would truncate every term.
  ez = data.reshape(-1, 1) * np.exp(log_pi)
  with np.errstate(divide='ignore', invalid='ignore'):
    log_rate = np.log(l[row]) + np.log(mu[col]) + log_u[col]
    # Entries with zero responsibility contribute nothing (0 * ln 0 = 0)
    fit = np.where(ez > 0, ez * (log_rate - log_pi), 0.0).sum()
  # The Poisson rate term runs over every (i, j), not just nonzero entries:
  # sum_{i,j,k} l_ik mu_j u_jk = sum_k (sum_i l_ik) (sum_j mu_j u_jk)
  exposure = (l.sum(axis=0) * (mu * u).sum(axis=0)).sum()
  return float(fit - exposure - sp.gammaln(data + 1).sum())


Simulate from a noisy topic model

\begin{align} x_{ij} \mid s_i, \lambda_{ij} &\sim \Pois(s_i \lambda_{ij})\\ \lambda_{ij} &= (\ml\mf')_{ij}\\ l_{i\cdot} &\sim \Dir(\boldsymbol{1}_k)\\ f_{jk} &= \mu_j u_{jk}\\ \mu_j &\sim \Dir(\boldsymbol{1}_p)\\ u_{jk} &\sim \Gam(\phi_{jk}, \phi_{jk})\\ \phi_{jk} &\sim \operatorname{Discrete}(\cdot) \end{align}
def simulate(n, p, k, s=1e4, seed=0):
  """Simulate counts from a noisy topic model

  n - number of rows (samples)
  p - number of columns (genes)
  k - number of topics
  s - size factor multiplying every Poisson rate
  seed - seed for the global NumPy random state

  Returns (x, l, mu, phi): x is the (n, p) count matrix, l the (n, k)
  Dirichlet-distributed loadings, mu the (p,) Dirichlet-distributed mean
  vector, and phi the (p, k) matrix used to draw u.

  """
  # Seed the global state, which both np.random and the scipy.stats draw
  # below use, so repeated calls with the same seed give the same data.
  np.random.seed(seed)
  mu = np.random.dirichlet(np.ones(p))
  phi = np.ones((p, k))
  # About 1% of the (gene, topic) pairs get phi = 2 instead of 1
  idx = np.random.uniform(size=(p, k)) <= 0.01
  phi[idx] = 2
  # Gamma with shape 1 / phi and scale phi: mean 1, variance phi
  u = st.gamma(a=1 / phi, scale=phi).rvs(size=(p, k))
  f = mu.reshape(-1, 1) * u
  l = np.random.dirichlet(np.ones(k), size=n)
  lam = l @ f.T
  x = np.random.poisson(s * lam)
  return x, l, mu, phi


Simulated example

# Draw a 100 x 10000 count matrix with 5 topics and size factor 1e3, keeping
# the simulated loadings, means and phi for later comparison.
x, l, mu, phi = simulate(n=100, p=10000, k=5, s=1e3)
# Store the counts in sparse COO format.
x = ss.coo_matrix(x)
<100x10000 sparse matrix of type '<class 'numpy.int64'>'
with 89681 stored elements in COOrdinate format>

Report the largest observed count, and the sparsity of the data.

# Largest count, and the fraction of entries that are nonzero.
x.max(), (x > 0).mean()
(8, 0.08968100000000152)

Analyze a subset of the data.

# Keep only the first 6 columns: convert to CSC for the column slice, then
# back to COO.
y = x.tocsc()[:,:6].tocoo()
# Fit a rank-5 model, printing the objective after each epoch.
res = hpmf(y, rank=5, verbose=True)

