Mixture of negative binomials

Table of Contents

Introduction

Here, we use the EM algorithm to fit a mixture of negative binomials to count data.

The motivating example is explaining bimodal variation in spike-in molecule counts using the fact that the distribution of size factors is bimodal:

chromium2-size.png

Setup

import numpy as np
import pandas as pd
import scipy.optimize as so
import scipy.stats as st
import scipy.special as sp
import scqtl.simple
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'

Methods

EM algorithm

The complete data model is:

\[ x_i \mid \lambda_i \sim \mathrm{Poisson}(\lambda_i) \]

\[ \lambda_i \mid z_i \sim \mu_{z_i} \mathrm{Gamma}(\phi_{z_i}^{-1}, \phi_{z_i}^{-1}) \]

def nb_llik(x, mean, inv_disp):
  """Return the elementwise negative binomial log likelihood.

  The distribution is parameterized by its mean and inverse dispersion, i.e.
  it is the Poisson-Gamma mixture with variance mean + mean ** 2 / inv_disp.
  The arguments are combined using numpy broadcasting.

  x - observed counts
  mean - mean of the distribution
  inv_disp - inverse dispersion
  """
  ratio = mean / inv_disp
  log1p_ratio = np.log1p(ratio)
  # The 1e-8 offset keeps the logarithm finite when the mean is zero
  llik = x * np.log(ratio + 1e-8)
  llik = llik - x * log1p_ratio
  llik = llik - inv_disp * log1p_ratio
  # Log of the combinatorial term Gamma(x + r) / (Gamma(r) Gamma(x + 1))
  llik = llik + sp.gammaln(x + inv_disp)
  llik = llik - sp.gammaln(inv_disp)
  llik = llik - sp.gammaln(x + 1)
  return llik

def objective(theta, x, z):
  """Return the negative expected complete-data log likelihood.

  This is the function minimized (over theta) in the M step.

  theta - log means of the components followed by their log inverse
          dispersions; it is flattened, so both a vector of length
          2 * n_components and an array of shape (2, n_components) work
  x - counts, shaped (n, 1) so that they broadcast against the components
  z - responsibilities (or hard assignments) of shape (n, n_components); a
      1-d array of length n is treated as a single component
  """
  z = np.asarray(z)
  if z.ndim == 1:
    z = z.reshape(-1, 1)
  # Important: theta gets flattened by minimize
  theta = np.ravel(theta)
  n_components = z.shape[1]
  mean = np.exp(theta[:n_components])
  inv_disp = np.exp(theta[n_components:])
  # Each observation contributes to each component in proportion to its
  # responsibility. Adding log(z) instead would only shift the objective by a
  # constant, so every component would be fit to all of the data.
  return -(z * nb_llik(x, mean, inv_disp)).sum()

def check_result(opt):
  """Return True if the optimization result opt looks trustworthy.

  opt - result object returned by so.minimize

  The result is accepted if the optimizer reported success, the objective
  value is finite, and (when a gradient is reported) every gradient entry is
  finite.
  """
  if not opt.success or not np.isfinite(opt.fun):
    return False
  # Derivative-free methods (e.g. Nelder-Mead) do not report a gradient, so
  # a missing jac must not be treated as a failure
  jac = getattr(opt, 'jac', None)
  return jac is None or bool(np.isfinite(jac).all())

def init(n_components, x, z):
  """Return initial parameters of shape (2, n_components).

  Row 0 holds the log mean of each component, set to the log of the average
  count of the observations with responsibility greater than 0.5 for that
  component. Row 1 holds the log inverse dispersions, drawn from a standard
  normal.

  n_components - number of mixture components
  x - counts, one observation per row
  z - responsibilities of shape (n, n_components)
  """
  theta = np.random.normal(size=(2, n_components))
  for j in range(n_components):
    assigned = x[z[:,j] > 0.5]
    # With more than two components it is possible that no observation has
    # responsibility > 0.5; fall back to the overall mean instead of NaN
    mean = assigned.mean() if assigned.size else x.mean()
    # Floor the mean so that all-zero counts do not produce log(0) = -inf
    theta[0,j] = np.log(max(mean, 1e-8))
  return theta

def mstep(theta, x, z):
  """Run the M step: re-estimate the parameters of every component given z.

  theta - current parameters of shape (2, n_components); only the shape is
          used, because each optimization is started from init(...)
  x - counts, shape (n, 1)
  z - responsibilities of shape (n, n_components)

  Returns the updated parameters, with the same shape as theta.

  Raises RuntimeError if the optimizer fails for any component.
  """
  n_components = theta.shape[1]
  theta1 = np.zeros(theta.shape)
  x0 = init(n_components, x, z)
  for j in range(n_components):
    # Select the column with a list index so z stays two-dimensional (n, 1):
    # objective reads the number of components from z.shape[1]
    opt = so.minimize(objective, x0=x0[:,j], args=(x, z[:,[j]]))
    if not check_result(opt):
      raise RuntimeError('M step failed to converge')
    theta1[:,j] = opt.x
  return theta1

def mix_nb(x, n_components=2, max_iters=1000, max_retries=10):
  """Fit mixture of negative binomials via EM

  The mixture weights are taken to be equal (the E step is a softmax of the
  component log likelihoods).

  x - array of counts; reshaped to one observation per row
  n_components - number of mixture components
  max_iters - maximum number of EM iterations
  max_retries - maximum number of attempts at each M step; every attempt
                starts from a fresh random initialization drawn by mstep

  Returns (z, theta, obj), where z has shape (n, n_components) and holds the
  responsibilities, theta has shape (2, n_components) and holds the log means
  (row 0) and log inverse dispersions (row 1), and obj is the negative
  marginal log likelihood up to the additive constant n * log(n_components).

  Raises RuntimeError if an M step fails on every attempt, if the objective
  becomes non-finite, if the log likelihood decreases, or if the algorithm
  does not converge within max_iters iterations.
  """
  x = np.asarray(x).reshape(-1, 1)
  n = x.shape[0]
  z = np.random.uniform(size=(n, n_components))
  z /= z.sum(axis=1, keepdims=True)
  theta = init(n_components, x, z)

  n_attempts = max(1, max_retries)
  obj = np.inf
  for i in range(max_iters):
    # M step
    for attempt in range(n_attempts):
      try:
        theta1 = mstep(theta, x, z)
        break
      except RuntimeError:
        if attempt + 1 == n_attempts:
          raise
    llik = nb_llik(x, np.exp(theta1[0]), np.exp(theta1[1]))
    # Track the negative marginal log likelihood. Unlike the M step objective
    # (which is evaluated under responsibilities that change every
    # iteration), EM guarantees this quantity does not increase.
    update = -sp.logsumexp(llik, axis=1).sum()
    if not np.isfinite(update):
      raise RuntimeError('objective is not finite')
    # Test for convergence first, so that an increase within numerical
    # tolerance is not reported as a failure
    if np.isclose(update, obj):
      print(f'{i}: {obj}')
      return z, theta, obj
    elif update > obj:
      # The minimized objective went up, i.e. the log likelihood decreased
      raise RuntimeError('objective decreased')
    else:
      obj = update
      print(f'{i}: {obj}')
      theta = theta1
      # E step
      z = sp.softmax(llik, axis=1)
  raise RuntimeError('failed to converge')

Results

Simulation

Degenerate model

First, make sure the easy case (one component) works. Generate some Poisson data.

# Seed the generator so the simulated data are reproducible
np.random.seed(1)
# 1000 Poisson counts with rate 100, one observation per row
x = np.random.poisson(lam=100, size=(1000, 1))
# With a single component, every observation is fully assigned to it
z = np.ones((x.shape[0], 1))
theta0 = init(1, x, z)
# Value of the objective at the initial parameters
objective(theta0, x, z)
5430.563796635256

# Minimize the objective starting from theta0, using the derivative-free
# Nelder-Mead method
opt = so.minimize(objective, x0=theta0, args=(x, z), method='Nelder-Mead')
assert opt.success
# opt.x[0] is the log mean, so exponentiate it to compare against the rate
{'fitted': np.exp(opt.x[0]), 'true': 100}
{'fitted': 100.66675997637512, 'true': 100}

Generate some negative binomial data.

np.random.seed(1)
# scipy's nbinom(n, p) has mean n (1 - p) / p and variance n (1 - p) / p^2
x = st.nbinom(n=3, p=1e-3).rvs(size=(1000, 1))
z = np.ones((x.shape[0], 1))
# NOTE: the starting point theta0 is reused from the Poisson example above
opt = so.minimize(objective, x0=theta0, args=(x, z), method='Nelder-Mead')
assert opt.success
# opt.x holds (log mean, log inverse dispersion). The negative binomial
# variance is mean + mean^2 / inv_disp; the first (Poisson) term must be
# included for the comparison against the true variance to be exact.
{'fitted_mean': np.exp(opt.x[0]),
 'true_mean': st.nbinom(n=3, p=1e-3).mean(),
 'fitted_var': np.exp(opt.x[0]) + np.exp(2 * opt.x[0] - opt.x[1]),
 'true_var': st.nbinom(n=3, p=1e-3).var()}
{'fitted_mean': 2978.7653010814693,
'true_mean': 2997.0,
'fitted_var': 2889291.139179591,
'true_var': 2997000.0}

Now make sure it works when we give the model more than the true number of components.

# Two components: assign each observation to exactly one of them at random
# (probability 0.5 each), so the columns of z are complementary indicators
z = np.zeros((x.shape[0], 2))
z[:,0] = (np.random.uniform(size=z.shape[0]) < 0.5).astype(int)
z[:,1] = 1 - z[:,0]
theta0 = init(2, x, z)
# Value of the objective at the initial parameters
objective(theta0, x, z)
36078.88820426077

# theta0 has shape (2, 2); minimize flattens it before calling objective
opt = so.minimize(objective, x0=theta0, args=(x, z), method='Nelder-Mead')
assert opt.success
# Flattened estimates: the log means of both components, followed by their
# log inverse dispersions
opt.x
array([7.99922953, 1.12198022, 1.70447064, 0.38548806])

# Sample mean of the observations assigned to the first component, next to
# the fitted mean of that component
x[z[:,0].astype(bool)].mean(), np.exp(opt.x[0])
(2994.14859437751, 2978.741830944349)

# Sample mean of the observations assigned to the second component, next to
# the fitted mean of that component. The flattened parameters are ordered
# (log mean 0, log mean 1, log inv disp 0, log inv disp 1), so the log mean
# of the second component is opt.x[1]; opt.x[2] is a log inverse dispersion.
x[~(z[:,0].astype(bool))].mean(), np.exp(opt.x[1])
(2963.430278884462, 8.322971383121208)

Author: Abhishek Sarkar

Created: 2019-05-13 Mon 10:04

Validate