Skip to content

Commit

Permalink
Fix bugs in global kmeans (#13809)
Browse files Browse the repository at this point in the history
  • Loading branch information
MBkkt authored Feb 1, 2025
1 parent 07c4b6b commit 5d71ae4
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 40 deletions.
4 changes: 2 additions & 2 deletions ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,10 @@ TExprBase DoRewriteTopSortOverKMeansTree(
const auto clusters = std::max<ui32>(2, settings.clusters());
const auto levels = std::max<ui32>(1, settings.levels());
Y_ENSURE(level <= levels);
const auto levelTop = std::min<ui32>(kqpCtx.Config->KMeansTreeSearchTopSize.Get().GetOrElse(1), clusters);

// TODO(mbkkt) count should be customizable via query options
auto count = ctx.Builder(pos)
.Callable("Uint64").Atom(0, std::to_string(std::min<ui32>(4, clusters)), TNodeFlags::Default).Seal()
.Callable("Uint64").Atom(0, std::to_string(levelTop), TNodeFlags::Default).Seal()
.Build();

// TODO(mbkkt) Is this the best way to do `SELECT FROM levelTable WHERE first_pk_column = 0`?
Expand Down
2 changes: 2 additions & 0 deletions ydb/core/kqp/provider/yql_kikimr_settings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ TKikimrConfiguration::TKikimrConfiguration() {
REGISTER_SETTING(*this, MaxTasksPerStage);
REGISTER_SETTING(*this, MaxSequentialReadsInFlight);

REGISTER_SETTING(*this, KMeansTreeSearchTopSize);

/* Runtime */
REGISTER_SETTING(*this, ScanQuery);

Expand Down
2 changes: 2 additions & 0 deletions ydb/core/kqp/provider/yql_kikimr_settings.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct TKikimrSettings {
NCommon::TConfSetting<ui32, false> MaxTasksPerStage;
NCommon::TConfSetting<ui32, false> MaxSequentialReadsInFlight;

NCommon::TConfSetting<ui32, false> KMeansTreeSearchTopSize;

/* Runtime */
NCommon::TConfSetting<bool, true> ScanQuery;

Expand Down
22 changes: 17 additions & 5 deletions ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2199,7 +2199,9 @@ Y_UNIT_TEST_SUITE(KqpIndexes) {
ORDER BY {} {}
LIMIT 3;
)", target, metric, direction)));
const TString indexQuery(Q1_(std::format(R"({}
const TString indexQuery(Q1_(std::format(R"(
pragma ydb.KMeansTreeSearchTopSize = "3";
{}
SELECT * FROM `/Root/TestTable` VIEW index
ORDER BY {} {}
LIMIT 3;
Expand All @@ -2214,6 +2216,7 @@ Y_UNIT_TEST_SUITE(KqpIndexes) {
LIMIT 3;
)", target, metric, metric, direction)));
const TString indexQuery(Q1_(std::format(R"({}
pragma ydb.KMeansTreeSearchTopSize = "2";
SELECT {}, `/Root/TestTable`.* FROM `/Root/TestTable` VIEW index
ORDER BY {} {}
LIMIT 3;
Expand All @@ -2227,7 +2230,9 @@ Y_UNIT_TEST_SUITE(KqpIndexes) {
ORDER BY m {}
LIMIT 3;
)", target, metric, direction)));
const TString indexQuery(Q1_(std::format(R"({}
const TString indexQuery(Q1_(std::format(R"(
pragma ydb.KMeansTreeSearchTopSize = "1";
{}
SELECT {} AS m, `/Root/TestTable`.* FROM `/Root/TestTable` VIEW index
ORDER BY m {}
LIMIT 3;
Expand All @@ -2241,18 +2246,18 @@ Y_UNIT_TEST_SUITE(KqpIndexes) {
std::string_view function,
std::string_view direction) {
// target is left, member is right
// DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "$target", "emb");
DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "$target", "emb");
// target is right, member is left
DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "emb", "$target");
}

void DoPositiveQueriesVectorIndexOrderByCosine(TSession& session) {
// distance, default direction
// DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "");
DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "");
// distance, asc direction
DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "ASC");
// similarity, desc direction
// DoPositiveQueriesVectorIndexOrderBy(session, "CosineSimilarity", "DESC");
DoPositiveQueriesVectorIndexOrderBy(session, "CosineSimilarity", "DESC");
}

TSession DoCreateTableForVectorIndex(TTableClient& db, bool nullable) {
Expand All @@ -2272,6 +2277,13 @@ Y_UNIT_TEST_SUITE(KqpIndexes) {
.AddNonNullableColumn("data", EPrimitiveType::String);
}
tableBuilder.SetPrimaryKeyColumns({"pk"});
tableBuilder.BeginPartitioningSettings()
.SetMinPartitionsCount(3)
.EndPartitioningSettings();
auto partitions = TExplicitPartitions{}
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(4).EndTuple().Build())
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(6).EndTuple().Build());
tableBuilder.SetPartitionAtKeys(partitions);
auto result = session.CreateTable("/Root/TestTable", tableBuilder.Build()).ExtractValueSync();
UNIT_ASSERT_VALUES_EQUAL(result.IsTransportError(), false);
UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString());
Expand Down
19 changes: 18 additions & 1 deletion ydb/core/scheme/scheme_tablecell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,24 @@ size_t TSerializedCellVec::SerializedSize(TConstArrayRef<TCell> cells) {
return size;
}

// Extracts the single cell at index `pos` from a serialized cell vector
// without materializing the other cells.
//
// `data` is parsed with TSerializedCellReader: a ui16 cell count is read
// first, then cells are read one after another until `pos` is reached.
//
// Returns a default-constructed TCell when the count cannot be read, when
// `pos` is not smaller than the count, or when reading any cell up to and
// including `pos` fails.
//
// NOTE(review): the returned TCell presumably refers to bytes inside `data`
// rather than owning a copy, so `data` would have to outlive it -- confirm
// against TSerializedCellReader::ReadNewCell.
TCell TSerializedCellVec::ExtractCell(std::string_view data, size_t pos) {
    TSerializedCellReader reader{data};

    ui16 total = 0;
    if (!reader.Read(&total)) {
        return {};
    }
    if (pos >= total) {
        return {};
    }

    // Cells are consumed sequentially, so everything before `pos` has to be
    // read (and discarded) first; `pos + 1` cannot overflow since pos < total.
    TCell current;
    for (size_t left = pos + 1; left != 0; --left) {
        current = {};
        if (!reader.ReadNewCell(&current)) {
            return {};
        }
    }
    return current;
}

bool TSerializedCellVec::DoTryParse() {
if (!TryDeserializeCellVec(Buf, Cells)) {
Buf.clear();
Expand Down Expand Up @@ -714,4 +732,3 @@ size_t GetCellHeaderSize() {
}

} // namespace NKikimr

2 changes: 2 additions & 0 deletions ydb/core/scheme/scheme_tablecell.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ class TSerializedCellVec {

static size_t SerializedSize(TConstArrayRef<TCell> cells);

static TCell ExtractCell(std::string_view data, size_t pos);

const TString& GetBuffer() const { return Buf; }

TString ReleaseBuffer() {
Expand Down
8 changes: 5 additions & 3 deletions ydb/core/tx/datashard/kmeans_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,14 @@ struct TMaxInnerProductSimilarity: TMetric<T> {

template <typename TMetric>
struct TCalculation: TMetric {
ui32 FindClosest(std::span<const TString> clusters, const char* embedding) const
ui32 FindClosest(std::span<const TString> clusters, TArrayRef<const char> embedding) const
{
Y_DEBUG_ABORT_UNLESS(this->IsExpectedSize(embedding));
auto min = this->Init();
ui32 closest = std::numeric_limits<ui32>::max();
for (size_t i = 0; const auto& cluster : clusters) {
auto distance = this->Distance(cluster.data(), embedding);
Y_DEBUG_ABORT_UNLESS(this->IsExpectedSize(cluster));
auto distance = this->Distance(cluster.data(), embedding.data());
if (distance < min) {
min = distance;
closest = i;
Expand All @@ -195,7 +197,7 @@ ui32 FeedEmbedding(const TCalculation<TMetric>& calculation, std::span<const TSt
if (!calculation.IsExpectedSize(embedding)) {
return std::numeric_limits<ui32>::max();
}
return calculation.FindClosest(clusters, embedding.data());
return calculation.FindClosest(clusters, embedding);
}

void AddRowMain2Build(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row);
Expand Down
40 changes: 17 additions & 23 deletions ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,16 @@ static constexpr const char* Name(TIndexBuildInfo::EState state) noexcept {
static std::tuple<ui32, ui32, ui32> ComputeKMeansBoundaries(const NSchemeShard::TTableInfo& tableInfo, const TIndexBuildInfo& buildInfo) {
const auto& kmeans = buildInfo.KMeans;
Y_ASSERT(kmeans.K != 0);
Y_ASSERT((kmeans.K & (kmeans.K - 1)) == 0);
const auto count = TIndexBuildInfo::TKMeans::BinPow(kmeans.K, kmeans.Level);
ui32 step = 1;
auto parts = count;
auto shards = tableInfo.GetShard2PartitionIdx().size();
if (!buildInfo.KMeans.NeedsAnotherLevel() || shards <= 1) {
shards = 1;
parts = 1;
if (!buildInfo.KMeans.NeedsAnotherLevel() || count <= 1 || shards <= 1) {
return {1, 1, 1};
}
for (; shards < parts; parts /= 2) {
for (; 2 * shards <= parts; parts = (parts + 1) / 2) {
step *= 2;
}
for (; parts < shards / 2; parts *= 2) {
Y_ASSERT(step == 1);
}
return {count, parts, step};
}

Expand Down Expand Up @@ -341,7 +336,7 @@ THolder<TEvSchemeShard::TEvModifySchemeTransaction> CreateBuildPropose(
modifyScheme.SetWorkingDir(path.Dive(buildInfo.IndexName).PathString());
modifyScheme.SetOperationType(NKikimrSchemeOp::ESchemeOpInitiateBuildIndexImplTable);
auto& op = *modifyScheme.MutableCreateTable();
const char* suffix = buildInfo.KMeans.Level % 2 != 0 ? BuildSuffix0 : BuildSuffix1;
std::string_view suffix = buildInfo.KMeans.Level % 2 != 0 ? BuildSuffix0 : BuildSuffix1;
op = CalcVectorKmeansTreePostingImplTableDesc(tableInfo, tableInfo->PartitionConfig(), implTableColumns, {}, suffix);

const auto [count, parts, step] = ComputeKMeansBoundaries(*tableInfo, buildInfo);
Expand All @@ -351,25 +346,24 @@ THolder<TEvSchemeShard::TEvModifySchemeTransaction> CreateBuildPropose(

auto& policy = *config.MutablePartitioningPolicy();
policy.SetSizeToSplit(0); // disable auto split/merge
policy.SetMinPartitionsCount(parts);
policy.SetMaxPartitionsCount(parts);
policy.ClearFastSplitSettings();
policy.ClearSplitByLoadSettings();

op.ClearSplitBoundary();
if (parts <= 1) {
return propose;
}
auto i = buildInfo.KMeans.Parent;
for (const auto end = i + count;;) {
i += step;
if (i >= end) {
Y_ASSERT(op.SplitBoundarySize() == std::min(count, parts) - 1);
break;
static constexpr std::string_view LogPrefix = "Create build table boundaries for ";
LOG_D(buildInfo.Id << " table " << suffix
<< ", count: " << count << ", parts: " << parts << ", step: " << step
<< ", kmeans: " << buildInfo.KMeansTreeToDebugStr());
if (parts > 1) {
const auto parentFrom = buildInfo.KMeans.ParentEnd + 1;
for (auto i = parentFrom + step, e = parentFrom + count; i < e; i += step) {
LOG_D(buildInfo.Id << " table " << suffix << " value: " << i);
auto cell = TCell::Make(i);
op.AddSplitBoundary()->SetSerializedKeyPrefix(TSerializedCellVec::Serialize({&cell, 1}));
}
auto cell = TCell::Make(i);
op.AddSplitBoundary()->SetSerializedKeyPrefix(TSerializedCellVec::Serialize({&cell, 1}));
}
policy.SetMinPartitionsCount(op.SplitBoundarySize() + 1);
policy.SetMaxPartitionsCount(op.SplitBoundarySize() + 1);
return propose;
}

Expand Down Expand Up @@ -574,7 +568,7 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
auto& clusters = *ev->Record.MutableClusters();
clusters.Reserve(buildInfo.Sample.Rows.size());
for (const auto& [_, row] : buildInfo.Sample.Rows) {
*clusters.Add() = row;
*clusters.Add() = TSerializedCellVec::ExtractCell(row, 0).AsBuf();
}

ev->Record.SetPostingName(path.Dive(buildInfo.KMeans.WriteTo()).PathString());
Expand Down
27 changes: 21 additions & 6 deletions ydb/core/tx/schemeshard/schemeshard_info_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct TSplitSettings {
TForceShardSplitSettings GetForceShardSplitSettings() const {
return TForceShardSplitSettings{
.ForceShardSplitDataSize = ui64(ForceShardSplitDataSize),
.DisableForceShardSplit = ui64(DisableForceShardSplit) != 0,
.DisableForceShardSplit = true,
};
}
};
Expand Down Expand Up @@ -3680,17 +3680,32 @@ struct TIndexBuildInfo: public TSimpleRefCount<TIndexBuildInfo> {
}

float CalcProgressPercent() const {
const auto total = Shards.size();
const auto done = DoneShards.size();
if (IsBuildVectorIndex()) {
const auto inProgress = InProgressShards.size();
const auto toUpload = ToUploadShards.size();
Y_ASSERT(KMeans.Level != 0);
// TODO(mbkkt) better calculation for vector index
return KMeans.Level * 100.0 / KMeans.Levels;
if (!KMeans.NeedsAnotherLevel() && !KMeans.NeedsAnotherParent()
&& toUpload == 0 && inProgress == 0) {
return 100.f;
}
auto percent = static_cast<float>(KMeans.Level - 1) / KMeans.Levels;
auto multiply = 1.f / KMeans.Levels;
if (KMeans.State == TKMeans::MultiLocal) {
percent += (multiply * (total - inProgress - toUpload)) / total;
} else {
const auto parentSize = KMeans.BinPow(KMeans.K, KMeans.Level - 1);
const auto parentFrom = KMeans.ParentEnd - parentSize + 1;
percent += (multiply * (KMeans.Parent - parentFrom)) / parentSize;
}
return 100.f * percent;
}
if (Shards) {
float totalShards = Shards.size();
return 100.0 * DoneShards.size() / totalShards;
return (100.f * done) / total;
}
// No shards - no progress
return 0.0;
return 0.f;
}

void SerializeToProto(TSchemeShard* ss, NKikimrIndexBuilder::TColumnBuildSettings* to) const;
Expand Down

0 comments on commit 5d71ae4

Please sign in to comment.