Skip to content

Commit

Permalink
feat(query): add l2 distance operator <-> (databendlabs#12382)
Browse files Browse the repository at this point in the history
* feat(query): add l2 distance operator <->

* feat(query): add l2 distance operator <->

* feat(query): add l2 distance operator <->

* fix
  • Loading branch information
sundy-li authored and andylokandy committed Nov 27, 2023
1 parent f5b5cdc commit d4ff344
Show file tree
Hide file tree
Showing 24 changed files with 69 additions and 98 deletions.
17 changes: 17 additions & 0 deletions src/common/vector/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ pub fn cosine_distance(from: &[f32], to: &[f32]) -> Result<f32> {

Ok(1.0 - (&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt()))
}

/// Computes the Euclidean (L2) distance between two `f32` vectors.
///
/// The result is the square root of the sum of squared element-wise
/// differences between `from` and `to`.
///
/// # Errors
///
/// Returns `ErrorCode::InvalidArgument` when the two slices do not have the
/// same number of elements.
pub fn l2_distance(from: &[f32], to: &[f32]) -> Result<f32> {
    let (lhs_len, rhs_len) = (from.len(), to.len());

    if lhs_len == rhs_len {
        // Accumulate the squared differences in f32, then take the root once.
        let squared_sum = from
            .iter()
            .zip(to)
            .map(|(x, y)| {
                let diff = x - y;
                diff.powi(2)
            })
            .sum::<f32>();
        return Ok(squared_sum.sqrt());
    }

    // Mismatched dimensions: the distance is undefined, so report both lengths.
    Err(ErrorCode::InvalidArgument(format!(
        "Vector length not equal: {:} != {:}",
        lhs_len, rhs_len,
    )))
}
1 change: 1 addition & 0 deletions src/common/vector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
mod distance;

pub use distance::cosine_distance;
pub use distance::l2_distance;
1 change: 0 additions & 1 deletion src/meta/process/src/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ pub fn print_table_meta(config: &Config) -> anyhow::Result<()> {
Ok(())
}

#[allow(dead_code)]
fn pretty<T>(v: &T) -> Result<String, serde_json::Error>
where T: Serialize {
serde_json::to_string_pretty(v)
Expand Down
6 changes: 6 additions & 0 deletions src/query/ast/src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ pub enum BinaryOperator {
BitwiseXor,
BitwiseShiftLeft,
BitwiseShiftRight,

L2Distance,
}

impl BinaryOperator {
Expand All @@ -445,6 +447,7 @@ impl BinaryOperator {
BinaryOperator::BitwiseShiftLeft => "bit_shift_left".to_string(),
BinaryOperator::BitwiseShiftRight => "bit_shift_right".to_string(),
BinaryOperator::Caret => "pow".to_string(),
BinaryOperator::L2Distance => "l2_distance".to_string(),
_ => {
let name = format!("{:?}", self);
name.to_lowercase()
Expand Down Expand Up @@ -664,6 +667,9 @@ impl Display for BinaryOperator {
BinaryOperator::BitwiseShiftRight => {
write!(f, ">>")
}
BinaryOperator::L2Distance => {
write!(f, "<->")
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/query/ast/src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ impl<'a, I: Iterator<Item = WithSpan<'a, ExprElement>>> PrattParser<I> for ExprP
BinaryOperator::BitwiseOr => Affix::Infix(Precedence(22), Associativity::Left),
BinaryOperator::BitwiseAnd => Affix::Infix(Precedence(22), Associativity::Left),
BinaryOperator::BitwiseXor => Affix::Infix(Precedence(22), Associativity::Left),
BinaryOperator::L2Distance => Affix::Infix(Precedence(22), Associativity::Left),

BinaryOperator::BitwiseShiftLeft => {
Affix::Infix(Precedence(23), Associativity::Left)
Expand Down Expand Up @@ -1093,6 +1094,7 @@ pub fn binary_op(i: Input) -> IResult<BinaryOperator> {
value(BinaryOperator::Div, rule! { DIV }),
value(BinaryOperator::Modulo, rule! { "%" }),
value(BinaryOperator::StringConcat, rule! { "||" }),
value(BinaryOperator::L2Distance, rule! { "<->" }),
value(BinaryOperator::Gt, rule! { ">" }),
value(BinaryOperator::Lt, rule! { "<" }),
value(BinaryOperator::Gte, rule! { ">=" }),
Expand Down
4 changes: 4 additions & 0 deletions src/query/ast/src/parser/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ pub enum TokenKind {
LOCATION_PREFIX,
#[token("ROLES", ignore(ascii_case))]
ROLES,
/// L2DISTANCE op, from https://github.com/pgvector/pgvector
#[token("<->")]
L2DISTANCE,
#[token("LEADING", ignore(ascii_case))]
LEADING,
#[token("LEFT", ignore(ascii_case))]
Expand Down Expand Up @@ -1048,6 +1051,7 @@ impl TokenKind {
| Abs
| SquareRoot
| CubeRoot
| L2DISTANCE
| Placeholder
| EOI
)
Expand Down
2 changes: 1 addition & 1 deletion src/query/ast/tests/it/testdata/expr-error.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ error:
--> SQL:1:10
|
1 | CAST(col1)
| ---- ^ expected `AS`, `,`, `(`, `.`, `IS`, `NOT`, or 69 more ...
| ---- ^ expected `AS`, `,`, `(`, `.`, `IS`, `NOT`, or 70 more ...
| |
| while parsing `CAST(... AS ...)`
| while parsing expression
Expand Down
2 changes: 1 addition & 1 deletion src/query/ast/tests/it/testdata/statement-error.txt
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ error:
--> SQL:1:41
|
1 | SELECT * FROM t GROUP BY GROUPING SETS ()
| ------ ^ expected `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, or 67 more ...
| ------ ^ expected `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, or 68 more ...
| |
| while parsing `SELECT ...`

Expand Down
27 changes: 27 additions & 0 deletions src/query/functions/src/scalars/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use common_expression::FunctionDomain;
use common_expression::FunctionRegistry;
use common_openai::OpenAI;
use common_vector::cosine_distance;
use common_vector::l2_distance;

pub fn register(registry: &mut FunctionRegistry) {
// cosine_distance
Expand Down Expand Up @@ -50,6 +51,32 @@ pub fn register(registry: &mut FunctionRegistry) {
),
);

// l2_distance
// This function takes two Float32 arrays as input and computes the L2 (Euclidean) distance between them.
registry.register_passthrough_nullable_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type, _, _>(
"l2_distance",
|_, _, _| FunctionDomain::MayThrow,
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
|lhs, rhs, output, ctx| {
let l_f32=
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
let r_f32=
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };

match l2_distance(l_f32.as_slice(), r_f32.as_slice()) {
Ok(dist) => {
output.push(F32::from(dist));
}
Err(err) => {
ctx.set_error(output.len(), err.to_string());
output.push(F32::from(0.0));
}
}
}
),
);

// embedding_vector
// This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,8 @@ Functions overloads:
1 json_path_query_first(Variant NULL, String NULL) :: Variant NULL
0 json_to_string(Variant) :: String
1 json_to_string(Variant NULL) :: String NULL
0 l2_distance(Array(Float32), Array(Float32)) :: Float32
1 l2_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL
0 left(String, UInt64) :: String
1 left(String NULL, UInt64 NULL) :: String NULL
0 length(Variant NULL) :: UInt32 NULL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use crate::pipelines::PipelineBuildResult;
use crate::sessions::QueryContext;
use crate::sessions::TableContext;

#[allow(dead_code)]
pub struct VacuumTableInterpreter {
ctx: Arc<QueryContext>,
plan: VacuumTablePlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ use crate::sessions::TableContext;

const DRY_RUN_LIMIT: usize = 1000;

#[allow(dead_code)]
pub struct VacuumDropTablesInterpreter {
ctx: Arc<QueryContext>,
plan: VacuumDropTablePlan,
Expand Down
2 changes: 0 additions & 2 deletions src/query/service/src/pipelines/executor/executor_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ struct Node {
processor: ProcessorPtr,

updated_list: Arc<UpdateList>,
#[allow(dead_code)]
inputs_port: Vec<Arc<InputPort>>,
#[allow(dead_code)]
outputs_port: Vec<Arc<OutputPort>>,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ pub struct PipelineCompleteExecutor {
}

// Use this executor when the pipeline is a complete pipeline (it has both a source and a sink)
#[allow(dead_code)]
impl PipelineCompleteExecutor {
pub fn try_create(
pipeline: Pipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,12 @@ impl State {
}

// Use this executor when the pipeline is a pushing pipeline (it has a sink but no source)
#[allow(dead_code)]
pub struct PipelinePushingExecutor {
state: Arc<State>,
executor: Arc<PipelineExecutor>,
sender: SyncSender<Option<DataBlock>>,
}

#[allow(dead_code)]
impl PipelinePushingExecutor {
fn wrap_pipeline(
ctx: Arc<QueryContext>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ pub struct TransformPartialAggregate<Method: HashMethodBounds> {
}

impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
#[allow(dead_code)]
pub fn try_create(
ctx: Arc<QueryContext>,
method: Method,
Expand Down
17 changes: 0 additions & 17 deletions src/query/service/src/servers/http/formats/mod.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/query/service/src/servers/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

mod clickhouse_federated;
mod clickhouse_handler;
pub mod formats;
mod http_services;
pub mod middleware;
pub mod v1;
Expand Down
31 changes: 3 additions & 28 deletions src/query/service/src/servers/mysql/mysql_federated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;

use common_config::DATABEND_COMMIT_VERSION;
use common_expression::types::StringType;
use common_expression::utils::FromData;
use common_expression::DataBlock;
Expand All @@ -30,36 +29,12 @@ use regex::Regex;

use crate::servers::federated_helper::FederatedHelper;
use crate::servers::federated_helper::LazyBlockFunc;
use crate::servers::mysql::MYSQL_VERSION;

#[allow(dead_code)]
pub struct MySQLFederated {
mysql_version: String,
databend_version: String,
}
pub struct MySQLFederated {}

impl MySQLFederated {
pub fn create() -> Self {
MySQLFederated {
mysql_version: MYSQL_VERSION.to_string(),
databend_version: DATABEND_COMMIT_VERSION.to_string(),
}
}

// Build block for select @@variable.
// Format:
// |@@variable|
// |value|
#[allow(dead_code)]
fn select_variable_block(name: &str, value: &str) -> Option<(TableSchemaRef, DataBlock)> {
let schema = TableSchemaRefExt::create(vec![TableField::new(
&format!("@@{}", name),
TableDataType::String,
)]);
let block = DataBlock::new_from_columns(vec![StringType::from_data(vec![
value.as_bytes().to_vec(),
])]);
Some((schema, block))
MySQLFederated {}
}

// Build block for select function.
Expand Down Expand Up @@ -260,7 +235,7 @@ impl MySQLFederated {
(Regex::new("(?i)^(/\\*!40103 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40111 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40101 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40014 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40014 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40000 SET(.*) \\*/)$").unwrap(), None),
(Regex::new("(?i)^(/\\*!40000 ALTER(.*) \\*/)$").unwrap(), None),
];
Expand Down
2 changes: 0 additions & 2 deletions src/query/service/src/test_kits/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,11 @@ pub async fn create_query_context_with_config(
Ok((guard, dummy_query_context))
}

#[allow(dead_code)]
pub struct ClusterDescriptor {
local_node_id: String,
cluster_nodes_list: Vec<Arc<NodeInfo>>,
}

#[allow(dead_code)]
impl ClusterDescriptor {
pub fn new() -> ClusterDescriptor {
ClusterDescriptor {
Expand Down
2 changes: 0 additions & 2 deletions src/query/sql/src/planner/optimizer/heuristic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ mod decorrelate;
#[allow(clippy::module_inception)]
mod heuristic;
mod prune_unused_columns;
mod rule_list;
mod subquery_rewriter;

pub use heuristic::HeuristicOptimizer;
pub use heuristic::DEFAULT_REWRITE_RULES;
pub use heuristic::RESIDUAL_RULES;
pub use rule_list::RuleList;
pub use subquery_rewriter::SubqueryRewriter;
36 changes: 0 additions & 36 deletions src/query/sql/src/planner/optimizer/heuristic/rule_list.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/query/storages/fuse/src/io/write/block_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ pub struct BloomIndexState {
pub(crate) data: Vec<u8>,
pub(crate) size: u64,
pub(crate) location: Location,
#[allow(dead_code)]
pub(crate) column_distinct_count: HashMap<FieldIndex, usize>,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@ query F
select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim
----
0.1264193

query F
select [1, 2] <-> [2, 3] as sim
----
1.4142135

0 comments on commit d4ff344

Please sign in to comment.