
[BUG] Data parallel training freezes due to different number of batches #75

Open
bschifferer opened this issue Sep 23, 2022 · 17 comments
Labels: bug (Something isn't working), P1

@bschifferer
Contributor

Bug description

In data parallel training, we start multiple workers, each with a differently initialized dataloader, and train with Horovod. After each batch update, the parameters are synced across workers. The Merlin dataloader produces a different number of batches depending on the selected rank. As a result, some workers finish the training loop while others are still training, which causes Horovod to freeze.

import cudf
import os

import merlin.models.tf.dataset as tf_dataloader
import nvtabular as nvt

os.system('mkdir ./test/')

df = cudf.DataFrame({
    'col1': range(0,9000000)
})
df.to_parquet('./test/part_1.parquet')
df = cudf.DataFrame({
    'col1': range(0,10000000)
})
df.to_parquet('./test/part_2.parquet')
df = cudf.DataFrame({
    'col1': range(0,11000000)
})
df.to_parquet('./test/part_3.parquet')
df = cudf.DataFrame({
    'col1': range(0,12000000)
})
df.to_parquet('./test/part_4.parquet')

ds = nvt.Dataset('./test/*.parquet', part_size='100MB')
for i in range(4):
    train_dl = tf_dataloader.BatchedDataset(
        ds,
        batch_size = 1024*16,
        shuffle=True,
        drop_last=True,
        cat_names=['col1'],
        global_size=4,
        global_rank=i,
    )
    print(len(train_dl))

Output:

549
610
671
732
bschifferer added the bug label on Sep 23, 2022
@rnyak

rnyak commented Sep 26, 2022

@bschifferer will try to set the seed before the dataloader, and check it out.

@bschifferer
Contributor Author

@jperez999 I provided an example with Merlin Models:
NVIDIA-Merlin/models#778

I added the seed_fn:

train_dl = tf_dataloader.BatchedDataset(
    train,
    batch_size = batch_size,
    shuffle=True,
    drop_last=True,
    global_size=2,
    global_rank=hvd.rank(),
    seed_fn=seed_fn
)
print(len(train_dl))

When I check the print statement, I get the following split:

[1,0]<stdout>:2
[1,0]<stdout>:0
[1,1]<stdout>:2
[1,1]<stdout>:1
[1,0]<stdout>:322
[1,1]<stdout>:104

@EvenOldridge
Member

@jperez999 @benfred @rjzamora
It looks like batching isn't correctly splitting the datasets. Is this particular to multi-GPU, or does the problem also occur on a single GPU where it just doesn't surface as an error?

Trying to figure out if this has always been broken or if it's a recent change.

rnyak assigned edknv and unassigned benfred on Oct 19, 2022
@jperez999
Collaborator

jperez999 commented Oct 19, 2022

So this is not the correct way to use the Merlin dataloader with Horovod; it needs a bit more setup. You should never create dataloaders in a for loop. When working with Horovod, follow the example in the NVTabular tests: https://github.com/NVIDIA-Merlin/NVTabular/blob/main/tests/unit/loader/test_tf_dataloader.py#L537. Note that to use Horovod you need to launch via the horovodrun subprocess, and you also need to use the supplied wrapper script, which adds the environment variables MPI needs under the hood: https://github.com/NVIDIA-Merlin/NVTabular/blob/main/examples/multi-gpu-movielens/hvd_wrapper.sh.
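
For reference, the launch pattern from that example looks like this (script name and flags as used later in this thread and in the NVTabular multi-GPU MovieLens example):

horovodrun -np 2 sh hvd_wrapper.sh python tf_trainer.py --dir_in $BASE_DIR --batch_size 16384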

@EvenOldridge
Member

@bschifferer can you follow up and make sure we're using Horovod correctly? We'll probably need to find a way to make it clear to our customers how to set this up properly, even if it's just giving the links that @jperez999 shared more visibility.

@edknv
Contributor

edknv commented Oct 19, 2022

@jperez999 Is there a way to produce an equal number of batches so that the workload is balanced across workers? Although NVTabular seems to produce equal-sized batches in tf_trainer.py, the number of batches per worker differs (hence the need for hvd.join(): https://github.com/NVIDIA-Merlin/NVTabular/blob/main/examples/multi-gpu-movielens/tf_trainer.py#L142).

for batch, (examples, labels) in enumerate(train_dataset_tf):
    loss_value = training_step(examples, labels, batch == 0)
print(f"There are {batch} batches in worker {hvd.local_rank()}.")
#hvd.join()
[1,1]<stdout>:There are 548 batches in worker 1.
[1,0]<stdout>:There are 670 batches in worker 0.

Without hvd.join(), worker 1 will terminate before worker 0 and Horovod will crash. Although it does work with hvd.join(), it doesn't seem ideal to have one worker sit idle while the other keeps processing its remaining batches.
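
For context, the join-based version is just the loop above with the join enabled (a sketch of what tf_trainer.py does; the worker that runs out of batches waits at the join instead of exiting and crashing the job):

for batch, (examples, labels) in enumerate(train_dataset_tf):
    loss_value = training_step(examples, labels, batch == 0)
hvd.join()  # early-finishing workers wait here for the others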

One workaround I have is to repartition the dataset with something like

train = Dataset(output_path / "train" / "*.parquet")
ddf = train.to_ddf().repartition(npartitions=hvd.size())
train = Dataset(ddf, schema=train.schema)

but I'm wondering if there is a better way to do this in the dataloader.

Edit: I ran horovodrun -np 2 sh hvd_wrapper.sh python tf_trainer.py --dir_in $BASE_DIR --batch_size 16384 using tf_trainer.py and hvd_wrapper.sh.

@jperez999
Collaborator

So I just ran this unit test: pytest tests/unit/loader/test_tf_dataloader.py::test_horovod_multigpu, and it runs as expected. There are five partitions spread across two workers, so naturally one worker will get more partitions than the other. The dataloader is designed like this. What can happen is that, depending on the batch size, you can end up slicing your partitions into much smaller pieces. This means one partition could give you 100+ batches, and if that worker has one more partition than the other workers, you will end up with 100 extra batches on that worker.

Remember that the split does not happen based on batches; it happens based on partitions. Those partitions are subsequently broken down into chunks of batch_size. You can try to repartition the dataset so that each worker gets the same number of partitions, but even then you would need to guarantee that the partitions are all the same size. This is a gotcha we have known about since the creation of the dataloader, because partitions are not merged across files. Say you have a dataset with two files: file one has 225 rows and your partition size is 50 rows, so that file yields 5 partitions (even though the last partition is only half full). File two has only 150 rows, so it yields 3 partitions, and you end up in a situation where one worker gets four full partitions and the other gets three full partitions plus one partial. Now extrapolate that to a scenario where all files end with half-full partitions; you can see how you end up with non-full partitions littered across your dataset.
At the current time dask has no way to fix this; it can't continue a partition across files. Another thing to remember is that if you have a dataset with 10 partitions and you repartition to 1000, you may also want to change the parts_per_chunk variable so that you can create bigger chunks of data.
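
As a rough sketch of that arithmetic (hypothetical partition sizes, not the loader's actual code), you can see how uneven partitions translate into uneven per-worker batch counts:

# Hypothetical illustration of the two-file example above: file one (225 rows)
# yields partitions [50, 50, 50, 50, 25], file two (150 rows) yields [50, 50, 50].
partition_rows = [50, 50, 50, 50, 25, 50, 50, 50]
batch_size = 16
global_size = 2

for rank in range(global_size):
    # assume a simple strided assignment of partitions to ranks
    # (illustration only; the real loader has its own index logic)
    my_parts = partition_rows[rank::global_size]
    # simplification: leftover rows per partition are dropped, as with drop_last=True
    n_batches = sum(rows // batch_size for rows in my_parts)
    print(f"rank {rank}: partitions {my_parts} -> {n_batches} batches")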

@rnyak

rnyak commented May 19, 2023

@jperez999 @EvenOldridge
The very same issue has been reported by @bschifferer for multi-GPU training with PyTorch DDP in the TF4Rec example. He was able to help a user with a workaround, but it would be great to get this fixed properly. @bschifferer do you have a repro for the TF4Rec example?

thanks.

viswa-nvidia added the P1 label and removed the P0 label on Jun 13, 2023
@bschifferer
Contributor Author

bschifferer commented Jun 29, 2023

Even with a single file, I can get a different number of batches per rank. The part_size parameter controls the number of partitions and thereby changes the number of batches.

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=100_000_000)
})
df.to_parquet('single2.parquet')
ds = Dataset('single2.parquet', part_size='100MB')
print('npartitions:' + str(ds.npartitions))
for i in range(4):
    train_dl = Loader(
        ds,
        batch_size = 1024,
        shuffle=True,
        drop_last=True,
        global_size=4,
        global_rank=i,
    )
    train_dl.global_rank = i
    print('Dataloader: ' + str(i) + ' Num Batches: ' + str(len(train_dl)))

Output:

npartitions:9
Dataloader: 0 Num Batches: 35156
Dataloader: 1 Num Batches: 35156
Dataloader: 2 Num Batches: 27343
Dataloader: 3 Num Batches: 0

Changing part_size to 200MB:

npartitions:4
Dataloader: 0 Num Batches: 24414
Dataloader: 1 Num Batches: 24414
Dataloader: 2 Num Batches: 24414
Dataloader: 3 Num Batches: 24414

Not providing any part_size:

npartitions:1

File /usr/local/lib/python3.8/dist-packages/merlin/dataloader/loader_base.py:224, in LoaderBase._indices_for_process(self)
    219 if len(self.indices) < self.global_size:
    220     warnings.warn(
    221         f"""You have more processes({self.global_size}) than dataset
    222             partitions({len(self.indices)}), reduce the number of processes."""
    223     )
--> 224     raise IndexError
    225 per_worker = _num_steps(len(self.indices), self.global_size)
    226 # identify process rank out of all processes (not local rank)

IndexError: 

@bschifferer
Contributor Author

bschifferer commented Jun 29, 2023

Let's repartition the dataset, as suggested above:

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset
import nvtabular as nvt

df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=100_000_000)
})
df.to_parquet('single2.parquet')
ds = Dataset('single2.parquet', part_size='100MB')
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*12)
ds2 = Dataset(ddf2)
print('npartitions:' + str(ds2.npartitions))
for i in range(4):
    train_dl = Loader(
        ds2,
        batch_size = 1024,
        shuffle=True,
        drop_last=True,
        global_size=4,
        global_rank=i,
    )
    train_dl.global_rank = i
    print('Dataloader: ' + str(i) + ' Num Batches: ' + str(len(train_dl)))

Output:

npartitions:48
Dataloader: 0 Num Batches: 28125
Dataloader: 1 Num Batches: 28125
Dataloader: 2 Num Batches: 28125
Dataloader: 3 Num Batches: 13281

Output with part_size=200MB:

npartitions:48
Dataloader: 0 Num Batches: 24414
Dataloader: 1 Num Batches: 24414
Dataloader: 2 Num Batches: 24414
Dataloader: 3 Num Batches: 24414

Output with no partsize:

npartitions:48
Dataloader: 0 Num Batches: 24414
Dataloader: 1 Num Batches: 24414
Dataloader: 2 Num Batches: 24414
Dataloader: 3 Num Batches: 24414

@bschifferer
Contributor Author

Having multiple files

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset
import nvtabular as nvt

df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=100_000_000)
})
df.to_parquet('single2_1.parquet')
df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=50_000_000)
})
df.to_parquet('single2_2.parquet')
ds = Dataset(['single2_1.parquet', 'single2_2.parquet'], part_size='200MB')
print('npartitions:' + str(ds.npartitions))
for i in range(4):
    train_dl = Loader(
        ds,
        batch_size = 1024,
        shuffle=True,
        drop_last=True,
        global_size=4,
        global_rank=i,
    )
    train_dl.global_rank = i
    print('Dataloader: ' + str(i) + ' Num Batches: ' + str(len(train_dl)))

Output:

npartitions:6
Dataloader: 0 Num Batches: 48828
Dataloader: 1 Num Batches: 48828
Dataloader: 2 Num Batches: 48828
Dataloader: 3 Num Batches: 0

@bschifferer
Contributor Author

bschifferer commented Jun 29, 2023

Having multiple files with repartition:

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset
import nvtabular as nvt

df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=100_000_000)
})
df.to_parquet('single2_1.parquet')
df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=90_000_000)
})
df.to_parquet('single2_2.parquet')
ds = Dataset(['single2_1.parquet', 'single2_2.parquet'], part_size='1000MB')
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*12)
ds2 = Dataset(ddf2)
print('npartitions:' + str(ds2.npartitions))
for i in range(4):
    train_dl = Loader(
        ds2,
        batch_size = 1024,
        shuffle=True,
        drop_last=True,
        global_size=4,
        global_rank=i,
    )
    train_dl.global_rank = i
    print('Dataloader: ' + str(i) + ' Num Batches: ' + str(len(train_dl)))

Output:

npartitions:48
Dataloader: 0 Num Batches: 48828
Dataloader: 1 Num Batches: 48828
Dataloader: 2 Num Batches: 43945
Dataloader: 3 Num Batches: 43945

@bschifferer
Contributor Author

If we use NVTabular to process multiple input files, it generates multiple output files with the same shapes as the inputs:

import cudf
import numpy as np

import nvtabular as nvt
from merlin.io import Dataset

df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=100_000_000)
})
df.to_parquet('single2_1.parquet')
df = cudf.DataFrame({
    'col1': np.random.randint(0,1,size=90_000_000)
})
df.to_parquet('single2_2.parquet')

features = ['col1'] >> nvt.ops.Categorify()

workflow = nvt.Workflow(features)
ds = Dataset(['single2_1.parquet', 'single2_2.parquet'], part_size='1000MB')
workflow.fit(ds)
workflow.transform(ds).to_parquet('/raid/test/')

df = cudf.read_parquet('/raid/test/part_0.parquet')
print(df.shape)
df = cudf.read_parquet('/raid/test/part_1.parquet')
print(df.shape)

Output:

(100000000, 1)
(90000000, 1)

@bschifferer
Contributor Author

bschifferer commented Jun 29, 2023

Memory Footprint

# Creating the dataset
import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

df = cudf.DataFrame({
    'col1': np.random.random(size=500_000_000)
})
df.to_parquet('/raid/single2_1.parquet')
# restart the environment to free GPU memory
import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

ds = Dataset('/raid/single2_1.parquet')
#check nvidia-smi
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*100)
ds2 = Dataset(ddf2)
print('npartitions:' + str(ds2.npartitions))
#check nvidia-smi
for i in range(4):
    train_dl = Loader(
        ds2,
        batch_size = 1024,
        shuffle=True,
        drop_last=True,
        global_size=4,
        global_rank=i,
    )
    train_dl.global_rank = i
    print('Dataloader: ' + str(i) + ' Num Batches: ' + str(len(train_dl)))
    break
#check nvidia-smi

@bschifferer
Contributor Author

Memory Footprint II

Dataset Creation

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

df = cudf.DataFrame({
    'col1': np.random.random(size=500_000_000)
})
df.to_parquet('/raid/single2_1.parquet')

Repartition

import os

os.environ["CUDA_VISIBLE_DEVICES"] = str("0")
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

ds = Dataset('/raid/single2_1.parquet')
# Check NVIDIASMI -> 825MB
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*1000)
ds2 = Dataset(ddf2)
print('npartitions:' + str(ds2.npartitions))
# Check NVIDIASMI -> 4.6GB
train_dl = Loader(
    ds2,
    batch_size = 1024,
    shuffle=True,
    drop_last=True,
    global_size=4,
    global_rank=0,
)
# Check NVIDIASMI -> 4.6GB
import gc

for ibatch, batch in enumerate(train_dl):
    if ibatch>100:
        break
    gc.collect()
    #Check NVIDIASMI in another process -> ~12GB

Without Repartition

import os

os.environ["CUDA_VISIBLE_DEVICES"] = str("0")
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

ds = Dataset('/raid/single2_1.parquet')
# Check NVIDIASMI -> 825MB
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*1000)
ddf2.to_parquet('/raid/single2_1_repart_2.parquet')
# Reset Environment

import os

os.environ["CUDA_VISIBLE_DEVICES"] = str("0")
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
from merlin.io import Dataset

ds2 = Dataset('/raid/single2_1_repart_2.parquet', part_size='100MB')
print('npartitions:' + str(ds2.npartitions))
# Check NVIDIASMI -> 825MB
train_dl = Loader(
    ds2,
    batch_size = 1024,
    shuffle=True,
    drop_last=True,
    global_size=4,
    global_rank=0,
)
# Check NVIDIASMI -> 825MB
import gc

for ibatch, batch in enumerate(train_dl):
    if ibatch>100:
        break
    gc.collect()
    #Check NVIDIASMI in another process -> 825MB

@bschifferer
Contributor Author

import cudf
import numpy as np

df = cudf.DataFrame({
    'user_id': np.random.randint(0,10,size=10_000_000),
    'item_id': np.random.randint(0,10,size=10_000_000),
    'target': np.random.randint(0,2,size=10_000_000),
})
df.to_parquet('/raid/single2_1.parquet')
df = cudf.DataFrame({
    'user_id': np.random.randint(0,10,size=9_000_000),
    'item_id': np.random.randint(0,10,size=9_000_000),
    'target': np.random.randint(0,2,size=9_000_000),
})
df.to_parquet('/raid/single2_2.parquet')
%%writefile './tf_trainer.py'

import os

MPI_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE"))
MPI_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK"))

os.environ["CUDA_VISIBLE_DEVICES"] = str(MPI_RANK)

import cudf
import os
import numpy as np

from merlin.loader.tensorflow import Loader
import nvtabular as nvt
from merlin.io import Dataset
import merlin.models.tf as mm
import tensorflow as tf

ds = Dataset(['/raid/single2_1.parquet', '/raid/single2_2.parquet'], part_size='1000MB')
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*12)

from merlin.schema import ColumnSchema, Schema, Tags
columns_list = []
columns_list.append(
    ColumnSchema(
        'user_id',
        tags=[Tags.CATEGORICAL],
        is_list=False,
        is_ragged=False,
        dtype='int64',
        properties={
            "domain": {"min": 0, "max": 20},
            "embedding_sizes": {"cardinality": 20, "dimension": 8}
        }
))
columns_list.append(
    ColumnSchema(
        'item_id',
        tags=[Tags.CATEGORICAL],
        is_list=False,
        is_ragged=False,
        dtype='int64',
        properties={
            "domain": {"min": 0, "max": 20},
            "embedding_sizes": {"cardinality": 20, "dimension": 8}
        }
))
columns_list.append(
    ColumnSchema(
        'target',
        dtype='int32',
        tags=[Tags.TARGET, Tags.BINARY_CLASSIFICATION],
        is_list=False,
        is_ragged=False,
))
schema = Schema(columns_list)

ds2 = Dataset(ddf2,schema=schema)
train_dl = Loader(
    ds2,
    batch_size = 1024*20,
    shuffle=True,
    drop_last=True,
    global_size=MPI_SIZE,
    global_rank=MPI_RANK,
)

model = mm.Model.from_block(mm.MLPBlock([64, 32]),
    schema, prediction_tasks=mm.BinaryOutput('target')
)

opt = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.01)
model.compile(optimizer=opt, run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
losses = model.fit(
    train_dl
)
!horovodrun -np 4 python tf_trainer.py

@jperez999
Collaborator

So I have done a preliminary investigation into this issue based on the reproducer, and I have found that the following code, using plain TensorFlow and Horovod, works as expected. In this example I used only two processes because I have only two GPUs available on my dev machine. Using the code provided previously in this thread, I was still able to reproduce the error (a hang during training). However, the following code runs successfully:

import os

MPI_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE"))
MPI_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK"))

os.environ["CUDA_VISIBLE_DEVICES"] = str(MPI_RANK)

import cudf
import numpy as np

from merlin.loader.tensorflow import Loader
import nvtabular as nvt
from merlin.io import Dataset
from merlin.core.dispatch import concat
import merlin.models.tf as mm
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()  # initialize Horovod before hvd.size() / hvd.rank() are used below

ds = Dataset(["single2_1.parquet", "single2_2.parquet"], part_size='1000MB')
ddf = ds.to_ddf()
ddf2 = ddf.repartition(npartitions=4*12)

from merlin.schema import ColumnSchema, Schema, Tags
columns_list = []
columns_list.append(
    ColumnSchema(
        'user_id',
        tags=[Tags.CATEGORICAL],
        is_list=False,
        is_ragged=False,
        dtype='int64',
        properties={
            "domain": {"min": 0, "max": 20},
            "embedding_sizes": {"cardinality": 20, "dimension": 8}
        }
))
columns_list.append(
    ColumnSchema(
        'item_id',
        tags=[Tags.CATEGORICAL],
        is_list=False,
        is_ragged=False,
        dtype='int64',
        properties={
            "domain": {"min": 0, "max": 20},
            "embedding_sizes": {"cardinality": 20, "dimension": 8}
        }
))
columns_list.append(
    ColumnSchema(
        'target',
        dtype='int32',
        tags=[Tags.TARGET, Tags.BINARY_CLASSIFICATION],
        is_list=False,
        is_ragged=False,
))
schema = Schema(columns_list)

ds2 = Dataset(ddf2,schema=schema)
train_dl = Loader(
    ds2,
    batch_size = 1024*20,
    shuffle=True,
    drop_last=True,
    global_size=MPI_SIZE,
    global_rank=MPI_RANK,
)
inputs = {} 
in_layers = []
for col in ['user_id',"item_id"]:
    inputs[col] = tf.keras.Input(name=col, dtype=tf.int64, shape=(1,))
    in_layers.append(tf.keras.layers.Dense(10, activation="relu")(inputs[col]))
# for col in ['target']:
    # inputs[col] = tf.keras.Input(name=col, dtype=tf.int32, shape=(1,))

x = tf.keras.layers.concatenate(in_layers)
x = tf.keras.layers.Dense(30, activation="relu")(x)
x = tf.keras.layers.Dense(10, activation="relu")(x)
x = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs=inputs, outputs=x)
loss = tf.losses.BinaryCrossentropy()
opt = tf.keras.optimizers.legacy.SGD(0.01 * hvd.size())
opt = hvd.DistributedOptimizer(opt)
checkpoint_dir = "./checkpoints"
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)


@tf.function(experimental_relax_shapes=True)
def training_step(examples, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = model(examples, training=True)
        loss_value = loss(labels, probs)
    # Horovod: add Horovod Distributed GradientTape.
    tape = hvd.DistributedGradientTape(tape, sparse_as_dense=True)
    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    #
    # Note: broadcast should be done after the first gradient step to ensure optimizer
    # initialization.
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
    return loss_value


# Horovod: adjust number of steps based on number of GPUs.
for batch, (examples, labels) in enumerate(train_dl):
    loss_value = training_step(examples, labels, batch == 0)
    if batch % 10 == 0 and hvd.local_rank() == 0:
        print("Step #%d\tLoss: %.6f" % (batch, loss_value))
print(batch)
hvd.join()
# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting it.
if hvd.rank() == 0:
    checkpoint.save(checkpoint_dir)

Running this code you get the following output:

[1,1]<stderr>:Instructions for updating:
[1,1]<stderr>:Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
[1,0]<stdout>:Step #0   Loss: 0.849249
[1,0]<stdout>:Step #10  Loss: 0.705233
[1,0]<stdout>:Step #20  Loss: 0.696026
[1,0]<stdout>:Step #30  Loss: 0.694616
[1,0]<stdout>:Step #40  Loss: 0.694334
[1,0]<stdout>:Step #50  Loss: 0.693732
[1,0]<stdout>:Step #60  Loss: 0.694103
[1,0]<stdout>:Step #70  Loss: 0.693531
[1,0]<stdout>:Step #80  Loss: 0.693285
[1,0]<stdout>:Step #90  Loss: 0.693885
[1,0]<stdout>:Step #100 Loss: 0.693521
[1,0]<stdout>:Step #110 Loss: 0.693950
[1,0]<stdout>:Step #120 Loss: 0.693708
[1,0]<stdout>:Step #130 Loss: 0.693711
[1,0]<stdout>:Step #140 Loss: 0.693409
[1,0]<stdout>:Step #150 Loss: 0.693746
[1,0]<stdout>:Step #160 Loss: 0.693621
[1,0]<stdout>:Step #170 Loss: 0.693788
[1,0]<stdout>:Step #180 Loss: 0.693215
[1,0]<stdout>:Step #190 Loss: 0.693627
[1,0]<stdout>:Step #200 Loss: 0.693268
[1,0]<stdout>:Step #210 Loss: 0.693444
[1,0]<stdout>:Step #220 Loss: 0.693600
[1,0]<stdout>:Step #230 Loss: 0.693448
[1,0]<stdout>:Step #240 Loss: 0.693495
[1,0]<stdout>:Step #250 Loss: 0.693269
[1,0]<stdout>:Step #260 Loss: 0.693298
[1,0]<stdout>:Step #270 Loss: 0.693406
[1,0]<stdout>:Step #280 Loss: 0.693249
[1,0]<stdout>:Step #290 Loss: 0.693375
[1,0]<stdout>:Step #300 Loss: 0.693253
[1,0]<stdout>:Step #310 Loss: 0.693140
[1,0]<stdout>:Step #320 Loss: 0.693589
[1,0]<stdout>:Step #330 Loss: 0.693192
[1,0]<stdout>:Step #340 Loss: 0.693360
[1,0]<stdout>:Step #350 Loss: 0.693374
[1,0]<stdout>:Step #360 Loss: 0.693204
[1,0]<stdout>:Step #370 Loss: 0.693087
[1,0]<stdout>:Step #380 Loss: 0.693173
[1,0]<stdout>:Step #390 Loss: 0.693287
[1,0]<stdout>:Step #400 Loss: 0.693279
[1,0]<stdout>:Step #410 Loss: 0.693377
[1,0]<stdout>:Step #420 Loss: 0.692989
[1,0]<stdout>:Step #430 Loss: 0.693398
[1,1]<stdout>:438
[1,0]<stdout>:Step #440 Loss: 0.693277
[1,0]<stdout>:Step #450 Loss: 0.693598
[1,0]<stdout>:Step #460 Loss: 0.693274
[1,0]<stdout>:Step #470 Loss: 0.693230
[1,0]<stdout>:Step #480 Loss: 0.693435
[1,0]<stdout>:487

Notice that one of the processes has ~50 fewer batches than the other, and we are still able to complete training. The process with fewer batches finishes, and the other continues iterating until it has trained on all of its batches. This leads me to believe the problem lies in the Merlin Models code. I have started investigating and have found a few differences in how the code is run: first, in Merlin Models we are using callbacks; second, I was unable to find anywhere in the code where the hvd.join() command is used; third, the Merlin Models code relies on the super().fit call to the Keras model class. I tried inserting the hvd.join command just after the super().fit call, but that did not fix the issue (the code still hangs). I will continue investigating; this is, however, IMO not a dataloader issue.
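
For comparison, the usual way to drive Keras model.fit under Horovod outside of Merlin Models is via horovod.tensorflow.keras and its callbacks. The sketch below only illustrates that standard pattern (the broadcast callback plays the role of the manual hvd.broadcast_variables calls in the custom loop above) and is not what Merlin Models currently does; the model and data here are stand-ins:

import numpy as np
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

# tiny stand-in model; in this thread it would be the Keras / Merlin Models model built above
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(2,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Horovod: wrap the optimizer and scale the learning rate by the number of workers
opt = hvd.DistributedOptimizer(tf.keras.optimizers.legacy.SGD(0.01 * hvd.size()))
model.compile(optimizer=opt, loss="binary_crossentropy")

callbacks = [
    # broadcast initial weights and optimizer state from rank 0 so all workers start in sync
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]

# synthetic data standing in for the Merlin Loader; swapping the Loader in here is
# exactly where the unequal per-worker batch counts become a problem for fit()
x = np.random.rand(2048, 2).astype("float32")
y = np.random.randint(0, 2, size=(2048, 1)).astype("float32")
model.fit(x, y, batch_size=64, epochs=1, callbacks=callbacks)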
