class neuralgym.train.Trainer(primary=True, **context)

Bases: object

Trainer class for training an iterative algorithm on a single GPU.

There are two types of trainer in neuralgym: the primary trainer and the secondary trainer. For the primary trainer, TensorFlow-related instances and configurations are initialized, e.g. initializing all variables, the summary writer, the session, start_queue_runners, and others. For the secondary trainer, only train_ops and losses are iteratively updated/run.
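The primary/secondary split can be illustrated with a plain-Python sketch. It does not use TensorFlow or the real neuralgym API: `SketchTrainer`, `setup_done`, and the `train_op` callable are hypothetical stand-ins. Only a trainer flagged as primary performs the one-time setup (standing in for variable initialization, the summary writer, the session, and start_queue_runners); a secondary trainer only runs its train op and collects losses.

```python
from typing import Callable, List


class SketchTrainer:
    """Hypothetical illustration of primary vs. secondary trainers (not neuralgym code)."""

    def __init__(self, train_op: Callable[[], float], primary: bool = True) -> None:
        self.train_op = train_op  # returns the loss for one step
        self.primary = primary
        self.setup_done = False
        if self.primary:
            self._init_primary()

    def _init_primary(self) -> None:
        # Placeholder for one-time setup: initializing variables,
        # creating the summary writer and session, starting queue runners.
        self.setup_done = True

    def run_steps(self, n: int) -> List[float]:
        # Both kinds of trainer iterate their train op and collect the losses.
        return [self.train_op() for _ in range(n)]


primary = SketchTrainer(lambda: 1.0, primary=True)
secondary = SketchTrainer(lambda: 0.5, primary=False)
print(primary.setup_done, secondary.setup_done)  # True False
print(secondary.run_steps(3))  # [0.5, 0.5, 0.5]
```

The point of the design is that shared resources are created exactly once, by the primary trainer, while any number of secondary trainers can reuse them.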


Add callbacks.

Parameters: callbacks – dict of callbacks
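The docstring only states that the argument is a dict of callbacks. The sketch below assumes, hypothetically, that keys are callback names and values are callables taking the step number, and shows how such a dict could be merged into a trainer's existing callbacks. `CallbackHolder` is an illustrative class, not the neuralgym implementation.

```python
from typing import Callable, Dict, List, Tuple

Callback = Callable[[int], None]


class CallbackHolder:
    """Hypothetical container for named callbacks (not neuralgym code)."""

    def __init__(self) -> None:
        self.callbacks: Dict[str, Callback] = {}

    def add_callbacks(self, callbacks: Dict[str, Callback]) -> None:
        # Merge new callbacks; an entry with an existing name replaces the old one.
        if not isinstance(callbacks, dict):
            raise TypeError("callbacks must be a dict")
        self.callbacks.update(callbacks)


seen: List[Tuple[str, int]] = []
holder = CallbackHolder()
holder.add_callbacks({"log": lambda step: seen.append(("log", step))})
holder.add_callbacks({"save": lambda step: seen.append(("save", step))})

# Invoke every registered callback for step 10, in name order.
for name in sorted(holder.callbacks):
    holder.callbacks[name](10)
print(seen)  # [('log', 10), ('save', 10)]
```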

Initialize the primary trainer context, including:

  • log_dir
  • global_step
  • sess_config
  • allow_growth
  • summary writer
  • saver
  • global_variables_initializer
  • start_queue_runners

progress_logger(step, loss)

Progress bar for logging.

Note that all statistics are averaged over the epoch.
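Averaging over the epoch means the logged loss is a running mean that restarts at each epoch boundary. The sketch below illustrates that idea in plain Python; `EpochAverager`, its `log` method, and the `steps_per_epoch` parameter are hypothetical and are not neuralgym's progress_logger implementation. Steps are assumed to be numbered from 1.

```python
class EpochAverager:
    """Hypothetical sketch: average the loss over the steps of the current epoch."""

    def __init__(self, steps_per_epoch: int) -> None:
        if steps_per_epoch <= 0:
            raise ValueError("steps_per_epoch must be positive")
        self.steps_per_epoch = steps_per_epoch
        self._total = 0.0
        self._count = 0

    def log(self, step: int, loss: float) -> float:
        # Reset the accumulators at the first step of each epoch (steps are 1-based).
        if (step - 1) % self.steps_per_epoch == 0:
            self._total, self._count = 0.0, 0
        self._total += loss
        self._count += 1
        mean = self._total / self._count
        pos = (step - 1) % self.steps_per_epoch + 1
        print(f"epoch step {pos}/{self.steps_per_epoch}  avg loss {mean:.4f}")
        return mean


logger = EpochAverager(steps_per_epoch=2)
means = [logger.log(s, l) for s, l in enumerate([4.0, 2.0, 1.0, 3.0], start=1)]
print(means)  # [4.0, 3.0, 1.0, 2.0]
```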


Start training with callbacks.
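A training loop "with callbacks" typically invokes user hooks before training, after each step, and after training. The sketch below shows that pattern under stated assumptions: the `Callback` base class, its hook names, the `train` function, and `max_iters` are all hypothetical and do not reproduce neuralgym's callback interface.

```python
from typing import Callable, List


class Callback:
    """Hypothetical callback base class with no-op hooks."""

    def on_train_begin(self) -> None: ...

    def on_step_end(self, step: int, loss: float) -> None: ...

    def on_train_end(self) -> None: ...


class LossRecorder(Callback):
    """Records the loss reported at the end of every step."""

    def __init__(self) -> None:
        self.history: List[float] = []

    def on_step_end(self, step: int, loss: float) -> None:
        self.history.append(loss)


def train(train_op: Callable[[int], float], max_iters: int, callbacks: List[Callback]) -> None:
    """Run train_op for max_iters steps, invoking callback hooks around it."""
    for cb in callbacks:
        cb.on_train_begin()
    for step in range(1, max_iters + 1):
        loss = train_op(step)
        for cb in callbacks:
            cb.on_step_end(step, loss)
    for cb in callbacks:
        cb.on_train_end()


recorder = LossRecorder()
train(lambda step: 1.0 / step, max_iters=4, callbacks=[recorder])
print(recorder.history)  # [1.0, 0.5, 0.3333333333333333, 0.25]
```

Keeping the hooks as no-ops in the base class lets each callback override only the events it cares about.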

class neuralgym.train.MultiGPUTrainer(**context)

Bases: neuralgym.train.trainer.Trainer

Trainer class for training an iterative algorithm on multiple GPUs.

Parameters:

  • num_gpus (int) – Number of GPU(s) for training.
  • async_train (bool) – Whether to train asynchronously.
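The difference that async_train controls can be sketched with plain lists standing in for gradients. This is a simplified illustration of data-parallel training in general, not neuralgym's implementation: `average_gradients`, `sgd_step`, and the two-GPU gradients are hypothetical. In the synchronous case the per-GPU gradients are averaged and applied once; in the (simplified, sequential) asynchronous case each GPU's gradient is applied as its own update. Real asynchronous training applies updates concurrently and may use stale parameters.

```python
from typing import List

Vector = List[float]


def average_gradients(tower_grads: List[Vector]) -> Vector:
    """Element-wise mean of per-GPU gradients (synchronous data parallelism)."""
    n = len(tower_grads)
    return [sum(gs) / n for gs in zip(*tower_grads)]


def sgd_step(params: Vector, grad: Vector, lr: float) -> Vector:
    """Plain gradient-descent update."""
    return [p - lr * g for p, g in zip(params, grad)]


params = [1.0, 1.0]
tower_grads = [[0.5, 1.0], [1.5, 3.0]]  # hypothetical gradients from 2 GPUs

# Synchronous: a single update with the averaged gradient.
sync_params = sgd_step(params, average_gradients(tower_grads), lr=0.5)

# Asynchronous (simplified): each tower's gradient is applied as a separate update.
async_params = params
for g in tower_grads:
    async_params = sgd_step(async_params, g, lr=0.5)

print(sync_params)   # [0.5, 0.0]
print(async_params)  # [0.0, -1.0]
```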

Start training with callbacks.