neuralgym.train
class neuralgym.train.Trainer(primary=True, **context)
Bases: object
Trainer class for training iterative algorithms on a single GPU.
There are two types of trainer in neuralgym: the primary trainer and the secondary trainer. The primary trainer initializes TensorFlow-related instances and configuration, e.g. it initializes all variables, creates the summary writer and the session, starts the queue runners, and so on. A secondary trainer does none of this setup; it only runs its train_ops and updates its losses iteratively.
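The split between the two kinds of trainer can be illustrated with a toy, framework-free sketch. `ToyTrainer`, its `train_op` argument and the `setup_log` list are hypothetical names invented for this illustration, and the strings in `setup_log` merely stand in for the TensorFlow setup listed above; this is not neuralgym's implementation.

```python
class ToyTrainer:
    """Hypothetical stand-in for a primary/secondary trainer (not neuralgym code)."""

    def __init__(self, train_op, primary=True, **context):
        self.train_op = train_op      # callable that performs one update and returns a loss
        self.primary = primary
        self.context = dict(context)
        self.setup_log = []           # records the one-time setup this trainer performed
        if self.primary:
            self._init_primary()

    def _init_primary(self):
        # Only the primary trainer does one-time setup; these strings stand in
        # for session creation, variable initialization, summary writer, etc.
        for item in ("session", "init_variables", "summary_writer", "queue_runners"):
            self.setup_log.append(item)

    def step(self):
        # Primary and secondary trainers both just run their train op.
        return self.train_op()


losses = iter([3.0, 2.0, 1.0])
primary = ToyTrainer(lambda: next(losses), primary=True)
secondary = ToyTrainer(lambda: 0.5, primary=False)
print(primary.setup_log, secondary.setup_log)
```

In this sketch the secondary trainer ends up with an empty `setup_log`, mirroring the idea that only the primary trainer owns the one-time initialization while both kinds can be stepped.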
add_callbacks(callbacks)
Add callbacks.
Parameters: callbacks – dict of callbacks
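One way a trainer could merge incoming callbacks into its internal list is sketched below. `CallbackRegistry` is a hypothetical helper, and accepting a single callable or an iterable in addition to the documented dict is an assumption of this sketch, not a statement about neuralgym's behavior.

```python
from typing import Callable, Dict, Iterable, List, Union

Callback = Callable[[int], None]


class CallbackRegistry:
    """Hypothetical helper showing how add_callbacks might collect callbacks."""

    def __init__(self) -> None:
        self.callbacks: List[Callback] = []

    def add_callbacks(
        self, callbacks: Union[Callback, Dict[str, Callback], Iterable[Callback]]
    ) -> None:
        if callable(callbacks):
            self.callbacks.append(callbacks)            # a single callback
        elif isinstance(callbacks, dict):
            self.callbacks.extend(callbacks.values())   # dict: keep callables, drop names
        else:
            self.callbacks.extend(callbacks)            # any other iterable of callbacks


calls = []
reg = CallbackRegistry()
reg.add_callbacks({"log": lambda step: calls.append(("log", step))})
reg.add_callbacks(lambda step: calls.append(("save", step)))
for cb in reg.callbacks:
    cb(10)
```

Because Python dicts preserve insertion order, callbacks registered through a dict fire in the order their keys were added.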
init_primary_trainer()
Initialize the primary trainer context, including:
- log_dir
- global_step
- sess_config
- allow_growth
- summary writer
- saver
- global_variables_initializer
- start_queue_runners
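The first few context entries above amount to filling in defaults before any TensorFlow objects are built. The sketch below models that step in plain Python; the function name `init_primary_context` and the default values (`"logs"`, step `0`) are hypothetical. The `sess_config` dict stands in for a TensorFlow 1.x `tf.ConfigProto`, where the same option is set as `config.gpu_options.allow_growth = True` so the process allocates GPU memory on demand instead of reserving it all up front.

```python
def init_primary_context(context):
    """Fill in hypothetical defaults for a primary trainer's context (sketch only)."""
    ctx = dict(context)
    # Directory for checkpoints and summaries (hypothetical default).
    ctx.setdefault("log_dir", "logs")
    # Step counter shared by logging, summaries and checkpointing.
    ctx.setdefault("global_step", 0)
    # Plain-dict stand-in for tf.ConfigProto; allow_growth defaults to True
    # unless the caller already set it.
    sess_config = dict(ctx.get("sess_config", {}))
    sess_config.setdefault("allow_growth", True)
    ctx["sess_config"] = sess_config
    return ctx


ctx = init_primary_context({})
ctx_custom = init_primary_context({"log_dir": "runs/exp1", "sess_config": {"allow_growth": False}})
```

User-supplied values win over the defaults, so a caller can still disable memory growth or point the logs elsewhere.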
progress_logger(step, loss)
Progress bar for logging.
Note that all statistics are averaged over the current epoch.
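Averaging over the epoch means the running statistics are reset whenever a new epoch begins. The following is a minimal sketch of such a logger, assuming 1-based global step numbers and a known epoch length; `EpochProgressLogger` and its `steps_per_epoch` and `width` parameters are hypothetical names for this illustration.

```python
class EpochProgressLogger:
    """Toy progress bar whose loss is averaged over the current epoch (sketch)."""

    def __init__(self, steps_per_epoch: int, width: int = 20) -> None:
        self.steps_per_epoch = steps_per_epoch
        self.width = width
        self._loss_sum = 0.0
        self._count = 0

    def log(self, step: int, loss: float) -> str:
        # Position of this (1-based, global) step inside its epoch.
        pos = (step - 1) % self.steps_per_epoch + 1
        if pos == 1:
            # First step of a new epoch: reset the running statistics.
            self._loss_sum, self._count = 0.0, 0
        self._loss_sum += loss
        self._count += 1
        mean = self._loss_sum / self._count
        filled = self.width * pos // self.steps_per_epoch
        bar = "#" * filled + "." * (self.width - filled)
        return f"[{bar}] {pos}/{self.steps_per_epoch} mean loss {mean:.4f}"


logger = EpochProgressLogger(steps_per_epoch=4, width=8)
lines = [logger.log(s, l) for s, l in enumerate([4.0, 2.0, 3.0, 1.0, 10.0], start=1)]
for line in lines:
    print(line)
```

Here the fifth step starts a second epoch, so its mean loss reflects only that step rather than all five.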
train()
Start training with callbacks.
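A training loop that fires callbacks at fixed points can be sketched as below. The `Hook` enum, the `(Hook, callable)` pairing and the `train` function signature are all hypothetical choices for this illustration; they show the general pattern of "training with callbacks", not neuralgym's exact callback API.

```python
from enum import Enum
from typing import Callable, List, Tuple


class Hook(Enum):
    """Hypothetical callback locations used by this sketch."""
    TRAIN_START = "train_start"
    STEP_START = "step_start"
    STEP_END = "step_end"
    TRAIN_END = "train_end"


def train(train_step: Callable[[], float], max_iters: int,
          callbacks: List[Tuple[Hook, Callable[[int], None]]]) -> List[float]:
    """Run train_step max_iters times, firing callbacks registered for each hook."""
    def fire(hook: Hook, step: int) -> None:
        for loc, cb in callbacks:
            if loc is hook:
                cb(step)

    losses = []
    fire(Hook.TRAIN_START, 0)
    for step in range(1, max_iters + 1):
        fire(Hook.STEP_START, step)
        losses.append(train_step())
        fire(Hook.STEP_END, step)
    fire(Hook.TRAIN_END, max_iters)
    return losses


events = []
cbs = [
    (Hook.TRAIN_START, lambda s: events.append(("start", s))),
    (Hook.STEP_END, lambda s: events.append(("end", s))),
    (Hook.TRAIN_END, lambda s: events.append(("done", s))),
]
losses = train(lambda: 1.0, max_iters=2, callbacks=cbs)
```

Tagging each callback with its location keeps the loop itself free of logging, checkpointing or summary code; those concerns live entirely in the registered callbacks.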