scvi.core.trainers.trainer.Trainer

class scvi.core.trainers.trainer.Trainer(model, adata, use_cuda=True, metrics_to_monitor=None, benchmark=False, frequency=None, weight_decay=1e-06, early_stopping_kwargs=None, data_loader_kwargs=None, silent=False, batch_size=128, seed=0, max_nans=10)

The abstract Trainer class for training a PyTorch model and monitoring its statistics.

It is meant to be subclassed: a subclass must implement at least a .loss() method, which is the objective optimized in the training loop.
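A minimal sketch of this subclassing pattern follows. It is written in plain Python rather than PyTorch so that it runs on its own. MiniTrainer and MeanTrainer are hypothetical names, and this is not scvi's implementation. It only illustrates a base class whose training loop calls an abstract loss() that each subclass must provide.

```python
from abc import ABC, abstractmethod


class MiniTrainer(ABC):
    """Hypothetical stand-in for an abstract trainer; not scvi's code."""

    def __init__(self, data, lr=0.1):
        self.data = data  # list of float observations
        self.lr = lr
        self.param = 0.0  # single scalar parameter being fitted
        self.history = []  # loss value recorded at each epoch

    @abstractmethod
    def loss(self, batch):
        """Return (loss_value, gradient of the loss w.r.t. self.param)."""

    def train(self, n_epochs=100):
        for _ in range(n_epochs):
            value, grad = self.loss(self.data)
            self.history.append(value)
            self.param -= self.lr * grad  # plain gradient-descent step
        return self.param


class MeanTrainer(MiniTrainer):
    """Concrete subclass: mean squared error, minimized at the data mean."""

    def loss(self, batch):
        n = len(batch)
        value = sum((self.param - x) ** 2 for x in batch) / n
        grad = sum(2.0 * (self.param - x) for x in batch) / n
        return value, grad


trainer = MeanTrainer([1.0, 2.0, 3.0])
print(round(trainer.train(n_epochs=100), 4))  # 2.0
```

Because loss() is declared with @abstractmethod, instantiating MiniTrainer directly raises TypeError. This enforces the "implement at least .loss()" contract when the object is constructed, rather than partway through training.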

Parameters
model

A model instance of class VAE, VAEC, or SCANVI.

adata : AnnData

A registered AnnData object.

use_cuda : bool (default: True)

Whether to use CUDA (GPU) for training. Default: True.

metrics_to_monitor : Optional[List] (default: None)

A list of the metrics to monitor. If not specified, the default_metrics_to_monitor defined by the trainer class is used. Default: None.

benchmark : bool (default: False)

If True, disables the computation of statistics during training. Default: False.

frequency : Optional[int] (default: None)

The frequency, in epochs, at which statistics are computed and recorded. Default: None.

early_stopping_metric

The statistic on which to perform early stopping. Default: None.

save_best_state_metric

The statistic used to select the network weights achieving the best score; these weights are kept during training and restored at the end. Default: None.

on

The name of the data loader on which early_stopping_metric and save_best_state_metric are evaluated. It must be specified if either of them is. Default: None.

silent : boolbool (default: False)

If True, disables the progress bar. Default: False.

seed : int (default: 0)

Random seed for the train/test/validation split. Default: 0.
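The early-stopping options described above (early_stopping_metric, save_best_state_metric, on) can be sketched as follows. They do not appear in the signature; presumably they reach the trainer through early_stopping_kwargs, although this page does not say so. EarlyStopper is a hypothetical helper, not scvi's logic. It assumes a lower-is-better metric and a patience counter, which is an assumed knob not documented here. For brevity it uses one metric both for the stopping decision and for best-state tracking, whereas the trainer allows the two to differ. The scores list stands in for a metric evaluated on the data loader named by on.

```python
import copy
import math


class EarlyStopper:
    """Hypothetical early-stopping helper; `patience` is an assumed knob."""

    def __init__(self, patience=3):
        self.patience = patience  # epochs without improvement tolerated
        self.best_score = math.inf  # lower is better (e.g. a loss)
        self.best_state = None  # snapshot of the weights at best_score
        self.wait = 0  # epochs since the last improvement

    def update(self, score, state):
        """Record one epoch's metric; return True to keep training."""
        if score < self.best_score:
            self.best_score = score
            self.best_state = copy.deepcopy(state)  # keep the best weights
            self.wait = 0
        else:
            self.wait += 1
        return self.wait < self.patience


# Simulated per-epoch metric, standing in for a held-out data loader.
scores = [5.0, 4.0, 3.5, 3.6, 3.7, 3.8, 3.9]
weights = {"w": 0.0}
stopper = EarlyStopper(patience=3)
for epoch, score in enumerate(scores):
    weights["w"] = float(epoch)  # pretend training changed the weights
    if not stopper.update(score, weights):
        break
weights = stopper.best_state  # restore the best weights at the end
print(epoch, stopper.best_score, weights)  # 5 3.5 {'w': 2.0}
```

The state is deep-copied when it improves. Without the copy, later in-place updates to the weights would silently overwrite the saved "best" snapshot.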

Attributes

default_metrics_to_monitor

scvi_data_loaders_loop
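The fallback from metrics_to_monitor to default_metrics_to_monitor can be sketched with a class attribute that subclasses override. BaseTrainer, ClassifierTrainer, should_compute_metrics, and the metric names "elbo" and "accuracy" are hypothetical placeholders. The schedule (every frequency epochs, skipped entirely when benchmark is True) is an assumption drawn from the parameter descriptions, not a copy of scvi's rule.

```python
class BaseTrainer:
    # Hypothetical class-level default; the metric name is a placeholder.
    default_metrics_to_monitor = ["elbo"]

    def __init__(self, metrics_to_monitor=None, frequency=None, benchmark=False):
        # Fall back to the class default when no list is given.
        if metrics_to_monitor is None:
            metrics_to_monitor = self.default_metrics_to_monitor
        self.metrics_to_monitor = list(metrics_to_monitor)
        self.frequency = frequency
        self.benchmark = benchmark

    def should_compute_metrics(self, epoch):
        """Assumed schedule: every `frequency` epochs, never in benchmark mode."""
        if self.benchmark or not self.frequency:
            return False
        return epoch % self.frequency == 0


class ClassifierTrainer(BaseTrainer):
    default_metrics_to_monitor = ["accuracy"]  # subclass overrides the default


print(BaseTrainer().metrics_to_monitor)  # ['elbo']
print(ClassifierTrainer().metrics_to_monitor)  # ['accuracy']
t = BaseTrainer(frequency=5)
print([e for e in range(12) if t.should_compute_metrics(e)])  # [0, 5, 10]
```

Looking the default up through self rather than through BaseTrainer lets each subclass supply its own defaults without overriding __init__.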

Methods

check_training_status()

Checks if loss is admissible.

compute_metrics()

create_scvi_dl([model, adata, shuffle, …])

data_loaders_loop()

Returns a zipped iterable corresponding to the loss signature.

on_epoch_begin()

on_epoch_end()

on_iteration_begin()

on_iteration_end()

on_training_begin()

on_training_end()

on_training_loop(tensors_dict)

register_data_loader(name, value)

train([n_epochs, lr, eps, params])

train_test_validation([model, adata, …])

Creates data loaders train_set, test_set, validation_set.

training_extras_end()

Place to put extra models in eval mode, etc.

training_extras_init(**extras_kwargs)

Initializes other models that need to be trained simultaneously.
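The methods listed above suggest a hook-based training loop. The sketch below shows one plausible ordering of the hooks around data_loaders_loop(), plus a NaN check tied to max_nans. HookedTrainer, the toy loss, and the exact stopping rule are assumptions for illustration only. They are not scvi's train() implementation, and the hook order is inferred from the method names rather than verified against the source.

```python
import math


class HookedTrainer:
    """Hypothetical loop showing one plausible ordering of the hook methods."""

    def __init__(self, loaders, max_nans=10):
        self.loaders = loaders  # list of re-iterable loaders, zipped per step
        self.max_nans = max_nans
        self.nan_count = 0
        self.current_loss = 0.0
        self.calls = []  # log of hook names, for inspection

    def data_loaders_loop(self):
        # One tuple per step, holding one batch from each loader.
        return zip(*self.loaders)

    def loss(self, *batches):
        return sum(sum(b) for b in batches)  # toy loss: sum of all values

    def check_training_status(self):
        # Assumed rule: raise once max_nans NaN losses have been seen.
        if math.isnan(self.current_loss):
            self.nan_count += 1
            if self.nan_count >= self.max_nans:
                raise ValueError("loss was NaN too many times")

    def on_training_loop(self, tensors_dict):
        # Here tensors_dict is a tuple of batches, one per loader.
        self.current_loss = self.loss(*tensors_dict)
        self.check_training_status()

    def train(self, n_epochs=1):
        self.calls.append("on_training_begin")
        for _ in range(n_epochs):
            self.calls.append("on_epoch_begin")
            for tensors_dict in self.data_loaders_loop():
                self.calls.append("on_iteration_begin")
                self.on_training_loop(tensors_dict)
                self.calls.append("on_iteration_end")
            self.calls.append("on_epoch_end")
        self.calls.append("on_training_end")


# Two loaders with two batches each, so each epoch runs two zipped steps.
trainer = HookedTrainer(loaders=[[[1.0, 2.0], [3.0]], [[0.5], [0.5]]])
trainer.train(n_epochs=1)
print(trainer.calls)
```

Because zip stops at the shortest loader, loaders of unequal length are silently truncated. A real implementation might instead cycle the shorter loader.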