Iterative refinement of NPMLE grid in ashr

Introduction

The NPMLE (Kiefer and Wolfowitz 1956) is

\begin{equation} \arg\max_g \sum_{i=1}^n \ln \int_0^{\infty} \operatorname{Poisson}(x_i; s_i \lambda_i)\,dg(\lambda_i) \end{equation}

where the maximum is taken over the space of all distributions \(g\). Theorem 2 of Koenker and Mizera states that if the likelihood is the standard normal density, then the optimum \(g^*\) is an atomic probability measure with no more than \(n\) atoms. (We assume that this still holds for the Poisson density, but have not proven it.)

In practice, we approximate this measure using a mixture of tiled uniform segments, and solve

\begin{equation} \arg\max_{\pi} \sum_{i=1}^n \ln \int_0^{\infty} \operatorname{Poisson}(x_i; s_i \lambda_i) \sum_{k=1}^{K} \pi_k \operatorname{Uniform}(\lambda_i; (k - 1)a, ka)\,d\lambda_i \end{equation}
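Each term in this objective can be evaluated in closed form: integrating the Poisson density over a uniform segment reduces to a difference of Gamma CDFs (the same identity used below to compute the fitted density). A minimal sketch, where `pois_unif_marginal` is a hypothetical helper name:

```python
import numpy as np
import scipy.integrate
import scipy.stats as st

def pois_unif_marginal(x, s, a, b):
  """Marginal likelihood of x ~ Poisson(s * lam) with lam ~ Uniform(a, b).

  Substituting u = s * lam shows that the integral of the Poisson density
  over [a, b] equals (F(b) - F(a)) / s, where F is the CDF of a
  Gamma(shape=x + 1, scale=1 / s) distribution; dividing by (b - a)
  accounts for the uniform density.
  """
  F = st.gamma(a=x + 1, scale=1 / s).cdf
  return (F(b) - F(a)) / (s * (b - a))

# Compare against direct numerical quadrature of the same integral
num, _ = scipy.integrate.quad(
  lambda lam: st.poisson(mu=2. * lam).pmf(3) / (.5 - .1), .1, .5)
assert np.isclose(pois_unif_marginal(3, 2., .1, .5), num)
```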

A fundamental challenge in applying this approach to every gene across a variety of data sets is that the size of the grid (i.e., the size of the convex optimization problem) depends on the range of the data. In practice, this means that a single fixed choice of step size can prove inadequate to fit the variation at some genes, or else require prohibitive memory to represent. To address this issue, we developed an iterative algorithm to refine the grid. The intuition behind our iterative refinement scheme is that if some segment \(k\) overlaps atoms of \(g^*\), then the corresponding \(\pi_k^* > 0\); otherwise, \(\pi_k^* = 0\). We can therefore discard the segments for which \(\pi_k^* = 0\) and divide the remaining segments into smaller segments, and this property should still hold.

Formally, given an initial grid size \(K\):

  1. Let \(\lambda_i = x_i / s_i\), initialize the step size \(a = \max(\lambda_i) / K\)
  2. Initialize \(\mathcal{G}\) to be the set of mixture distributions over uniform segments \(\{(ka, (k + 1)a) \mid 0 \leq k < K\}\)
  3. Initialize the current solution \(\hat{g} = \operatorname{EBPM}(x_1, \ldots, x_n, s_1, \ldots, s_n, \mathcal{G})\)
  4. Until the improvement in log likelihood falls below a threshold (or the maximum number of updates is exceeded)
    1. Update \(\mathcal{G} = \{(a_k, (a_k + b_k) / 2), ((a_k + b_k) / 2, b_k) \mid (a_k, b_k) \in \mathcal{G}, \pi_k > \mathrm{threshold}\}\)
    2. Update the current solution \(\hat{g} = \operatorname{EBPM}(x_1, \ldots, x_n, s_1, \ldots, s_n, \mathcal{G})\)
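The dividing step (4.1) can be written as a small pure function. This is an illustrative sketch, not the scmodes implementation; renormalizing the retained weights is an assumption about how the next solve is initialized:

```python
import numpy as np

def split_segments(a, b, pi, thresh=1e-7):
  """One grid refinement update.

  Drops segments whose fitted weight is at most `thresh`, splits each
  surviving segment (a_k, b_k) at its midpoint, and divides its weight
  evenly between the two halves. Weights are renormalized to sum to one
  (an assumption about how the next solve is initialized).
  """
  keep = pi > thresh
  a, b, pi = a[keep], b[keep], pi[keep]
  mid = (a + b) / 2
  # Interleave left and right halves so segments stay in order
  new_a = np.column_stack([a, mid]).ravel()
  new_b = np.column_stack([mid, b]).ravel()
  new_pi = np.repeat(pi / 2, 2)
  return new_a, new_b, new_pi / new_pi.sum()

# A grid of two segments where only the first got non-trivial weight
a, b, pi = split_segments(np.array([0., .5]), np.array([.5, 1.]), np.array([1., 0.]))
# -> segments (0, .25) and (.25, .5), each with weight .5
```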

Here, we investigate cases where this scheme fails.

Setup

import anndata
import collections
import numpy as np
import os
import pandas as pd
import scanpy as sc
import scmodes
import scipy.optimize as so
import scipy.stats as st
import rpy2.robjects.packages
import rpy2.robjects.pandas2ri
rpy2.robjects.pandas2ri.activate()
ashr = rpy2.robjects.packages.importr('ashr')
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import colorcet
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Results

Control data example

Look at a case where the improvement of the NPMLE expression model over the point mass expression model appears worse than that of the unimodal expression model. Plot the log density of the observed data.

# assumes `data`, a dict mapping dataset names to loader functions returning AnnData, is defined elsewhere
dat = data['chromium2']()
x = dat[:,dat.var['index'] == 'ERCC-00031'].X.A.ravel()
s = dat.X.sum(axis=1).A.ravel()
lam = x / s
plt.clf()
plt.gcf().set_size_inches(2.5, 2.5)
n, h = np.histogram(x, np.arange(x.max() + 2))
plt.plot(h[:-1], np.log(n) - np.log(x.shape[0]), c='k', lw=1)
plt.xlabel('Number of molecules')
plt.ylabel('Log density')
plt.tight_layout()

chromium2-ERCC-00031.png

Report the log likelihood achieved by different initial grids.

pd.Series({
  k: scmodes.ebpm.ebpm_npmle(x, s, K=k, max_grid_updates=30, tol=1e-7).rx2('loglik')[0]
  for k in (32, 64, 128, 256, 512)
})
32    -132.714344
64    -132.714344
128   -128.090748
256   -132.714344
512   -124.649173
dtype: float64

Trace the refinement of the grid starting from \(K = 256\) and \(K = 512\).

max_grid_updates = 30
thresh = 1e-7
tol = 1e-5

res = collections.defaultdict(list)
for init in (256, 512):
  K = init
  grid = np.linspace(0, lam.max(), K + 1)
  fit = ashr.ash_pois(
    pd.Series(x), pd.Series(s),
    g=ashr.unimix(pd.Series(np.ones(K) / K), pd.Series(grid[:-1]), pd.Series(grid[1:])))
  obj = fit.rx2('loglik')[0]
  for i in range(max_grid_updates):
    print(f'init={init} iteration={i} obj={obj}')
    g = np.array(fit.rx2('fitted_g'))
    g = g[:,g[0] > thresh]
    res[init].append(g)
    grid = np.linspace(g[1], g[2], 3)
    a = pd.Series(grid[:-1,:].ravel(order='F'))
    b = pd.Series(grid[1:,:].ravel(order='F'))
    pi = pd.Series((np.tile(g[0], (2, 1)) / 2).ravel(order='F'))
    if i == 0:
      print(np.vstack((pi, a, b)))
    fit1 = ashr.ash_pois(pd.Series(x), pd.Series(s), g=ashr.unimix(pi, a, b))
    update = fit1.rx2('loglik')[0]
    if update < obj:
      raise RuntimeError('loglik decreased')
    elif update - obj < tol:
      g = np.array(fit1.rx2('fitted_g'))
      g = g[:,g[0] > thresh]
      res[init].append(g)
      break
    else:
      obj = update
      fit = fit1
plt.clf()
fig, ax = plt.subplots(1, 2, sharey=True)
fig.set_size_inches(4, 2.5)
for k, a in zip(res, ax):
  for i, g in enumerate(res[k]):
    xx = g[1:].ravel(order='F')
    yy = np.zeros(xx.shape[0])
    for j in np.arange(xx.shape[0]):
      if j % 2:
        yy[j] = yy[j - 1] + g[0][j // 2]
      elif j > 0:
        yy[j] = yy[j - 1]
    a.plot(xx, yy, c=colorcet.cm['bmy'](i / len(res[k])), marker='.', ms=4, lw=1)
  a.set_title(f'Initial $K$ = {k}')
  a.set_xlim(0, 4e-5)
  a.set_ylim(0, 1)
ax[0].set_ylabel('CDF')
a = fig.add_subplot(111, frameon=False, xticks=[], yticks=[])
a.set_xlabel('Latent gene expression', labelpad=32)
fig.tight_layout()

chromium2-ERCC-00031-trace-cdf.png

Clearly, even the first dividing step leads to the wrong solution. Look more closely at the dividing step.

init = dict()
for K in (256, 512):
  grid = np.linspace(0, lam.max(), K + 1)
  init[K] = ashr.ash_pois(pd.Series(x), pd.Series(s), g=ashr.unimix(pd.Series(np.ones(K) / K), pd.Series(grid[:-1]), pd.Series(grid[1:])))
plt.clf()
fig, ax = plt.subplots(1, 2, sharey=True)
fig.set_size_inches(4, 2.5)
for k, a in zip(init, ax):
  g = np.array(init[k].rx2('fitted_g'))
  xx = g[1:].ravel(order='F')
  yy = np.zeros(xx.shape[0])
  for j in np.arange(xx.shape[0]):
    if j % 2:
      yy[j] = yy[j - 1] + g[0][j // 2]
    elif j > 0:
      yy[j] = yy[j - 1]
  a.plot(xx, yy, c='k', marker='.', ms=4, lw=1)
  a.set_title(f'$K$ = {k}')
  a.set_xlim(0, 4e-5)
  a.set_ylim(0, 1)
ax[0].set_ylabel('CDF')
a = fig.add_subplot(111, frameon=False, xticks=[], yticks=[])
a.set_xlabel('Latent gene expression', labelpad=32)
fig.tight_layout()

chromium2-ERCC-00031-init-cdf.png

First, look at the segments with non-trivial mixture weight in the \(K=512\) initialization.

g = np.array(init[512].rx2('fitted_g'))
g[:,g[0] > thresh]
array([[9.94031230e-01, 5.54471208e-03, 4.24058253e-04],
       [0.00000000e+00, 7.60198382e-04, 4.79509748e-03],
       [1.16953597e-05, 7.71893741e-04, 4.80679284e-03]])

Next, look at the segments with non-trivial mixture weight in the \(K=256\) initialization.

g = np.array(init[256].rx2('fitted_g'))
g[:,g[0] > thresh]
array([[1.00000000e+00],
       [0.00000000e+00],
       [2.33907194e-05]])

Now, look at the result of dividing the segments with non-trivial mixture weight.

g = np.array(init[256].rx2('fitted_g'))
g = g[:,g[0] > 1e-7]
temp = np.linspace(g[1], g[2], 3)
a = pd.Series(temp[:-1,:].ravel(order='F'))
b = pd.Series(temp[1:,:].ravel(order='F'))
pi = pd.Series((np.tile(g[0], (2, 1)) / 2).ravel(order='F'))
temp = ashr.ash_pois(pd.Series(x), pd.Series(s), g=ashr.unimix(pi, a, b))
g = np.array(temp.rx2('fitted_g'))
g[:,g[0] > thresh]
array([[7.31373678e-01, 2.68626322e-01],
       [0.00000000e+00, 1.16953597e-05],
       [1.16953597e-05, 2.33907194e-05]])

This example appears to contradict our intuition that the solution should have density wherever there was some underlying point mass, because the \(K=512\) solution put non-trivial weight on some segment that did not get non-trivial weight in the \(K=256\) solution.

B cell data example

Find a gene where the refinement procedure fails to terminate in a reasonable number of iterations.

dat = data['b_cells'](chunk=0)
x = dat[:,dat.var['gene'] == 'ENSG00000142676'].X.A.ravel()
s = dat.obs['size']
gene = dat.var.loc[dat.var['gene'] == 'ENSG00000142676', 'name'][0]
y = np.arange(x.max() + 1)
pmf = dict()

gamma_res = scmodes.ebpm.ebpm_gamma(x, s)
pmf['Gamma'] = np.array([scmodes.benchmark.gof._zig_pmf(k, size=s, log_mu=gamma_res[0], log_phi=-gamma_res[1]).mean() for k in y])

npmle_res = scmodes.ebpm.wrappers.ebpm_npmle(x, s, max_grid_updates=10, tol=1e-4, verbose=True)
g = np.array(npmle_res.rx2('fitted_g'))
a = np.fmin(g[1], g[2])
b = np.fmax(g[1], g[2])
comp_dens_conv = np.array([
  ((st.gamma(a=k + 1, scale=1 / s.values.reshape(-1, 1)).cdf(b.reshape(1, -1))
    - st.gamma(a=k + 1, scale=1 / s.values.reshape(-1, 1)).cdf(a.reshape(1, -1)))
   / np.outer(s, b - a)).mean(axis=0)
  for k in y])
comp_dens_conv[:,0] = st.poisson(mu=s.values.reshape(-1, 1) * b[0]).pmf(y).mean(axis=0)
pmf['NPMLE'] = comp_dens_conv @ g[0]

opt = so.minimize_scalar(lambda a: -scmodes.ebpm.ebpm_gamma(x, np.exp((1 - a) * np.log(s)))[-1], bracket=[0, 1], method='brent')
assert opt.success
sa = np.exp((1 - opt.x) * np.log(s))
gamma_alpha_res = scmodes.ebpm.ebpm_gamma(x, sa)
pmf[rf'Gamma ($\alpha$ = {opt.x:.3g})'] = np.array([scmodes.benchmark.gof._zig_pmf(k, size=sa.values.reshape(-1, 1), log_mu=gamma_alpha_res[0], log_phi=-gamma_alpha_res[1]).mean() for k in y])

iter  obj             K
0     -27044.0962599  512
1     -27044.0948073  12
2     -27044.094779   10

cm = plt.get_cmap('Dark2')
plt.clf()
plt.gcf().set_size_inches(4, 2)
plt.hist(x, bins=y, color='0.7', density=True)
for i, k in enumerate(pmf):
  plt.plot(y + .5, pmf[k], lw=1, c=cm(i), label=k)
plt.legend(frameon=False)
plt.xlim(0, y.max())
plt.xlabel('Number of molecules')
plt.ylabel('Density')
plt.title(gene)
plt.tight_layout()

b_cells-ENSG00000142676.png

Author: Abhishek Sarkar

Created: 2020-08-22 Sat 17:55
