diff --git a/demo/dali_reader.py b/demo/dali_reader.py index 3efde21fc539b85d54ecda26d058c26b87072c01..00cea874aa79a45db5b40e6d897d974ea527c101 100644 --- a/demo/dali_reader.py +++ b/demo/dali_reader.py @@ -38,6 +38,7 @@ def main(): ins.set_dataset_dir(args.data_dir) ins.set_train_epochs(args.num_epochs) ins.set_loss_type(args.loss_type) + ins.set_mixed_precision(True) # 1. Build a dali reader gpu_id = ins.trainer_id % 8 # Assume 8 card per machine dali_iter = dali.train(ins.train_batch_size,