[BUG] Multi-gpu training notebook is giving error if we generate schema from core #651

Closed · rnyak opened this issue Mar 16, 2023 · 7 comments · Fixed by #654

@rnyak (Contributor) commented Mar 16, 2023

Bug description

I am getting the following error when I run the multi-gpu training notebook:

/usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py:180: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  warnings.warn(
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
Traceback (most recent call last):
  File "pyt_trainer.py", line 41, in <module>
    input_module = tr.TabularSequenceFeatures.from_schema(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/sequence.py", line 193, in from_schema
    output: TabularSequenceFeatures = super().from_schema(  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/tabular.py", line 176, in from_schema
    output = cls(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/sequence.py", line 127, in __init__
    super().__init__(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/tabular.py", line 84, in __init__
    assert to_merge != {}, "Please provide at least one input layer"
AssertionError: Please provide at least one input layer
Traceback (most recent call last):
  File "pyt_trainer.py", line 41, in <module>
    input_module = tr.TabularSequenceFeatures.from_schema(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/sequence.py", line 193, in from_schema
    output: TabularSequenceFeatures = super().from_schema(  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/tabular.py", line 176, in from_schema
    output = cls(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/sequence.py", line 127, in __init__
    super().__init__(
  File "/usr/local/lib/python3.8/dist-packages/transformers4rec/torch/features/tabular.py", line 84, in __init__
    assert to_merge != {}, "Please provide at least one input layer"
AssertionError: Please provide at least one input layer
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 23905) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py", line 195, in <module>
    main()
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py", line 191, in main
    launch(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py", line 176, in launch
    run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
pyt_trainer.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-03-16_19:30:24
  host      : 1902e905751e
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 23906)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-03-16_19:30:24
  host      : 1902e905751e
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 23905)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
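
For context on where this fails: the assertion is raised because TabularSequenceFeatures.from_schema built no input modules from the schema it was given. A minimal diagnostic sketch (the parquet path below is hypothetical; point it at the preprocessed output the notebook actually reads) for checking whether the schema generated from merlin core still carries the expected column tags:

from merlin.io import Dataset
from merlin.schema import Tags

# Hypothetical path; replace with the preprocessed data used by the notebook.
train = Dataset("/workspace/data/preproc_sessions_by_day/1/train.parquet")
schema = train.schema

# If both of these come back empty, from_schema has nothing to build and the
# "Please provide at least one input layer" assertion fires.
print(schema.select_by_tag(Tags.CATEGORICAL).column_names)
print(schema.select_by_tag(Tags.CONTINUOUS).column_names)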

Steps/Code to reproduce bug

You need to run notebooks 01 and 03 in this folder, in order. For dataset generation you can use:

import calendar
import datetime

import cudf  # used below to move the generated frame to GPU
import numpy as np
import pandas as pd


def generate_synthetic_data(
    start_date: datetime.date, end_date: datetime.date, rows_per_day: int = 1000
) -> pd.DataFrame:
    assert end_date > start_date, "end_date must be later than start_date"

    number_of_days = (end_date - start_date).days
    total_number_of_rows = number_of_days * rows_per_day

    # Generate a long-tail distribution of item interactions. This simulates that some items are
    # more popular than others.
    long_tailed_item_distribution = np.clip(
        np.random.lognormal(3.0, 1.0, total_number_of_rows).astype(np.int64), 1, 50000
    )

    # generate random item interaction features
    df = pd.DataFrame(
        {
            "session_id": np.random.randint(70000, 80000, total_number_of_rows),
            "item_id": long_tailed_item_distribution,
        },
    )

    # generate category mapping for each item-id
    df["category"] = pd.cut(df["item_id"], bins=334, labels=np.arange(1, 335)).astype(
        np.int64
    )

    max_session_length = 60 * 60  # 1 hour

    def add_timestamp_to_session(session: pd.DataFrame):
        random_start_date_and_time = calendar.timegm(
            (
                start_date
                # Add day offset from start_date
                + datetime.timedelta(days=np.random.randint(0, number_of_days))
                # Add time offset within the random day
                + datetime.timedelta(seconds=np.random.randint(0, 86_400))
            ).timetuple()
        )
        session["timestamp"] = random_start_date_and_time + np.clip(
            np.random.lognormal(3.0, 1.0, len(session)).astype(np.int64),
            0,
            max_session_length,
        )
        return session

    df = df.groupby("session_id").apply(add_timestamp_to_session).reset_index()

    return df

interactions_df = generate_synthetic_data(datetime.date(2014, 4, 1), datetime.date(2014, 6, 30))
interactions_df = cudf.from_pandas(interactions_df)
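
To hand the synthetic data to notebook 01, one option is to write it out as parquet (a minimal sketch; the exact path and file name the notebook expects may differ):

# Hypothetical output location; adjust to wherever notebook 01 reads the raw interactions.
interactions_df.to_parquet("/workspace/data/interactions_merged_df.parquet")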

Environment details

  • Transformers4Rec version:
  • Platform:
  • Python version:
  • Huggingface Transformers version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):

Additional context

I am using the merlin-pytorch:23.02 image with the latest main branches pulled for the Merlin libraries.

@rnyak added the bug (Something isn't working) and status/needs-triage labels on Mar 16, 2023
@rnyak changed the title from "[BUG] Multi-gpu training notebook is given error" to "[BUG] Multi-gpu training notebook is giving error" on Mar 16, 2023
@rnyak added the P0 label on Mar 16, 2023
@rnyak added this to the Merlin 23.03 milestone on Mar 16, 2023
@rnyak changed the title from "[BUG] Multi-gpu training notebook is giving error" to "[BUG] Multi-gpu training notebook is giving error if we generate schema from core" on Mar 20, 2023
@edknv (Collaborator) commented Mar 21, 2023

This might be due to a change in the dataloader. Specifically, NVIDIA-Merlin/dataloader@dbf8816 (and the related NVIDIA-Merlin/dataloader@4301447).

These dataloader changes did not make it into the 23.02 release (even though they are in the release-23.02 branch). You can see that the above commit is missing from the v23.02 release notes. Verifying from the container:

$ docker run --rm -it --gpus all --net host -v ~/data:/workspace/data nvcr.io/nvidia/merlin/merlin-pytorch:23.02 bash
root@0810733-lcedt:/dataloader# git log
commit 02aad2124e247e6a4f229d6638eaaec0931aca8c (grafted, HEAD, tag: v23.02.00)
Author: Karl Higley <[email protected]>
Date:   Mon Feb 13 15:23:40 2023 -0500

    Replace `nnzs` with `row_lengths` for clarity (#99)

If I install the problematic commit and run the multi-gpu notebook, the notebook fails:

$ docker run --rm -it --gpus all --net host -v ~/data:/workspace/data nvcr.io/nvidia/merlin/merlin-pytorch:23.02 bash
root@0810733-lcedt:/opt/tritonserver# cd /dataloader/                                                                                                                                                                      
root@0810733-lcedt:/dataloader# git fetch origin 226ad6903a7abfb5c1288f20eaf7d91eb952e374                                                                                                                                  
root@0810733-lcedt:/dataloader# git checkout 226ad6903a7abfb5c1288f20eaf7d91eb952e374    
root@0810733-lcedt:/dataloader# pip install . --no-deps                                                                                                                                  

It works again if I revert:

root@0810733-lcedt:/transformers4rec# cd /dataloader/
root@0810733-lcedt:/dataloader# git checkout 02aad2124e247e6a4f229d6638eaaec0931aca8c
root@0810733-lcedt:/dataloader# pip install . --no-deps

I didn't have time today to see which condition is failing and to test out a solution. (My guess is that the local_rank environment variable is not being passed to the device correctly by T4R/PyTorch, and we need to set it properly; a sketch of what that might look like follows.)
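
If that is the issue, a minimal sketch of that kind of fix (not the notebook's actual code) is to pin each worker process to the GPU that torchrun assigns it before the dataloader is built:

import os

import torch

# torchrun / torch.distributed.elastic exports LOCAL_RANK for every worker process.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)  # make this worker's default CUDA device match its assigned GPU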

@rnyak (Contributor, Author) commented Apr 5, 2023

I reopened this since I am getting an error when running the multi-gpu notebook.

@bbozkaya (Contributor) commented Apr 5, 2023

I get the error "RuntimeError: CUDA error at: /usr/local/include/rmm/device_uvector.hpp:316: cudaErrorIllegalAddress an illegal memory access was encountered" when I run this notebook on a 2-GPU 32GB NGC instance with the 23.02 PyTorch container and the main branch pulled and compiled for all six Merlin libraries (core, nvtabular, dataloader, models, systems, transformers4rec). The error is raised after executing the cell "! torchrun --nproc_per_node 2 pyt_trainer.py --path "/workspace/data/preproc_sessions_by_day" --learning-rate 0.0005". The full error message is below:

File "pyt_trainer.py", line 101, in
recsys_trainer.train()
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1290, in train
for step, inputs in enumerate(epoch_iterator):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 630, in next
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 673, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
data.append(next(self.dataset_iter))
File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/torch.py", line 62, in next
converted_batch = self.convert_batch(super().next())
File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/loader_base.py", line 259, in next
return self._get_next_batch()
File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/loader_base.py", line 330, in _get_next_batch
batch = next(self._batch_itr)
File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/loader_base.py", line 367, in make_tensors
tensors_by_name = self._convert_df_to_tensors(gdf)
File "/usr/local/lib/python3.8/dist-packages/nvtx/nvtx.py", line 101, in inner
result = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/loader_base.py", line 546, in _convert_df_to_tensors
if isinstance(leaves[0], list):
File "/usr/local/lib/python3.8/dist-packages/nvtx/nvtx.py", line 101, in inner
result = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/cudf/core/series.py", line 1171, in getitem
return self.loc[arg]
File "/usr/local/lib/python3.8/dist-packages/nvtx/nvtx.py", line 101, in inner
result = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/cudf/core/series.py", line 259, in getitem
return self._frame.iloc[arg]
File "/usr/local/lib/python3.8/dist-packages/nvtx/nvtx.py", line 101, in inner
result = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/cudf/core/series.py", line 181, in getitem
data = self._frame._get_elements_from_column(arg)
File "/usr/local/lib/python3.8/dist-packages/cudf/core/single_column_frame.py", line 385, in _get_elements_from_column
return self._column.element_indexing(int(arg))
File "/usr/local/lib/python3.8/dist-packages/cudf/core/column/column.py", line 453, in element_indexing
return libcudf.copying.get_element(self, idx).value
File "scalar.pyx", line 162, in cudf._lib.scalar.DeviceScalar.value.get
File "scalar.pyx", line 138, in cudf._lib.scalar.DeviceScalar._to_host_scalar
File "scalar.pyx", line 426, in cudf._lib.scalar._get_np_scalar_from_numeric
RuntimeError: CUDA error at: /usr/local/include/rmm/device_uvector.hpp:316: cudaErrorIllegalAddress an illegal memory access was encountered
Error in sys.excepthook:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/exceptiongroup/_formatting.py", line 71, in exceptiongroup_excepthook
TypeError: 'NoneType' object is not callable

Original exception was:
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Exception ignored in: 'cupy.cuda.function.Module.__dealloc__'
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Error in sys.excepthook:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/exceptiongroup/_formatting.py", line 71, in exceptiongroup_excepthook
TypeError: 'NoneType' object is not callable

Original exception was:
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Exception ignored in: 'cupy.cuda.function.Module.__dealloc__'
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Error in sys.excepthook:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/exceptiongroup/_formatting.py", line 71, in exceptiongroup_excepthook
TypeError: 'NoneType' object is not callable

Original exception was:
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Exception ignored in: 'cupy.cuda.function.Module.__dealloc__'
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Error in sys.excepthook:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/exceptiongroup/_formatting.py", line 71, in exceptiongroup_excepthook
TypeError: 'NoneType' object is not callable

Original exception was:
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Exception ignored in: 'cupy.cuda.function.Module.__dealloc__'
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Error in sys.excepthook:

Original exception was:
Error in sys.excepthook:

Original exception was:
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4406 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 1 (pid: 4407) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

pyt_trainer.py FAILED

Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2023-04-05_15:06:10
host : 4466477
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 4407)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
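
As a general debugging aid for this kind of asynchronous illegal-access failure (not something from the notebook itself), forcing synchronous kernel launches before any CUDA work starts usually makes the traceback point at the call that actually triggered the error:

import os

# Must be set before the CUDA context is created, i.e. before cudf or torch.cuda is initialized.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"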

@edknv (Collaborator) commented Apr 5, 2023

It looks like we are seeing an error again due to this change in the dataloader: NVIDIA-Merlin/dataloader@1452e82. If I check out the previous commit (a075ebfd2afc17b97bf8b271bebfbbe308f288e3) the notebook works. I'm not sure exactly which change in the problematic commit is causing the notebook to fail.
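
When bisecting dataloader commits like this, a quick sanity check that the container is really running the build you think it is (assuming the package is distributed as merlin-dataloader):

from importlib.metadata import version

# Compare this against the commit or tag you expect to have installed.
print(version("merlin-dataloader"))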

@rnyak modified the milestones: Merlin 23.03, Merlin 23.04 on Apr 6, 2023
@rnyak (Contributor, Author) commented Apr 6, 2023

> It looks like we are seeing an error again due to this change in the dataloader: NVIDIA-Merlin/dataloader@1452e82. If I check out the previous commit (a075ebfd2afc17b97bf8b271bebfbbe308f288e3) the notebook works. I'm not sure exactly which change in the problematic commit is causing the notebook to fail.

@jperez999 and @karlhigley fyi.

@rnyak (Contributor, Author) commented Apr 12, 2023

@edknv is working on NVIDIA-Merlin/dataloader#132, which might solve the multi-GPU error.

@edknv (Collaborator) commented Apr 13, 2023

> @edknv is working on NVIDIA-Merlin/dataloader#132, which might solve the multi-GPU error.

Unfortunately, NVIDIA-Merlin/dataloader#132 doesn't solve the problem for the notebook because we still get an error with list columns. See this issue for more details: NVIDIA-Merlin/dataloader#131.

@rnyak closed this as completed on Jun 21, 2023