Sparse Poisson Factor Analysis

Table of Contents

Introduction

flash (Wang and Stephens 2018) is a low rank model which estimates a flexible prior on loadings and factors \( \DeclareMathOperator\E{E} \DeclareMathOperator\KL{\mathcal{KL}} \newcommand\kl[2]{\KL(#1\;\Vert\;#2)} \DeclareMathOperator\N{\mathcal{N}} \DeclareMathOperator\Poi{Poisson} \DeclareMathOperator\V{V} \newcommand\Gfam{\mathcal{G}} \newcommand\mf{\mathbf{F}} \newcommand\ml{\mathbf{L}} \newcommand\mx{\mathbf{X}} \newcommand\vf{\mathbf{f}} \newcommand\vl{\mathbf{l}} \)

\begin{align} x_{ij} &\sim \N((\ml\mf')_{ij}, \sigma_{ij}^2)\\ l_{1k}, \ldots, l_{nk} &\sim g_{\vl_k} \in \Gfam\\ f_{1k}, \ldots, f_{pk} &\sim g_{\vf_k} \in \Gfam, \end{align}

where \(\Gfam\) denotes a family of distributions. The posterior \(p(\ml, \mf \mid \mx)\) can be approximately estimated using variational inference. In particular, coordinate ascent updates which maximize the ELBO correspond to solutions to the Empirical Bayes Normal Means problem. They study two main cases: (1) \(\Gfam\) is the family of unimodal, zero-mode scale mixtures of Normals over a fixed grid of scales, and (2) \(\Gfam\) is the family of point-Normal mixtures. They found that in some practical examples, the differences in inference between these two families were minimal.

We are now interested in developing similarly flexible methods to learn low-rank, (possibly) sparse representations of scRNA-seq data, which follow a different generative process (Sarkar and Stephens 2020)

\begin{align} x_{ij} \mid s_i, \lambda_{ij} &\sim \Poi(s_i \lambda_{ij})\\ \lambda_{ij} &= h^{-1}((\ml\mf')_{ij})\\ l_{1k}, \ldots, l_{nk} &\sim g_{\vl_k} \in \Gfam\\ f_{1k}, \ldots, f_{pk} &\sim g_{\vf_k} \in \Gfam, \end{align}

where \(h\) is a link function. Although it is natural to choose \(h = \ln\), for practical purposes it can be preferable to choose the softplus inverse link \(h^{-1}(x) = \ln(1 + \exp(x))\) (Argelaguet 2018). Variational inference in this model requires either Taylor approximation (Seeger 2012) or a Monte Carlo approach (Kingma and Welling 2014, Rezende et al. 2014). Here, we investigate the latter.

Setup

import numpy as np
import torch
import torch.utils.tensorboard as tb
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Method

The evidence lower bound (ELBO) can be written

\begin{equation} \ell = \sum_{i,j} \E_q[\ln p(x_{ij} \mid s_i, \lambda_{ij})] - \sum_{i,k} \kl{q(l_{ik})}{p(l_{ik})} - \sum_{j,k} \kl{q(f_{jk})}{p(f_{jk})}. \end{equation}

The KL divergence terms in the ELBO are analytic (e.g., Carbonetto and Stephens 2012). The remaining terms involve intractable expectations, so we take a Monte Carlo approach to estimate them, resulting in a stochastic objective function. The critical calculation in this approach is

\begin{equation} \V[\ml\mf'] = \V[\ml]\V[\mf]' + \E[\ml]^2 \V[\mf]' + \V[\ml] (\E[\mf]^2)', \end{equation}

where expectations (variances) are taken with respect to \(q\), and variances and squares are applied elementwise. We optimize the stochastic objective function by gradient descent.

class SpikeSlab(torch.nn.Module):
  """Independent point-Normal priors on columns of matrix M

  Each entry of the [p, k] matrix M has a variational approximation that is a
  mixture of a point mass at zero and a Normal "slab". The prior has the same
  form, with one inclusion log odds and one slab precision per column.

  """
  def __init__(self, p, k):
    super().__init__()
    # Variational parameters, one per entry of M: logit of the posterior
    # inclusion probability, slab mean, and log slab precision
    self.logits = torch.nn.Parameter(torch.randn([p, k]))
    self.mean = torch.nn.Parameter(torch.randn([p, k]))
    self.log_prec = torch.nn.Parameter(torch.zeros([p, k]))
    # Prior parameters, one per column of M (broadcast over rows)
    self.prior_logodds = torch.nn.Parameter(torch.randn([1, k]))
    self.log_prior_prec = torch.nn.Parameter(torch.full([1, k], 1.))

  def forward(self):
    """Return E[M], V[M], and KL(q(M) || p_prior(M))

    E[M] and V[M] are [p, k] tensors of entry-wise moments under q; the KL
    divergence is a scalar summed over all entries.

    """
    finfo = torch.finfo(self.logits.dtype)
    # Keep the inclusion probability strictly inside (0, 1) so the logs below
    # are finite
    pip = torch.clamp(torch.sigmoid(self.logits), finfo.tiny, 1 - finfo.eps)
    prec = torch.exp(self.log_prec)
    prior_prec = torch.exp(self.log_prior_prec)
    # KL between the Bernoulli inclusion indicators
    kl_indicator = (pip * torch.log(pip / torch.sigmoid(self.prior_logodds)) +
                    (1 - pip) * torch.log((1 - pip) / torch.sigmoid(-self.prior_logodds)))
    # KL(N(mean, 1 / prec) || N(0, 1 / prior_prec)). The constant is -1 (not
    # +1), so that the divergence vanishes when q equals the prior.
    kl_slab = .5 * (self.log_prec - self.log_prior_prec
                    + prior_prec * (self.mean ** 2 + 1 / prec) - 1)
    kl = (kl_indicator + pip * kl_slab).sum()
    assert not torch.isnan(kl).any()
    pm = pip * self.mean
    # V[M] = E[M^2] - E[M]^2 for the point-Normal mixture
    pv = pip / prec + pip * (1 - pip) * self.mean ** 2
    return pm, pv, kl

class SFA(torch.nn.Module):
  """Sparse factor analysis with a Gaussian likelihood

  Loadings (n x k) and factors (p x k) each get a SpikeSlab variational
  approximation and prior.

  """
  def __init__(self, n, p, k):
    super().__init__()
    self.l = SpikeSlab(n, k)
    self.f = SpikeSlab(p, k)

  def forward(self, x, writer=None, global_step=None):
    """Return a one-sample stochastic estimate of -ELBO(x, q)

    x - data tensor, compared against the sampled n x p linear predictor
    writer - optional SummaryWriter; if given, the KL terms, the expected log
             likelihood estimate, and the ELBO are logged at global_step

    """
    pm_l, pv_l, kl_l = self.l.forward()
    pm_f, pv_f, kl_f = self.f.forward()
    eta_mean = pm_l @ pm_f.T
    # Entry-wise variance of the product L F' under q
    eta_scale = torch.sqrt(pv_l @ pv_f.T + (pm_l ** 2) @ pv_f.T + pv_l @ (pm_f.T ** 2))
    # TODO: more than one stochastic sample
    # Reparameterized draw from a Normal matching the mean and variance of eta
    eta = torch.distributions.Normal(eta_mean, eta_scale).rsample()
    err = torch.distributions.Normal(eta, 1).log_prob(x).sum()
    elbo = err - kl_l - kl_f
    if writer is not None:
      writer.add_scalar('loss/kl_l', kl_l, global_step)
      writer.add_scalar('loss/kl_f', kl_f, global_step)
      writer.add_scalar('loss/err', err, global_step)
      writer.add_scalar('loss/elbo', elbo, global_step)
    # Negated so that the optimizer minimizes
    return -elbo

  def fit(self, x, init=None, n_epochs=1000, log_dir=None, **kwargs):
    """Fit the model by minimizing the stochastic negative ELBO with RMSprop

    x - data tensor
    init - optional pair (loadings, factors) used to initialize the slab means
    n_epochs - number of gradient steps
    log_dir - optional directory for tensorboard logs
    kwargs - passed on to torch.optim.RMSprop

    Returns self. Raises RuntimeError if the loss becomes nan.

    """
    assert torch.is_tensor(x)
    if init is not None:
      # Copy values in place rather than rebinding .data, so the parameters
      # keep their dtype and do not share memory with the caller's tensors
      with torch.no_grad():
        self.l.mean.copy_(torch.as_tensor(init[0]))
        self.f.mean.copy_(torch.as_tensor(init[1]))
    opt = torch.optim.RMSprop(self.parameters(), **kwargs)
    writer = tb.SummaryWriter(log_dir) if log_dir is not None else None
    try:
      for global_step in range(n_epochs):
        opt.zero_grad()
        loss = self.forward(x, writer=writer, global_step=global_step)
        if torch.isnan(loss):
          raise RuntimeError('nan loss')
        loss.backward()
        opt.step()
    finally:
      # Flush pending events even if optimization fails
      if writer is not None:
        writer.close()
    return self

class SPFA(torch.nn.Module):
  """Sparse Poisson factor analysis with a softplus inverse link

  Loadings (n x k) and factors (p x k) each get a SpikeSlab variational
  approximation and prior.

  """
  def __init__(self, n, p, k):
    super().__init__()
    self.l = SpikeSlab(n, k)
    self.f = SpikeSlab(p, k)

  def forward(self, x, s, writer=None, global_step=None):
    """Return a one-sample stochastic estimate of -ELBO(x, q)

    x - count data tensor, compared against the sampled n x p Poisson rates
    s - size factors, multiplied into the Poisson rate
    writer - optional SummaryWriter; if given, the KL terms, the expected log
             likelihood estimate, and the ELBO are logged at global_step

    """
    pm_l, pv_l, kl_l = self.l.forward()
    pm_f, pv_f, kl_f = self.f.forward()
    eta_mean = pm_l @ pm_f.T
    # Entry-wise variance of the product L F' under q
    eta_scale = torch.sqrt(pv_l @ pv_f.T + (pm_l ** 2) @ pv_f.T + pv_l @ (pm_f.T ** 2))
    # TODO: more than one stochastic sample
    # Reparameterized draw from a Normal matching the mean and variance of eta
    eta = torch.distributions.Normal(eta_mean, eta_scale).rsample()
    err = torch.distributions.Poisson(s * torch.nn.functional.softplus(eta)).log_prob(x).sum()
    elbo = err - kl_l - kl_f
    if writer is not None:
      writer.add_scalar('loss/kl_l', kl_l, global_step)
      writer.add_scalar('loss/kl_f', kl_f, global_step)
      writer.add_scalar('loss/err', err, global_step)
      writer.add_scalar('loss/elbo', elbo, global_step)
    # Negated so that the optimizer minimizes
    return -elbo

  def fit(self, x, s, init=None, n_epochs=1000, log_dir=None, **kwargs):
    """Fit the model by minimizing the stochastic negative ELBO with RMSprop

    x - count data tensor
    s - size factor tensor of shape [x.shape[0], 1]
    init - optional pair (loadings, factors) used to initialize the slab means
    n_epochs - number of gradient steps
    log_dir - optional directory for tensorboard logs
    kwargs - passed on to torch.optim.RMSprop

    Returns self. Raises RuntimeError if the loss becomes nan.

    """
    assert torch.is_tensor(x)
    assert torch.is_tensor(s)
    assert s.shape == torch.Size([x.shape[0], 1])
    if init is not None:
      # Copy values in place rather than rebinding .data, so the parameters
      # keep their dtype and do not share memory with the caller's tensors
      with torch.no_grad():
        self.l.mean.copy_(torch.as_tensor(init[0]))
        self.f.mean.copy_(torch.as_tensor(init[1]))
    opt = torch.optim.RMSprop(self.parameters(), **kwargs)
    writer = tb.SummaryWriter(log_dir) if log_dir is not None else None
    try:
      for global_step in range(n_epochs):
        opt.zero_grad()
        loss = self.forward(x, s, writer=writer, global_step=global_step)
        if torch.isnan(loss):
          raise RuntimeError('nan loss')
        loss.backward()
        opt.step()
    finally:
      # Flush pending events even if optimization fails
      if writer is not None:
        writer.close()
    return self

Results

Sanity check

Make sure the algorithm works by simulating from the Gaussian model, and then fitting that model.

# Simulation settings: n samples, p features, k factors
rng = np.random.default_rng(1)
n, p, k = 100, 500, 3
# Per-factor probability that an entry is exactly zero
pi0 = [0.1, 0.9, 0.99]
# Per-factor scale of the non-zero entries
sa = np.array([[.5, 1, 2]])

# Point-Normal loadings: draw the non-zero pattern first, then the values
zl = rng.uniform(size=(n, k)) > pi0
l = np.where(zl, rng.normal(size=(n, k), scale=sa), 0.)
# Point-Normal factors, drawn the same way
zf = rng.uniform(size=(p, k)) > pi0
f = np.where(zf, rng.normal(size=(p, k), scale=sa), 0.)

# Gaussian observations with unit variance around the low rank mean
x = rng.normal(loc=l @ f.T)
# Fit the Gaussian model, starting the slab means at the true values
fit = SFA(n=n, p=p, k=k).fit(
  torch.tensor(x, dtype=torch.float),
  init=(torch.tensor(l), torch.tensor(f)),
  n_epochs=2000,
  log_dir='runs/spfa/norm-1/')

Simulated example

Simulate from the model. This simulation itself reveals some tricky aspects of the problem: (1) it is typically estimated that \(\lambda_{ij}\) is close to zero for almost all entries, but this means that e.g., \(\ln\lambda_{ij}\) is typically large and negative, rather than zero; (2) the maximum value of \(\lambda_{ij}\) is typically not observed to be very large, putting strong constraints on the sparsity and scale of the prior.

# Simulation settings: n samples, p features, k factors
rng = np.random.default_rng(1)
n, p, k = 100, 500, 3
# Per-factor probability that an entry is exactly zero
pi0 = [0.1, 0.9, 0.99]
# Per-factor scale of the non-zero entries
sa = np.array([[.5, 1, 2]])

# Point-Normal loadings: draw the non-zero pattern first, then the values
zl = rng.uniform(size=(n, k)) > pi0
l = np.where(zl, rng.normal(size=(n, k), scale=sa), 0.)
# Point-Normal factors, drawn the same way
zf = rng.uniform(size=(p, k)) > pi0
f = np.where(zf, rng.normal(size=(p, k), scale=sa), 0.)

# Poisson rates are the softplus of the low rank linear predictor
lam = np.log1p(np.exp(l @ f.T))
x = rng.poisson(lam)
# Fit the Poisson model with unit size factors, starting the slab means at
# the true values
fit = SPFA(n=n, p=p, k=k).fit(
  torch.tensor(x, dtype=torch.float),
  s=torch.ones([x.shape[0], 1]),
  init=(torch.tensor(l), torch.tensor(f)),
  n_epochs=4000,
  log_dir='runs/spfa/sim-ex-7/')

Author: Abhishek Sarkar

Created: 2020-06-30 Tue 14:29

Validate