Convolutional VAE for spatial transcriptomic data

Introduction

There is great interest in analyzing transcriptomic profiles in complex tissues measured at single (or few) cell resolution jointly with associated spatial information (Ståhl et al. 2016, Vickovic et al. 2019, Rodriques et al. 2019). Such analysis could elucidate the molecular basis for the spatial organization of cell types/states (Moor and Itzkovitz 2017, Lein et al. 2017), and could be useful for genomics-based histology. Existing statistical methods for analyzing spatial transcriptomic data are largely based on Gaussian processes (Svensson et al. 2018, Edsgärd et al. 2018, Sun et al. 2020). These methods have proven successful in identifying genes that exhibit statistically significant spatial patterns of gene expression (compared to a null model of no structured spatial variation). However, these methods analyze one gene at a time, and therefore lose information on gene co-expression within and between samples. Further, they rely on a statistical test to prioritize genes that could be relevant to observed spatial structures (e.g., in associated imaging data).

For spatial transcriptomic data that are collected on a regular grid, we can instead use convolutional neural networks (CNNs; Denker et al. 1989, LeCun et al. 1989) to learn spatial structure. There are several advantages to the convolutional approach.

  1. A single model can be learned on all genes, and therefore can exploit covariance of gene expression induced by the gene regulatory network.
  2. The model can learn a lower-dimensional representation of spatial expression patterns, which can be used to identify genes with highly similar spatial expression patterns.
  3. With associated imaging data, the convolutional filters can be jointly learned between image features and gene expression features, which can more directly associate cell type/state differences at the expression level with e.g., differences in morphology.

The key idea of our approach is to combine a Poisson measurement model (Sarkar and Stephens 2020),

\begin{equation} \newcommand\Pois{\operatorname{Pois}} \newcommand\N{\mathcal{N}} \newcommand\vx{\mathbf{x}} \newcommand\vz{\mathbf{z}} x_{ij} \mid s_i, \lambda_{ij} \sim \Pois(s_i \lambda_{ij}) \end{equation}

with a spatial expression model,

\begin{equation} \lambda_{ij} \mid \vz_j = (h(\vz_j))_i, \end{equation}

where \(h\) denotes a CNN. To perform approximate Bayesian inference (on \(\vz_j\)), we treat \(h\) as a decoder network in a variational autoencoder (Kingma and Welling 2014) and introduce an encoder network

\begin{equation} \vz_j \mid \vx_{\cdot j} \sim \N(\mu(\vx_{\cdot j}), \sigma^2(\vx_{\cdot j})), \end{equation}

where \(\mu, \sigma^2\) are also CNNs (Salimans et al. 2014). Unlike existing applications of VAEs to scRNA-seq data (e.g., Lopez et al. 2018), in the spatial transcriptomics setting, inference is done over minibatches of genes rather than samples.
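
To make the training objective concrete, here is a sketch of the evidence lower bound implied by the three equations above. It additionally assumes a standard Gaussian prior \(\vz_j \sim \N(0, I)\), which is not stated above but is the usual choice for VAEs, and writes \(q\) for the encoder distribution:

\begin{equation} \ell = \sum_j \left( \operatorname{E}_{q(\vz_j \mid \vx_{\cdot j})}\left[ \sum_i \left( x_{ij} \ln (s_i \lambda_{ij}) - s_i \lambda_{ij} \right) \right] - \operatorname{KL}\left( q(\vz_j \mid \vx_{\cdot j}) \,\Vert\, \N(0, I) \right) \right) + \mathrm{const}, \end{equation}

where \(\lambda_{ij} = (h(\vz_j))_i\) and the constant collects the \(-\ln x_{ij}!\) terms, which do not depend on the network parameters. The expectation is typically approximated with a small number of reparameterized samples from \(q\).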

The learned VAE can be used to answer several questions of interest:

  1. What are the predominant patterns of spatial gene expression? These are learned as the low-dimensional spatial representations of each gene \(\vz_j\).
  2. Which spatial structures define those patterns? These can be recovered using e.g., saliency maps to associate learned convolutional filters with the specific latent variable \(\vz_j\) of interest.
  3. Which genes co-vary spatially? These can be recovered by clustering the learned \(\vz_j\).
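
As an illustration of the third question, the sketch below clusters latent representations with a small k-means routine. Everything in it is hypothetical: the `kmeans` helper, the farthest-first initialization, and the simulated array `z`, which stands in for encoder means (genes by latent dimensions) from a fitted model.

```python
import numpy as np

def kmeans(z, k, n_iter=50):
  """Cluster the rows of z (genes x latent_dim) with Lloyd's algorithm."""
  # Farthest-first initialization: start at the first row, then repeatedly
  # add the row farthest from all centers chosen so far
  centers = [z[0]]
  for _ in range(k - 1):
    d = np.min([((z - c) ** 2).sum(axis=1) for c in centers], axis=0)
    centers.append(z[d.argmax()])
  centers = np.stack(centers)
  for _ in range(n_iter):
    # Squared Euclidean distance from every gene to every center
    d = ((z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = d.argmin(axis=1)
    # Move each center to the mean of its assigned genes (keep it if empty)
    new_centers = np.stack([
      z[labels == c].mean(axis=0) if np.any(labels == c) else centers[c]
      for c in range(k)])
    if np.allclose(new_centers, centers):
      break
    centers = new_centers
  return labels, centers

# Simulated stand-in for encoder means: two groups of "genes" in a 4-d latent space
rng = np.random.default_rng(1)
z = np.vstack([
  rng.normal(loc=-3, scale=0.5, size=(100, 4)),
  rng.normal(loc=3, scale=0.5, size=(100, 4))])
labels, centers = kmeans(z, k=2)
```

Genes that share a label would then be candidates for sharing a spatial expression pattern. In practice, the number of clusters and the distance metric are choices that would need to be checked against the data.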

Setup

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.io as si
import scipy.sparse as ss
import scipy.stats as st
import sklearn.datasets as skd
import scmodes
import torch
import torch.utils.data as td
import torch.utils.tensorboard as tb
import torchvision as tv
import umap
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import colorcet
import matplotlib.colors as mc
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'

Methods

Data

Implementation

class Encoder(torch.nn.Module):
  def __init__(self, input_h, input_w, output_dim):
    super().__init__()
    self.net = torch.nn.Sequential(
      torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1, stride=2),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2),
      torch.nn.ReLU(),
      torch.nn.Flatten(),
      torch.nn.Linear(32 * input_h * input_w // 16, 1024),
      torch.nn.ReLU(),
      torch.nn.Linear(1024, 512),
      torch.nn.ReLU(),
    )
    self.mean = torch.nn.Linear(512, output_dim)
    self.scale = torch.nn.Sequential(
      torch.nn.Linear(512, output_dim),
      torch.nn.Softplus())

  def forward(self, x):
    h = self.net(x)
    return self.mean(h), self.scale(h)

class Reshape(torch.nn.Module):
  def __init__(self, shape):
    super().__init__()
    self.shape = shape

  def forward(self, x):
    return torch.reshape(x, self.shape)

class Decoder(torch.nn.Module):
  def __init__(self, input_dim, output_h, output_w):
    super().__init__()
    self.net = torch.nn.Sequential(
      torch.nn.Linear(input_dim, 512),
      torch.nn.ReLU(),
      torch.nn.Linear(512, 1024),
      torch.nn.ReLU(),
      torch.nn.Linear(1024, 32 * output_h * output_w // 16),
      torch.nn.ReLU(),
      Reshape([-1, 32, output_h // 4, output_w // 4]),
      torch.nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
      torch.nn.ReLU(),
      torch.nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
      torch.nn.Sigmoid(),
    )
    self.output_h = output_h
    self.output_w = output_w

  def forward(self, x):
    return self.net(x)

class ConvVAE(torch.nn.Module):
  def __init__(self, input_h, input_w, latent_dim):
    super().__init__()
    self.encoder = Encoder(input_h, input_w, latent_dim)
    self.decoder = Decoder(latent_dim, input_h, input_w)

  def forward(self, x, n_samples, writer=None, global_step=None):
    mean, scale = self.encoder(x)
    # KL(N(mean, scale^2) || N(0, 1)), summed over latent dimensions
    kl = torch.sum(.5 * (mean * mean + scale * scale - 1 - 2 * torch.log(scale)), dim=1)
    qz = torch.distributions.Normal(mean, scale).rsample(n_samples)
    probs = self.decoder.forward(qz)
    error = torch.sum(torch.mean(x * torch.log(probs) + (1 - x) * torch.log(1 - probs), dim=1), dim=[1, 2])
    if writer is not None:
      writer.add_scalar('loss/error', torch.sum(error), global_step)
      writer.add_scalar('loss/kl', torch.sum(kl), global_step)
    loss = -torch.sum(error - kl)
    if torch.isnan(loss):
      raise RuntimeError
    return loss

  def fit(self, data, n_epochs=5, n_samples=None, log_dir=None, **kwargs):
    if log_dir is None:
      writer = None
    else:
      writer = tb.SummaryWriter(log_dir)
    if n_samples is None:
      n_samples = torch.Size([1])
    if torch.cuda.is_available():
      self.cuda()
    opt = torch.optim.Adam(self.parameters(), **kwargs)
    global_step = 0
    for epoch in range(n_epochs):
      for x, y in data:
        if torch.cuda.is_available():
          x = x.cuda()
        opt.zero_grad()
        loss = self.forward(x, n_samples, writer, global_step)
        loss.backward()
        opt.step()
        global_step += 1
    return self
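
The KL term between a diagonal Gaussian and a standard Gaussian has a well-known closed form. As a sanity check, the sketch below compares that closed form against a brute-force numerical integral for a single latent dimension. The function names and the integration grid are illustrative choices, not part of the model code.

```python
import numpy as np

def kl_normal_std(mean, scale):
  """Closed-form KL(N(mean, scale^2) || N(0, 1))."""
  return 0.5 * (mean ** 2 + scale ** 2 - 1 - 2 * np.log(scale))

def kl_numeric(mean, scale, lo=-30., hi=30., n=600001):
  """Riemann-sum approximation of the same KL divergence."""
  t = np.linspace(lo, hi, n)
  dt = t[1] - t[0]
  # Log densities of q = N(mean, scale^2) and p = N(0, 1) on the grid
  log_q = -0.5 * np.log(2 * np.pi * scale ** 2) - 0.5 * ((t - mean) / scale) ** 2
  log_p = -0.5 * np.log(2 * np.pi) - 0.5 * t ** 2
  # KL(q || p) is the integral of q(t) (ln q(t) - ln p(t)) dt
  return np.sum(np.exp(log_q) * (log_q - log_p)) * dt

closed = kl_normal_std(1.5, 0.7)
numeric = kl_numeric(1.5, 0.7)
```

The divergence is zero exactly when the approximate posterior equals the prior (mean 0, scale 1), which is a quick way to check the sign of the constant term.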
class Pois(torch.nn.Module):
  def __init__(self, input_dim, output_h, output_w):
    super().__init__()
    self.net = torch.nn.Sequential(
      torch.nn.Linear(input_dim, 512),
      torch.nn.ReLU(),
      torch.nn.Linear(512, 1024),
      torch.nn.ReLU(),
      torch.nn.Linear(1024, 32 * output_h * output_w // 16),
      torch.nn.ReLU(),
      Reshape([-1, 32, output_h // 4, output_w // 4]),
      torch.nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
      torch.nn.ReLU(),
      torch.nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
      torch.nn.Softplus(),
    )
    self.output_h = output_h
    self.output_w = output_w

  def forward(self, x):
    return self.net(x)

class PoisConvVAE(torch.nn.Module):
  def __init__(self, input_h, input_w, latent_dim):
    super().__init__()
    self.encoder = Encoder(input_h, input_w, latent_dim)
    self.decoder = Pois(latent_dim, input_h, input_w)

  def forward(self, x, size, writer=None, global_step=None):
    mean, scale = self.encoder(x)
    # KL(N(mean, scale^2) || N(0, 1)), summed over latent dimensions
    kl = torch.sum(.5 * (mean * mean + scale * scale - 1 - 2 * torch.log(scale)), dim=1)
    # TODO: support n_samples
    qz = torch.distributions.Normal(mean, scale).rsample([1])
    lam = size.unsqueeze(0) * self.decoder.forward(qz)
    w = size > 0
    error = torch.sum(w * torch.mean(x * torch.log(lam + 1e-16) - lam, dim=1), dim=[1, 2])
    loss = -torch.sum(error - kl)
    if writer is not None:
      writer.add_scalar('loss/error', torch.sum(error), global_step)
      writer.add_scalar('loss/kl', torch.sum(kl), global_step)
      writer.add_scalar('loss/neg_elbo', loss, global_step)
    if torch.isnan(loss):
      raise RuntimeError
    return loss

  def fit(self, data, size, n_epochs=5, log_dir=None, **kwargs):
    assert torch.cuda.is_available()
    if log_dir is None:
      writer = None
    else:
      writer = tb.SummaryWriter(log_dir)
    if torch.cuda.is_available():
      self.cuda()
    opt = torch.optim.Adam(self.parameters(), **kwargs)
    global_step = 0
    for epoch in range(n_epochs):
      for x in data:
        x = x.cuda()
        opt.zero_grad()
        loss = self.forward(x, size, writer, global_step)
        loss.backward()
        opt.step()
        global_step += 1
    return self
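
The reconstruction term of the Poisson model only counts grid positions that have a nonzero size factor, since the dense grid is padded with positions where no spot was measured. A minimal NumPy sketch of that masked log likelihood for one gene follows. The function name and the toy arrays are hypothetical, and it drops the \(-\ln x!\) term, as is common when that term does not depend on the parameters.

```python
import numpy as np

def masked_pois_llik(x, size, lam, eps=1e-16):
  """Poisson log likelihood (up to a constant) over positions with size > 0.

  x    -- observed counts, shape (h, w)
  size -- size factor per position, shape (h, w); zero where there is no spot
  lam  -- latent expression per position, shape (h, w)
  """
  mu = size * lam
  mask = size > 0
  # x ln(mu) - mu, dropping -ln(x!), which does not depend on lam
  return np.sum(mask * (x * np.log(mu + eps) - mu))

# Toy 2 x 2 grid with one padded position (size factor zero)
size = np.array([[10., 0.], [20., 5.]])
x = np.array([[3., 0.], [8., 1.]])
# Per-position maximum likelihood estimate x / size, set to zero where padded
lam_hat = np.where(size > 0, x / np.where(size > 0, size, 1), 0)
llik_hat = masked_pois_llik(x, size, lam_hat)
```

Changing `lam` at the padded position leaves the value unchanged, so the decoder is free to output anything there without affecting the loss.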

Results

MNIST

Implement a VAE using convolutional layers instead of fully connected layers. Refer to “Common architectures in convolutional neural networks”.

batch_size = 32
mnist_train = td.DataLoader(
  tv.datasets.MNIST(
    root='/scratch/midway2/aksarkar/singlecell/',
    transform=lambda x: (np.frombuffer(x.tobytes(), dtype='uint8') > 0).astype(np.float32).reshape((1, x.size[0], x.size[1]))),
  batch_size=batch_size,
  shuffle=True,
  pin_memory=True)
latent_dim = 4
n_epochs = 2
run = 3
torch.manual_seed(run)
model = ConvVAE(input_h=28, input_w=28, latent_dim=latent_dim)
model.fit(mnist_train,
          n_epochs=n_epochs,
          lr=1e-4,
          log_dir=f'runs/convvae/mnist/stride_latent_{latent_dim}_epochs_{n_epochs}_run_{run}')
ConvVAE(
(encoder): Encoder(
(net): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): Flatten()
(5): Linear(in_features=1568, out_features=1024, bias=True)
(6): ReLU()
(7): Linear(in_features=1024, out_features=512, bias=True)
(8): ReLU()
)
(mean): Linear(in_features=512, out_features=4, bias=True)
(scale): Sequential(
(0): Linear(in_features=512, out_features=4, bias=True)
(1): Softplus(beta=1, threshold=20)
)
)
(decoder): Decoder(
(net): Sequential(
(0): Linear(in_features=4, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=1024, bias=True)
(3): ReLU()
(4): Linear(in_features=1024, out_features=1568, bias=True)
(5): ReLU()
(6): Reshape()
(7): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(8): ReLU()
(9): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(10): Sigmoid()
)
)
)

Serialize the fitted model parameters.

torch.save(model.state_dict(), f'/scratch/midway2/aksarkar/singlecell/mnist-stride_latent_{latent_dim}_epochs_{n_epochs}_run_{run}.pt')

Draw some samples.

rng = np.random.default_rng(1)
with torch.no_grad():
  z = rng.normal(size=(10, latent_dim))
  xh = model.decoder.forward(torch.tensor(z, dtype=torch.float, device='cuda')).squeeze().cpu().numpy()
plt.clf()
fig, ax = plt.subplots(2, 5)
for i, a in enumerate(ax.ravel()):
  a.imshow(xh[i], cmap=colorcet.cm['gray'])
  a.set_xticks([])
  a.set_yticks([])
fig.tight_layout()

mnist-ex.png

Plot the latent space.

with torch.no_grad():
  temp = [(y.numpy(), model.encoder(x.cuda())[0].cpu().numpy()) for x, y in mnist_train]
labels = np.hstack([y for y, _ in temp])
embed = np.vstack([z for _, z in temp])
# Principal component scores of the (centered) latent means
u, d, vt = np.linalg.svd(embed - embed.mean(axis=0), full_matrices=False)
pcs = u * d
cm = plt.get_cmap('Set3')
plt.clf()
fig, ax = plt.subplots(1, 3, sharey=True)
fig.set_size_inches(7.5, 2.5)
for k, a in enumerate(ax):
  for y in range(10):
    idx = labels == y
    a.scatter(pcs[idx,k + 1], pcs[idx,0], s=1, alpha=0.25, c=np.array(cm(y)).reshape(1, -1), label=y)
  a.set_xlabel(f'PC {k + 2}')
ax[0].set_ylabel('PC 1')
leg = ax[-1].legend(frameon=False, markerscale=4, handletextpad=0, loc='center left', bbox_to_anchor=(1, .5))
for h in leg.legendHandles:
  h.set_alpha(1)
fig.tight_layout()

mnist-cvae-latent-pca.png
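
For reference, principal component scores can be obtained from the singular value decomposition of the centered matrix. The sketch below shows the relationship on simulated data. The helper name `pc_scores` and the random `embed` array are illustrative stand-ins, not the MNIST latent means.

```python
import numpy as np

def pc_scores(embed):
  """Return principal component scores and singular values of embed (n x p)."""
  centered = embed - embed.mean(axis=0)
  u, d, vt = np.linalg.svd(centered, full_matrices=False)
  # Scores are projections onto the right singular vectors:
  # centered @ vt.T equals u * d
  return u * d, d

# Simulated stand-in for latent means with correlated columns
rng = np.random.default_rng(0)
embed = rng.normal(size=(500, 4)) @ rng.normal(size=(4, 4))
scores, d = pc_scores(embed)
```

The columns of `scores` are mutually orthogonal and ordered by decreasing variance, which is what distinguishes them from the raw latent coordinates.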

umaps = dict()
for n_neighbors in (5, 10, 25, 50):
  print(f'fitting {n_neighbors}')
  umaps[n_neighbors] = umap.UMAP(n_neighbors=n_neighbors, min_dist=0, metric='euclidean', random_state=1).fit_transform(embed)
cm = plt.get_cmap('Set3')
plt.clf()
fig, ax = plt.subplots(1, 4, sharex=True, sharey=True)
fig.set_size_inches(7.5, 2.5)
for k, a in zip(umaps, ax):
  for y in range(10):
    idx = labels == y
    a.scatter(umaps[k][idx,0], umaps[k][idx,1], s=1, alpha=0.25, c=np.array(cm(y)).reshape(1, -1), label=y)
  a.set_title(f'k = {k}')
  a.set_xlabel('UMAP 1')
  a.set_xticks([])
  a.set_yticks([])
ax[0].set_ylabel('UMAP 2')
leg = ax[-1].legend(frameon=False, markerscale=4, handletextpad=0, loc='center left', bbox_to_anchor=(1, .5))
for h in leg.legendHandles:
  h.set_alpha(1)
fig.tight_layout()

mnist-cvae-latent-umap.png

Planted spatial signal

Generate an interesting spatial pattern.

X, y = skd.make_moons(n_samples=5000, noise=0.15, random_state=1)
cm = plt.get_cmap('Dark2')
plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.scatter(X[:,0], X[:,1], s=2, c=cm(y))
plt.tight_layout()

moons.png

Bin the samples to get spots.

n_bins = 64
t = [[X[:,0].min(), X[:,0].max()], [X[:,1].min(), X[:,1].max()]]
H0, *_ = np.histogram2d(X[y == 0,0], X[y == 0,1], bins=n_bins, range=t, density=True)
H1, *_ = np.histogram2d(X[y == 1,0], X[y == 1,1], bins=n_bins, range=t, density=True)
cm = colorcet.cm['fire']
plt.clf()
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(6, 3)
ax[0].pcolormesh(H0.T, cmap=cm)
ax[1].pcolormesh(H1.T, cmap=cm)
plt.tight_layout()

moons-binned.png

Treat the spots as the true \(\lambda\), and generate example noisy observations.

np.random.seed(2)
x0 = st.nbinom(n=np.exp(-1), p = 1 / (1 + H0 / np.exp(-1))).rvs()
x1 = st.nbinom(n=np.exp(-2), p = 1 / (1 + H0 / np.exp(-2))).rvs()
cm = colorcet.cm['fire']
plt.clf()
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(6, 3)
ax[0].pcolormesh(x0.T, cmap=cm)
ax[0].set_title(r'$\phi$ = $\exp(1)$')
ax[1].pcolormesh(x1.T, cmap=cm)
ax[1].set_title(r'$\phi$ = $\exp(2)$')
plt.tight_layout()

moons-gen.png
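
The negative binomial parameterization used above, with `n` equal to the inverse dispersion \(1/\phi\) and `p` equal to \(1 / (1 + \lambda \phi)\), has mean \(\lambda\) and variance \(\lambda + \phi \lambda^2\). One way to see this is to draw the same counts as a Gamma-Poisson mixture, as in this sketch. The function name `nb_rvs` and the constants are illustrative.

```python
import numpy as np

def nb_rvs(lam, phi, rng):
  """Draw counts with mean lam and variance lam + phi * lam^2."""
  # A Gamma variable with shape 1 / phi and scale phi has mean 1 and variance phi
  u = rng.gamma(shape=1 / phi, scale=phi, size=np.shape(lam))
  # Conditional on u, the count is Poisson with mean lam * u
  return rng.poisson(lam * u)

rng = np.random.default_rng(0)
phi = np.exp(1)
lam = np.full(200000, 2.0)
x = nb_rvs(lam, phi, rng)
# Theoretical moments for comparison with x.mean() and x.var()
expected_mean = 2.0
expected_var = 2.0 + phi * 2.0 ** 2
```

Larger \(\phi\) therefore means noisier observations around the same underlying pattern.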

Generate a training data set.

n = 8192
z = np.random.uniform(size=n) < 0.5
log_inv_disp = np.random.uniform(-3, 0, size=n)
x = np.stack([
  st.nbinom(n=np.exp(phi_j), p=1 / (1 + H0 / np.exp(phi_j))).rvs() if z_j else
  st.nbinom(n=np.exp(phi_j), p=1 / (1 + H1 / np.exp(phi_j))).rvs()
  for z_j, phi_j in zip(z, log_inv_disp)])
train = td.DataLoader(
  torch.tensor(x, dtype=torch.float).unsqueeze(1),
  batch_size=64,
  shuffle=True,
  pin_memory=True,
  num_workers=1)
latent_dim = 4
n_epochs = 40
lr = 5e-3
run = 0
torch.manual_seed(run)
model = PoisConvVAE(input_h=64, input_w=64, latent_dim=latent_dim)
model.fit(train,
          size=torch.tensor(x.sum(axis=0), dtype=torch.float, device='cuda'),
          n_epochs=n_epochs,
          lr=lr,
          log_dir=f'runs/convvae/planted/adam_latent_{latent_dim}_epochs_{n_epochs}_lr_{lr:.1g}_run_{run}')
PoisConvVAE(
(encoder): Encoder(
(net): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
(5): Linear(in_features=8192, out_features=1024, bias=True)
(6): ReLU()
(7): Linear(in_features=1024, out_features=512, bias=True)
(8): ReLU()
)
(mean): Linear(in_features=512, out_features=4, bias=True)
(scale): Sequential(
(0): Linear(in_features=512, out_features=4, bias=True)
(1): Softplus(beta=1, threshold=20)
)
)
(decoder): Pois(
(net): Sequential(
(0): Linear(in_features=4, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=1024, bias=True)
(3): ReLU()
(4): Linear(in_features=1024, out_features=8192, bias=True)
(5): ReLU()
(6): Reshape()
(7): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(8): ReLU()
(9): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(10): Softplus(beta=1, threshold=20)
)
)
)
torch.save(model.state_dict(), f'/scratch/midway2/aksarkar/singlecell/planted_latent_{latent_dim}_epochs_{n_epochs}_lr_{lr:.1g}_run_{run}.pt')

De-noise one of the “genes” ignoring spatial information.

s = x.sum(axis=0)
fit0 = scmodes.ebpm.ebpm_gamma(x[0].ravel(), s.ravel(), tol=1e-7)
pm0 = np.where(s > 0, (x[0] + np.exp(fit0[1])) / (s + np.exp(fit0[1] - fit0[0])), 0)
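
The posterior mean expression above follows from Gamma-Poisson conjugacy. The sketch below spells out that calculation. It assumes, from how the fit is used above, that `fit0[0]` is the log prior mean and `fit0[1]` is the log inverse dispersion; that reading and the helper name `gamma_post_mean` are assumptions rather than documented behavior of `scmodes`.

```python
import numpy as np

def gamma_post_mean(x, s, log_mean, log_inv_disp):
  """Posterior mean of lam given x ~ Pois(s * lam) and lam ~ Gamma(a, b).

  The prior has shape a = exp(log_inv_disp) and rate b = a / exp(log_mean),
  so the prior mean is exp(log_mean).
  """
  a = np.exp(log_inv_disp)
  b = np.exp(log_inv_disp - log_mean)
  # Conjugacy: the posterior is Gamma(x + a, s + b), with mean (x + a) / (s + b)
  return (x + a) / (s + b)

# With no data (x = 0, s = 0), the posterior mean equals the prior mean
prior_only = gamma_post_mean(0., 0., np.log(0.1), 0.)
# With many observed molecules, it approaches the raw ratio x / s
data_rich = gamma_post_mean(500., 1000., np.log(0.1), 0.)
```

Note that the expression above sets the estimate to zero where `s` is zero instead of falling back to the prior mean, which keeps padded positions blank in the plots.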

De-noise the same “gene” using the spatial VAE.

with torch.no_grad():
  pm = model.decoder(model.encoder(torch.tensor(x[0], dtype=torch.float, device='cuda').unsqueeze(0).unsqueeze(0))[0]).squeeze().cpu().numpy()

Plot a fitted “gene”.

plt.clf()
fig, ax = plt.subplots(1, 3)
fig.set_size_inches(7.5, 3)
im = ax[0].pcolormesh(x[0].T, cmap=colorcet.cm['fire'])
ax[0].set_title('Observed counts')
fig.colorbar(im, ax=ax[0], orientation='horizontal')

im = ax[1].pcolormesh(pm0.T, cmap=colorcet.cm['bmy'])
ax[1].set_title('Naive denoising')
fig.colorbar(im, ax=ax[1], orientation='horizontal')

im = ax[2].pcolormesh(pm.T, cmap=colorcet.cm['bmy'])
ax[2].set_title('Spatial denoising')
fig.colorbar(im, ax=ax[2], orientation='horizontal')
plt.tight_layout()

planted-denoised.png

Visium

Preprocess the data.

x = si.mmread('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/filtered_feature_bc_matrix/matrix.mtx.gz')
var = pd.read_csv('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/filtered_feature_bc_matrix/features.tsv.gz', sep='\t', header=None)
var.columns = ['gene', 'name', 'featuretype']
barcodes = pd.read_csv('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/filtered_feature_bc_matrix/barcodes.tsv.gz', sep='\t', header=None)
coord = pd.read_csv('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/spatial/tissue_positions_list.csv', header=None)
obs = (barcodes
       .merge(coord, on=0)
       .rename(dict(enumerate(['barcode', 'in_tissue', 'row', 'col', 'pxl_row', 'pxl_col'])), axis=1))
idx = np.array(obs.sort_values(['row', 'col']).index)
dat = anndata.AnnData(x.T.tocsr()[idx], obs=obs.loc[idx], var=var)
dat.write('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/dat.h5ad')

Read the pre-processed data.

dat = anndata.read_h5ad('/project2/mstephens/aksarkar/projects/singlecell-ideas/data/spatial/V1_Breast_Cancer_Block_A_Section_1/dat.h5ad')
sc.pp.filter_genes(dat, min_cells=10)

The data are a (sparse) tensor of positions \(x, y\) and counts \(z\). Convert to a dense tensor.

TODO:

temp = dat.X.A
r = dat[:,0].obs['row'].values
c = dat[:,0].obs['col'].values // 2
x = np.stack([
  ss.coo_matrix((temp[:,j], (r, c)),
                # Important: we need dimensions to be divisible by 4
                shape=(80, 64)).A
  for j in range(dat.shape[1])])
x.shape
(19690, 80, 64)
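
The reason the grid dimensions need to be divisible by 4 can be checked with the standard output-size formulas for strided convolutions (for dilation 1). The sketch below assumes the kernel size, stride, and padding used in the encoder and decoder defined earlier; the helper names are illustrative.

```python
def conv_out(n, kernel_size=3, stride=2, padding=1):
  """Output length of a strided convolution along one axis (dilation 1)."""
  return (n + 2 * padding - kernel_size) // stride + 1

def conv_transpose_out(n, kernel_size=3, stride=2, padding=1, output_padding=1):
  """Output length of a transposed convolution along one axis (dilation 1)."""
  return (n - 1) * stride - 2 * padding + kernel_size + output_padding

def round_trip(n):
  """Length after two stride-2 convolutions and two transposed convolutions."""
  return conv_transpose_out(conv_transpose_out(conv_out(conv_out(n))))

# Lengths divisible by 4 are recovered exactly; a length of 78 is not
sizes = {n: round_trip(n) for n in (28, 64, 78, 80)}
```

For a length that is not divisible by 4 (78 in this example), the decoder output would have a different shape from the input, so the reconstruction term could not be computed elementwise. Padding the grid to a multiple of 4 avoids this.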

Make sure the conversion worked by extracting the spatial gene expression of HLA-B.

name = 'HLA-B'
z = x[dat.var['name'] == name].squeeze()
size = ss.coo_matrix((dat.X.sum(axis=1).A.ravel(), (dat[:,0].obs['row'], dat[:,0].obs['col'].values // 2)), shape=x.shape[1:]).A
fit = scmodes.ebpm.ebpm_gamma(z.ravel(), size.ravel())
pm = np.where(size.ravel(), (z.ravel() + np.exp(fit[1])) / (size.ravel() + np.exp(-fit[0] + fit[1])), 0)
plt.clf()
plt.gcf().set_size_inches(3, 3)
plt.imshow(pm.reshape(x.shape[1:]), cmap=colorcet.cm['fire'], vmin=pm.min(), vmax=pm.max())
plt.colorbar(label='Posterior mean gene expression', shrink=0.5)
plt.xticks([])
plt.yticks([])
plt.title(name)
plt.tight_layout()

visium-ex.png

train = td.DataLoader(
  torch.tensor(x, dtype=torch.float).unsqueeze(1),
  batch_size=64,
  shuffle=True,
  pin_memory=True,
  num_workers=4)
latent_dim = 4
n_epochs = 40
run = 4
torch.manual_seed(run)
model = PoisConvVAE(input_h=80, input_w=64, latent_dim=latent_dim)
model.fit(train,
          size=torch.tensor(x.sum(axis=0), dtype=torch.float, device='cuda'),
          n_epochs=n_epochs,
          lr=5e-7,
          log_dir=f'runs/convvae/visium/adam_latent_{latent_dim}_epochs_{n_epochs}_run_{run}')
PoisConvVAE(
(encoder): Encoder(
(net): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): Flatten()
(5): Linear(in_features=10240, out_features=1024, bias=True)
(6): ReLU()
(7): Linear(in_features=1024, out_features=512, bias=True)
(8): ReLU()
)
(mean): Linear(in_features=512, out_features=4, bias=True)
(scale): Sequential(
(0): Linear(in_features=512, out_features=4, bias=True)
(1): Softplus(beta=1, threshold=20)
)
)
(decoder): Pois(
(net): Sequential(
(0): Linear(in_features=4, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=1024, bias=True)
(3): ReLU()
(4): Linear(in_features=1024, out_features=10240, bias=True)
(5): ReLU()
(6): Reshape()
(7): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(8): ReLU()
(9): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(10): Softplus(beta=1, threshold=20)
)
)
)
torch.save(model.state_dict(), '/scratch/midway2/aksarkar/singlecell/visium-stride_latent_4_epochs_2_run_9.pt')

Encode and decode each gene.

with torch.no_grad():
  lam = np.stack([
    model.decoder(
      model.encoder(
        torch.tensor(x[j].squeeze(), dtype=torch.float, device='cuda').unsqueeze(0).unsqueeze(0))[0]).cpu().numpy()
    for j in range(x.shape[0])])

Encode and decode HLA-B.

with torch.no_grad():
  lam = model.decoder(
    model.encoder(
      torch.tensor(x[dat.var['name'] == 'HLA-B'].squeeze(), dtype=torch.float, device='cuda').unsqueeze(0).unsqueeze(0))
    [0]).cpu().numpy()
plt.clf()
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 3)
i = ax[0].imshow(pm.reshape(x.shape[1:]), cmap=colorcet.cm['fire'], vmin=pm.min(), vmax=pm.max())
fig.colorbar(i, label='Posterior mean expression', shrink=0.5, orientation='horizontal', ax=ax[0])
ax[0].set_title('Non-spatial model')

temp = lam.squeeze()
i = ax[1].imshow(temp, cmap=colorcet.cm['fire'], vmin=temp.min(), vmax=temp.max())
fig.colorbar(i, label='Posterior mean expression', shrink=0.5, orientation='horizontal', ax=ax[1])
ax[1].set_title('Spatial model')

for a in ax:
  a.set_xticks([])
  a.set_yticks([])
plt.tight_layout()

visium-ex2.png

Author: Abhishek Sarkar

Created: 2021-09-01 Wed 14:16
