scvi.core.trainers.TotalTrainer

class scvi.core.trainers.TotalTrainer(model, dataset, train_size=0.9, test_size=0.1, pro_recons_weight=1.0, n_epochs_kl_warmup=None, n_iter_kl_warmup='auto', discriminator=None, use_adversarial_loss=False, kappa=None, early_stopping_kwargs='auto', **kwargs)

Unsupervised training for totalVI using variational inference.

Parameters
model : TOTALVAE

A model instance of class TOTALVAE.

adata

A registered AnnData object. In the signature above, this argument is named dataset.

train_size : float (default: 0.9)

Proportion of the dataset to use for training, a float between 0 and 1. Default: 0.90.

test_size : float (default: 0.1)

Proportion of the dataset to use for testing, a float between 0 and 1. Default: 0.10. If train_size and test_size do not sum to 1, the remainder is placed in a validation set.
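To make the split rule concrete, here is a minimal sketch of how the two proportions could translate into cell counts. The helper split_sizes is hypothetical and not part of scvi; it assumes counts are rounded down and that the validation set receives every cell left over, which may differ from the library's exact rounding.

```python
def split_sizes(n_cells: int, train_size: float = 0.9, test_size: float = 0.1):
    """Hypothetical helper: (n_train, n_test, n_validation) for a dataset.

    Assumption: counts are rounded down and the validation set takes the
    remainder; the real TotalTrainer may round differently.
    """
    if not (0 <= train_size <= 1 and 0 <= test_size <= 1):
        raise ValueError("train_size and test_size must lie in [0, 1]")
    # Small tolerance so that e.g. 0.9 + 0.1 is not rejected by float error.
    if train_size + test_size > 1 + 1e-9:
        raise ValueError("train_size + test_size must not exceed 1")
    n_train = int(n_cells * train_size)
    n_test = int(n_cells * test_size)
    # Whatever is not used for training or testing becomes validation data.
    n_validation = n_cells - n_train - n_test
    return n_train, n_test, n_validation
```

With the defaults the proportions sum to 1, so no validation set is formed; lowering train_size to 0.8 frees 10% of the cells for validation.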

pro_recons_weight : float (default: 1.0)

Scaling factor on the reconstruction loss for proteins. Default: 1.0.
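As an illustration of where pro_recons_weight enters, the sketch below assumes a simplified per-cell objective made of an RNA reconstruction term, a protein reconstruction term and a single KL term. The function weighted_loss and its arguments are hypothetical; the actual totalVI objective contains several KL terms and is computed on tensors, not floats.

```python
def weighted_loss(recon_rna: float, recon_protein: float, kl: float,
                  pro_recons_weight: float = 1.0, kl_weight: float = 1.0) -> float:
    """Hypothetical simplified totalVI loss for one cell.

    Assumption: all terms are already negative log-likelihoods or KL
    divergences, so the loss is a weighted sum to be minimized.
    """
    # Protein reconstruction is rescaled by pro_recons_weight;
    # the KL term is rescaled by the (possibly annealed) kl_weight.
    return recon_rna + pro_recons_weight * recon_protein + kl_weight * kl
```

Values of pro_recons_weight below 1.0 reduce the influence of the protein data on training; values above 1.0 increase it.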

n_epochs_kl_warmup : Optional[int] (default: None)

Number of epochs over which the KL terms for z and mu of the ELBO are annealed (from 0 to 1). If None, no warmup is performed unless n_iter_kl_warmup is set.

n_iter_kl_warmup : Union[str, int] (default: 'auto')

Number of minibatches for annealing the KL terms for z and mu of the ELBO (from 0 to 1). If set to “auto”, the number of iterations is equal to 75% of the number of cells. n_epochs_kl_warmup takes precedence if it is not None. If both are None, then no warmup is performed.
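The precedence rules for the two warmup parameters can be sketched as follows. Both functions are hypothetical, and the linear ramp from 0 to 1 is an assumption; the documentation above only states that the KL weight is annealed from 0 to 1.

```python
from typing import Optional, Union


def resolve_n_iter_kl_warmup(n_iter_kl_warmup: Union[str, int, None],
                             n_cells: int) -> Optional[int]:
    """Hypothetical: turn 'auto' into 75% of the number of cells."""
    if n_iter_kl_warmup == "auto":
        return max(1, int(0.75 * n_cells))  # at least one iteration
    return n_iter_kl_warmup


def kl_weight(epoch: int, iteration: int,
              n_epochs_kl_warmup: Optional[int],
              n_iter_kl_warmup: Optional[int]) -> float:
    """Hypothetical linear KL annealing schedule (assumed, not confirmed)."""
    # An epoch-based warmup takes precedence whenever it is set.
    if n_epochs_kl_warmup is not None:
        if n_epochs_kl_warmup <= 0:
            return 1.0
        return min(1.0, epoch / n_epochs_kl_warmup)
    # Otherwise fall back to an iteration (minibatch) based warmup.
    if n_iter_kl_warmup is not None:
        if n_iter_kl_warmup <= 0:
            return 1.0
        return min(1.0, iteration / n_iter_kl_warmup)
    # Neither is set: no warmup, full KL weight from the start.
    return 1.0
```

For a dataset of 10,000 cells, "auto" gives a 7,500-iteration warmup, so the KL weight reaches 0.5 after 3,750 minibatches unless an epoch-based warmup is also given.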

discriminator : Optional[Classifier] (default: None)

Classifier used for the adversarial training scheme.

use_adversarial_loss : bool (default: False)

Whether to use an adversarial classifier to improve mixing.

kappa : Optional[float] (default: None)

Scaling factor for the adversarial loss. If None, it follows the inverse of the KL warmup schedule.
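One reading of "inverse of the KL warmup schedule" is that the adversarial scale equals one minus the current KL weight, so adversarial pressure is strongest early in training and fades as the KL terms reach full weight. The sketch below encodes that interpretation; adversarial_scale is a hypothetical name and the 1 - kl_weight rule is an assumption, not confirmed by this page.

```python
from typing import Optional


def adversarial_scale(kappa: Optional[float], kl_weight: float) -> float:
    """Hypothetical scale applied to the adversarial loss.

    Assumption: with kappa=None the scale is 1 - kl_weight; an explicit
    kappa is used as a constant instead.
    """
    return 1.0 - kl_weight if kappa is None else kappa


# Example: a 4-epoch linear KL warmup and the resulting adversarial scale.
schedule = [adversarial_scale(None, min(1.0, epoch / 4)) for epoch in range(6)]
```

Here schedule is [1.0, 0.75, 0.5, 0.25, 0.0, 0.0]: the adversarial term stops contributing once the warmup is complete, whereas a fixed kappa keeps it constant throughout.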

early_stopping_kwargs : Union[dict, str, None] (default: 'auto')

Keyword args for early stopping. If “auto”, use totalVI defaults. If None, disable early stopping.

Attributes

default_metrics_to_monitor

kl_weight

scvi_data_loaders_loop

Methods

check_training_status()

Checks if loss is admissible.

compute_metrics()

create_scvi_dl([model, adata, shuffle, …])

data_loaders_loop()

Returns a zipped iterable corresponding to the loss signature.

loss(tensors)

loss_discriminator(z, batch_index[, …])

on_epoch_begin()

on_epoch_end()

on_iteration_begin()

on_iteration_end()

on_training_begin()

on_training_end()

on_training_loop(tensors_dict)

register_data_loader(name, value)

train([n_epochs, lr, eps, params])

train_test_validation([model, adata, …])

Creates data loaders train_set, test_set, validation_set.

training_extras_end()

Place to put extra models in eval mode, etc.

training_extras_init([lr_d, eps])

Initializes other models that must be trained simultaneously.