Poisson-Gamma log likelihood surface

Introduction

We previously implemented maximum likelihood estimation of the ZINB model

\begin{align*}
x_{ij} \mid x_i^+, \lambda_{ij} &\sim \operatorname{Poisson}(x_i^+ \lambda_{ij})\\
\lambda_{ij} &\sim g_j(\cdot) = \pi_j \delta_0(\cdot) + (1 - \pi_j) \operatorname{Gamma}(1/\phi_j, 1/(\mu_j\phi_j))
\end{align*}

via batch gradient descent. We chose this algorithm because in the motivating application (variance QTL mapping; Sarkar et al. 2019):

  1. The parameters \(\mu_j, \phi_j, \pi_j\) were structured (one set of parameters per donor, per gene), resulting in >1.5 million parameters in total to be estimated from the data
  2. Observed technical covariates had effects that were shared across donors (due to randomization in the study design); this makes them estimable from the data, but complicates the model
  3. The algorithm is amenable to automatic differentiation and GPU computation

However, in our benchmarking for this study, this implementation has a major drawback: it requires memory on the order of the size of the dense data matrix. This is especially problematic for large, sparse data sets (e.g., 10X Genomics data), where the dense representation can be much larger than the data itself. To address this problem, we re-implemented this approach using stochastic gradient descent (in the scmodes package). Here, we investigate convergence problems in this implementation.
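To make the memory argument concrete, here is a minimal sketch of the minibatch approach, in which only one minibatch needs to be held in dense form at a time. This is an illustration of the idea, not the scmodes implementation; nb_nllik and fit_nb_sgd are hypothetical names, and the actual optimizer used by scmodes may differ.

import torch
import torch.utils.data

def nb_nllik(x, s, log_mu, neg_log_phi):
  """Negative log likelihood of x ~ NB with size exp(neg_log_phi), mean s exp(log_mu)

  This is the marginal of x ~ Poisson(s lam), lam ~ Gamma(1/phi, 1/(mu phi)).
  """
  n = torch.exp(neg_log_phi)
  m = s * torch.exp(log_mu)
  return -(torch.lgamma(x + n) - torch.lgamma(n) - torch.lgamma(x + 1)
           + n * torch.log(n / (n + m)) + x * torch.log(m / (n + m))).sum()

def fit_nb_sgd(x, s, batch_size=100, max_epochs=100, lr=1e-2):
  # Hypothetical sketch of the role scmodes.sgd.ebpm_gamma plays below
  data = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.tensor(x, dtype=torch.float),
                                   torch.tensor(s, dtype=torch.float)),
    batch_size=batch_size, shuffle=True)
  log_mu = torch.zeros(1, requires_grad=True)
  neg_log_phi = torch.zeros(1, requires_grad=True)
  opt = torch.optim.Adam([log_mu, neg_log_phi], lr=lr)
  for _ in range(max_epochs):
    for x_batch, s_batch in data:
      opt.zero_grad()
      loss = nb_nllik(x_batch, s_batch, log_mu, neg_log_phi)
      loss.backward()
      opt.step()
  return log_mu.item(), neg_log_phi.item()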

Setup

import numpy as np
import scipy.stats as st
import scmodes.sgd
import scqtl.simple
import torch
import torch.utils.data
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import colorcet
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Results

Likelihood surface for a simulated problem

Simulate data for a single gene, fixing \(\pi_j = 0\) for simplicity.
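With \(\pi_j = 0\), the Gamma prior has mean \(\mu_j\) and variance \(\mu_j^2 \phi_j\), and integrating out \(\lambda_{ij}\) gives a negative binomial marginal

\begin{align*}
x_{ij} \mid x_i^+ \sim \operatorname{NB}\left(\frac{1}{\phi_j}, \frac{1}{1 + x_i^+ \mu_j \phi_j}\right),
\end{align*}

where the arguments are the size and the success probability, so \(E[x_{ij} \mid x_i^+] = x_i^+ \mu_j\) and \(V[x_{ij} \mid x_i^+] = x_i^+ \mu_j (1 + x_i^+ \mu_j \phi_j)\). This is the parameterization used in the simulation and likelihood computations below.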

np.random.seed(1)
# Typical values (Sarkar et al. 2019)
log_mu = np.random.uniform(-12, -6)
log_phi = np.random.uniform(-4, 0)

n = 5000
s = 1e5 * np.ones((n, 1))
# Gamma with shape 1/phi and scale mu * phi, so E[lam] = mu and V[lam] = mu^2 phi
lam = np.random.gamma(shape=np.exp(-log_phi), scale=np.exp(log_mu + log_phi), size=(n, 1))
x = np.random.poisson(lam=s * lam)
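As a quick sanity check (an addition here, not part of the original analysis), the sample moments should match the negative binomial mean and variance given above:

sm = s[0, 0] * np.exp(log_mu)  # s mu; the size factor is constant across cells
print(f'sample mean {x.mean():.4g} vs. theory {sm:.4g}')
print(f'sample variance {x.var():.4g} vs. theory {sm * (1 + sm * np.exp(log_phi)):.4g}')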

Plot the simulated data.

plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.hist(x, bins=np.arange(x.max() + 1), color='k')
plt.xlim(0, x.max())
plt.xlabel('Number of molecules')
plt.ylabel('Number of cells')
plt.tight_layout()

example.png

Compute the log likelihood surface in a small window about the true values.

n_points = 100
window = 0.25
log_mu_grid = np.linspace(log_mu - window, log_mu + window, n_points)
log_phi_grid = np.linspace(log_phi - window, log_phi + window, n_points)
z = np.zeros((n_points, n_points))
for x_ in range(n_points):
  for y_ in range(n_points):
    # Negative binomial marginal (above): size 1/phi, success probability
    # 1 / (1 + s mu phi)
    z[y_, x_] = st.nbinom(n=np.exp(-log_phi_grid[y_]), p=1 / (1 + s * np.exp(log_mu_grid[x_] + log_phi_grid[y_]))).logpmf(x).sum()
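The double loop above makes one scipy call per grid point. As a sketch (equivalent up to floating point error), broadcasting over the \(\ln\mu\) grid reduces this to one call per \(\ln\phi\) value, at the cost of an intermediate (n, n_points) array:

z_vec = np.zeros((n_points, n_points))
for y_ in range(n_points):
  # (n, n_points) success probabilities, one column per log_mu grid value
  p = 1 / (1 + s * np.exp(log_mu_grid[None, :] + log_phi_grid[y_]))
  z_vec[y_] = st.nbinom(n=np.exp(-log_phi_grid[y_]), p=p).logpmf(x).sum(axis=0)
assert np.allclose(z, z_vec)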

Compute the MLE using the Nelder-Mead algorithm.

mu_hat, inv_phi_hat, _ = scqtl.simple.fit_nb(x.ravel(), s.ravel())
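scqtl.simple.fit_nb returns the MLE of \(\mu\) and \(1/\phi\). For reference, a minimal equivalent using scipy.optimize (an illustration under the parameterization above, not the scqtl implementation; nb_nllik_np is a hypothetical name):

import scipy.optimize

def nb_nllik_np(theta, x, s):
  # theta = (ln mu, ln(1/phi)), matching the negative binomial marginal above
  log_mu, neg_log_phi = theta
  p = 1 / (1 + s * np.exp(log_mu - neg_log_phi))
  return -st.nbinom(n=np.exp(neg_log_phi), p=p).logpmf(x).sum()

opt = scipy.optimize.minimize(nb_nllik_np, x0=[0, 0], args=(x.ravel(), s.ravel()),
                              method='Nelder-Mead')
# opt.x should agree with (np.log(mu_hat), np.log(inv_phi_hat))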

Trace the parameter values through batch gradient descent, starting from \(\ln\mu = 0, \ln\phi = 0\).

batch_gd_result = scmodes.sgd.ebpm_gamma(x, s, batch_size=n, max_epochs=2000, trace=True)
# The trace stores (ln mu, ln(1/phi)), so negate the second column to recover ln(phi)
batch_trace_x = np.array(batch_gd_result[-1])[:,0]
batch_trace_y = -np.array(batch_gd_result[-1])[:,1]
batch_trace_keep = np.logical_and.reduce(
  (log_mu - window < batch_trace_x,
   batch_trace_x < log_mu + window,
   log_phi - window < batch_trace_y,
   batch_trace_y < log_phi + window))

Trace the parameter values through SGD, starting from \(\ln\mu = 0, \ln\phi = 0\).

# Important: the learning rate needs to be reduced compared to batch gradient
# descent (the default is 1e-2). Single-observation gradients are much noisier
# than full-data gradients, so smaller steps are needed for stable convergence.
sgd_result = scmodes.sgd.ebpm_gamma(x, s, batch_size=1, max_epochs=5, lr=1e-3, trace=True)
sgd_trace_x = np.array(sgd_result[-1])[:,0]
sgd_trace_y = -np.array(sgd_result[-1])[:,1]
sgd_trace_keep = np.logical_and.reduce(
  (log_mu - window < sgd_trace_x,
   sgd_trace_x < log_mu + window,
   log_phi - window < sgd_trace_y,
   sgd_trace_y < log_phi + window))

Trace the parameter values through minibatch SGD, starting from \(\ln\mu = 0, \ln\phi = 0\).

mb_sgd_result = scmodes.sgd.ebpm_gamma(x, s, batch_size=100, max_epochs=100, trace=True)
mb_sgd_trace_x = np.array(mb_sgd_result[-1])[:,0]
mb_sgd_trace_y = -np.array(mb_sgd_result[-1])[:,1]
mb_sgd_trace_keep = np.logical_and.reduce(
  (log_mu - window < mb_sgd_trace_x,
   mb_sgd_trace_x < log_mu + window,
   log_phi - window < mb_sgd_trace_y,
   mb_sgd_trace_y < log_phi + window))
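The same window filter is applied to all three traces; a small helper (hypothetical; in_window is not part of scmodes) would remove the duplication, e.g.

def in_window(trace_x, trace_y):
  """Return a mask of trace points inside the plotted window about the truth"""
  return np.logical_and.reduce(
    (log_mu - window < trace_x, trace_x < log_mu + window,
     log_phi - window < trace_y, trace_y < log_phi + window))

mb_sgd_trace_keep = in_window(mb_sgd_trace_x, mb_sgd_trace_y)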

Plot the log likelihood surface and optimization traces.

cm = plt.get_cmap('Dark2')
plt.clf()
plt.gcf().set_size_inches(6, 4)
cf = plt.contour(log_mu_grid, log_phi_grid, z, levels=20, cmap=colorcet.cm['fire'], linewidths=1)
plt.scatter(log_mu, log_phi, marker='x', c=cm(0), s=16, zorder=10, label='Ground truth')
plt.scatter(np.log(mu_hat), -np.log(inv_phi_hat), marker='x', c=cm(1), s=16, zorder=10, label='Nelder-Mead')
plt.scatter(batch_gd_result[0], -batch_gd_result[1], marker='x', c=cm(2), s=16, zorder=10, label='Batch GD')
plt.scatter(sgd_result[0], -sgd_result[1], marker='x', c=cm(3), s=16, zorder=10, label='SGD')
plt.scatter(mb_sgd_result[0], -mb_sgd_result[1], marker='x', c=cm(4), s=16, zorder=10, label='Minibatch SGD')

plt.plot(batch_trace_x[batch_trace_keep], batch_trace_y[batch_trace_keep], c=cm(2), marker=None, ls=':', lw=1, alpha=0.5, zorder=4, label=None)
plt.plot(sgd_trace_x[sgd_trace_keep][::100], sgd_trace_y[sgd_trace_keep][::100], c=cm(3), marker=None, lw=1, ls=':', alpha=0.5, zorder=5, label=None)
plt.plot(mb_sgd_trace_x[mb_sgd_trace_keep][::100], mb_sgd_trace_y[mb_sgd_trace_keep][::100], c=cm(4), marker=None, lw=1, ls=':', alpha=0.5, zorder=5, label=None)

cb = plt.colorbar(cf)
cb.set_label('Log lik')
plt.legend(frameon=True, loc='best', handletextpad=0)
plt.xlabel(r'$\ln(\mu)$')
plt.ylabel(r'$\ln(\phi)$')
plt.tight_layout()

example-llik.png

Author: Abhishek Sarkar

Created: 2019-11-25 Mon 13:47
