Skip to content

Commit

Permalink
Move embedding_rocksdb_wrapper to its own header. (#3622)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3622

X-link: facebookresearch/FBGEMM#700

We can't instantiate the KVTensorWrapper without EmbeddingRocksDBWrapper in C++.

This diff moves EmbeddingRocksDBWrapper into its own header, embedding_rocksdb_wrapper.h, so it can be included independently.

Reviewed By: jiayulu

Differential Revision: D68727692

fbshipit-source-id: ca141c45ac2520323a55ab93a63ba617a19c8c57
  • Loading branch information
pradeepfn authored and facebook-github-bot committed Jan 28, 2025
1 parent f7d5dae commit bab9b62
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 128 deletions.
143 changes: 143 additions & 0 deletions fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "kv_tensor_wrapper.h"

namespace ssd {

/// TorchScript custom-class wrapper around ssd::EmbeddingRocksDB.
///
/// Every method is a thin forwarder to the underlying EmbeddingRocksDB
/// instance; this class only adapts it to torch::jit::CustomClassHolder
/// so it can be registered and used from Python/TorchScript.
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
 public:
  /// Constructs the underlying EmbeddingRocksDB with the given RocksDB
  /// tuning knobs. See ssd::EmbeddingRocksDB for parameter semantics.
  EmbeddingRocksDBWrapper(
      std::string path,
      int64_t num_shards,
      int64_t num_threads,
      int64_t memtable_flush_period,
      int64_t memtable_flush_offset,
      int64_t l0_files_per_compact,
      int64_t max_D,
      int64_t rate_limit_mbps,
      int64_t size_ratio,
      int64_t compaction_ratio,
      int64_t write_buffer_size,
      int64_t max_write_buffer_num,
      double uniform_init_lower,
      double uniform_init_upper,
      int64_t row_storage_bitwidth = 32,
      int64_t cache_size = 0,
      bool use_passed_in_path = false,
      int64_t tbe_unique_id = 0,
      int64_t l2_cache_size_gb = 0,
      bool enable_async_update = false)
      : impl_(std::make_shared<ssd::EmbeddingRocksDB>(
            // `path` is a sink parameter taken by value; move it into the
            // impl instead of copying the string a second time.
            std::move(path),
            num_shards,
            num_threads,
            memtable_flush_period,
            memtable_flush_offset,
            l0_files_per_compact,
            max_D,
            rate_limit_mbps,
            size_ratio,
            compaction_ratio,
            write_buffer_size,
            max_write_buffer_num,
            uniform_init_lower,
            uniform_init_upper,
            row_storage_bitwidth,
            cache_size,
            use_passed_in_path,
            tbe_unique_id,
            l2_cache_size_gb,
            enable_async_update)) {}

  /// Asynchronously writes `weights` for `indices` (first `count` entries)
  /// at training step `timestep`; forwards to EmbeddingRocksDB::set_cuda.
  void set_cuda(
      Tensor indices,
      Tensor weights,
      Tensor count,
      int64_t timestep,
      bool is_bwd) {
    return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
  }

  /// Asynchronously reads rows for `indices` into `weights`.
  void get_cuda(Tensor indices, Tensor weights, Tensor count) {
    return impl_->get_cuda(indices, weights, count);
  }

  /// Synchronous write of `weights` for `indices` (first `count` entries).
  void set(Tensor indices, Tensor weights, Tensor count) {
    return impl_->set(indices, weights, count);
  }

  /// Writes a contiguous range [start, start + length) of rows from
  /// `weights` directly into storage.
  void set_range_to_storage(
      const Tensor& weights,
      const int64_t start,
      const int64_t length) {
    return impl_->set_range_to_storage(weights, start, length);
  }

  /// Synchronous read of rows for `indices` into `weights`; `sleep_ms` is
  /// forwarded to the impl (presumably a test/throttling hook — confirm).
  void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
    return impl_->get(indices, weights, count, sleep_ms);
  }

  /// Returns memory-usage counters reported by the underlying DB.
  std::vector<int64_t> get_mem_usage() {
    return impl_->get_mem_usage();
  }

  /// Returns RocksDB I/O duration stats for the given step/interval.
  std::vector<double> get_rocksdb_io_duration(
      const int64_t step,
      const int64_t interval) {
    return impl_->get_rocksdb_io_duration(step, interval);
  }

  /// Returns L2-cache performance stats for the given step/interval.
  std::vector<double> get_l2cache_perf(
      const int64_t step,
      const int64_t interval) {
    return impl_->get_l2cache_perf(step, interval);
  }

  /// Triggers a manual RocksDB compaction.
  void compact() {
    return impl_->compact();
  }

  /// Flushes memtables / pending writes to storage.
  void flush() {
    return impl_->flush();
  }

  /// Drops the contents of the L2 cache.
  void reset_l2_cache() {
    return impl_->reset_l2_cache();
  }

  /// Blocks until background filling work is done.
  /// NOTE(review): "util" looks like a typo for "until", but the name must
  /// match EmbeddingRocksDB's method and any existing callers.
  void wait_util_filling_work_done() {
    return impl_->wait_util_filling_work_done();
  }

  /// Creates a DB snapshot and wraps its handle for TorchScript; the
  /// wrapper keeps `impl_` alive via the shared_ptr it captures.
  c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() {
    auto handle = impl_->create_snapshot();
    return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle, impl_);
  }

  /// Releases a snapshot previously produced by create_snapshot().
  void release_snapshot(
      c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
    auto handle = snapshot_handle->handle;
    CHECK_NE(handle, nullptr);
    impl_->release_snapshot(handle);
  }

  /// Returns the number of currently outstanding snapshots.
  int64_t get_snapshot_count() const {
    return impl_->get_snapshot_count();
  }

 private:
  // KVTensorWrapper needs direct access to impl_.
  friend class KVTensorWrapper;

  // shared pointer since we use shared_from_this() in callbacks.
  std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
};

} // namespace ssd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <torch/custom_class.h>

#include "./ssd_table_batched_embeddings.h"
#include "embedding_rocksdb_wrapper.h"
#include "fbgemm_gpu/utils/ops_utils.h"

using namespace at;
Expand Down Expand Up @@ -258,134 +259,6 @@ void compact_indices_cuda(

namespace ssd {

// NOTE: This class definition is relocated verbatim to
// embedding_rocksdb_wrapper.h so that KVTensorWrapper can be instantiated
// from C++ without pulling in this .cpp. TorchScript-facing wrapper that
// forwards every call to the underlying ssd::EmbeddingRocksDB.
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
 public:
  // See ssd::EmbeddingRocksDB for the semantics of these tuning knobs.
  EmbeddingRocksDBWrapper(
      std::string path,
      int64_t num_shards,
      int64_t num_threads,
      int64_t memtable_flush_period,
      int64_t memtable_flush_offset,
      int64_t l0_files_per_compact,
      int64_t max_D,
      int64_t rate_limit_mbps,
      int64_t size_ratio,
      int64_t compaction_ratio,
      int64_t write_buffer_size,
      int64_t max_write_buffer_num,
      double uniform_init_lower,
      double uniform_init_upper,
      int64_t row_storage_bitwidth = 32,
      int64_t cache_size = 0,
      bool use_passed_in_path = false,
      int64_t tbe_unique_id = 0,
      int64_t l2_cache_size_gb = 0,
      bool enable_async_update = false)
      : impl_(std::make_shared<ssd::EmbeddingRocksDB>(
            path,
            num_shards,
            num_threads,
            memtable_flush_period,
            memtable_flush_offset,
            l0_files_per_compact,
            max_D,
            rate_limit_mbps,
            size_ratio,
            compaction_ratio,
            write_buffer_size,
            max_write_buffer_num,
            uniform_init_lower,
            uniform_init_upper,
            row_storage_bitwidth,
            cache_size,
            use_passed_in_path,
            tbe_unique_id,
            l2_cache_size_gb,
            enable_async_update)) {}

  // Asynchronous write of `weights` for `indices` at step `timestep`.
  void set_cuda(
      Tensor indices,
      Tensor weights,
      Tensor count,
      int64_t timestep,
      bool is_bwd) {
    return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
  }

  // Asynchronous read of rows for `indices` into `weights`.
  void get_cuda(Tensor indices, Tensor weights, Tensor count) {
    return impl_->get_cuda(indices, weights, count);
  }

  // Synchronous write of the first `count` rows.
  void set(Tensor indices, Tensor weights, Tensor count) {
    return impl_->set(indices, weights, count);
  }

  // Writes rows [start, start + length) from `weights` into storage.
  void set_range_to_storage(
      const Tensor& weights,
      const int64_t start,
      const int64_t length) {
    return impl_->set_range_to_storage(weights, start, length);
  }

  // Synchronous read; `sleep_ms` is forwarded to the impl.
  void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
    return impl_->get(indices, weights, count, sleep_ms);
  }

  std::vector<int64_t> get_mem_usage() {
    return impl_->get_mem_usage();
  }

  std::vector<double> get_rocksdb_io_duration(
      const int64_t step,
      const int64_t interval) {
    return impl_->get_rocksdb_io_duration(step, interval);
  }

  std::vector<double> get_l2cache_perf(
      const int64_t step,
      const int64_t interval) {
    return impl_->get_l2cache_perf(step, interval);
  }

  // Manual RocksDB compaction.
  void compact() {
    return impl_->compact();
  }

  // Flush pending writes to storage.
  void flush() {
    return impl_->flush();
  }

  void reset_l2_cache() {
    return impl_->reset_l2_cache();
  }

  // "util" appears to be a typo for "until"; the name mirrors the impl.
  void wait_util_filling_work_done() {
    return impl_->wait_util_filling_work_done();
  }

  // The returned wrapper holds impl_ (shared_ptr), keeping the DB alive.
  c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() {
    auto handle = impl_->create_snapshot();
    return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle, impl_);
  }

  void release_snapshot(
      c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
    auto handle = snapshot_handle->handle;
    CHECK_NE(handle, nullptr);
    impl_->release_snapshot(handle);
  }

  int64_t get_snapshot_count() const {
    return impl_->get_snapshot_count();
  }

 private:
  // KVTensorWrapper needs direct access to impl_.
  friend class KVTensorWrapper;

  // shared pointer since we use shared_from_this() in callbacks.
  std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
};

SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
auto num_shards = db->num_shards();
CHECK_GT(num_shards, 0);
Expand Down

0 comments on commit bab9b62

Please sign in to comment.