Skip to content

Commit

Permalink
Fix library loading for SSD TBE tests in OSS (#3069)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3069

X-link: facebookresearch/FBGEMM#164

- Fix library loading for SSD TBE tests in OSS

Reviewed By: duduyi2013

Differential Revision: D62159277

fbshipit-source-id: 037e737ca93bac10a2dff4b6ccd1ecb1c2baeb5b
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 4, 2024
1 parent 4260382 commit 669b6c7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 25 deletions.
33 changes: 8 additions & 25 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,13 @@
# pyre-strict
# pyre-ignore-all-errors[56]

from typing import Optional

import torch


def _load_torch_module(
unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
) -> None:
try:
torch.ops.load_library(unified_path)
except Exception:
if torch.version.hip:
if not hip_path:
hip_path = f"{unified_path}_hip"
torch.ops.load_library(hip_path)
else:
if not cuda_path:
cuda_path = f"{unified_path}_cuda"
torch.ops.load_library(cuda_path)


_load_torch_module(
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
)

from fbgemm_gpu.utils.loader import load_torch_module

try:
load_torch_module(
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
)
except Exception:
pass

ASSOC = 32
6 changes: 6 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
29 changes: 29 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/utils/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[56]

from typing import Optional

import torch


def load_torch_module(
unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
) -> None:
try:
torch.ops.load_library(unified_path)
except Exception:
if torch.version.hip:
if not hip_path:
hip_path = f"{unified_path}_hip"
torch.ops.load_library(hip_path)
else:
if not cuda_path:
cuda_path = f"{unified_path}_cuda"
torch.ops.load_library(cuda_path)

0 comments on commit 669b6c7

Please sign in to comment.