Sparse Poisson Factor Analysis
Introduction
flash (Wang and Stephens 2018) is a low-rank model which estimates a flexible prior on loadings and factors, \(
\DeclareMathOperator\E{E}
\DeclareMathOperator\KL{\mathcal{KL}}
\newcommand\kl[2]{\KL(#1\;\Vert\;#2)}
\DeclareMathOperator\N{\mathcal{N}}
\DeclareMathOperator\Poi{Poisson}
\DeclareMathOperator\V{V}
\newcommand\Gfam{\mathcal{G}}
\newcommand\mf{\mathbf{F}}
\newcommand\ml{\mathbf{L}}
\newcommand\mx{\mathbf{X}}
\newcommand\vf{\mathbf{f}}
\newcommand\vl{\mathbf{l}}
\)
\(l_{1k}, \ldots, l_{nk} \sim g_{\vl_k} \in \Gfam\) and \(f_{1k}, \ldots, f_{pk} \sim g_{\vf_k} \in \Gfam\), where \(\Gfam\) denotes a family of distributions. The posterior \(p(\ml, \mf \mid \mx)\) can be approximated using variational inference. In particular, the coordinate ascent updates which maximize the ELBO correspond to solutions of the Empirical Bayes Normal Means problem. They study two main cases: (1) \(\Gfam\) is the family of unimodal, zero-mode scale mixtures of Normals over a fixed grid of scales, and (2) \(\Gfam\) is the family of point-Normal mixtures. They found that, in some practical examples, the differences in inference between these two families were minimal.
We are now interested in developing similarly flexible methods to learn low-rank, (possibly) sparse representations of scRNA-seq data, which follow a different generative process (Sarkar and Stephens 2020)
\begin{align}
  x_{ij} \mid s_i, \lambda_{ij} &\sim \Poi(s_i \lambda_{ij})\\
  \lambda_{ij} &= h^{-1}((\ml\mf')_{ij})\\
  l_{1k}, \ldots, l_{nk} &\sim g_{\vl_k} \in \Gfam\\
  f_{1k}, \ldots, f_{pk} &\sim g_{\vf_k} \in \Gfam,
\end{align}where \(h\) is a link function. Although it is natural to choose \(h = \ln\), for practical purposes it can be preferable to take \(h^{-1}\) to be the softplus function \(h^{-1}(x) = \ln(1 + \exp(x))\) (Argelaguet 2018). Variational inference in this model requires either a Taylor approximation (Seeger 2012) or a Monte Carlo approach (Kingma and Welling 2014, Rezende et al. 2014). Here, we investigate the latter.
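One practical consideration (an aside added here, not taken from the cited references): writing \(\eta_{ij} = (\ml\mf')_{ij}\) and \(\lambda_{ij} = \ln(1 + \exp(\eta_{ij}))\),
\begin{align}
  \lambda_{ij} &\approx \exp(\eta_{ij}), \quad \eta_{ij} \to -\infty\\
  \lambda_{ij} &\approx \eta_{ij}, \quad \eta_{ij} \to \infty,
\end{align}so small rates behave as under the log link, while large rates grow only linearly rather than exponentially in \(\eta_{ij}\), making the likelihood much less sensitive to large entries of \(\ml\mf'\).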
Setup
import numpy as np
import torch
import torch.utils.tensorboard as tb
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'
Method
The evidence lower bound (ELBO) can be written
\begin{equation}
  \ell = \sum_{i,j} \E_q[\ln p(x_{ij} \mid s_i, \lambda_{ij})] - \sum_{i,k} \kl{q(l_{ik})}{p(l_{ik})} - \sum_{j,k} \kl{q(f_{jk})}{p(f_{jk})}.
\end{equation}The KL divergence terms in the ELBO are analytic (e.g., Carbonetto and Stephens 2012). The remaining terms involve intractable expectations, so we take a Monte Carlo approach to estimate them, resulting in a stochastic objective function. The critical calculation in this approach is
\begin{equation}
  \V[\ml\mf'] = \V[\ml]\,\V[\mf]' + \E[\ml]^2\,\V[\mf]' + \V[\ml]\,(\E[\mf]^2)',
\end{equation}where expectations (variances) are taken with respect to \(q\), and the squares are taken elementwise. In the implementation below, rather than sampling \(\ml\) and \(\mf\) directly, we sample each \(\eta_{ij} = (\ml\mf')_{ij}\) from a Normal distribution with mean \((\E[\ml]\E[\mf]')_{ij}\) and the variance above (matching the first two moments under \(q\)), and optimize the stochastic objective function by gradient descent.
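For reference (a restatement added here; it is the standard result implemented in SpikeSlab.forward below), writing the approximate posterior on a single loading as \(q(l_{ik}) = \alpha_{ik}\,\N(\mu_{ik}, 1/\tau_{ik}) + (1 - \alpha_{ik})\,\delta_0\) and its point-Normal prior as \(p(l_{ik}) = \pi_k\,\N(0, 1/\tau_{0k}) + (1 - \pi_k)\,\delta_0\), the analytic KL term is
\begin{equation}
  \kl{q(l_{ik})}{p(l_{ik})} = \alpha_{ik} \ln\frac{\alpha_{ik}}{\pi_k} + (1 - \alpha_{ik}) \ln\frac{1 - \alpha_{ik}}{1 - \pi_k} + \frac{\alpha_{ik}}{2}\left[\ln\frac{\tau_{ik}}{\tau_{0k}} + \tau_{0k}\left(\mu_{ik}^2 + \frac{1}{\tau_{ik}}\right) - 1\right],
\end{equation}and analogously for the factors.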
class SpikeSlab(torch.nn.Module):
    """Independent Point-Normal priors on columns of matrix M"""
    def __init__(self, p, k):
        super().__init__()
        # Variational parameters: posterior inclusion logits, slab means, slab log precisions
        self.logits = torch.nn.Parameter(torch.randn([p, k]))
        self.mean = torch.nn.Parameter(torch.randn([p, k]))
        self.log_prec = torch.nn.Parameter(torch.zeros([p, k]))
        # Prior parameters, shared within each column (factor)
        self.prior_logodds = torch.nn.Parameter(torch.randn([1, k]))
        self.log_prior_prec = torch.nn.Parameter(torch.full([1, k], 1.))

    def forward(self):
        """Return E[M], V[M], and KL(q(M) || p_prior(M))"""
        finfo = torch.finfo(self.logits.dtype)
        pip = torch.clamp(torch.sigmoid(self.logits), finfo.tiny, 1 - finfo.eps)
        prec = torch.exp(self.log_prec)
        # Analytic KL divergence between the point-Normal q and the point-Normal prior
        kl = (pip * torch.log(pip / torch.sigmoid(self.prior_logodds))
              + (1 - pip) * torch.log((1 - pip) / torch.sigmoid(-self.prior_logodds))
              + .5 * pip * (self.log_prec - self.log_prior_prec
                            + torch.exp(self.log_prior_prec) * (self.mean ** 2 + 1 / prec) - 1)).sum()
        assert not torch.isnan(kl).any()
        # Posterior mean and variance of each entry of M under q
        pm = pip * self.mean
        pv = pip / prec + pip * (1 - pip) * self.mean ** 2
        return pm, pv, kl

class SFA(torch.nn.Module):
    """Sparse (Gaussian) factor analysis"""
    def __init__(self, n, p, k):
        super().__init__()
        self.l = SpikeSlab(n, k)
        self.f = SpikeSlab(p, k)

    def forward(self, x, writer=None, global_step=None):
        """Return the negative ELBO of x under q"""
        pm_l, pv_l, kl_l = self.l.forward()
        pm_f, pv_f, kl_f = self.f.forward()
        # Mean and standard deviation of eta = LF' under q
        eta_mean = pm_l @ pm_f.T
        eta_scale = torch.sqrt(pv_l @ pv_f.T + (pm_l ** 2) @ pv_f.T + pv_l @ (pm_f.T ** 2))
        # TODO: more than one stochastic sample
        eta = torch.distributions.Normal(eta_mean, eta_scale).rsample()
        err = torch.distributions.Normal(eta, 1).log_prob(x).sum()
        elbo = err - kl_l - kl_f
        if writer is not None:
            writer.add_scalar('loss/kl_l', kl_l, global_step)
            writer.add_scalar('loss/kl_f', kl_f, global_step)
            writer.add_scalar('loss/err', err, global_step)
            writer.add_scalar('loss/elbo', elbo, global_step)
        return -elbo

    def fit(self, x, init=None, n_epochs=1000, log_dir=None, **kwargs):
        assert torch.is_tensor(x)
        if init is not None:
            self.l.mean.data = init[0]
            self.f.mean.data = init[1]
        opt = torch.optim.RMSprop(self.parameters(), **kwargs)
        global_step = 0
        if log_dir is not None:
            writer = tb.SummaryWriter(log_dir)
        else:
            writer = None
        for i in range(n_epochs):
            opt.zero_grad()
            loss = self.forward(x, writer=writer, global_step=global_step)
            if torch.isnan(loss):
                raise RuntimeError('nan loss')
            loss.backward()
            opt.step()
            global_step += 1
        return self

class SPFA(torch.nn.Module):
    """Sparse Poisson factor analysis"""
    def __init__(self, n, p, k):
        super().__init__()
        self.l = SpikeSlab(n, k)
        self.f = SpikeSlab(p, k)

    def forward(self, x, s, writer=None, global_step=None):
        """Return the negative ELBO of x under q"""
        pm_l, pv_l, kl_l = self.l.forward()
        pm_f, pv_f, kl_f = self.f.forward()
        # Mean and standard deviation of eta = LF' under q
        eta_mean = pm_l @ pm_f.T
        eta_scale = torch.sqrt(pv_l @ pv_f.T + (pm_l ** 2) @ pv_f.T + pv_l @ (pm_f.T ** 2))
        # TODO: more than one stochastic sample
        eta = torch.distributions.Normal(eta_mean, eta_scale).rsample()
        err = torch.distributions.Poisson(s * torch.nn.functional.softplus(eta)).log_prob(x).sum()
        elbo = err - kl_l - kl_f
        if writer is not None:
            writer.add_scalar('loss/kl_l', kl_l, global_step)
            writer.add_scalar('loss/kl_f', kl_f, global_step)
            writer.add_scalar('loss/err', err, global_step)
            writer.add_scalar('loss/elbo', elbo, global_step)
        return -elbo

    def fit(self, x, s, init=None, n_epochs=1000, log_dir=None, **kwargs):
        assert torch.is_tensor(x)
        assert torch.is_tensor(s)
        assert s.shape == torch.Size([x.shape[0], 1])
        if init is not None:
            self.l.mean.data = init[0]
            self.f.mean.data = init[1]
        opt = torch.optim.RMSprop(self.parameters(), **kwargs)
        global_step = 0
        if log_dir is not None:
            writer = tb.SummaryWriter(log_dir)
        else:
            writer = None
        for i in range(n_epochs):
            opt.zero_grad()
            loss = self.forward(x, s, writer=writer, global_step=global_step)
            if torch.isnan(loss):
                raise RuntimeError('nan loss')
            loss.backward()
            opt.step()
            global_step += 1
        return self
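As a quick numerical check of the variance identity used to compute eta_scale (a throwaway sketch added here; the array names and sizes are arbitrary):
# Random first and second moments for small L (5 x 3) and F (7 x 3)
rng = np.random.default_rng(0)
ml, vl = rng.normal(size=(5, 3)), rng.gamma(1, size=(5, 3))
mf, vf = rng.normal(size=(7, 3)), rng.gamma(1, size=(7, 3))
# Monte Carlo estimate of V[(LF')_ij] under independent Normal entries
n_samples = 100000
ls = rng.normal(loc=ml, scale=np.sqrt(vl), size=(n_samples, 5, 3))
fs = rng.normal(loc=mf, scale=np.sqrt(vf), size=(n_samples, 7, 3))
mc = np.einsum('sik,sjk->sij', ls, fs).var(axis=0)
# Analytic formula from the Method section
analytic = vl @ vf.T + (ml ** 2) @ vf.T + vl @ (mf ** 2).T
np.abs(mc - analytic).max()  # maximum absolute deviation; should be small relative to typical entries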
Results
Sanity check
Make sure the algorithm works by simulating from the Gaussian model, and then fitting that model.
# Simulate sparse loadings and factors: entries in column k are nonzero
# with probability 1 - pi0[k] and have standard deviation sa[0, k]
rng = np.random.default_rng(1)
n = 100
p = 500
k = 3
pi0 = [0.1, 0.9, 0.99]
sa = np.array([[.5, 1, 2]])
l = np.zeros((n, k))
zl = rng.uniform(size=(n, k)) > pi0
l[zl] = rng.normal(size=(n, k), scale=sa)[zl]
f = np.zeros((p, k))
zf = rng.uniform(size=(p, k)) > pi0
f[zf] = rng.normal(size=(p, k), scale=sa)[zf]
# Gaussian observations with unit noise variance
x = rng.normal(loc=l @ f.T)
fit = (SFA(n=n, p=p, k=k)
       .fit(torch.tensor(x, dtype=torch.float),
            init=(torch.tensor(l), torch.tensor(f)),
            n_epochs=2000,
            log_dir='runs/spfa/norm-1/'))
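One way to check that the fit recovered the simulated structure (a sketch added here; the run above was presumably monitored via the TensorBoard traces written to log_dir):
# Correlation between the fitted posterior mean of LF' and the true low-rank structure
with torch.no_grad():
    pm_l, _, _ = fit.l.forward()
    pm_f, _, _ = fit.f.forward()
    eta_hat = (pm_l @ pm_f.T).numpy()
np.corrcoef(eta_hat.ravel(), (l @ f.T).ravel())[0, 1]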
Simulated example
Simulate from the model. This simulation itself reveals some tricky aspects of the problem: (1) in real data, \(\lambda_{ij}\) is typically estimated to be close to zero for almost all entries, which means that, e.g., \(\ln\lambda_{ij}\) is typically large and negative rather than zero; (2) the maximum value of \(\lambda_{ij}\) is typically not observed to be very large, which puts strong constraints on the sparsity and scale of the prior.
# Simulate sparse loadings and factors as in the sanity check above
rng = np.random.default_rng(1)
n = 100
p = 500
k = 3
pi0 = [0.1, 0.9, 0.99]
sa = np.array([[.5, 1, 2]])
l = np.zeros((n, k))
zl = rng.uniform(size=(n, k)) > pi0
l[zl] = rng.normal(size=(n, k), scale=sa)[zl]
f = np.zeros((p, k))
zf = rng.uniform(size=(p, k)) > pi0
f[zf] = rng.normal(size=(p, k), scale=sa)[zf]
# Poisson observations with rates lambda_ij = softplus((LF')_ij) (implicitly s_i = 1)
lam = np.log1p(np.exp(l @ f.T))
x = rng.poisson(lam)
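To see how the points above play out in this particular draw (a quick summary added here):
# Quantiles of the simulated rates, the largest rate, and the fraction of exact zeros in x
np.quantile(lam, [.1, .5, .9, .99]), lam.max(), (x == 0).mean()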
fit = (SPFA(n=n, p=p, k=k)
       .fit(torch.tensor(x, dtype=torch.float),
            s=torch.ones([x.shape[0], 1]),
            init=(torch.tensor(l), torch.tensor(f)),
            n_epochs=4000,
            log_dir='runs/spfa/sim-ex-7/'))
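As with the Gaussian sanity check, a rough check of recovery (a sketch added here, plugging the posterior mean of \(\ml\mf'\) through the softplus and ignoring posterior uncertainty):
# Correlation between fitted and true Poisson rates
with torch.no_grad():
    pm_l, _, _ = fit.l.forward()
    pm_f, _, _ = fit.f.forward()
    lam_hat = torch.nn.functional.softplus(pm_l @ pm_f.T).numpy()
np.corrcoef(lam_hat.ravel(), lam.ravel())[0, 1]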