Remove indices and offsets copying from prefetch (#2186)
Summary:
Pull Request resolved: #2186

Since `linearize_cache_indices` did not support the case where
`indices` and `offsets` have different types, we cast both `indices`
and `offsets` to the same type, namely `int64_t`. In some cases, this
cast caused peak memory usage to surge. This diff modifies the
`linearize_cache_indices` op to support `indices` and `offsets` with
different types, so the upcast in prefetch is no longer needed.

Reviewed By: ehsanardestani

Differential Revision: D51723551

fbshipit-source-id: f40c9cc6e8e4435a8cc01702edd85cf843835c07
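
To make the memory point concrete, here is a minimal sketch of the upcast cost; the sizes are invented for illustration and this is not code from the diff:

import torch

# Illustrative sizes only: 50M cached-embedding lookups in one prefetch call.
num_indices = 50_000_000
indices = torch.randint(0, 1_000_000, (num_indices,), dtype=torch.int32)
offsets = torch.arange(0, num_indices + 1, 100, dtype=torch.int32)

def mbytes(t: torch.Tensor) -> float:
    return t.numel() * t.element_size() / 1e6

# Before this diff: prefetch upcast both tensors to int64 before calling
# linearize_cache_indices, temporarily doubling their footprint.
print(f"indices: {mbytes(indices):.0f} MB as int32, "
      f"{mbytes(indices.long()):.0f} MB as int64")

# After this diff: the op consumes the int32 (or mixed-dtype) tensors directly,
# and only its int64 output is newly allocated.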
sryap authored and facebook-github-bot committed Dec 1, 2023
1 parent c58679a commit 0fc0d4e
Showing 3 changed files with 24 additions and 24 deletions.
@@ -505,10 +505,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         if not self.lxu_cache_weights.numel():
             return
 
-        # FIXME: check the int32_t range failure in https://fburl.com/gdoc/kcdnrnvg .
-        # The real failure should be in cache handling in https://fburl.com/ox3f26r0 .
-        indices, offsets = indices.long(), offsets.long()
-
         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
             self.cache_hash_size_cumsum,
             indices,
@@ -1158,7 +1158,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         if not self.lxu_cache_weights.numel():
             return
 
-        (indices, offsets) = indices.long(), offsets.long()
         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
             self.cache_hash_size_cumsum,
             indices,
fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu (24 additions, 19 deletions)
@@ -13,15 +13,15 @@ using namespace fbgemm_gpu;
 
 namespace {
 
-template <typename index_t>
+template <typename index_t, typename offset_t>
 __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel(
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         cache_hash_size_cumsum,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
         table_offsets,
-    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         linear_cache_indices) {
   const index_t index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index >= indices.size(0)) {
@@ -72,31 +72,36 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda(
   const auto B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B >= 0);
 
-  auto linear_cache_indices = at::empty_like(indices);
+  auto linear_cache_indices =
+      at::empty(indices.sizes(), indices.options().dtype(at::kLong));
   const auto num_indices = indices.numel();
   if (B == 0 || num_indices == 0) {
     return linear_cache_indices;
   }
 
-  auto table_offsets = offsets.slice(0, B, B * T, B);
+  const auto table_offsets = offsets.slice(0, B, B * T, B);
 
   AT_DISPATCH_INDEX_TYPES(
-      indices.scalar_type(), "linearize_cache_indices_kernel", [&] {
+      table_offsets.scalar_type(), "linearize_cache_indices_kernel_1", [&] {
+        using offset_t = index_t;
+        AT_DISPATCH_INDEX_TYPES(
+            indices.scalar_type(), "linearize_cache_indices_kernel_2", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "linearize_cache_indices_kernel";
+              const char* func_name = "linearize_cache_indices_kernel";
 #endif
-        linearize_cache_indices_kernel<<<
-            div_round_up(num_indices, kMaxThreads),
-            kMaxThreads,
-            0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, table_offsets, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, linear_cache_indices, index_t, 1, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+              linearize_cache_indices_kernel<<<
+                  div_round_up(num_indices, kMaxThreads),
+                  kMaxThreads,
+                  0,
+                  at::cuda::getCurrentCUDAStream()>>>(
+                  MAKE_PTA_WITH_NAME(
+                      func_name, cache_hash_size_cumsum, int64_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(func_name, table_offsets, offset_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(
+                      func_name, linear_cache_indices, int64_t, 1, 32));
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
       });
   return linear_cache_indices;
 }
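
For completeness, a rough usage sketch of the op after this change. This is an assumption-laden illustration, not code from the diff: it presumes an fbgemm_gpu build containing this change, a CUDA device, and the three-argument form (cache_hash_size_cumsum, indices, offsets) implied by this excerpt, with made-up table and batch sizes.

import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

T, B = 2, 4  # made-up: two embedding tables, batch size four
# Per-table cumulative hash sizes; the embedding module stores this as int64.
cache_hash_size_cumsum = torch.tensor([0, 100, 300], dtype=torch.int64, device="cuda")

# Mixed dtypes: int32 indices with int64 offsets. Before this diff, prefetch
# upcast both to int64 before calling the op.
indices = torch.randint(0, 100, (16,), dtype=torch.int32, device="cuda")
offsets = torch.arange(0, 17, 2, dtype=torch.int64, device="cuda")  # length T * B + 1

linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
    cache_hash_size_cumsum, indices, offsets
)
assert linear_cache_indices.dtype == torch.int64  # output is always int64 now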
