Noisy topic models

Table of Contents

Introduction

Matthew Stephens and Zihao Wang suggest a variation on hierarchical Poisson matrix factorization (Cemgil 2009)\( \DeclareMathOperator\Dir{Dirichlet} \DeclareMathOperator\Gam{Gamma} \DeclareMathOperator\Mult{Multinomial} \DeclareMathOperator\Pois{Poisson} \newcommand\E[1]{\left\langle #1 \right\rangle} \newcommand\const{\mathrm{const}} \newcommand\mf{\mathbf{F}} \newcommand\ml{\mathbf{L}} \newcommand\mphi{\boldsymbol{\Phi}} \newcommand\vmu{\boldsymbol{\mu}} \)

\begin{align} x_{ij} &= \sum_{k=1}^K z_{ijk}\\ z_{ijk} &\sim \Pois(l_{ik} \mu_j u_{jk})\\ u_{jk} &\sim \Gam(\theta_{jk}, \theta_{jk}), \end{align}

where the Gamma distribution is parameterized by shape and rate, \(\E{u_{jk}} = 1\), and \(V[u_{jk}] = 1 / \theta_{jk}\). The intuition is to rewrite factors \(f_{jk} = \mu_j u_{jk}\). After a suitable scaling, \(\ml\) and \(\mf\) are then a valid topic model in which most topics reflect the average gene expression at most genes, and \(\theta_{jk}\) can be used to find genes which depart from the mean, which could be of biological interest.

Setup

import ctypes
import numba
import numpy as np
import scipy.sparse as ss
import scipy.special as sp
import scipy.stats as st
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
# Global matplotlib defaults for every figure in this notebook:
# 'w' sets the figure background color to white.
plt.rcParams['figure.facecolor'] = 'w'
# Request the 'Nimbus Sans' font family for figure text.
plt.rcParams['font.family'] = 'Nimbus Sans'

Methods

Inference

The log joint is

\begin{multline} \ln p = \sum_{i,j,k} \left[ z_{ijk} \ln (l_{ik} \mu_j u_{jk}) - l_{ik} \mu_j u_{jk} - \ln\Gamma(z_{ijk} + 1) \right]\\ + \sum_{j,k} \left[ \theta_{jk}\ln \theta_{jk} + (\theta_{jk} - 1) \ln u_{jk} - \theta_{jk} u_{jk} - \ln\Gamma(\theta_{jk})\right], \end{multline}

if \(x_{ij} = \sum_k z_{ijk}\), and \(-\infty\) otherwise. By a variational argument

\begin{align} q^*(z_{ij1}, \ldots, z_{ijK}) &\propto \exp\left(\textstyle\sum_k \left[ z_{ijk}(\ln(l_{ik} \mu_j) + \E{\ln u_{jk}}) - \ln\Gamma(z_{ijk} + 1) \right]\right)\\ &= \Mult(x_{ij}, \pi_{ij1}, \ldots, \pi_{ijK}), \qquad \pi_{ijk} \propto l_{ik}\mu_j\exp(\E{\ln u_{jk}})\\ q^*(u_{jk}) &\propto \exp\left(\left(\textstyle\sum_i \E{z_{ijk}} + \theta_{jk} - 1\right) \ln u_{jk} - \left(\textstyle\sum_i l_{ik} \mu_j + \theta_{jk}\right) u_{jk}\right)\\ &= \Gam(\textstyle\sum_i \E{z_{ijk}} + \theta_{jk}, \textstyle\sum_i l_{ik}\mu_j + \theta_{jk})\\ &\triangleq \Gam(\alpha_{jk}, \beta_{jk}), \end{align}

where

\begin{align} \E{z_{ijk}} &= x_{ij} \pi_{ijk}\\ \E{u_{jk}} &= \alpha_{jk} / \beta_{jk}\\ \E{\ln u_{jk}} &= \psi(\alpha_{jk}) - \ln \beta_{jk} \end{align}

and \(\psi\) denotes the digamma function. The evidence lower bound (ELBO) is

\begin{multline} \ell = \sum_{i,j,k} \left[ \E{z_{ijk}} (\ln (l_{ik} \mu_j) + \E{\ln u_{jk}} - \ln\pi_{ijk}) - l_{ik} \mu_j \E{u_{jk}} \right] - \sum_{i,j} \ln\Gamma(x_{ij} + 1)\\ + \sum_{j,k} \left[ (\theta_{jk} - \alpha_{jk}) \E{\ln u_{jk}} - (\theta_{jk} - \beta_{jk}) \E{u_{jk}} + \theta_{jk}\ln \theta_{jk} - \alpha_{jk} \ln\beta_{jk} - \ln\Gamma(\theta_{jk}) + \ln\Gamma(\alpha_{jk})\right]. \end{multline}

To maximize the ELBO,

\begin{align} \frac{\partial\ell}{\partial l_{ik}} &= \sum_j \left[ \frac{\E{z_{ijk}}}{l_{ik}} - \mu_j \E{u_{jk}} \right] = 0\\ l_{ik} &= \frac{\sum_j \E{z_{ijk}}}{\sum_j \mu_j \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \mu_j} &= \sum_{i, k} \left[ \frac{\E{z_{ijk}}}{\mu_j} - l_{ik} \E{u_{jk}} \right] = 0\\ \mu_j &= \frac{\sum_{i, k} \E{z_{ijk}}}{\sum_{i, k} l_{ik} \E{u_{jk}}}\\ \frac{\partial\ell}{\partial \theta_{jk}} &= 1 + \ln \theta_{jk} + \E{\ln u_{jk}} - \E{u_{jk}} - \psi(\theta_{jk}), \end{align}

where \(\theta_{jk}\) can be updated via gradient ascent with line search.

Implementation

# Scalar special functions, called through ctypes with the C signature
# double -> double. The function addresses are looked up in
# scipy.special.cython_special by exported symbol name.
# NOTE(review): presumably these exist so that the functions can be called
# inside numba-compiled loops -- confirm; they accept one double at a time,
# not arrays.

# lgamma(x): the 'gammaln' symbol (log of the Gamma function).
lgamma = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(
  numba.extending.get_cython_function_address('scipy.special.cython_special', 'gammaln'))
# digamma(x): the '__pyx_fuse_1psi' symbol, one specialization of the
# fused-type function psi, declared here as double -> double.
digamma = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(
  numba.extending.get_cython_function_address('scipy.special.cython_special', '__pyx_fuse_1psi'))

def hpmf(x, rank, step=1e-2, atol=1e-4, max_epochs=1000, verbose=False):
  """Fit hierarchical PMF by variational Bayes EM (VBEM)

  x - count matrix (n, p); anything accepted by scipy.sparse.coo_matrix
  rank - number of latent factors
  step - initial step size for VBEM update to log(theta)
  atol - convergence criterion (change in ELBO)
  max_epochs - max number of VBEM updates
  verbose - report ELBO after each epoch

  Returns (l, mu, alpha, beta, theta), where l is (n, rank), mu is (p, 1),
  and alpha, beta, theta are (p, rank). q(u_jk) = Gamma(alpha_jk, beta_jk)
  (shape, rate).

  Raises RuntimeError if the objective decreases between epochs, or if
  max_epochs updates did not reach convergence.

  """
  if not ss.isspmatrix_coo(x):
    x = ss.coo_matrix(x)
  n, p = x.shape
  l = np.random.uniform(size=(n, rank))
  mu = np.ones((p, 1))
  # Uniform responsibilities. The fill value must be a float (-log(rank), so
  # that exp(log_pi) sums to one over factors); an integer fill value would
  # make this an integer array and truncate every later in-place update.
  log_pi = np.full((x.nnz, rank), -np.log(rank))
  alpha = np.ones((p, rank))
  beta = np.ones((p, rank))
  theta = np.ones((p, rank))

  def objective(u, log_u):
    # elbo() takes the expectations (u, log_u), not (alpha, beta). The terms
    # coming from the Gamma prior and the entropy of q(u) are added here.
    # Pass plain arrays (data, row, col) rather than the coo_matrix itself.
    return (elbo(x.data, x.row, x.col, log_pi, l, mu, u, log_u)
            + _elbo_gamma_terms(alpha, beta, theta, u, log_u))

  u, log_u = _expected_u(alpha, beta)
  obj = objective(u, log_u)
  for t in range(max_epochs):
    # Expectations wrt variational distribution
    z = x.data.reshape(-1, 1) * np.exp(log_pi)
    # Coordinate updates (in-place)
    update_l(l, z, x.row, u, mu)
    update_mu(mu, z, x.col, u, l)
    update_u(alpha, beta, z, x.col, l, mu, theta)
    # q(u) just changed, so its expectations must be refreshed before they are
    # used to update q(z) and theta
    u, log_u = _expected_u(alpha, beta)
    update_z(log_pi, x.row, x.col, l, mu, log_u)
    # Hyperparameter update (in-place; the return value is not relied upon)
    update_theta(theta, u, log_u, alpha, beta, step=step)

    update = objective(u, log_u)
    if verbose:
      print(f'[{t}] elbo={update:.2g}')
    # Allow for floating point round-off when checking monotonicity
    if obj - update > 1e-8 * max(1., abs(obj)):
      raise RuntimeError('objective decreased')
    if abs(update - obj) < atol:
      return l, mu, alpha, beta, theta
    obj = update
  raise RuntimeError('max_epochs exceeded')

def _expected_u(alpha, beta):
  """Return (E[u], E[ln u]) for u ~ Gamma(alpha, beta) (shape, rate)"""
  return alpha / beta, sp.digamma(alpha) - np.log(beta)

def _elbo_gamma_terms(alpha, beta, theta, u, log_u):
  """Return E_q[ln p(u | theta)] - E_q[ln q(u)], summed over all (j, k)"""
  return ((theta - alpha) * log_u - (theta - beta) * u
          + theta * np.log(theta) - alpha * np.log(beta)
          - sp.gammaln(theta) + sp.gammaln(alpha)).sum()

def update_l(l, z, row, u, mu):
  """Update loadings in-place: l_ik = sum_j E[z_ijk] / sum_j mu_j E[u_jk]

  l - loadings (n, k), modified in-place
  z - E[z_ijk] for each stored entry (nnz, k)
  row - row index i of each stored entry (nnz,)
  u - E[u_jk] (p, k)
  mu - mean expression, (p,) or (p, 1)

  Rows with no stored entries are left unchanged.

  """
  denom = (np.reshape(mu, (-1, 1)) * u).sum(axis=0)
  # Unbuffered scatter-add: one pass over the stored entries instead of one
  # boolean mask over all entries per row
  num = np.zeros(l.shape)
  np.add.at(num, row, z)
  seen = np.zeros(l.shape[0], dtype=bool)
  seen[row] = True
  l[seen] = num[seen] / denom

def update_mu(mu, z, col, u, l):
  """Update means in-place: mu_j = sum_ik E[z_ijk] / sum_ik l_ik E[u_jk]

  mu - mean expression, (p,) or (p, 1), modified in-place
  z - E[z_ijk] for each stored entry (nnz, k)
  col - column index j of each stored entry (nnz,). The update sums over
        samples within a gene, so the entries must be grouped by column;
        row indices cannot be used here.
  u - E[u_jk] (p, k)
  l - loadings (n, k)

  Columns with no stored entries are left unchanged.

  """
  p = mu.shape[0]
  num = np.zeros(p)
  np.add.at(num, col, z.sum(axis=1))
  # sum_ik l_ik u_jk = sum_k u_jk (sum_i l_ik)
  denom = (u * l.sum(axis=0)).sum(axis=1)
  seen = np.zeros(p, dtype=bool)
  seen[col] = True
  mu[seen] = (num[seen] / denom[seen]).reshape((-1,) + mu.shape[1:])

def update_u(alpha, beta, z, col, l, mu, theta):
  """Update q(u_jk) = Gamma(alpha_jk, beta_jk) (shape, rate) in-place

  alpha_jk = sum_i E[z_ijk] + theta_jk
  beta_jk = mu_j sum_i l_ik + theta_jk

  alpha, beta - variational parameters (p, k), modified in-place
  z - E[z_ijk] for each stored entry (nnz, k)
  col - column index j of each stored entry (nnz,)
  l - loadings (n, k)
  mu - mean expression, (p,) or (p, 1)
  theta - prior shape/rate (p, k)

  Every column is updated: a column with no stored entries has
  sum_i E[z_ijk] = 0, but its rate still depends on l and mu.

  """
  # Per-factor totals for each column (unbuffered scatter-add over the
  # stored entries); the sum must not be pooled across factors
  totals = np.zeros(alpha.shape)
  np.add.at(totals, col, z)
  alpha[...] = totals + theta
  # (p, 1) * (k,) broadcasts to (p, k)
  beta[...] = np.reshape(mu, (-1, 1)) * l.sum(axis=0) + theta

def update_z(log_pi, row, col, l, mu, log_u):
  """Update the log responsibilities of q(z) in-place

  For each stored entry (i, j),

    pi_ijk = l_ik mu_j exp(E[ln u_jk]) / sum_k' l_ik' mu_j exp(E[ln u_jk'])

  log_pi - log responsibilities (nnz, k), modified in-place; must be a
           floating point array
  row, col - indices (i, j) of each stored entry (nnz,)
  l - loadings (n, k)
  mu - mean expression, (p,) or (p, 1)
  log_u - E[ln u_jk] (p, k)

  """
  # l[row] selects the whole loading vector of sample i for each entry
  with np.errstate(divide='ignore'):
    # l_ik = 0 gives -inf, i.e. zero responsibility for factor k
    scores = np.log(l[row]) + np.log(np.reshape(mu, (-1, 1))[col]) + log_u[col]
  # Normalize with log-sum-exp, shifted by the row maximum for stability
  peak = scores.max(axis=1, keepdims=True)
  log_norm = peak + np.log(np.exp(scores - peak).sum(axis=1, keepdims=True))
  log_pi[...] = scores - log_norm

def theta_loss(theta, u, log_u, alpha, beta):
  """Return the terms of the ELBO that involve theta (elementwise)

  Despite the name, this is an objective to be maximized: it is
  E_q[ln p(u | theta)] plus terms in alpha and beta that are constant in
  theta.

  theta - prior shape/rate of u
  u, log_u - E[u] and E[ln u] under q
  alpha, beta - shape and rate of q(u)

  Scalars and arrays are both accepted.

  """
  return ((theta - alpha) * log_u - (theta - beta) * u
          + theta * np.log(theta) - sp.gammaln(theta))

def update_theta(theta, u, log_u, alpha, beta, step=1, c=0.5, tau=0.5, max_iters=32, eps=1e-15):
  """Take one gradient ascent step on log(theta), with backtracking line search

  theta - prior shape/rate (p, k), modified in-place and also returned
  u, log_u - E[u_jk] and E[ln u_jk] (p, k)
  alpha, beta - shape and rate of q(u_jk) (p, k)
  step - initial step size
  c - sufficient increase (Armijo) constant
  tau - factor by which the step size is shrunk on each backtrack
  max_iters - max number of backtracking steps per element
  eps - lower bound on the updated theta

  Each element has its own step size and backtracking budget. An element for
  which no step size gives a sufficient increase is left unchanged, so the
  objective never decreases.

  """
  # Important: take steps wrt log_theta to avoid non-negativity constraint
  log_theta = np.log(theta)
  # Chain rule: d/d log(theta) = theta * d/d theta, where
  # d/d theta = 1 + ln theta + E[ln u] - E[u] - psi(theta)
  grad = theta * (1 + log_theta + log_u - u - sp.digamma(theta))
  current = theta_loss(theta, u, log_u, alpha, beta)
  steps = np.full(theta.shape, float(step))
  pending = np.ones(theta.shape, dtype=bool)
  # Overflowing proposals evaluate to inf/nan and are rejected below
  with np.errstate(all='ignore'):
    # One initial trial plus at most max_iters backtracking steps
    for _ in range(max_iters + 1):
      proposal = np.maximum(np.exp(log_theta + steps * grad), eps)
      value = theta_loss(proposal, u, log_u, alpha, beta)
      # Armijo condition for ascent along the gradient direction
      accept = pending & np.isfinite(value) & (value >= current + c * steps * grad ** 2)
      theta[accept] = proposal[accept]
      pending &= ~accept
      if not pending.any():
        break
      steps[pending] *= tau
  return theta

def elbo(data, row, col, log_pi, l, mu, u, log_u):
  """Return the part of the ELBO that involves x and z

    sum_ijk [E[z_ijk] (ln(l_ik mu_j) + E[ln u_jk] - ln pi_ijk) - l_ik mu_j E[u_jk]]
      - sum_ij ln Gamma(x_ij + 1)

  data - stored (nonzero) counts (nnz,)
  row, col - indices (i, j) of each stored entry (nnz,)
  log_pi - log responsibilities of q(z) (nnz, k)
  l - loadings (n, k)
  mu - mean expression, (p,) or (p, 1)
  u, log_u - E[u_jk] and E[ln u_jk] (p, k)

  TODO: the terms from the Gamma prior on u and the entropy of q(u) need
  alpha, beta, and theta, and are not included here.

  """
  mu = np.reshape(mu, (-1, 1))
  # Float even when the counts are integers, so nothing is truncated
  z = np.reshape(data, (-1, 1)) * np.exp(log_pi)
  with np.errstate(divide='ignore', invalid='ignore'):
    log_ratio = np.log(l[row]) + np.log(mu[col]) + log_u[col] - log_pi
    # Convention 0 * ln(0) = 0 for factors with zero responsibility
    fit = np.where(z > 0, z * log_ratio, 0.).sum()
  # The expected rate is summed over every (i, j, k), including entries where
  # x_ij = 0: sum_ijk l_ik mu_j u_jk = sum_k (sum_i l_ik) (sum_j mu_j u_jk)
  rate = (l.sum(axis=0) * (mu * u).sum(axis=0)).sum()
  return fit - rate - sp.gammaln(data + 1).sum()

Simulation

Simulate from a noisy topic model

\begin{align} x_{ij} \mid s_i, \lambda_{ij} &\sim \Pois(s_i \lambda_{ij})\\ \lambda_{ij} &= (\ml\mf')_{ij}\\ l_{i\cdot} &\sim \Dir(\boldsymbol{1}_k)\\ f_{jk} &= \mu_j u_{jk}\\ \vmu &\sim \Dir(\boldsymbol{1}_p)\\ u_{jk} &\sim \Gam(1 / \phi_{jk}, 1 / \phi_{jk})\\ \phi_{jk} &\sim \operatorname{Discrete}(\cdot) \end{align}
where the Gamma distribution is again parameterized by shape and rate, so that \(\E{u_{jk}} = 1\) and \(V[u_{jk}] = \phi_{jk}\).
def simulate(n, p, k, s=1e4, seed=0):
  """Simulate counts from a noisy topic model

  n - number of samples
  p - number of genes
  k - number of topics
  s - size factor (total expected count per sample)
  seed - seed for the global numpy random state

  Returns (x, l, mu, phi), where x is the (n, p) count matrix, l is the
  (n, k) matrix of topic weights, mu is the (p,) vector of mean expression,
  and phi is the (p, k) matrix of dispersions (phi = 2 for about 1% of the
  entries, and 1 otherwise).

  """
  # Use the caller's seed rather than a hard-coded one
  np.random.seed(seed)
  mu = np.random.dirichlet(np.ones(p))
  phi = np.ones((p, k))
  idx = np.random.uniform(size=(p, k)) <= 0.01
  phi[idx] = 2
  # Shape 1 / phi and scale phi: E[u] = 1, V[u] = phi
  u = st.gamma(a=1 / phi, scale=phi).rvs(size=(p, k))
  f = mu.reshape(-1, 1) * u
  l = np.random.dirichlet(np.ones(k), size=n)
  lam = l @ f.T
  x = np.random.poisson(s * lam)
  return x, l, mu, phi

Results

Simulated example

x, l, mu, phi = simulate(n=100, p=10000, k=5, s=1e3)
x = ss.coo_matrix(x)
x
<100x10000 sparse matrix of type '<class 'numpy.int64'>'
with 89681 stored elements in COOrdinate format>

Report the largest observed count, and the fraction of entries that are nonzero.

x.max(), (x > 0).mean()
(8, 0.08968100000000152)

Analyze a subset of the data.

y = x.tocsc()[:,:6].tocoo()
res = hpmf(y, rank=5, verbose=True)

0 - 6f2e5199-b6aa-4b3d-886f-8e3374688295

Author: Abhishek Sarkar

Created: 2020-05-14 Thu 23:52

Validate