Speed up ash mode estimation


Introduction

We want to deconvolve scRNA-seq data assuming \(g_j\) is some unimodal distribution over non-negative reals. In practice, we represent this family of distributions as

\[ g_j = \sum_{k=1}^K \pi_{jk} \mathrm{Uniform}(\cdot; \lambda_{0j}, a_{jk}) \]

where \(K\) is sufficiently large and \(\lambda_{0j}\) is the mode (Stephens 2016).

To estimate the mode \(\lambda_{0j}\) for gene \(j\), we find:

\[ \lambda_{0j}^* = \arg\max_{\lambda_{0j}} \sum_i \ln \int f(x_i \mid \lambda_i)\, g_j(\lambda_i \mid \pi, \lambda_{0j})\ d\lambda_i \]

using golden section search. Here, we investigate practical issues in this approach.
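
As an illustrative sketch of the procedure (not the ashr implementation), the search amounts to wrapping the marginal log likelihood in a scalar objective and handing it to a one-dimensional optimizer. Here, neg_llik and estimate_mode are hypothetical helpers, the ashr handle comes from the rpy2 setup below, and scipy's bounded scalar minimizer stands in for the golden section search used by ashr.

# Illustrative sketch: estimate the mode by maximizing the ash marginal log
# likelihood over a bounded interval. `ashr` is the rpy2 handle set up below;
# scipy's bounded scalar minimizer stands in for golden section search.
import numpy as np
import pandas as pd
import scipy.optimize as so

def neg_llik(lam0, xj, size_factor):
  # -log p(x | mode = lam0), with the mixture weights re-estimated for each candidate mode
  res = ashr.ash(
    pd.Series(np.zeros(xj.shape)),
    1,
    lik=ashr.lik_pois(y=xj, scale=size_factor, link='identity'),
    mode=lam0,
    outputlevel='loglik')
  return -np.array(res.rx2('loglik'))[0]

def estimate_mode(xj, size_factor):
  # Search over the plausible range of per-cell estimates x_i / R_i
  lam = xj / size_factor
  opt = so.minimize_scalar(neg_llik, bounds=(lam.min(), lam.max()),
                           args=(xj, size_factor), method='bounded')
  return opt.x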

Setup

import functools as ft
import multiprocessing as mp
import numpy as np
import pandas as pd
import scipy.stats as st
import scipy.special as sp
import scmodes
import scqtl
import sklearn.model_selection as skms

import rpy2.robjects.packages
import rpy2.robjects.pandas2ri
import rpy2.robjects.numpy2ri

rpy2.robjects.pandas2ri.activate()
rpy2.robjects.numpy2ri.activate()

ashr = rpy2.robjects.packages.importr('ashr')
descend = rpy2.robjects.packages.importr('descend')
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import colorcet
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Results

Convexity

As an example, use the highest expressed gene in 10K sorted CD8+ cytotoxic T cells (Zheng et al. 2017).

x = scmodes.dataset.read_10x('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/10xgenomics/cytotoxic_t/filtered_matrices_mex/hg19')
xj = pd.Series(x[:,x.mean(axis=0).argmax()])
size_factor = pd.Series(x.sum(axis=1))
lam = xj / size_factor

deconv-example.png

Mengyin Liu claimed this problem is convex in \(\lambda_{0j}\). However, on the above example, the quality of the result depends on the bounds of the search. Is this problem actually convex?

By default, the bounds are \([\min(x_i), \max(x_i)]\), which can be extremely large. However, we need to account for the size factor \(R_i\), so should we instead search over \([\min(x_i / R_i), \max(x_i / R_i)]\)? The motivation for this alternative is to search only over plausible values of \(\lambda_i\).

# Evaluate the marginal log likelihood on a log-spaced grid of candidate modes
grid = np.geomspace(1e-3, xj.max(), 100)
llik = np.array([np.array(
  ashr.ash(
    pd.Series(np.zeros(xj.shape)),
    1,
    lik=ashr.lik_pois(y=xj, scale=size_factor, link='identity'),
    mode=lam0,
    outputlevel='loglik').rx2('loglik')) for lam0 in grid]).ravel()
# Default mode estimation: search over [min(x_i), max(x_i)]
res0 = ashr.ash(
  pd.Series(np.zeros(xj.shape)),
  1,
  lik=ashr.lik_pois(y=xj, scale=size_factor, link='identity'),
  mode='estimate')
# Restricted mode estimation: search over [min(x_i / R_i), max(x_i / R_i)]
res1 = ashr.ash(
  pd.Series(np.zeros(x.shape[0])),
  1,
  lik=ashr.lik_pois(y=xj, scale=size_factor, link='identity'),
  mode=pd.Series([lam.min(), lam.max()]))
plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.xscale('log')
plt.plot(grid, np.array(llik).ravel(), lw=1, c='k')
plt.axvline(x=np.array(res0.rx2('fitted_g').rx2('a'))[0], c='k', lw=1, ls=':', label='Default')
plt.axvline(x=np.array(res1.rx2('fitted_g').rx2('a'))[0], c='r', lw=1, ls=':', label='Restricted')
plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1, .5))
plt.xlabel('Mode $\lambda_0$')
_ = plt.ylabel('Marginal likelihood')

mode-est.png

It appears the problem is actually non-convex. Surprisingly, it appears non-convex even in this case, where the data are not bimodal. For bimodal data, we might expect the choice of mode to change the weight on/near zero and produce a non-convex objective.

According to the documentation, the search can fail for a poor choice of initial query, which depends entirely on the initial interval. In this example, the initial interval does not contain the optimal mode, so the search finds the best local optimum within the interval but misses the global optimum.

This result does not necessarily mean that our proposed alternative, searching over \([\min(x_i / R_i), \max(x_i / R_i)]\), will work, because Poisson noise could mean the true \(\lambda_i > x_i / R_i\) for some sample \(i\).

Should we search further to be reasonably certain we haven't missed the mode? Intuitively, the largest \(\hat\lambda_i\) we do observe should be an overestimate; if it were not, we would expect higher density of \(g\) around it, and observations larger than it in the data.
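
As a quick check (a sketch, not part of the original analysis), we could widen the upper end of the restricted interval by some factor, assumed here to be 2, and verify that neither the estimated mode nor the marginal log likelihood improves. This reuses xj, size_factor, lam, and res1 from above.

# Sketch: widen the upper bound of the restricted interval by an assumed factor
# of 2 and compare the estimated mode and log likelihood to the original fit
res_wide = ashr.ash(
  pd.Series(np.zeros(xj.shape)),
  1,
  lik=ashr.lik_pois(y=xj, scale=size_factor, link='identity'),
  mode=pd.Series([lam.min(), 2 * lam.max()]))
pd.DataFrame({
  'mode': [np.array(res1.rx2('fitted_g').rx2('a'))[0],
           np.array(res_wide.rx2('fitted_g').rx2('a'))[0]],
  'llik': [np.array(res1.rx2('loglik'))[0],
           np.array(res_wide.rx2('loglik'))[0]]},
  index=['restricted', 'widened'])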

Speed

We have to solve an ash subproblem for each query \(\lambda_0\), which becomes extremely expensive for large data sets. We can speed up the procedure by downsampling the data for mode estimation. How much worse is the fitted model?

def score_mode_estimation(data, seed=0, p=0.1):
  """Estimate the mode on a random subsample of the data, then fit ash to the
  full data with the estimated mode fixed. Return the estimated mode and the
  log likelihood of the full-data fit."""
  temp = data.sample(random_state=seed, frac=p)
  # Estimate the mode by searching over [min(lam), max(lam)] in the subsample
  res0 = ashr.ash(
    pd.Series(np.zeros(temp.shape[0])),
    1,
    lik=ashr.lik_pois(y=temp['x'], scale=temp['scale'], link='identity'),
    mode=pd.Series([temp['lam'].min(), temp['lam'].max()]))
  lam0 = np.array(res0.rx2('fitted_g').rx2('a'))[0]
  # Fit the full data, fixing the mode to the subsample estimate
  res = ashr.ash(
    pd.Series(np.zeros(data.shape[0])),
    1,
    lik=ashr.lik_pois(y=data['x'], scale=data['scale'], link='identity'),
    mode=lam0)
  return lam0, np.array(res.rx2('loglik'))[0]

def evaluate_mode_estimation(data, num_trials):
  result = []
  for p in (0.1, 0.25, 0.5):
    for trial in range(num_trials):
      lam0, llik = score_mode_estimation(data, seed=trial, p=p)
      result.append([p, trial, lam0, llik])
  result = pd.DataFrame(result, columns=['p', 'trial', 'lam0', 'llik'])
  return result
mode_estimation_result = evaluate_mode_estimation(pd.DataFrame({'x': xj, 'scale': size_factor, 'lam': lam}), num_trials=10)
plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.scatter(mode_estimation_result['p'], mode_estimation_result['llik'], s=4, c='k')
plt.axhline(y=np.array(res1.rx2('loglik'))[0], c='k', lw=1, ls=':')
plt.xlabel('Fraction of original data')
_ = plt.ylabel('Training log likelihood')

downsampling-mode-estimation.png

Downsampling is likely to result in a much worse model fit, so we should not pursue this strategy to speed up mode estimation.

Using a simpler model

We previously found an example of near-Poisson data where ash mode estimation fails.

data, _ = scqtl.simulation.simulate(num_samples=1000, logodds=-5, seed=2)
x = pd.Series(data[:,0])
s = pd.Series(data[:,1])
lam = x / s

Plot the data.

pois-point-unimodal-example.png

Fit \(g\) assuming a point mass at \(\mu\).

fit0 = scqtl.simple.fit_pois(x, s)

Fit \(g\) assuming Gamma.

\[ \lambda_i \sim \operatorname{Gamma}(1/\phi, 1/(\mu\phi)) \]
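
Under this parameterization, \(E[\lambda_i] = \mu\) and \(V[\lambda_i] = \mu^2 \phi\), and the mode of the Gamma is

\[ \max\left(\frac{1/\phi - 1}{1/(\mu\phi)}, 0\right) = \max(\mu (1 - \phi), 0), \]

which is what is computed below from the returned estimates of \(\mu\) and \(1/\phi\).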

# Important: this returns mu, 1/phi
fit1 = scqtl.simple.fit_nb(x, s)

Report the estimated modes and log likelihoods.

# Important: Gamma(a, b) mode is max((a - 1) / b, 0)
pd.DataFrame({'mode': [fit0[0], fit1[0] * (fit1[1] - 1) / fit1[1]], 'llik': [fit0[-1], fit1[-1]]}, index=['pointmass', 'gamma'])
mode         llik
pointmass  0.000034 -2008.663140
gamma      0.000033 -2008.503731

The unimodal distribution is parameterized:

\[ g = \sum_{k=1}^{K} w_k \operatorname{Uniform}(\lambda_0, \lambda_0 \pm a_k) \]

where we abuse notation for brevity, and the endpoints \(a_k\) follow a geometric series. Fix the mode to the Gamma mode, and fit unimodal \(g\) for different choices of ratio between successive endpoints.

ratio = np.linspace(1.1, 1.5, 50)
res = []
for r in ratio:
  low = 1 / s.mean()
  high = (x / s).max()
  mixsd = pd.Series(np.exp(np.arange(np.log(low), np.log(high), step=np.log(r))))
  fit = ashr.ash_workhorse(
    pd.Series(np.zeros(x.shape)), 1,
    lik=ashr.lik_pois(y=x, scale=s, link='identity'),
    mixsd=mixsd,
    mode=fit1[0] * (fit1[1] - 1) / fit1[1],
    output=pd.Series(['loglik', 'fitted_g']))
  res.append([r, np.array(fit.rx2('loglik'))[0], np.array(fit.rx2('fitted_g').rx2('pi'))[0]])
res = pd.DataFrame(res)
res.columns = ['ratio', 'llik', 'pi0']

Plot the log likelihood versus chosen ratio. Mark the location of ratio \(\sqrt{2}\).

plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.plot(res['ratio'], res['llik'], lw=1, c='k')
plt.axvline(x=np.sqrt(2), ls=':', c='r', lw=1)
plt.xlabel('Ratio')
plt.ylabel('Log likelihood')
plt.tight_layout()

ratio-ex1.png

Report the fit which puts the most weight \(\pi_0\) on the component at the mode.

res.loc[res['pi0'].idxmax()]
ratio       1.434694
llik    -2008.561414
pi0         0.966383
Name: 41, dtype: float64

Try applying this approach to the highest expressed gene in 10K sorted CD8+ cytotoxic T cells (Zheng et al. 2017).

x = scmodes.dataset.read_10x('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/10xgenomics/cytotoxic_t/filtered_matrices_mex/hg19', return_df=True)
xj = x[x.mean(axis=0).idxmax()]
s = x.sum(axis=1)
lam = xj / s

Fit Gamma, and test for goodness of fit.

init = scqtl.simple.fit_nb(xj, s)
scmodes.benchmark.gof._gof(
  xj.values.ravel(),
  cdf=scmodes.benchmark.gof._zig_cdf,
  pmf=scmodes.benchmark.gof._zig_pmf,
  size=s.values.ravel(),
  log_mu=np.log(init[0]),
  log_phi=-np.log(init[1]))
KstestResult(statistic=0.0186349904442156, pvalue=0.0016661060591321017)

Fit unimodal \(g\), fixing the mode at the Gamma mode, and test for goodness of fit.

fit = ashr.ash_workhorse(
  pd.Series(np.zeros(xj.shape)), 1,
  lik=ashr.lik_pois(y=xj, scale=s, link='identity'),
  mixsd=pd.Series(np.exp(np.arange(np.log(1 / s.mean()), np.log((xj / s).max()), step=.5 * np.log(2)))),
  mode=init[0] * (init[1] - 1) / init[1],
  output=pd.Series(['loglik', 'fitted_g', 'data']))
scmodes.benchmark.gof._gof(xj, cdf=scmodes.benchmark.gof._ash_cdf, pmf=scmodes.benchmark.gof._ash_pmf, fit=fit, s=s)
KstestResult(statistic=0.059205028160045026, pvalue=1.6543190838957947e-31)

Previously, we found that full mode search worked on this example.

fit_estmode = ashr.ash_workhorse(
  pd.Series(np.zeros(xj.shape)), 1,
  lik=ashr.lik_pois(y=xj, scale=s, link='identity'),
  mixsd=pd.Series(np.exp(np.arange(np.log(1 / s.mean()), np.log((xj / s).max()), step=.5 * np.log(2)))),
  mode=pd.Series([lam.min(), lam.max()]),
  output=pd.Series(['loglik', 'fitted_g', 'data']))
scmodes.benchmark.gof._gof(xj, cdf=scmodes.benchmark.gof._ash_cdf, pmf=scmodes.benchmark.gof._ash_pmf, fit=fit_estmode, s=s)
KstestResult(statistic=0.02630989797207417, pvalue=1.4551355049273953e-06)

Report the estimated modes and marginal log likelihoods of the fits.

pd.DataFrame(
  {
    'mode': [init[0] * (init[1] - 1) / init[1], init[0] * (init[1] - 1) / init[1], np.array(fit_estmode.rx2('fitted_g').rx2('a'))[0]],
    'llik': [init[-1], np.array(fit.rx2('loglik'))[0], np.array(fit_estmode.rx2('loglik'))[0]]
  },
  index=['gamma', 'unimodal_gamma', 'unimodal'])
mode          llik
gamma           0.035399 -41349.739021
unimodal_gamma  0.035399 -41409.282482
unimodal        0.037238 -41274.359548

Plot the data, and the fitted distributions.

grid = pd.Series(np.linspace(lam.min(), lam.max(), 1000))
gamma_cdf = st.gamma(a=init[1], scale=init[0] / init[1]).cdf(grid)
unimodal_cdf = np.array(ashr.cdf_ash(fit, grid).rx2('y')).ravel()
estmode_cdf = np.array(ashr.cdf_ash(fit_estmode, grid).rx2('y')).ravel()
cm = plt.get_cmap('Dark2')
plt.clf()
fig, ax = plt.subplots(2, 1)
fig.set_size_inches(6, 4)
ax[0].hist(xj.values, bins=np.arange(xj.values.max() + 1), color='k')
ax[0].set_xlabel('Number of molecules')
ax[0].set_ylabel('Number of cells')

for i, (k, F) in enumerate(zip(['Gamma', 'Unimodal (Gamma mode)', 'Unimodal'], [gamma_cdf, unimodal_cdf, estmode_cdf])):
  ax[1].plot(grid, F, c=cm(i), label=k, lw=1)
ax[1].set_xlim(0, 0.1)
ax[1].legend(frameon=False)
ax[1].set_xlabel('Latent gene expression')
ax[1].set_ylabel('CDF')

fig.tight_layout()

malat1-gamma-mode.png

Try denser grids, and report the log likelihoods.

pd.Series({step: np.array(ashr.ash_workhorse(
  pd.Series(np.zeros(xj.shape)), 1,
  lik=ashr.lik_pois(y=xj, scale=s, link='identity'),
  mixsd=pd.Series(np.exp(np.arange(np.log(1 / s.mean()), np.log((xj / s).max()), step=np.log(step)))),
  mode=pd.Series([lam.min(), lam.max()]),
  output=pd.Series(['loglik'])).rx2('loglik'))[0]
              for step in (1.1, 1.15, 1.2, 1.25, 1.3)})
1.10   -41273.537196
1.15   -41273.525195
1.20   -41273.574435
1.25   -41273.954289
1.30   -41273.725401
dtype: float64

Author: Abhishek Sarkar

Created: 2019-10-01 Tue 00:25
