Paired factor analysis for single cell lineage tracing

Table of Contents

Introduction

Lineage tracing is the problem of assigning cells to a branching phylogeny of cell types. A number of approaches have been proposed to solve this problem (for a review, see GrĂ¼n 2018).

Gao Wang and Kushal Dey proposed paired factor analysis (PFA) as a new solution to this problem. The intuition behind PFA is to cluster cells, under the constraint that each cell can belong to at most two clusters. Then, the cluster centroids are the nodes of the phylogenetic tree, and the cluster weights for each cell position that cell on a branch of the tree.

Here, we develop a fast, simple method to find a maximum likelihood solution to PFA.

Setup

import numpy as np
import scipy.stats as st
import scqtl
import tensorflow as tf
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import colorcet
import matplotlib
import matplotlib.pyplot as plt

Methods

Paired factor analysis

For modeling scRNA-seq count data, we start from Poisson factorization (PF):

\[ x_{ij} \sim \mathrm{Poisson}(\lambda_{ij}) \]

\[ \lambda_{ij} = [\mathbf{LF}]_{ij} \]

The PF MLE can be converted to a grade of membership model by normalizing to satisfy the constraints:

\[ l_{ik} \geq 0 \]

\[ \sum_i l_{ik} = 1 \]

\[ f_{kj} \geq 0 \]

After normalizing, the interpretation of the model is:

  • \(f_{ik}\) is the centroid of cluster \(k\) (vector of mean expression)
  • \(l_{ik}\) is the cluster weight of sample \(i\) on cluster \(k\).

In this model, sample \(i\) could potentially belong to all clusters. For lineage tracing, we instead model sample \(i\) as belonging to at most two clusters. Then, the interpretation is that sample loadings interpolate between reference points, which are factors.

We introduce latent indicator variables \(z_{ik}\), which denote whether sample \(i\) belongs to cluster \(k\).

\[ l_{ik} = \tilde{l}_{ik} z_{ik} \]

Now, we need to efficiently maximize the likelihood of the observed data and latent variables.

\[ \max_{\mathbf{L},\mathbf{F},\mathbf{Z}} \ln p(\mathbf{X} \mid \cdot) \]

\[ \mathrm{s.t.} \sum_k z_{ik} \leq 2\ \forall i \]

\[ z_{ik} \in \{0, 1\} \]

This is a mixed integer non-convex optimization problem.

EM algorithm

Given \(\mathbf{Z}, \mathbf{F}\), we can update \(\mathbf{L}\) using coordinate descent (Cichoki et al. 2009, Hsieh et al. 2011, Lin 2018).

Given \(\mathbf{L}\), we can update \(\mathbf{F}\) using a multiplicative update (Lee and Seung 2001).

Given \(\mathbf{L}, \mathbf{F}\), we can update \(\mathbf{Z}\):

\[ p(z_{ik} \mid \cdot) \propto p(x_{ik} \mid z_{ik}, \cdot) \]

Non-convex relaxation

We will relax the integer constraint to make the problem easier.

\[ \max_{\mathbf{L},\mathbf{F},\mathbf{Z}} \ln p(\mathbf{X} \mid \cdot) \]

\[ \mathrm{s.t.} \sum_k z_{ik} \leq 2\ \forall i \]

\[ 0 \leq z_{ik} \leq 1 \]

We can re-parameterize to drop one constraint:

\[ z_{ik} = \mathrm{sigmoid}(\tilde{z}_{ik}) \]

\[ \max_{\mathbf{L},\mathbf{F},\mathbf{\tilde{Z}}} \ln p(\mathbf{X} \mid \cdot) \]

\[ \mathrm{s.t.} \sum_k z_{ik} \leq 2\ \forall i \]

We can solve this problem using the barrier method.

\[ \max_{\mathbf{L},\mathbf{F},\mathbf{\tilde{Z}}} \ln p(\mathbf{X} \mid \cdot) + \alpha \sum_i \ln\left(2 - \sum_k z_{ik}\right) \]

This approach can be more readily generalized to more complicated likelihoods (NB, ZINB) through the use of automatic differentiation.

import tensorflow as tf

def pois_llik(x, mean):
  return x * tf.log(mean) - mean - tf.lgamma(x + 1)

def pfa(x, k, size, alpha=None, F=None, learning_rate=1e-2, max_epochs=10000, atol=1e-1, verbose=True):
  """Return paired factor analysis estimate

  :param x: array-like data ([n, p], tf.float32)
  :param k: number of factors
  :param size: size factor per observation
  :param alpha: increasing sequence of penalty weights (default: np.logspace(-1, 3, 10))
  :param F: known factors
  :param learning_rate: base learning rate for RMSProp
  :param max_epochs: maximum number of iterations per penalty weight
  :param atol: absolute tolerance for convergence
  :param verbose: print objective function updates

  :returns loadings: Estimated sample loadings ([n, k])
  :returns factors: Estimated factors ([k, p])

  """
  n, p = x.shape
  if alpha is None:
    alpha = np.logspace(-1, 3, 10)
  else:
    alpha = np.sort(np.atleast_1d(alpha).ravel())
  graph = tf.Graph()
  with graph.as_default():
    x = tf.Variable(x.astype(np.float32), trainable=False)
    size = tf.Variable(size.reshape(-1, 1).astype(np.float32), trainable=False)
    if F is not None:
      assert F.shape == (k, p)
      factors = tf.Variable(F.astype(np.float32), trainable=False)
    else:
      factors = tf.exp(tf.Variable(tf.random_normal([k, p])))
    Z = tf.sigmoid(tf.Variable(tf.random_normal([n, k])))
    Z /= tf.reduce_sum(Z, axis=1, keepdims=True)
    loadings = tf.exp(tf.Variable(tf.random_normal([n, k])))
    weight = tf.placeholder(tf.float32, [])

    mean = size * tf.matmul(loadings, factors)
    llik = tf.reduce_sum(pois_llik(x, mean))
    penalty = weight * tf.reduce_sum(tf.log(2 - tf.reduce_sum(Z, axis=1, keepdims=True)))
    loss = -llik - penalty

    optim = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    train = optim.minimize(loss)
    reset = tf.variables_initializer(optim.variables())
    trace = [loss, weight, penalty]
    opt = [loadings, factors]
    obj = float('-inf')
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for a in alpha:
        sess.run(reset)  # Restart momentum
        for i in range(max_epochs):
          _, update = sess.run([train, trace], {weight: a})
          if not np.isfinite(update[0]):
            raise tf.train.NanLossDuringTrainingError
          elif np.isclose(update[0], obj, atol=atol):
            break
          if verbose and not i % 500:
            print(i, *update, end='\r')
        if verbose:
          print(i, *update)
      return sess.run(opt)

\(l_1\) penalized PF

The preliminary simulation results suggest simply penalizing \(Z\) to be sparse could be sufficient to satisfy the relaxed constraint.

def pfa_l1(x, k, size, weight=None, learning_rate=1e-2, max_epochs=10000, atol=1e-1, verbose=True):
  """Return paired factor analysis estimate

  :param x: array-like data ([n, p], tf.float32)
  :param k: number of factors
  :param size: size factor per observation
  :param weight: l1 penalty weight
  :param F: known factors
  :param learning_rate: base learning rate for RMSProp
  :param max_epochs: maximum number of iterations per penalty weight
  :param atol: absolute tolerance for convergence
  :param verbose: print objective function updates

  :returns loadings: Estimated sample loadings ([n, k])
  :returns factors: Estimated factors ([k, p])

  """
  n, p = x.shape
  graph = tf.Graph()
  with graph.as_default():
    x = tf.Variable(x.astype(np.float32), trainable=False)
    size = tf.Variable(size.reshape(-1, 1).astype(np.float32), trainable=False)
    weight = tf.Variable(np.array(weight).astype(np.float32), trainable=False)
    if F is not None:
      assert F.shape == (k, p)
      factors = tf.Variable(F.astype(np.float32), trainable=False)
    else:
      factors = tf.exp(tf.Variable(tf.random_normal([k, p])))
    Z = tf.sigmoid(tf.Variable(tf.random_normal([n, k])))
    Z /= tf.reduce_sum(Z, axis=1, keepdims=True)
    loadings = tf.exp(tf.Variable(tf.random_normal([n, k])))

    mean = size * tf.matmul(loadings, factors)
    llik = tf.reduce_sum(pois_llik(x, mean))
    penalty = weight * tf.reduce_sum(Z)
    loss = -llik - penalty

    train = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(loss)
    trace = [loss, weight, penalty]
    opt = [loadings, factors]
    obj = float('-inf')
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for i in range(max_epochs):
        _, update = sess.run([train, trace])
        if not np.isfinite(update[0]):
          raise tf.train.NanLossDuringTrainingError
        elif np.isclose(update[0], obj, atol=atol):
          break
        if verbose and not i % 500:
          print(i, *update, end='\r')
      if verbose:
        print(i, *update)
      return sess.run(opt)

Plotting functions

Produce a STRUCTURE plot.

def plot_structure(weights, ax=None, idx=None):
  if ax is None:
    ax = plt.gca()
  prop = np.cumsum(weights, axis=1)
  if idx is None:
    idx = np.lexsort(weights.T)
  for i in range(prop.shape[1]):
    if i > 0:
      bot = prop[idx,i - 1]
    else:
      bot = None
    ax.bar(np.arange(prop.shape[0]), weights[idx,i], bottom=bot, color=f'C{i}', width=1, label=f'Topic {i + 1}')
  ax.set_xlim(0, weights.shape[0])
  ax.set_xticks([])
  ax.set_ylim(0, 1)

Plot estimated against true topics.

def plot_topics(topics, ax=None):
  if ax is None:
    ax = plt.gca()
  z = np.cumsum(topics, axis=1)
  for row in z:
    ax.bar()

Results

Simulation

Unconstrained Poisson factorization

Simulate some Poisson data.

np.random.seed(1)
L = np.random.lognormal(size=(100, 3))
F = np.random.lognormal(size=(3, 10))
size = 100 * np.ones((L.shape[0], 1))
X = np.random.poisson(lam=size * L.dot(F)).astype(np.float32)
topics = F / F.sum(axis=1, keepdims=True)
weights = L * F.sum(axis=1) / size
weights /= weights.sum(axis=1, keepdims=True)

Plot the pairwise correlation of the topics to each other.

plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.imshow(np.corrcoef(topics), cmap=colorcet.cm['coolwarm'], vmin=-1, vmax=1)
cb = plt.colorbar(shrink=0.5)
cb.set_label('Correlation')
plt.xlabel('True topic')
plt.ylabel('True topic')
plt.gcf().tight_layout()

pf-true-corr.png

Fix F and optimize L.

lhat, fhat = pfa(X, 3, alpha=0, F=F, size=size, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)
plt.clf()
plt.set_cmap('Set2')
fig, ax = plt.subplots(2, 1, sharex=True)
fig.set_size_inches(6, 4)

plot_structure(weights, ax=ax[0])
plot_structure(est_weights, ax=ax[1], idx=np.lexsort(weights.T))

ax[0].set_ylabel('True topic weight')
ax[1].set_ylabel('Estimated topic weight')
ax[1].set_xlabel('Sample')

ax[1].legend(loc='center left', bbox_to_anchor=(1, .5), frameon=False)
fig.tight_layout()

pf-fixed-f.png

Fit PF. This is bi-convex, so we might not expect to recover \(\mathbf{L}, \mathbf{F}\).

lhat, fhat = pfa(X, 3, alpha=0, size=size, learning_rate=1e-3)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the pairwise correlation between the estimated topics and the true topics.

plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.imshow(np.triu(np.corrcoef(F, fhat)), cmap=colorcet.cm['coolwarm'], vmin=-1, vmax=1)
cb = plt.colorbar(shrink=0.5)
cb.set_label('Correlation')
plt.xticks(np.arange(6), ['True1', 'True2', 'True3', 'Est1', 'Est2', 'Est3'], rotation=90)
plt.yticks(np.arange(6), ['True1', 'True2', 'True3', 'Est1', 'Est2', 'Est3'])
plt.gcf().tight_layout()

pf-corr.png

Reorder the estimated topics to maximally correlate with the true topics.

est_weights = est_weights[:,np.argmax(np.corrcoef(F, fhat)[:3,3:], axis=1)]

Plot the corresponding topic model.

pf-full.png

PFA model

Simulate some data where each sample interpolates between a pair of factors.

np.random.seed(1)
L = np.random.lognormal(size=(100, 3))
F = np.random.lognormal(size=(3, 10))
Z = np.zeros(L.shape)
for i in range(Z.shape[0]):
  Z[i,np.random.choice(3, size=2).astype(int)] = 1
L *= Z
L /= L.sum(axis=1, keepdims=True)
size = 100 * np.ones((L.shape[0], 1))
X = np.random.poisson(lam=size * L.dot(F)).astype(np.float32)

topics = F / F.sum(axis=1, keepdims=True)
weights = L * F.sum(axis=1) / size
weights /= weights.sum(axis=1, keepdims=True)

Plot the pairwise correlation of the topics to each other.

pfa-true-corr.png

Fix \(F\) and fit \(L\) using PF.

lhat, fhat = pfa(X, 3, alpha=0, F=F, size=size, max_epochs=10000, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the corresponding topic models.

pfa-pf-fixed-f.png

Fix \(F\) to the truth and fit \(L\) using PFA.

lhat, fhat = pfa(X, 3, alpha=1, F=F, size=size, max_epochs=10000, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the corresponding topic models.

pfa-fixed-f.png

Project the solution onto the hard constraint.

est_weights[np.arange(est_weights.shape[0]),np.argmin(est_weights, axis=1)] = 0
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the corresponding topic models.

pfa-fixed-f-hard.png

Fit full PFA.

lhat, fhat = pfa(X, 3, size=size, alpha=np.logspace(-1, 8, 10), max_epochs=10000, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the pairwise correlation between the estimated topics and the true topics.

plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.imshow(np.triu(np.corrcoef(F, fhat)), cmap=colorcet.cm['coolwarm'], vmin=-1, vmax=1)
plt.colorbar(shrink=0.5)
plt.xticks(np.arange(6), ['True1', 'True2', 'True3', 'Est1', 'Est2', 'Est3'], rotation=90)
plt.yticks(np.arange(6), ['True1', 'True2', 'True3', 'Est1', 'Est2', 'Est3'])
plt.gcf().tight_layout()

pfa-corr.png

Reorder the estimated topics to maximally correlate with the true topics.

est_weights = est_weights[:,np.argmax(np.corrcoef(F, fhat)[:3,3:], axis=1)]

Plot the corresponding topic models.

pfa-full.png

Project the solution onto the hard constraint.

est_weights[np.arange(est_weights.shape[0]),np.argmin(est_weights, axis=1)] = 0
est_weights /= est_weights.sum(axis=1, keepdims=True)

pfa-full-hard.png

\(l_1\) penalized PF

Simulate data from PFA.

np.random.seed(1)
L = np.random.lognormal(size=(100, 3))
F = np.random.lognormal(size=(3, 10))
Z = np.zeros(L.shape)
for i in range(Z.shape[0]):
  Z[i,np.random.choice(3, size=2).astype(int)] = 1
L *= Z
L /= L.sum(axis=1, keepdims=True)
size = 100 * np.ones((L.shape[0], 1))
X = np.random.poisson(lam=size * L.dot(F)).astype(np.float32)

topics = F / F.sum(axis=1, keepdims=True)
weights = L * F.sum(axis=1) / size
weights /= weights.sum(axis=1, keepdims=True)

Fit \(L\) using sparse PF.

lhat, fhat = pfa_l1(X, 3, weight=0., size=size, max_epochs=10000, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)

Reorder the estimated topics to maximally correlate with the true topics.

est_weights = est_weights[:,np.argmax(np.corrcoef(F, fhat)[:3,3:], axis=1)]

Plot the corresponding topic models.

pfa-l1-full.png

Project the solution onto the hard constraint.

est_weights[np.arange(est_weights.shape[0]),np.argmin(est_weights, axis=1)] = 0
est_weights /= est_weights.sum(axis=1, keepdims=True)

Plot the corresponding topic models.

pfa-full-hard.png

NB data

The preliminary results suggest that in the absence of extra-Poisson noise, ordinary PF gives a reasonable solution. Now, try simulating negative binomial data from PFA:

\[ x_{ij} \sim \mathrm{Poisson}(R_i \lambda_{ij}) \]

\[ \lambda_{ij} \sim [\mathbf{LF}]_{ij} u_{ij} \]

\[ u_{ij} \sim \mathrm{Gamma}(\phi_j^{-1}, \phi_j^{-1}) \]

Now, marginally \(x_{ij} \sim \mathrm{NB}(R_i \lambda_{ij}, \phi_j)\).

np.random.seed(1)
L = np.random.lognormal(size=(100, 3))
F = np.random.lognormal(size=(3, 10))
Z = np.zeros(L.shape)
for i in range(Z.shape[0]):
  Z[i,np.random.choice(3, size=2).astype(int)] = 1
L *= Z
L /= L.sum(axis=1, keepdims=True)
size = 100 * np.ones((L.shape[0], 1))
mu = size * L.dot(F)
phi = np.random.lognormal(sigma=.1, size=(100, 1))
u = np.random.gamma(shape=phi, size=(100, 10))
X = np.random.poisson(lam=mu * u).astype(np.float32)

topics = F / F.sum(axis=1, keepdims=True)
weights = L * F.sum(axis=1) / size
weights /= weights.sum(axis=1, keepdims=True)

Fix \(F\) and fit PF.

lhat, fhat = pfa(X, 3, alpha=0, F=F, size=size, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)
est_weights = est_weights[:,np.argmax(np.corrcoef(F, fhat)[:3,3:], axis=1)]

Plot the corresponding topic models.

nb-pf-fixed-f.png

Fit PF.

lhat, fhat = pfa(X, 3, alpha=0, size=size, learning_rate=1e-2)
est_topics = fhat / fhat.sum(axis=1, keepdims=True)
est_weights = lhat * fhat.sum(axis=1) / size
est_weights /= est_weights.sum(axis=1, keepdims=True)
est_weights = est_weights[:,np.argmax(np.corrcoef(F, fhat)[:3,3:], axis=1)]

Plot the corresponding topic models.

nb-pf-full.png

Tabula Muris data

Tabula Muris (Tabula Muris consortium et al. 2018) has collected scRNA-seq from 20 murine tissues. In particular, they collected 53,760 FACS-sorted cells and 55,656 unsorted cells.

We can place the unsorted cells on lineages in two stages:

  1. Estimate 20 factors (mean vectors) from the FACS-sorted cells.
  2. Estimate PFA on the unsorted cells, using the fixed factors from (1).

Human embryoid body data

Han et al. 2018 study embryoid bodies (EBs) generated by allowing human pluripotent stem cells to spontaneously differentiate.

Author: Abhishek Sarkar

Created: 2020-05-14 Thu 23:52

Validate