Introduction to gimVI

Imputing missing genes in spatial data from sequencing data with gimVI

[2]:
import sys

# if stable is True, install from PyPI; otherwise install from source
stable = True
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB and stable:
    !pip install --quiet scvi-tools[tutorials]
elif IN_COLAB and not stable:
    !pip install --quiet --upgrade jsonschema
    !pip install --quiet git+https://github.com/yoseflab/scvi-tools@master#egg=scvi-tools[tutorials]


[3]:
import scanpy
import anndata
import numpy as np
import copy
import matplotlib.pyplot as plt

from scipy.stats import spearmanr
from scvi.data import (
    smfish,
    cortex,
    setup_anndata
)
from scvi.model import GIMVI

train_size = 0.8
[4]:
spatial_data = smfish(run_setup_anndata=False)
seq_data = cortex(run_setup_anndata=False)
INFO      Downloading file at /content/data/osmFISH_SScortex_mouse_all_cell.loom
INFO      Loading smFISH dataset
INFO      Downloading file at /content/data/expression.bin
INFO      Loading Cortex data from /content/data/expression.bin
INFO      Finished loading Cortex data
/usr/local/lib/python3.6/dist-packages/anndata/_core/anndata.py:119: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)

Preparing the data

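In this section, we hold out some of the genes in the spatial dataset so that the imputed values can be compared against their measured values. With train_size = 0.8, 80% of the shared genes are used for training and the remaining 20% are hidden from the model.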
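
The hold-out idea can be sketched independently of the datasets. The snippet below is a minimal illustration, assuming a hypothetical panel of gene names and a fixed random seed; the tutorial cell itself draws an unseeded split, so its train and test genes change between runs.

```python
import numpy as np

# Hypothetical gene panel, for illustration only
gene_names = np.array([f"gene_{i}" for i in range(10)])
train_size = 0.8

# Seeded generator so the split is reproducible (an assumption, not the tutorial's setting)
rng = np.random.default_rng(0)
n_genes = len(gene_names)
n_train = int(n_genes * train_size)

# Sample training indices without replacement; every other index is held out
train_idx = rng.choice(n_genes, size=n_train, replace=False)
test_idx = np.setdiff1d(np.arange(n_genes), train_idx)

train_genes = gene_names[train_idx]
test_genes = gene_names[test_idx]
print(len(train_genes), len(test_genes))
```

Keeping the test indices sorted, as `np.setdiff1d` does, makes it easy to index the imputed matrix with them later.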

[5]:
#only use genes in both datasets
seq_data = seq_data[:, spatial_data.var_names].copy()

seq_gene_names = seq_data.var_names
n_genes = seq_data.n_vars
n_train_genes = int(n_genes*train_size)

#randomly select training_genes
rand_train_gene_idx = np.random.choice(range(n_genes), n_train_genes, replace = False)
rand_test_gene_idx = sorted(set(range(n_genes)) - set(rand_train_gene_idx))
rand_train_genes = seq_gene_names[rand_train_gene_idx]
rand_test_genes = seq_gene_names[rand_test_gene_idx]

#spatial_data_partial has a subset of the genes to train on
spatial_data_partial = spatial_data[:,rand_train_genes].copy()

#remove cells with no counts
scanpy.pp.filter_cells(spatial_data_partial, min_counts= 1)
scanpy.pp.filter_cells(seq_data, min_counts = 1)

#setup_anndata for spatial and sequencing data
setup_anndata(spatial_data_partial, labels_key='labels', batch_key='batch')
setup_anndata(seq_data, labels_key='labels')

#spatial_data should use the same cells as our training data
#cells may have been removed by scanpy.pp.filter_cells()
spatial_data = spatial_data[spatial_data_partial.obs_names].copy()
INFO      Using batches from adata.obs["batch"]
INFO      Using labels from adata.obs["labels"]
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Successfully registered anndata object containing 4530 cells, 26
          genes, 1 batches, 6 labels, and 0 proteins. Also registered 0 extra
          categorical covariates and 0 extra continuous covariates.
INFO      Please do not further modify adata until model is trained.
INFO      No batch_key inputted, assuming all cells are same batch
INFO      Using labels from adata.obs["labels"]
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Successfully registered anndata object containing 2996 cells, 33
          genes, 1 batches, 7 labels, and 0 proteins. Also registered 0 extra
          categorical covariates and 0 extra continuous covariates.
INFO      Please do not further modify adata until model is trained.

Creating the model and training

[6]:
#create our model
model = GIMVI(seq_data, spatial_data_partial)

#train for 200 epochs
model.train(200)

Analyzing the results

Getting the latent representations and plotting UMAPs
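
The bookkeeping used in the next cell (stack the two latent matrices, label each row by origin, then split results back by position) can be sketched with synthetic arrays. The shapes and values here are made up for illustration and do not come from the model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic latent coordinates: 5 "seq" cells and 3 "spatial" cells, 10 latent dimensions
latent_seq = rng.normal(size=(5, 10))
latent_spatial = rng.normal(size=(3, 10))

# Stack row-wise: sequencing cells first, spatial cells after
latent = np.concatenate([latent_seq, latent_spatial])
labels = ["seq"] * latent_seq.shape[0] + ["spatial"] * latent_spatial.shape[0]

# Any per-row result computed on the stacked matrix can be split back by position
n_seq = latent_seq.shape[0]
seq_part, spatial_part = latent[:n_seq], latent[n_seq:]
```

Because the split relies purely on row order, the per-dataset objects must contain exactly the cells that were passed to the model, in the same order.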

[15]:
#get the latent representations for the sequencing and spatial data
latent_seq, latent_spatial = model.get_latent_representation()

#concatenate to one latent representation
latent_representation = np.concatenate([latent_seq, latent_spatial])
latent_adata = anndata.AnnData(latent_representation)

#labels which cells were from the sequencing dataset and which were from the spatial dataset
latent_labels = (['seq'] * latent_seq.shape[0]) + (['spatial'] * latent_spatial.shape[0])
latent_adata.obs['labels'] = latent_labels

#compute umap
scanpy.pp.neighbors(latent_adata, use_rep = 'X')
scanpy.tl.umap(latent_adata)

#save umap representations to original seq and spatial_datasets
seq_data.obsm['X_umap'] = latent_adata.obsm['X_umap'][:seq_data.shape[0]]
spatial_data.obsm['X_umap'] = latent_adata.obsm['X_umap'][seq_data.shape[0]:]
[16]:
#umap of the combined latent space
scanpy.pl.umap(latent_adata, color = 'labels', show = True)
... storing 'labels' as categorical
[Figure: UMAP of the combined latent space, colored by dataset of origin (seq vs. spatial)]
[17]:
#umap of sequencing dataset
scanpy.pl.umap(seq_data, color = 'cell_type')
[Figure: UMAP of the sequencing dataset, colored by cell_type]
[18]:
#umap of spatial dataset
scanpy.pl.umap(spatial_data, color = 'str_labels')
[Figure: UMAP of the spatial dataset, colored by str_labels]

Getting Imputation Score

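imputation_score() computes, for each held-out gene, the Spearman correlation between the measured and imputed expression across cells, and returns the median of these per-gene correlations.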
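
To make the scoring rule concrete, here is a minimal sketch on synthetic data (random numbers, not real expression values). It assumes three held-out genes: one imputed perfectly, one imputed with unrelated values, and one imputed as all zeros, which is scored as 0 because the correlation is undefined for a constant column.

```python
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)

# Synthetic stand-ins: 50 cells x 3 held-out genes
original = rng.random((50, 3))
imputed = original.copy()
imputed[:, 1] = rng.random(50)  # gene 1: imputation unrelated to the truth
imputed[:, 2] = 0.0             # gene 2: imputed as all zeros

per_gene = []
for g in range(imputed.shape[1]):
    if np.all(imputed[:, g] == 0):
        # Spearman is undefined for a constant column; score it as 0
        per_gene.append(0.0)
    else:
        # Correlation across cells for this gene
        per_gene.append(spearmanr(original[:, g], imputed[:, g])[0])

# The score is the median over genes, not over cells
score = float(np.median(per_gene))
```

Using the median rather than the mean keeps a few very well or very poorly imputed genes from dominating the score.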

[19]:
# utility function for scoring the imputation
def imputation_score(model, data_spatial, gene_ids_test, normalized=True):
    _, fish_imputation = model.get_imputed_values(normalized=normalized)
    original, imputed = (
        data_spatial.X[:, gene_ids_test],
        fish_imputation[:, gene_ids_test],
    )

    if normalized:
        original /= data_spatial.X.sum(axis=1).reshape(-1, 1)

    spearman_gene = []
    for g in range(imputed.shape[1]):
        if np.all(imputed[:, g] == 0):
            correlation = 0
        else:
            correlation = spearmanr(original[:, g], imputed[:, g])[0]
        spearman_gene.append(correlation)
    return np.median(np.array(spearman_gene))

imputation_score(model, spatial_data, rand_test_gene_idx, True)
[19]:
0.1893881544444226

Plot the imputed values for Lamp5, which must be among the held-out genes (the assertion in the cell below checks this).

[20]:
#utility function for plotting spatial genes
def plot_gene_spatial(model, data_spatial, gene):
    data_seq = model.adatas[0]
    data_fish = data_spatial

    fig, (ax_gt, ax) = plt.subplots(1, 2)

    if isinstance(gene, str):
        # AnnData stores gene names in var_names
        gene_id = list(data_seq.var_names).index(gene)
    else:
        gene_id = gene

    # use plain arrays so the positional indexing below is unambiguous
    x_coord = data_fish.obs["x_coord"].to_numpy()
    y_coord = data_fish.obs["y_coord"].to_numpy()

    # sort points by intensity so the strongest ones are drawn last (on top)
    def order_by_strength(x, y, z):
        ind = np.argsort(z)
        return x[ind], y[ind], z[ind]

    s = 20

    def transform(data):
        return np.log(1 + 100 * data)

    # Plot groundtruth
    x, y, z = order_by_strength(
        x_coord, y_coord, data_fish.X[:, gene_id] / (data_fish.X.sum(axis=1) + 1)
    )
    ax_gt.scatter(x, y, c=transform(z), s=s, edgecolors="none", marker="s", cmap="Reds")
    ax_gt.set_title("Groundtruth")
    ax_gt.axis("off")

    _, imputed = model.get_imputed_values(normalized=True)
    x, y, z = order_by_strength(x_coord, y_coord, imputed[:, gene_id])
    ax.scatter(x, y, c=transform(z), s=s, edgecolors="none", marker="s", cmap="Reds")
    ax.set_title("Imputed")
    ax.axis("off")
    plt.tight_layout()
    plt.show()

assert 'Lamp5' in rand_test_genes
plot_gene_spatial(model, spatial_data, 9)
[Figure: ground-truth (left) and imputed (right) spatial expression of the plotted gene]

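Inspect the discriminator's classification accuracy (we expect a uniform matrix)

If the matrix is close to diagonal, the discriminator can still tell the two datasets apart; kappa, the scaling applied to the discriminator loss, then needs to be increased to ensure mixing.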
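
One way to turn "uniform versus diagonal" into a number is to measure how far the matrix entries are from 1 / n_classes. The helper below, `mixing_gap`, is a hypothetical name introduced for illustration and is not part of scvi-tools; the matrix values are copied from the output printed in this tutorial.

```python
import numpy as np

# Values copied from the confusion matrix printed in this tutorial
confusion = np.array([
    [0.50057477, 0.49942523],
    [0.49828842, 0.5017117],
])

def mixing_gap(confusion):
    """Largest absolute deviation from the uniform value 1 / n_classes (hypothetical helper)."""
    n_classes = confusion.shape[1]
    return float(np.max(np.abs(confusion - 1.0 / n_classes)))

gap = mixing_gap(confusion)           # small: the datasets are well mixed
diagonal_gap = mixing_gap(np.eye(2))  # worst case: a perfectly diagonal matrix
```

A gap near 0 indicates that the discriminator is guessing; a gap approaching 0.5 for two datasets indicates that it separates them almost perfectly.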

[21]:
discriminator_classification = model.trainer.get_discriminator_confusion()
print(discriminator_classification)
[[0.50057477 0.49942523]
 [0.49828842 0.5017117 ]]
[22]:
import pandas as pd

results = pd.DataFrame(
    model.trainer.get_loss_magnitude(),
    index=["reconstruction", "kl_divergence", "discriminator"],
    columns=["Sequencing", "Spatial"],
)
results.columns.name = "Dataset"
results.index.name = "Loss"
results
[22]:
Dataset         Sequencing      Spatial
Loss
reconstruction  789.895563  1670.660999
kl_divergence   202.259926   197.577706
discriminator    22.259008    22.046192