Add a method to the TBE module to recompute buffers
Summary: If we initialize the TBE module on the meta device, its buffers will be on meta, and we need to recompute them before running inference. There are a couple of places that need to do this, so this diff adds a method to the TBE module for this logic so it can be called from different places.
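A minimal sketch of the intended call flow. `make_tbe_module`, `load_real_weights`, `indices`, and `offsets` are hypothetical placeholders, not part of this diff; the construction-on-meta mechanism and module class (e.g. IntNBitTableBatchedEmbeddingBagsCodegen) are assumptions about the surrounding fbgemm_gpu API:

import torch

# Hypothetical: `make_tbe_module` stands in for however the inference TBE
# module is actually constructed in the caller's code.
with torch.device("meta"):  # buffers are created on the meta device
    tbe = make_tbe_module()

load_real_weights(tbe)  # hypothetical: materialize weights on a real device

# Before running inference, rebuild the metadata buffers that are still on
# meta -- this is the method added in this diff.
tbe.recompute_module_buffers()
output = tbe(indices, offsets)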

Differential Revision: D53838053
842974287 authored and facebook-github-bot committed Feb 16, 2024
1 parent 66a8779 commit e4f9c29
Showing 1 changed file with 57 additions and 0 deletions.
@@ -880,6 +880,63 @@ def reset_embedding_spec_location(
            for spec in self.embedding_specs
        ]

def recompute_module_buffers(self) -> None:
    """
    Compute module buffers that are on the meta device and are not materialized
    in reset_weights_placements_and_offsets(). Currently those buffers are
    `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
    Pruning-related or UVM-related buffers are not computed right now.
    """
    # Nothing to do if the buffers already live on the target device, or if
    # the module itself is still on the meta device.
    if (
        self.weights_tys.device == self.current_device
        or self.current_device.type == "meta"
    ):
        return

    # Rebuild the dense metadata buffers from embedding_specs on the real device.
    weights_tys_int = [e[3].as_int() for e in self.embedding_specs]
    self.weights_tys = torch.tensor(
        [weights_tys_int[t] for t in self.feature_table_map],
        device=self.current_device,
        dtype=torch.uint8,
    )
    rows = [e[1] for e in self.embedding_specs]
    self.rows_per_table = torch.tensor(
        [rows[t] for t in self.feature_table_map],
        device=self.current_device,
        dtype=torch.int64,
    )
    # D_offsets is the exclusive cumulative sum of per-feature embedding dims,
    # e.g. dims [4, 8] -> D_offsets [0, 4, 12].
    dims = [e[2] for e in self.embedding_specs]
    D_offsets_list = [dims[t] for t in self.feature_table_map]
    D_offsets_list = [0] + list(accumulate(D_offsets_list))
    self.D_offsets = torch.tensor(
        D_offsets_list, device=self.current_device, dtype=torch.int32
    )
    self.bounds_check_warning = torch.tensor(
        [0], device=self.current_device, dtype=torch.int64
    )

    # For pruning-related or UVM-related buffers, we just set them as empty tensors.
    self.index_remapping_hash_table = torch.empty_like(
        self.index_remapping_hash_table, device=self.current_device
    )
    self.index_remapping_hash_table_offsets = torch.empty_like(
        self.index_remapping_hash_table_offsets, device=self.current_device
    )
    self.index_remappings_array = torch.empty_like(
        self.index_remappings_array, device=self.current_device
    )
    self.index_remappings_array_offsets = torch.empty_like(
        self.index_remappings_array_offsets, device=self.current_device
    )
    self.lxu_cache_weights = torch.empty_like(
        self.lxu_cache_weights, device=self.current_device
    )
    self.original_rows_per_table = torch.empty_like(
        self.original_rows_per_table, device=self.current_device
    )
    self.table_wise_cache_miss = torch.empty_like(
        self.table_wise_cache_miss, device=self.current_device
    )

def _apply_split(
    self,
    dev_size: int,
