Implement distributed training using horovod #3533

Merged (4 commits) on Mar 15, 2021
doc/TRAINING.rst (21 additions, 0 deletions)
@@ -196,6 +196,27 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil

On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.

Distributed training using Horovod
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If you have a capable compute architecture, it is possible to distribute the training using `Horovod <https://github.com/horovod/horovod>`_. A fast network is recommended.
Collaborator: While I understand that people using Horovod should know the underlying requirements, I think "a fast network" might be troublesome for some users: we have had requests from people who want to use distributed training to leverage several GPUs over only Gigabit Ethernet networks.


Reply: We are still doing experiments, but Ethernet might be sufficient for small setups, e.g. using two or three systems.

Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication.
It also offers `Gloo <https://github.com/facebookincubator/gloo>`_ as an easy-to-setup communication backend.

For more information about setup or tuning of Horovod please visit `Horovod's documentation <https://horovod.readthedocs.io/en/stable/summary_include.html>`_.
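As a quick illustration (an editor's sketch based on Horovod's install documentation and the version pinned in this PR's setup.py, not text from the PR itself), a GPU-enabled installation with the TensorFlow bindings typically looks like this:

.. code-block:: bash

   # Build Horovod 0.21.3 with the TensorFlow bindings and NCCL-backed GPU operations.
   # HOROVOD_GPU_OPERATIONS is a documented Horovod build flag; adjust to your environment.
   HOROVOD_GPU_OPERATIONS=NCCL pip install 'horovod[tensorflow]==0.21.3'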

Horovod is expected to run on heterogeneous systems (e.g. a different number and model of GPUs per machine).
However, this can cause unpredictable problems, and user intervention in the training code would be needed.
Therefore, we only support homogeneous systems, meaning the same hardware and also the same software configuration (OS, drivers, MPI, NCCL, TensorFlow, ...) on each machine.
Collaborator: 💯

The only exception is a different number of GPUs per machine, since this can be controlled with ``horovodrun -H``.
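For example (an editor's sketch; the hostnames are placeholders), a run across two machines with four and two GPUs respectively would pass per-host slot counts to ``-H``:

.. code-block:: bash

   # 6 processes in total: 4 slots on serverA, 2 slots on serverB
   horovodrun -np 6 -H serverA:4,serverB:2 python3 DeepSpeech.py --train_files [...] --horovod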
Collaborator: No risk of improper interactions with batch size, for example?

Contributor Author: I do not get your question at all.

The batch size specified via the CLI is treated as the batch size for each worker, not for the machine or the complete system; that is why we rescale the learning rate.
If all GPUs are equal I do not see any problems, since all of them get the same batch size.

If you changed the code it would be possible to set a different batch size on each GPU (e.g. for different memory sizes or for load balancing). That would open the door to load-balancing problems you do not want to support.

Collaborator:

"If all GPUs are equal I do not see any problems, since all of them get the same batch size."

Does the batch size apply equally to all GPUs of one machine?

Sorry, but the few horovodrun examples make it hard to grasp the meaning of the commands.

Contributor Author: Horovod by itself does nothing with the batch size, and neither does horovodrun.
The batch size used by each Horovod process, and therefore by every GPU on every machine, is the batch size you specify as usual with DeepSpeech.py --train_batch_size; it is only used when creating the dataset:

train_set = create_dataset(FLAGS.train_files.split(','),
                           batch_size=FLAGS.train_batch_size,
                           epochs=FLAGS.epochs,
                           augmentations=Config.augmentations,
                           cache_path=FLAGS.feature_cache,
                           train_phase=True,
                           exception_box=exception_box,
                           process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
                           reverse=FLAGS.reverse_train,
                           limit=FLAGS.limit_train,
                           buffering=FLAGS.read_buffer,
                           split_dataset=split_dataset)

https://github.com/tud-zih-tools/DeepSpeech/blob/329bf876069720cf05b4e4700e6d0dde104b6bac/training/deepspeech_training/train.py#L423 (Is it possible to link the code here directly?)

So, your effective batch size for training, the one the optimizer is applied to, is FLAGS.train_batch_size * NUM_PROC, where NUM_PROC is the number of Horovod processes in use, which equals the number of GPUs across your whole setup. In Horovod terminology, as in HPC/MPI terminology, this is horovod.size().

To prevent convergence problems caused by this bigger effective batch size, we scale the learning rate, as recommended by the Horovod developers, with
FLAGS.learning_rate * NUM_PROC
https://github.com/tud-zih-tools/DeepSpeech/blob/329bf876069720cf05b4e4700e6d0dde104b6bac/training/deepspeech_training/train.py#L487

In theory Horovod has no problem if you apply a different batch size to each GPU. In practice you want to make sure that every process finishes its batch at about the same time (load balance). If one process is far behind, Horovod's error handling will take action.
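To make the arithmetic above concrete, here is a minimal, self-contained sketch (an editor's illustration with example values; in train.py the numbers come from --train_batch_size and --learning_rate):

import horovod.tensorflow as hvd

hvd.init()

# Example values; in DeepSpeech these come from the CLI flags.
per_gpu_batch_size = 32
base_learning_rate = 0.0001

# Every Horovod process (one per GPU) consumes its own batch, so the
# effective batch size grows with the number of processes, hvd.size().
effective_batch_size = per_gpu_batch_size * hvd.size()

# train.py compensates by scaling the learning rate by the same factor
# before building the optimizer: create_optimizer(learning_rate_var * hvd.size())
scaled_learning_rate = base_learning_rate * hvd.size()

print('effective batch size:', effective_batch_size)
print('scaled learning rate:', scaled_learning_rate)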


Detailed documentation on how to run Horovod is provided `here <https://horovod.readthedocs.io/en/stable/running.html>`_.
The short command to train on 4 machines using 4 GPUs each:

.. code-block:: bash

horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
Collaborator: Maybe it is worth linking an official, stable "how to use Horovod" crash course here?


Checkpointing
^^^^^^^^^^^^^

setup.py (10 additions, 0 deletions)
@@ -76,6 +76,10 @@ def main():
'tensorflow == 1.15.4'
]

horovod_pypi_dep = [
'horovod[tensorflow] == 0.21.3'
Collaborator: (I have not checked the PyPI repo.) How does this work when tensorflow-gpu is needed? Is horovod[tensorflow] still working as expected?

Contributor Author: Maybe you want to take a look here: https://horovod.readthedocs.io/en/stable/install.html
There is only horovod[tensorflow] for the TensorFlow bindings, and it should be the same as HOROVOD_WITH_TENSORFLOW=1 pip install horovod.
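For reference, the two forms mentioned above would look roughly like this (an editor's sketch following Horovod's install documentation; not verified against tensorflow-gpu setups):

pip install 'horovod[tensorflow]==0.21.3'
# or, per the Horovod install docs, build the TensorFlow extension explicitly:
HOROVOD_WITH_TENSORFLOW=1 pip install horovod==0.21.3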

]

# Due to pip craziness environment variables are the only consistent way to
# get options into this script when doing `pip install`.
tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
@@ -94,6 +98,12 @@ def main():
else:
install_requires = install_requires + tensorflow_pypi_dep

if os.environ.get('DS_WITH_HOROVOD', ''):
install_requires = install_requires + horovod_pypi_dep
else:
install_requires = install_requires


setup(
name='deepspeech_training',
version=version,
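Given the DS_WITH_HOROVOD check above, installing the training package with the Horovod dependency enabled would look roughly like this (an editor's sketch; the working directory is an assumption, run it wherever this setup.py lives):

# Pull in horovod[tensorflow] as an extra requirement during installation.
DS_WITH_HOROVOD=1 pip install -e .

Without the environment variable set, the Horovod requirement is simply not added to install_requires.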
training/deepspeech_training/train.py (120 additions, 73 deletions)
@@ -413,18 +413,24 @@ def log_grads_and_vars(grads_and_vars):
def train():
exception_box = ExceptionBox()

if FLAGS.horovod:
import horovod.tensorflow as hvd

# Create training and validation datasets
split_dataset = FLAGS.horovod

train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
epochs=FLAGS.epochs,
augmentations=Config.augmentations,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
reverse=FLAGS.reverse_train,
limit=FLAGS.limit_train,
buffering=FLAGS.read_buffer)
buffering=FLAGS.read_buffer,
split_dataset=split_dataset)

iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
@@ -439,10 +445,11 @@ def train():
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer) for source in dev_sources]
buffering=FLAGS.read_buffer,
split_dataset=split_dataset) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

if FLAGS.metrics_files:
@@ -451,10 +458,11 @@ def train():
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer) for source in metrics_sources]
buffering=FLAGS.read_buffer,
split_dataset=split_dataset) for source in metrics_sources]
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]

# Dropout
@@ -474,22 +482,38 @@ def train():
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
if FLAGS.horovod:
# Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
optimizer = create_optimizer(learning_rate_var * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)
else:
optimizer = create_optimizer(learning_rate_var)

# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
if FLAGS.horovod:
loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
gradients = optimizer.compute_gradients(loss)

tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
log_grads_and_vars(gradients)

# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
else:
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)

# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
@@ -506,18 +530,22 @@ def train():
}

# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
if Config.is_master_process:
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

# Save flags next to checkpoints
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
# Save flags next to checkpoints
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())

if FLAGS.horovod:
bcast = hvd.broadcast_global_variables(0)

with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
@@ -527,6 +555,8 @@ def train():

# Load checkpoint or initialize variables
load_or_init_graph_for_training(session)
if FLAGS.horovod:
bcast.run()

def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
@@ -554,12 +584,13 @@ def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
if Config.is_master_process:
prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

# Initialize iterator to the appropriate dataset
session.run(init_op)
@@ -583,15 +614,17 @@ def __call__(self, progress, data, **kwargs):
total_loss += batch_loss
step_count += 1

pbar.update(step_count)
if Config.is_master_process:
pbar.update(step_count)

step_summary_writer.add_summary(step_summary, current_step)
step_summary_writer.add_summary(step_summary, current_step)

if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()

pbar.finish()
if Config.is_master_process:
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count

@@ -603,21 +636,25 @@ def __call__(self, progress, data, **kwargs):
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
if Config.is_master_process:
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if Config.is_master_process:
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
if Config.is_master_process:
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
if Config.is_master_process:
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
@@ -629,16 +666,19 @@ def __call__(self, progress, data, **kwargs):
else:
epochs_without_improvement = 0

# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
if Config.is_master_process:
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
if Config.is_master_process:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break

# Reduce learning rate on plateau
@@ -655,26 +695,31 @@ def __call__(self, progress, data, **kwargs):
# Reduce learning rate
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
if Config.is_master_process:
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))

# Overwrite best checkpoint with new learning rate value
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
# Overwrite best checkpoint with new learning rate value
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
latest_filename='best_dev_checkpoint')
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))

if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking
for source, init_op in zip(metrics_sources, metrics_init_ops):
log_progress('Metrics for epoch %d on %s...' % (epoch, source))
if Config.is_master_process:
log_progress('Metrics for epoch %d on %s...' % (epoch, source))
set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
if Config.is_master_process:
log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))

print('-' * 80)


except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
if Config.is_master_process:
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')


@@ -951,30 +996,32 @@ def main(_):
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)

train()

if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if Config.is_master_process:
if FLAGS.test_files:
tfv1.reset_default_graph()
test()

if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()

if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True

if listdir_remote(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
if listdir_remote(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)

export()
package_zip()
export()
package_zip()

if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)


def run_script():