Massively Parallel Empirical Bayes Poisson Means


Introduction

The Empirical Bayes Poisson Means (EBPM) problem is \( \DeclareMathOperator\Gam{Gamma} \DeclareMathOperator\Poi{Poisson} \DeclareMathOperator\argmin{arg min} \newcommand\mf{\mathbf{F}} \newcommand\ml{\mathbf{L}} \newcommand\mx{\mathbf{X}} \newcommand\vl{\mathbf{l}} \newcommand\vx{\mathbf{x}} \)

\begin{align*} x_i \mid s_i, \lambda_i &\sim \Poi(s_i \lambda_i)\\ \lambda_i &\sim g(\cdot) \in \mathcal{G}, \end{align*}

where the primary inference goal is to estimate \(g\) by maximizing the marginal likelihood. In our prior work (Sarkar et al. 2019), we used this approach to estimate the mean and variance of gene expression from scRNA-seq data collected on a homogeneous sample of cells from each of a number of donor individuals, assuming \(\mathcal{G}\) was the family of point-Gamma distributions. This procedure removes the variation introduced by the measurement process, leaving the variation in true gene expression levels, which is of interest (Sarkar and Stephens 2020). In total, we solved 537,678 EBPM problems in parallel by formulating them as a single factor model

\begin{align*} x_{ij} \mid x_{i+}, \lambda_{ij} &\sim \Poi(x_{i+} \lambda_{ij})\\ \lambda_{ij} \mid \mu_{ij}, \phi_{ij}, \pi_{ij} &\sim \pi_{ij} \delta_0(\cdot) + (1 - \pi_{ij}) \Gam(\phi_{ij}^{-1}, \mu_{ij}^{-1} \phi_{ij}^{-1})\\ \ln \mu_{ij} &= (\ml \mf_\mu')_{ij}\\ \ln \phi_{ij} &= (\ml \mf_\phi')_{ij}\\ \operatorname{logit} \pi_{ij} &= (\ml \mf_\pi')_{ij}, \end{align*}

where

  • \(x_{ij}\) is the number of molecules of gene \(j = 1, \ldots, p\) observed in cell \(i = 1, \ldots, n\)
  • \(x_{i+} \triangleq \sum_j x_{ij}\) is the total number of molecules observed in cell \(i\)
  • cells are taken from \(m\) donor individuals, \(\ml\) is \(n \times m\), and each \(\mf_{(\cdot)}\) is \(p \times m\)
  • assignments of cells to donors (loadings) \(l_{ik} \in \{0, 1\}, k = 1, \ldots, m\) are known and fixed.
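With one-hot loadings, the product \(\ml \mf_\mu'\) copies donor \(k\)'s parameters to every cell from that donor. The factor model is therefore equivalent to one EBPM problem per donor and gene, all solved in a single optimization. The numpy sketch below illustrates this with made-up dimensions and values; the names `donor`, `L` and `F_mu` are illustrative only.

```python
import numpy as np

# Hypothetical toy problem: n = 5 cells, m = 2 donors, p = 3 genes
donor = np.array([0, 0, 1, 1, 1])          # donor index of each cell
n, m, p = donor.shape[0], 2, 3

# One-hot loadings: l_ik = 1 if cell i comes from donor k
L = np.zeros((n, m))
L[np.arange(n), donor] = 1

# Donor-level log means; F_mu is p x m (one column per donor)
F_mu = np.log(np.array([[1e-4, 2e-4],
                        [5e-3, 1e-3],
                        [2e-2, 2e-2]]))

# Cell-level log means (L F_mu')_ij, an n x p matrix
log_mu = L @ F_mu.T

# Each cell inherits exactly its donor's parameters
assert np.allclose(log_mu, F_mu.T[donor])
```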

We previously implemented maximum likelihood inference of this model via batch gradient descent in the Python package scqtl. We have now developed a new Python package mpebpm, which scales to much larger data sets. The key improvements are optimization using minibatch gradient descent and support for sparse matrices. Here, we evaluate the method on simulations and large biological data sets.
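mpebpm's own implementation is not reproduced here. The following is a rough sketch of the two ideas, assuming only numpy and scipy. It runs plain minibatch gradient ascent on the negative binomial marginal likelihood of the Gamma EBPM model (no point mass), densifying only the sampled rows of a sparse CSR matrix at each step. The function names, step size and initialization are illustrative choices, not mpebpm's.

```python
import numpy as np
import scipy.sparse as ss
import scipy.special as sp


def nb_grads(y, m, r):
    """Elementwise gradients of the NB log likelihood (mean m, size r = 1/phi)
    with respect to log(mu) and log(r)."""
    g_log_mu = r * (y - m) / (r + m)
    g_log_r = r * (sp.digamma(y + r) - sp.digamma(r)
                   + np.log(r) - np.log(r + m) + (m - y) / (r + m))
    return g_log_mu, g_log_r


def sgd_ebpm_gamma(x, s, onehot, batch_size=64, lr=0.05, num_epochs=20, seed=0):
    """Hypothetical minibatch gradient ascent for m x p Gamma EBPM problems.

    x: n x p scipy.sparse CSR counts; s: length-n size factors;
    onehot: n x m dense 0/1 matrix assigning cells to donors.
    Returns log(mu) and -log(phi), each m x p."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    # Initialize log(mu) from per-donor totals, computed without densifying x
    counts = np.asarray(x.T @ onehot).T
    sizes = onehot.T @ s
    log_mu = np.log(np.maximum(counts, 1e-3) / sizes[:, None])
    log_r = np.zeros_like(log_mu)  # start at phi = 1
    for _ in range(num_epochs):
        perm = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            yb = x[idx].toarray()  # densify only the sampled rows
            hb = onehot[idx]
            m = s[idx, None] * np.exp(hb @ log_mu)
            r = np.exp(hb @ log_r)
            g_mu, g_r = nb_grads(yb, m, r)
            # Average per-cell gradients within each donor present in the batch
            n_k = np.maximum(hb.sum(axis=0), 1)[:, None]
            log_mu += lr * (hb.T @ g_mu) / n_k
            log_r += lr * (hb.T @ g_r) / n_k
    return log_mu, log_r
```

Consider a single donor (`onehot = np.ones((n, 1))`) with simulated Gamma-Poisson counts. The estimated mean should then approach each gene's sample mean, and \(-\ln\phi\) should approach the simulated value.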

Setup


import anndata
import loompy
import mpebpm
import numpy as np
import os
import pandas as pd
import scipy.sparse as ss
import scipy.special as sp
import scipy.stats as st
import scmodes
import scqtl
import time
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Start an interactive session on the cluster and launch TensorBoard to monitor the fits (shell):

srun --pty --partition=mstephens bash
source activate singlecell
tensorboard --host=$(hostname -i) --logdir=runs &

Results

Accuracy of parameter estimation

We previously evaluated scqtl by simulating data from the model. Here, we evaluate mpebpm on the same kind of simulation.

def evaluate(num_samples, num_mols, num_trials=10, **kwargs):
  # Important: generate all of the samples for each trial in one shot, and use
  # one-hot encoding to get separate estimates
  args = [(num_samples * num_trials, num_mols, log_mu, log_phi, logodds, None, None, None)
          for log_mu in np.linspace(-12, -6, 7)
          for log_phi in np.linspace(-4, 0, 5)
          for logodds in np.linspace(-3, 3, 7)]
  x = np.concatenate([scqtl.simulation.simulate(*a)[0][:,:1] for a in args], axis=1)
  x = ss.csr_matrix(x)
  s = num_mols * np.ones((x.shape[0], 1))
  onehot = np.zeros((num_samples * num_trials, num_trials))
  onehot[np.arange(onehot.shape[0]), np.arange(onehot.shape[0]) // num_samples] = 1
  onehot = ss.csr_matrix(onehot)

  log_mu, neg_log_phi, logodds = mpebpm.sgd.ebpm_point_gamma(x, s=s, onehot=onehot, **kwargs)
  result = pd.DataFrame(
    [(a[0] // num_trials, int(a[1]), int(a[2]), int(a[3]), int(a[4]), a[-1], trial)
     for a in args
     for trial in range(num_trials)],
    columns=['num_samples', 'num_mols', 'log_mu', 'log_phi', 'logodds', 'fold', 'trial'])
  result['mean'] = np.exp(result['log_mu'])
  # V[lam] = (1 - pi) mu^2 phi + pi (1 - pi) mu^2, where pi = sigmoid(logodds)
  result['var'] = (1 - sp.expit(result['logodds'])) * np.exp(2 * result['log_mu'] + result['log_phi']) + sp.expit(result['logodds']) * sp.expit(-result['logodds']) * np.exp(2 * result['log_mu'])

  result['log_mu_hat'] = log_mu.ravel(order='F')
  result['log_phi_hat'] = -neg_log_phi.ravel(order='F')
  result['logodds_hat'] = logodds.ravel(order='F')
  result['mean_hat'] = np.exp(result['log_mu_hat'])
  result['var_hat'] = (1 - sp.expit(result['logodds_hat'])) * np.exp(2 * result['log_mu_hat'] + result['log_phi_hat']) + sp.expit(result['logodds_hat']) * sp.expit(-result['logodds_hat']) * np.exp(2 * result['log_mu_hat'])

  diagnostic = []
  for i in range(x.shape[1]):
    for j in range(onehot.shape[1]):
      idx = onehot.A[:,j].astype(bool)
      diagnostic.append(scqtl.diagnostic.diagnostic_test(
        x.A[idx,i].reshape(-1, 1),
        log_mu[j,i],
        -neg_log_phi[j,i],
        -logodds[j,i],
        num_mols,
        np.ones((num_samples, 1))))
  diagnostic = np.array(diagnostic)
  result['ks_d'] = diagnostic[:,0]
  result['ks_p'] = diagnostic[:,1]
  return result
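Under the point-Gamma prior, the latent mean is \((1 - \pi)\mu\). The latent variance is \((1 - \pi)\mu^2\phi + \pi(1 - \pi)\mu^2\), where \(\pi = \operatorname{sigmoid}(\text{logodds})\) is the weight on the point mass. Here the Gamma component has shape \(1/\phi\) and rate \(1/(\mu\phi)\), as in the model above. The Monte Carlo check below assumes this sign convention for logodds; the parameter values are arbitrary.

```python
import numpy as np
import scipy.special as sp


def point_gamma_moments(log_mu, log_phi, logodds):
    """Mean and variance of lambda ~ pi delta_0 + (1 - pi) Gamma(1/phi, 1/(mu phi)),
    with pi = sigmoid(logodds)."""
    pi = sp.expit(logodds)
    mu, phi = np.exp(log_mu), np.exp(log_phi)
    mean = (1 - pi) * mu
    var = (1 - pi) * mu ** 2 * phi + pi * (1 - pi) * mu ** 2
    return mean, var


rng = np.random.default_rng(0)
log_mu, log_phi, logodds = -8.0, -1.0, 0.5
pi, mu, phi = sp.expit(logodds), np.exp(log_mu), np.exp(log_phi)
n = 1_000_000
nonzero = rng.random(n) > pi
# Gamma with shape 1/phi and rate 1/(mu phi), i.e. scale mu phi
lam = np.where(nonzero, rng.gamma(shape=1 / phi, scale=mu * phi, size=n), 0.0)
mean, var = point_gamma_moments(log_mu, log_phi, logodds)
```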

Run the simulation.

result = [evaluate(num_samples=num_samples,
                   num_mols=int(1e5),
                   batch_size=32,
                   num_epochs=num_epochs,
                   log_dir=f'runs/mpebpm/sim-{num_samples}/')
          # Important: for fixed batch size, having more samples means more
          # updates to each parameter per epoch
          for num_samples, num_epochs in zip((100, 1000), (200, 20))]
pd.concat(result).to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-sim.txt.gz', sep='\t')

Read the results.

result = pd.read_csv('/scratch/midway2/aksarkar/ideas/mpebpm-sim.txt.gz', sep='\t', index_col=0)

Plot the estimated values against the ground truth values.

samples_pass = result['num_samples'] == 100
mu_pass = result['log_mu'] > -10
pi_pass = result['logodds'] < 0

plt.clf()
fig, ax = plt.subplots(2, 3)
fig.set_size_inches(8, 5)

subset = result.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 0].scatter(subset['log_mu'], subset['log_mu_hat'], s=2, c='k')
ax[0, 0].set_xlim(-14, -5)
ax[0, 0].set_ylim(ax[0, 0].get_xlim())
ax[0, 0].plot(ax[0, 0].get_xlim(), ax[0, 0].get_xlim(), c='r', ls=':', lw=1)
ax[0, 0].set_xlabel('True $\ln(\mu)$')
ax[0, 0].set_ylabel('Estimated $\ln(\mu)$')

ax[1, 0].set_xscale('log')
ax[1, 0].set_yscale('log')
ax[1, 0].scatter(subset['mean'], subset['mean_hat'], s=2, c='k')
ax[1, 0].set_xlim(1e-6, 1e-2)
ax[1, 0].set_ylim(ax[1, 0].get_xlim())
ax[1, 0].plot(ax[1, 0].get_xlim(), ax[1, 0].get_xlim(), c='r', ls=':', lw=1)
ax[1, 0].set_xlabel('True latent mean')
ax[1, 0].set_ylabel('Estimated latent mean')

subset = result.loc[np.logical_and.reduce(np.vstack([samples_pass, mu_pass, pi_pass]))]
ax[0, 1].scatter(subset['log_phi'], subset['log_phi_hat'], s=2, c='k')
ax[0, 1].set_xlim(-5, 2)
ax[0, 1].set_ylim(ax[0, 1].get_xlim())
ax[0, 1].plot(ax[0, 1].get_xlim(), ax[0, 1].get_xlim(), c='r', ls=':', lw=1)
ax[0, 1].set_xlabel('True $\ln(\phi)$')
ax[0, 1].set_ylabel('Estimated $\ln(\phi)$')

ax[1, 1].set_xscale('log')
ax[1, 1].set_yscale('log')
ax[1, 1].scatter(subset['var'], subset['var_hat'], s=2, c='k')
ax[1, 1].set_xlim(1e-9, 5e-5)
ax[1, 1].set_ylim(ax[1, 1].get_xlim())
ax[1, 1].plot(ax[1, 1].get_xlim(), ax[1, 1].get_xlim(), c='r', ls=':', lw=1)
ax[1, 1].set_xlabel('True latent variance')
ax[1, 1].set_ylabel('Estimated latent variance')

subset = result.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 2].scatter(subset['logodds'], subset['logodds_hat'], s=2, c='k')
ax[0, 2].plot(ax[0, 2].get_xlim(), ax[0, 2].get_xlim(), c='r', ls=':', lw=1)
ax[0, 2].set_xlabel('True $\mathrm{logit}(\pi)$')
ax[0, 2].set_ylabel('Estimated $\mathrm{logit}(\pi)$')

ax[1, 2].set_axis_off()
fig.tight_layout()

sim-params.png

samples_pass = result['num_samples'] == 1000
mu_pass = result['log_mu'] > -10
pi_pass = result['logodds'] < 0

plt.clf()
fig, ax = plt.subplots(2, 3)
fig.set_size_inches(8, 5)

subset = result.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 0].scatter(subset['log_mu'], subset['log_mu_hat'], s=2, c='k')
ax[0, 0].set_xlim(-14, -5)
ax[0, 0].set_ylim(ax[0, 0].get_xlim())
ax[0, 0].plot(ax[0, 0].get_xlim(), ax[0, 0].get_xlim(), c='r', ls=':', lw=1)
ax[0, 0].set_xlabel('True $\ln(\mu)$')
ax[0, 0].set_ylabel('Estimated $\ln(\mu)$')

ax[1, 0].set_xscale('log')
ax[1, 0].set_yscale('log')
ax[1, 0].scatter(subset['mean'], subset['mean_hat'], s=2, c='k')
ax[1, 0].set_xlim(1e-6, 1e-2)
ax[1, 0].set_ylim(ax[1, 0].get_xlim())
ax[1, 0].plot(ax[1, 0].get_xlim(), ax[1, 0].get_xlim(), c='r', ls=':', lw=1)
ax[1, 0].set_xlabel('True latent mean')
ax[1, 0].set_ylabel('Estimated latent mean')

subset = result.loc[np.logical_and.reduce(np.vstack([samples_pass, mu_pass, pi_pass]))]
ax[0, 1].scatter(subset['log_phi'], subset['log_phi_hat'], s=2, c='k')
ax[0, 1].set_xlim(-5, 2)
ax[0, 1].set_ylim(ax[0, 1].get_xlim())
ax[0, 1].plot(ax[0, 1].get_xlim(), ax[0, 1].get_xlim(), c='r', ls=':', lw=1)
ax[0, 1].set_xlabel('True $\ln(\phi)$')
ax[0, 1].set_ylabel('Estimated $\ln(\phi)$')

ax[1, 1].set_xscale('log')
ax[1, 1].set_yscale('log')
ax[1, 1].scatter(subset['var'], subset['var_hat'], s=2, c='k')
ax[1, 1].set_xlim(1e-9, 5e-5)
ax[1, 1].set_ylim(ax[1, 1].get_xlim())
ax[1, 1].plot(ax[1, 1].get_xlim(), ax[1, 1].get_xlim(), c='r', ls=':', lw=1)
ax[1, 1].set_xlabel('True latent variance')
ax[1, 1].set_ylabel('Estimated latent variance')

subset = result.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 2].scatter(subset['logodds'], subset['logodds_hat'], s=2, c='k')
ax[0, 2].plot(ax[0, 2].get_xlim(), ax[0, 2].get_xlim(), c='r', ls=':', lw=1)
ax[0, 2].set_xlabel('True $\mathrm{logit}(\pi)$')
ax[0, 2].set_ylabel('Estimated $\mathrm{logit}(\pi)$')

ax[1, 2].set_axis_off()
fig.tight_layout()

sim-1000-params.png

Goodness of fit

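We previously developed a test for goodness of fit, based on the fact that if \(x_{ij} \sim F_{ij}\) and \(F_{ij}\) is continuous, then \(F_{ij}(x_{ij}) \sim \operatorname{Uniform}(0, 1)\). For discrete counts, the same holds for the randomized quantile \(F_{ij}(x_{ij} - 1) + u\, f_{ij}(x_{ij})\), where \(f_{ij}\) is the probability mass function and \(u \sim \operatorname{Uniform}(0, 1)\). Uniformity is assessed with a Kolmogorov-Smirnov test. We applied this test to the distributions estimated from the simulated data sets. Plot the histogram of goodness-of-fit \(p\)-values.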

plt.clf()
plt.gcf().set_size_inches(2, 2)
plt.hist(result.loc[result['num_samples'] == 100, 'ks_p'], bins=np.linspace(0, 1, 11), density=True, color='0.7')
plt.axhline(y=1, lw=1, ls=':', c='k')
plt.xlim(0, 1)
plt.xlabel('$p$-value')
plt.ylabel('Density')
plt.tight_layout()

mpebpm-sim-gof.png

plt.clf()
plt.gcf().set_size_inches(2, 2)
plt.hist(result.loc[result['num_samples'] == 1000, 'ks_p'], bins=np.linspace(0, 1, 11), density=True, color='0.7')
plt.axhline(y=1, lw=1, ls=':', c='k')
plt.xlim(0, 1)
plt.xlabel('$p$-value')
plt.ylabel('Density')
plt.tight_layout()

mpebpm-sim-1000-gof.png
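The sketch below is an illustrative scipy reimplementation of the randomized-quantile test described above, assuming that is how the test works. It is not the scqtl or scmodes code. It tests a negative binomial sample against the true distribution and against one whose mean is off by a factor of 2; the parameter values are arbitrary.

```python
import numpy as np
import scipy.stats as st


def randomized_pit_ks(x, dist, rng):
    """KS test of randomized PIT values u = F(x - 1) + v f(x), v ~ Uniform(0, 1),
    which are exactly Uniform(0, 1) when x ~ dist (a discrete scipy.stats distribution)."""
    v = rng.random(x.shape)
    u = dist.cdf(x - 1) + v * dist.pmf(x)
    return st.kstest(u, 'uniform')


rng = np.random.default_rng(0)
# NB with mean 5 and size 2 (scipy: n = size, p = size / (size + mean))
true = st.nbinom(n=2, p=2 / (2 + 5))
x = true.rvs(size=2000, random_state=rng)
wrong = st.nbinom(n=2, p=2 / (2 + 10))   # mean off by a factor of 2
ok = randomized_pit_ks(x, true, rng)
bad = randomized_pit_ks(x, wrong, rng)
```

Under the correct distribution the \(p\)-value is uniform over repeated simulations. Under the misspecified one it is essentially zero at this sample size.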

Report the number of simulation trials, for each sample size, in which the observed data depart significantly from the estimated distribution (\(p < 0.05\) after Bonferroni correction within each sample size).

sig = result.groupby('num_samples').apply(lambda x: x.loc[x['ks_p'] < 0.05 / x.shape[0]]).reset_index(drop=True)
sig.groupby('num_samples')['ks_p'].agg(len)
num_samples
1000    21.0
Name: ks_p, dtype: float64

Plot the estimated and ground truth parameters for trials where the data departed from the estimated distribution.

samples_pass = sig['num_samples'] == 1000
mu_pass = sig['log_mu'] > -10
pi_pass = sig['logodds'] < 0

plt.clf()
fig, ax = plt.subplots(2, 3)
fig.set_size_inches(8, 5)

subset = sig.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 0].scatter(subset['log_mu'], subset['log_mu_hat'], s=2, c='k')
ax[0, 0].set_xlim(-14, -5)
ax[0, 0].set_ylim(ax[0, 0].get_xlim())
ax[0, 0].plot(ax[0, 0].get_xlim(), ax[0, 0].get_xlim(), c='r', ls=':', lw=1)
ax[0, 0].set_xlabel('True $\ln(\mu)$')
ax[0, 0].set_ylabel('Estimated $\ln(\mu)$')

ax[1, 0].set_xscale('log')
ax[1, 0].set_yscale('log')
ax[1, 0].scatter(subset['mean'], subset['mean_hat'], s=2, c='k')
ax[1, 0].set_xlim(1e-6, 1e-2)
ax[1, 0].set_ylim(ax[1, 0].get_xlim())
ax[1, 0].plot(ax[1, 0].get_xlim(), ax[1, 0].get_xlim(), c='r', ls=':', lw=1)
ax[1, 0].set_xlabel('True latent mean')
ax[1, 0].set_ylabel('Estimated latent mean')

subset = sig.loc[np.logical_and.reduce(np.vstack([samples_pass, mu_pass, pi_pass]))]
ax[0, 1].scatter(subset['log_phi'], subset['log_phi_hat'], s=2, c='k')
ax[0, 1].set_xlim(-5, 2)
ax[0, 1].set_ylim(ax[0, 1].get_xlim())
ax[0, 1].plot(ax[0, 1].get_xlim(), ax[0, 1].get_xlim(), c='r', ls=':', lw=1)
ax[0, 1].set_xlabel('True $\ln(\phi)$')
ax[0, 1].set_ylabel('Estimated $\ln(\phi)$')

ax[1, 1].set_xscale('log')
ax[1, 1].set_yscale('log')
ax[1, 1].scatter(subset['var'], subset['var_hat'], s=2, c='k')
ax[1, 1].set_xlim(1e-9, 5e-5)
ax[1, 1].set_ylim(ax[1, 1].get_xlim())
ax[1, 1].plot(ax[1, 1].get_xlim(), ax[1, 1].get_xlim(), c='r', ls=':', lw=1)
ax[1, 1].set_xlabel('True latent variance')
ax[1, 1].set_ylabel('Estimated latent variance')

subset = sig.loc[np.logical_and(pi_pass, samples_pass)]
ax[0, 2].scatter(subset['logodds'], subset['logodds_hat'], s=2, c='k')
ax[0, 2].plot(ax[0, 2].get_xlim(), ax[0, 2].get_xlim(), c='r', ls=':', lw=1)
ax[0, 2].set_xlabel('True $\mathrm{logit}(\pi)$')
ax[0, 2].set_ylabel('Estimated $\mathrm{logit}(\pi)$')

ax[1, 2].set_axis_off()
fig.tight_layout()

sim-1000-params-sig.png

Application to iPSCs

We previously generated scRNA-seq of 5,597 cells from 54 donors (Sarkar et al. 2019). Read the data, and remove the donor with evidence of contamination.

x = anndata.read_h5ad('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/ipsc/ipsc.h5ad')
x = x[x.obs.chip_id != 'NA18498']
x.X
<5578x9957 sparse matrix of type '<class 'numpy.float32'>'
with 39529537 stored elements in Compressed Sparse Row format>

Prepare the data.

# Important: the dense data will fit on the GPU
y = x.X.A
s = x.obs['mol_hs'].values.reshape(-1, 1)
# Important: constructing this as a dense matrix will blow up memory for larger
# data sets
onehot = ss.coo_matrix((np.ones(x.shape[0]), (np.arange(x.shape[0]), pd.Categorical(x.obs['chip_id']).codes))).tocsr()
# Important: center the matrix of dummy variables (batch), because there is no
# baseline
design = ss.coo_matrix(pd.get_dummies(x.obs['experiment'])).astype(float).A
design -= design.mean(axis=0)
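The effect of centering can be sketched with hypothetical batch labels. With a full set of dummy columns (no baseline level dropped), every column of the centered matrix averages to zero over cells. So, for any coefficient vector, the batch term averages to zero and does not shift the overall level that the donor-specific parameters capture. This sketch assumes the covariate enters \(\ln\mu\) additively, which is not confirmed here against the mpebpm source.

```python
import numpy as np
import pandas as pd

# Hypothetical batch labels for six cells
batch = pd.Series(['chip1', 'chip1', 'chip2', 'chip2', 'chip3', 'chip3'])
design = pd.get_dummies(batch).values.astype(float)   # one column per level, no baseline
design -= design.mean(axis=0)

b = np.array([0.3, -0.1, 0.5])      # arbitrary batch effects
offset = design @ b                 # assumed per-cell contribution to ln(mu)
```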

Fit mpebpm (41 s to initialize; 39 s to fit).

trial = 5
num_epochs = 40
batch_size = 64
lr = 1e-2
res1 = mpebpm.sgd.ebpm_gamma(
  y,
  s=s,
  onehot=onehot,
  batch_size=batch_size,
  shuffle=True,
  lr=lr,
  num_epochs=num_epochs,
  log_dir=f'runs/mpebpm/ipsc{trial}/init/')
log_mu, neg_log_phi, logodds = mpebpm.sgd.ebpm_point_gamma(
  y,
  s=s,
  onehot=onehot,
  init=res1,
  batch_size=batch_size,
  shuffle=True,
  lr=lr,
  num_epochs=num_epochs,
  log_dir=f'runs/mpebpm/ipsc{trial}/fit/')

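Compute the log likelihood of each observation under the Gamma and point-Gamma expression models, whose marginal distributions are negative binomial and zero-inflated negative binomial, respectively.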

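# Important: onehot is a scipy.sparse matrix, for which * with a dense array is a
# matrix product; compute the per-cell means first, then scale by the size factors
mean = s * (onehot @ np.exp(res1[0]))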
inv_disp = onehot @ np.exp(res1[1])
nb_llik = (y * np.log(mean / inv_disp)
           - y * np.log(1 + mean / inv_disp)
           - inv_disp * np.log(1 + mean / inv_disp)
           + sp.gammaln(y + inv_disp)
           - sp.gammaln(inv_disp)
           - sp.gammaln(y + 1))

# Important: parenthesize so that the size factors scale the per-cell means elementwise
mean = s * (onehot @ np.exp(log_mu))
inv_disp = onehot @ np.exp(neg_log_phi)
temp = (y * np.log(mean / inv_disp)
           - y * np.log(1 + mean / inv_disp)
           - inv_disp * np.log(1 + mean / inv_disp)
           + sp.gammaln(y + inv_disp)
           - sp.gammaln(inv_disp)
           - sp.gammaln(y + 1))
case_zero = -np.log1p(np.exp(onehot @ -logodds)) + np.log1p(np.exp(temp - (onehot @ logodds)))
case_non_zero = -np.log1p(np.exp(onehot @ logodds)) + temp
zinb_llik = np.where(y < 1, case_zero, case_non_zero)
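The expressions above can be cross-checked against `scipy.stats.nbinom`, which parameterizes the negative binomial by size `n` and success probability `p = size / (size + mean)`. The self-contained sketch below reuses the same formulas with scalar parameters. As in the model, it assumes \(\pi = \operatorname{sigmoid}(\text{logodds})\) is the weight on the point mass at zero.

```python
import numpy as np
import scipy.special as sp
import scipy.stats as st


def nb_llik(y, mean, inv_disp):
    # NB log pmf with mean `mean` and size `inv_disp`, written as above
    return (y * np.log(mean / inv_disp)
            - y * np.log(1 + mean / inv_disp)
            - inv_disp * np.log(1 + mean / inv_disp)
            + sp.gammaln(y + inv_disp)
            - sp.gammaln(inv_disp)
            - sp.gammaln(y + 1))


def zinb_llik(y, mean, inv_disp, logodds):
    # log(pi + (1 - pi) NB(0)) for y = 0 and log(1 - pi) + log NB(y) otherwise
    temp = nb_llik(y, mean, inv_disp)
    case_zero = -np.log1p(np.exp(-logodds)) + np.log1p(np.exp(temp - logodds))
    case_non_zero = -np.log1p(np.exp(logodds)) + temp
    return np.where(y < 1, case_zero, case_non_zero)


y = np.arange(6)
mean, inv_disp, logodds = 3.0, 0.8, -0.5
ref_nb = st.nbinom(n=inv_disp, p=inv_disp / (inv_disp + mean)).logpmf(y)
pi = sp.expit(logodds)
ref_zinb = np.log(pi * (y == 0) + (1 - pi) * np.exp(ref_nb))
```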

For each donor/gene combination, take the better-fitting model, then report the average log likelihood per observation over the full data.

L = np.tensordot(onehot.T.A, np.stack([nb_llik, zinb_llik], axis=-1), 1)
L.max(axis=-1).sum() / np.prod(y.shape)
-51.14861191512254
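Here is a toy illustration of the aggregation step, with made-up log likelihoods. `np.tensordot` with the transposed one-hot matrix sums the per-cell log likelihoods within each donor. Taking `argmax` over the last axis then picks the better model for each donor/gene combination.

```python
import numpy as np
import scipy.sparse as ss

# Hypothetical per-cell log likelihoods for 4 cells and 2 genes under two
# models (last axis: 0 = Gamma, 1 = point-Gamma)
llik = np.array([[[-1.0, -0.9], [-2.0, -2.5]],
                 [[-1.5, -1.2], [-0.5, -0.7]],
                 [[-3.0, -3.5], [-1.0, -0.8]],
                 [[-2.0, -2.2], [-1.0, -0.9]]])
donor = np.array([0, 0, 1, 1])
onehot = ss.csr_matrix((np.ones(4), (np.arange(4), donor)))
# Sum over cells within each donor: donors x genes x models
L = np.tensordot(onehot.T.toarray(), llik, 1)
best = L.argmax(axis=-1)     # 0 = Gamma, 1 = point-Gamma
```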

Evaluate the proportion of times each model was the best fit for donor/gene combinations.

pd.Series({k: v for k, v in zip(
  ['gamma', 'point_gamma'],
  np.histogram(L.argmax(axis=-1), bins=np.arange(3))[0])}) / np.prod(L.shape[:2])
gamma          0.311343
point_gamma    0.688657
dtype: float64

Test each individual-gene combination for goodness-of-fit to the mpebpm-estimated distribution.

result = dict()
for j in range(x.shape[1]):
  for k, donor in enumerate(pd.Categorical(x.obs['chip_id']).categories):
    idx = onehot[:,k].A.ravel().astype(bool)
    size = s[idx].ravel()
    if L[k,j,0] > L[k,j,1]:
      d, p = scmodes.benchmark.gof._gof(
        y[idx,j],
        cdf=scmodes.benchmark.gof._zig_cdf,
        pmf=scmodes.benchmark.gof._zig_pmf,
        size=size,
        log_mu=res1[0][k,j],
        log_phi=-res1[1][k,j])
    else:
      d, p = scmodes.benchmark.gof._gof(
        y[idx,j],
        cdf=scmodes.benchmark.gof._zig_cdf,
        pmf=scmodes.benchmark.gof._zig_pmf,
        size=size,
        log_mu=log_mu[k,j],
        log_phi=-neg_log_phi[k,j],
        logodds=logodds[k,j])
    result[(donor, x.var.iloc[j].name)] = pd.Series({'stat': d, 'p': p})
result = pd.DataFrame.from_dict(result, orient='index')
result.index.names = ['donor', 'gene']
result = result.reset_index()

Write out the GOF tests.

result.to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-gof.txt.gz', sep='\t')

Plot the histogram of GOF \(p\)-values.

plt.clf()
plt.gcf().set_size_inches(2, 2)
plt.hist(result['p'], bins=np.linspace(0, 1, 11), color='0.7', density=True)
plt.axhline(y=1, lw=1, ls=':', c='k')
plt.xlim(0, 1)
plt.xlabel('$p$-value')
plt.ylabel('Density')
plt.tight_layout()

mpebpm-ipsc-gof.png

Report the number (and proportion) of donor-gene combinations that depart significantly from the estimated distribution (\(p < 0.05\) after Bonferroni correction).

sig = result.loc[result['p'] < 0.05 / result.shape[0]]
sig.shape[0], sig.shape[0] / result.shape[0]
(65, 0.000123171145359006)
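The reported proportion is consistent with 65 significant tests out of 53 × 9,957 = 527,721 donor-gene combinations, assuming one test per gene for each of the 53 remaining donors. The Bonferroni cutoff is then about 9.5e-8.

```python
# Donor-gene tests: 54 donors minus the excluded one, times 9,957 genes
num_tests = (54 - 1) * 9957
threshold = 0.05 / num_tests          # Bonferroni-corrected cutoff
proportion = 65 / num_tests
```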

Look at one of the examples where the data depart from the estimated distribution.

plt.clf()
fig, ax = plt.subplots(2, 1)
fig.set_size_inches(6, 4)
query = x[x.obs['chip_id'] == sig.iloc[0]['donor'], x.var.index == sig.iloc[0]['gene']].X.A.ravel()
ax[0].hist(query, bins=np.arange(query.max() + 2), color='k')
ax[0].set_xlabel('Number of molecules')
ax[0].set_ylabel('Number of cells')
ax[0].set_title(x.var.loc[sig.iloc[0]['gene'], 'name'])

grid = np.linspace(0, 1e-3, 1000)
j = list(x.var.index).index(sig.iloc[0]['gene'])
k = list(pd.Categorical(x.obs['chip_id']).categories).index(sig.iloc[0]['donor'])
ax[1].plot(grid, st.gamma(a=np.exp(neg_log_phi[k,j]), scale=np.exp(log_mu[k,j] - neg_log_phi[k,j])).cdf(grid), c='k', lw=1)
ax[1].set_xlabel('Latent gene expression')
ax[1].set_ylabel('CDF')

fig.tight_layout()

mpebpm-ipsc-ex.png

Report all genes at which the data depart from the estimated distribution for at least one individual.

x.var.merge(sig, left_index=True, right_on='gene', how='inner')['name'].unique()
array(['NUP98', 'B4GALT5', 'ANXA5', 'RHOG', 'MT-CO2', 'MT-CYB', 'MT-ND2',
'MT-ND4', 'MT-ATP6', 'MT-CO3', 'MT-ND4L'], dtype=object)

Confounder correction in iPSC data

Repeat the analysis, including C1 chip as a covariate (1.4 minutes init, 1.35 minutes fit).

num_epochs = 80
init = mpebpm.sgd.ebpm_gamma(
  y,
  s=s,
  onehot=onehot,
  design=design,
  batch_size=64,
  lr=1e-2,
  num_epochs=num_epochs,
  log_dir='runs/mpebpm/ipsc/design2/init')
log_mu1, neg_log_phi1, logodds1, bhat1 = mpebpm.sgd.ebpm_point_gamma(
  y,
  s=s,
  init=init[:-1],
  onehot=onehot,
  design=design,
  batch_size=64,
  lr=1e-2,
  num_epochs=num_epochs,
  log_dir='runs/mpebpm/ipsc/design2/fit')
np.save('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-log-mu', log_mu1)
np.save('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-neg-log-phi', neg_log_phi1)
np.save('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-logodds', logodds1)
np.save('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-bhat', bhat1)

Read the mpebpm and scqtl estimates.

log_mu1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-log-mu.npy')
neg_log_phi1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-neg-log-phi.npy')
logodds1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-logodds.npy')
bhat1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-bhat.npy')

log_mu2 = pd.read_table('/project2/mstephens/aksarkar/projects/singlecell-qtl/data/density-estimation/design1/zi2-log-mu.txt.gz', index_col=0, sep=' ')
log_phi2 = pd.read_table('/project2/mstephens/aksarkar/projects/singlecell-qtl/data/density-estimation/design1/zi2-log-phi.txt.gz', index_col=0, sep=' ')
logodds2 = pd.read_table('/project2/mstephens/aksarkar/projects/singlecell-qtl/data/density-estimation/design1/zi2-logodds.txt.gz', index_col=0, sep=' ')
bhat2 = pd.read_table('/project2/mstephens/aksarkar/projects/singlecell-qtl/data/density-estimation/design1/beta.txt.gz', index_col=0, sep=' ')

Compare the mpebpm estimates with and without correcting for batch.

plt.clf()
fig, ax = plt.subplots(1, 3)
fig.set_size_inches(7, 2.5)

ax[0].scatter(log_mu.ravel(), log_mu1.ravel(), s=1, c='k', alpha=0.1)
ax[0].set_xlim(ax[0].get_ylim())
ax[0].plot(ax[0].get_xlim(), ax[0].get_ylim(), lw=1, ls=':', c='r')
ax[0].set_xlabel('Est $\log(\mu)$')
ax[0].set_ylabel('Corrected est $\log(\mu)$')

ax[1].scatter(-neg_log_phi.ravel(), -neg_log_phi1.ravel(), s=1, c='k', alpha=0.1)
ax[1].set_xlim(ax[1].get_ylim())
ax[1].plot(ax[1].get_xlim(), ax[1].get_ylim(), lw=1, ls=':', c='r')
ax[1].set_xlabel('Est $\log(\phi)$')
ax[1].set_ylabel('Corrected est $\log(\phi)$')

ax[2].scatter(logodds.ravel(), logodds1.ravel(), s=1, c='k', alpha=0.1)
ax[2].set_xlim(ax[2].get_ylim())
ax[2].plot(ax[2].get_xlim(), ax[2].get_ylim(), lw=1, ls=':', c='r')
ax[2].set_xlabel('Est $\mathrm{logit}(\pi)$')
ax[2].set_ylabel('Corrected est $\mathrm{logit}(\pi)$')

fig.tight_layout()

mpebpm-ipsc-design.png

Compare the mpebpm estimated mean latent gene expression against the scqtl estimate.

mean1 = -np.log1p(np.exp(logodds1)) + log_mu1
mean2 = -np.log1p(np.exp(logodds2)) + log_mu2
del mean2['NA18498']
plt.clf()
plt.gcf().set_size_inches(2.5, 2.5)
plt.scatter(mean1.ravel(), mean2.values.ravel(order='F'), c='k', s=1, alpha=0.1)
plt.xlim(plt.ylim())
plt.plot(plt.xlim(), plt.ylim(), lw=1, ls=':', c='r')
plt.xlabel('mpebpm log latent mean')
plt.ylabel('scqtl log latent mean')
plt.tight_layout()

mpebpm-scqtl-ipsc-latent-mean.png

Variance within versus between individuals in iPSC data

Read the estimated parameters.

log_mu1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-log-mu.npy')
neg_log_phi1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-neg-log-phi.npy')
logodds1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-logodds.npy')
bhat1 = np.load('/scratch/midway2/aksarkar/ideas/mpebpm-ipsc-design-bhat.npy')

Estimate the latent mean and variance of gene expression for each individual, for each gene.

log_mean = -np.log1p(np.exp(-logodds1)) + log_mu1
# V[lam] = (1 - pi) mu^2 phi + pi (1 - pi) mu^2
# log(x + y) = log(x) + softplus(log y - log x)
# log(sigmoid(x)) = -softplus(-x)
a = -np.log1p(np.exp(-logodds1)) + 2 * log_mu1 - neg_log_phi1
b = -np.log1p(np.exp(-logodds1)) - np.log1p(np.exp(logodds1)) + 2 * log_mu1
log_var = a + np.log1p(np.exp(b - a))
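The identity in the comments, \(\log(x + y) = \log x + \operatorname{softplus}(\log y - \log x)\), is what `np.logaddexp` computes. It also avoids the overflow that `np.log1p(np.exp(...))` hits once the argument exceeds roughly 709 in double precision. A small numpy illustration with arbitrary values:

```python
import numpy as np

a = np.array([-20.0, -5.0, 0.0])
b = np.array([-21.0, -4.0, 800.0])

with np.errstate(over='ignore'):
    naive = a + np.log1p(np.exp(b - a))   # exp(800) overflows to inf
stable = np.logaddexp(a, b)               # log(exp(a) + exp(b)), computed stably


def softplus(x):
    # softplus(x) = log(1 + exp(x)) = logaddexp(0, x), also stable for large x
    return np.logaddexp(0.0, x)
```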

Plot the relationship between variance between individuals (sample variance of latent mean gene expression values), and: (1) average gene expression, and (2) average variance of latent gene expression within an individual.

plt.clf()
fig, ax = plt.subplots(1, 2, sharey=True)
fig.set_size_inches(4.5, 2.5)
for a in ax.ravel():
  a.set_xscale('log')
  a.set_yscale('log')
  a.set_aspect('equal', adjustable='datalim')
ax[0].scatter(np.exp(log_mean).mean(axis=0), np.exp(log_mean).var(axis=0), c='k', s=1,alpha=0.2)
ax[0].set_xlabel('Mean latent gene expression')
ax[0].set_ylabel('Between individual variance\nlatent gene expression')
ax[1].scatter(np.exp(log_var).mean(axis=0), np.exp(log_mean).var(axis=0), c='k', s=1,alpha=0.2)
ax[1].set_xlabel('Within individual variance\nlatent gene expression')
fig.tight_layout()

ipsc-var-between-var-within.png

The distribution of variance between individuals appears to be bimodal. Look at the genes in the lower mode.

query = np.exp(log_mean).var(axis=0) < 1e-20
temp = x.var.loc[query].copy()
temp['mean'] = np.exp(log_mean).mean(axis=0)[query]
print(temp.sort_values('mean', ascending=False).head(n=20).to_html(classes='table'))
gene             chr   start      end        name      strand  source      mean
ENSG00000198886  hsMT  10760      12137      MT-ND4    +       H. sapiens  2.657278e-10
ENSG00000198727  hsMT  14747      15887      MT-CYB    +       H. sapiens  2.137406e-10
ENSG00000198899  hsMT  8527       9207       MT-ATP6   +       H. sapiens  1.966070e-10
ENSG00000198938  hsMT  9207       9990       MT-CO3    +       H. sapiens  1.633298e-10
ENSG00000198712  hsMT  7586       8269       MT-CO2    +       H. sapiens  1.565673e-10
ENSG00000087086  hs19  49468558   49470135   FTL       +       H. sapiens  1.232170e-10
ENSG00000111640  hs12  6643093    6647537    GAPDH     +       H. sapiens  1.212971e-10
ENSG00000198763  hsMT  4470       5511       MT-ND2    +       H. sapiens  1.182991e-10
ENSG00000075624  hs7   5566782    5603415    ACTB      -       H. sapiens  1.138831e-10
ENSG00000164587  hs5   149822753  149829319  RPS14     -       H. sapiens  1.133044e-10
ENSG00000084207  hs11  67351066   67354131   GSTP1     +       H. sapiens  1.117254e-10
ENSG00000181163  hs5   170814120  170838141  NPM1      +       H. sapiens  1.109008e-10
ENSG00000228253  hsMT  8366       8572       MT-ATP8   +       H. sapiens  1.093323e-10
ENSG00000137818  hs15  69745123   69748255   RPLP1     +       H. sapiens  1.091205e-10
ENSG00000231500  hs6   33239787   33244287   RPS18     +       H. sapiens  1.037726e-10
ENSG00000198804  hsMT  5904       7445       MT-CO1    +       H. sapiens  1.030106e-10
ENSG00000105372  hs19  42363988   42376994   RPS19     +       H. sapiens  1.018074e-10
ENSG00000149273  hs11  75110530   75133324   RPS3      +       H. sapiens  9.832752e-11
ENSG00000197061  hs6   26104104   26104518   HIST1H4C  +       H. sapiens  9.703387e-11
ENSG00000177954  hs1   153963235  153964626  RPS27     +       H. sapiens  9.048103e-11

Application to Census of Immune Cells

The Census of Immune Cells is part of the Human Cell Atlas. Currently, it comprises scRNA-seq data of 782,859 cells from 16 donors (593,844 cells remain in the pre-processed data analyzed below). The cells vary in donor sex, donor ancestry and tissue of origin, as well as in (latent) cell types, subtypes and states within each donor.

with loompy.connect('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/human-cell-atlas/immune-cell-census.loom') as con:
  metadata = pd.DataFrame(np.hstack([
    con.ca['derived_organ_parts_label', 'donor_organism.human_specific.ethnicity.ontology_label'],
    con.ca['donor_organism.provenance.document_id', 'donor_organism.sex']]))
  print(metadata.groupby([0,3,1,2]).agg(len).to_frame().to_html(classes='table'))
organ                 sex     ethnicity         donor                                 cells
bone marrow           female  European          085e737d-adb5-4597-bd54-5ebeda170038  49716
bone marrow           female  European          af7fe7a6-7d7e-4cdf-9799-909680fa9a3f  44584
bone marrow           female  European          cf514c66-88b2-45e4-a397-7fb362ae9950  48630
bone marrow           female  European          fb30bb83-0278-4117-bd42-e2e8dddfedfe  49839
bone marrow           male    African American  d23515a7-e182-4bc6-89e2-b1635885c0ec  51239
bone marrow           male    African American  eb8fb36b-6e02-41c4-8760-3eabbde6bacb  51024
bone marrow           male    European          0a6c46dd-0905-4581-95eb-d89eef8a7213  47167
bone marrow           male    European          9aaf8a07-924f-456c-86dc-82f5da718246  48784
umbilical cord blood  female  Asian             e4b5115d-3a0d-4c50-aba4-04b5f76810da  57893
umbilical cord blood  female  European          4a404c91-0dbf-4246-bc23-d13aff961ba7  45078
umbilical cord blood  female  European          509c507c-4759-452f-994e-d134d90329fd  39584
umbilical cord blood  male    African American  0b91cb1f-e2a8-413a-836c-1d38e7af3f2d  54544
umbilical cord blood  male    European          53af872d-b838-44d6-ae1b-25b56405483c  62609
umbilical cord blood  male    European          6072d1f5-aa0c-4ab1-a8a6-a00ab479a1ba  52142
umbilical cord blood  male    nan               31f89559-2682-4bbc-84c6-826dfe4a4e39  29455
umbilical cord blood  male    nan               4e98f612-15ec-44ab-b5f9-39787f92b01a  50571

To demonstrate the scalability of mpebpm, fit a point-Gamma distribution to each gene in each donor, resulting in 254,432 EBPM problems. We previously pre-processed the data to npz format, which is much faster to read than h5ad or loom.

y_csr = ss.load_npz('/scratch/midway2/aksarkar/modes/immune-cell-census.npz')
genes = pd.read_csv('/scratch/midway2/aksarkar/modes/immune-cell-census-genes.txt.gz', sep='\t', index_col=0)
donor = pd.read_csv('/scratch/midway2/aksarkar/modes/immune-cell-census-samples.txt.gz', sep='\t', index_col=0)['0']
onehot = ss.csr_matrix(pd.get_dummies(donor).values)

Remove genes that have only zero observations in at least one donor.

# Important: CSC needed to subset on columns (genes)
y_csc = y_csr.tocsc()
keep = (((y_csc.T @ onehot) > 0).sum(axis=1) == onehot.shape[1]).A.ravel()
genes = genes.loc[keep]
y_csc = y_csc[:,keep]
y_csr = y_csc.tocsr()
s = y_csr.sum(axis=1).A.ravel()
y_csr
<593844x15902 sparse matrix of type '<class 'numpy.int32'>'
with 550918891 stored elements in Compressed Sparse Row format>
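The filter above uses a sparse product to get per-donor gene totals without densifying the counts. Here is the same logic on a hypothetical 4 × 3 matrix:

```python
import numpy as np
import scipy.sparse as ss

# Hypothetical counts: 4 cells x 3 genes; cells 0-1 from donor 0, cells 2-3 from donor 1
y = ss.csr_matrix(np.array([[0, 2, 0],
                            [1, 0, 0],
                            [0, 3, 0],
                            [0, 1, 4]]))
onehot = ss.csr_matrix(np.array([[1, 0], [1, 0], [0, 1], [0, 1]]))
# Total count of each gene in each donor (genes x donors), still sparse
per_donor = y.T @ onehot
# Keep genes with a nonzero total in every donor
keep = ((per_donor > 0).sum(axis=1) == onehot.shape[1]).A.ravel()
# Genes 0 and 2 are all zero in one donor, so only gene 1 is kept
```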

Fit mpebpm (6 minutes/epoch; 33 minutes fit).

# This converges quickly
init = mpebpm.sgd.ebpm_gamma(
  y_csr, onehot=onehot, batch_size=128, lr=1e-2,
  max_epochs=1, shuffle=True)
log_mu, neg_log_phi, logodds = mpebpm.sgd.ebpm_point_gamma(
  y_csr, onehot=onehot, init=init, batch_size=128, lr=1e-2,
  max_epochs=5, shuffle=True, logdir='runs/mpebpm6')
# Important: pd.get_dummies orders donors lexicographically, so label the rows in the same order
pd.DataFrame(log_mu, index=sorted(donor.unique()), columns=genes['featurekey']).to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-log-mu.txt.gz', sep='\t')
pd.DataFrame(neg_log_phi, index=sorted(donor.unique()), columns=genes['featurekey']).to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-neg-log-phi.txt.gz', sep='\t')
pd.DataFrame(logodds, index=sorted(donor.unique()), columns=genes['featurekey']).to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-logodds.txt.gz', sep='\t')

Read the estimated parameters.

log_mu = pd.read_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-log-mu.txt.gz', sep='\t', index_col=0)
neg_log_phi = pd.read_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-neg-log-phi.txt.gz', sep='\t', index_col=0)
logodds = pd.read_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-logodds.txt.gz', sep='\t', index_col=0)

Test each donor-gene combination for goodness-of-fit to the mpebpm-estimated distribution.

result = dict()
for j in range(y_csr.shape[1]):
  query = y_csc[:,j].tocsr()
  # Important: match the column order of onehot (pd.get_dummies sorts donors)
  for k, name in enumerate(sorted(donor.unique())):
    # Important: scqtl.diagnostic blows up memory for some reason
    idx = onehot[:,k].tocsc().indices
    d, p = scmodes.benchmark.gof._gof(
       query[idx].A.ravel(),
       cdf=scmodes.benchmark.gof._zig_cdf,
       pmf=scmodes.benchmark.gof._zig_pmf,
       size=s[idx].ravel(),
       log_mu=log_mu.iloc[k,j],
       log_phi=-neg_log_phi.iloc[k,j],
       logodds=logodds.iloc[k,j])
    result[(name, genes.iloc[j]['Gene'])] = pd.Series({'stat': d, 'p': p})
result = pd.DataFrame.from_dict(result, orient='index')
result.index.names = ['donor', 'gene']
result = result.reset_index()
result.to_csv('/scratch/midway2/aksarkar/ideas/mpebpm-immune-census-gof.txt.gz', sep='\t')


Plot the histogram of GOF \(p\)-values.

plt.clf()
plt.gcf().set_size_inches(2, 2)
plt.hist(result['p'], bins=np.linspace(0, 1, 11), color='0.7', density=True)
plt.axhline(y=1, lw=1, ls=':', c='k')
plt.xlabel('$p$-value')
plt.ylabel('Density')
plt.tight_layout()

mpebpm-immune-census-gof.png

Look at an example.

j = 1
k = 0
# Gene j in donor k, matching the parameters log_mu.iloc[k,j] used below
query = y_csc[:,j].tocsr()
idx = onehot[:,k].tocsc().indices
x = query[idx].A.ravel()
scmodes.benchmark.gof._gof(
  x, 
  cdf=scmodes.benchmark.gof._zig_cdf,
  pmf=scmodes.benchmark.gof._zig_pmf,
  size=s[idx],
  log_mu=init[0][k,j],
  log_phi=-init[1][k,j],)
KstestResult(statistic=0.14346460385738558, pvalue=0.0)

scmodes.benchmark.gof._gof(
  x, 
  cdf=scmodes.benchmark.gof._zig_cdf,
  pmf=scmodes.benchmark.gof._zig_pmf,
  size=s[idx],
  log_mu=log_mu.iloc[k,j],
  log_phi=-neg_log_phi.iloc[k,j],
  logodds=logodds.iloc[k,j])
KstestResult(statistic=0.13278115317033756, pvalue=0.0)

res1 = scmodes.ebpm.ebpm_point_gamma(x, s[idx].ravel())
scmodes.benchmark.gof._gof(
  x, 
  cdf=scmodes.benchmark.gof._zig_cdf,
  pmf=scmodes.benchmark.gof._zig_pmf,
  size=s[idx],
  log_mu=res1[0],
  log_phi=-res1[1],
  logodds=res1[2])
KstestResult(statistic=0.006399062002559908, pvalue=0.1017356159621166)

Compare the negative log likelihood for this gene in this individual.

pd.Series({'scmodes': scmodes.ebpm.wrappers._zinb_obj(res1[:-1], x, s[idx]),
           'mpebpm': scmodes.ebpm.wrappers._zinb_obj([log_mu.iloc[k,j], neg_log_phi.iloc[k,j], logodds.iloc[k,j]], x, s[idx])})
scmodes     858.372281
mpebpm     7139.640261
dtype: float64

cm = plt.get_cmap('Paired')
plt.clf()
fig, ax = plt.subplots(2, 1)
fig.set_size_inches(5, 3.5)

grid = np.arange(x.max() + 2)
ax[0].hist(x, bins=grid, color='k')
ax[0].set_xticks(grid)
ax[0].set_xlabel('Number of molecules')
ax[0].set_ylabel('Number of cells')
ax[0].set_title(genes.iloc[1]['Gene'])

grid = np.linspace(0, 1e-5, 1000)
pi0 = sp.expit(logodds.iloc[k,j])
# Gamma component with shape 1/phi and scale mu phi
F = pi0 + (1 - pi0) * st.gamma(a=np.exp(neg_log_phi.iloc[k,j]), scale=np.exp(log_mu.iloc[k,j] - neg_log_phi.iloc[k,j])).cdf(grid)
ax[1].plot(grid, F, c=cm(0), lw=1, label='mpebpm')

pi0 = sp.expit(res1[2])
F = pi0 + (1 - pi0) * st.gamma(a=np.exp(res1[1]), scale=np.exp(res1[0] - res1[1])).cdf(grid)
ax[1].plot(grid, F, c=cm(1), lw=1, label='scmodes')
ax[1].set_xlabel('Latent gene expression')
ax[1].set_ylabel('CDF')
ax[1].legend(frameon=False)
fig.tight_layout()

ic-dpm1.png

Author: Abhishek Sarkar

Created: 2020-05-14 Thu 23:53
