Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlsmith): Generate EquiJoin #3613

Merged
merged 19 commits into from
Jul 12, 2022
22 changes: 22 additions & 0 deletions src/frontend/src/expr/type_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ pub enum DataTypeName {
List,
}

impl DataTypeName {
pub fn is_scalar(&self) -> bool {
match self {
DataTypeName::Boolean
| DataTypeName::Int16
| DataTypeName::Int32
| DataTypeName::Int64
| DataTypeName::Decimal
| DataTypeName::Float32
| DataTypeName::Float64
| DataTypeName::Varchar
| DataTypeName::Date
| DataTypeName::Timestamp
| DataTypeName::Timestampz
| DataTypeName::Time
| DataTypeName::Interval => true,

DataTypeName::Struct | DataTypeName::List => false,
}
}
}

impl From<&DataType> for DataTypeName {
fn from(ty: &DataType) -> Self {
match ty {
Expand Down
2 changes: 1 addition & 1 deletion src/tests/sqlsmith/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ In the second mode, it will test the entire query handling end-to-end. We provid

```sh
cd risingwave
./target/debug/sqlsmith test --testdata ./src/tests/sqlsmith/tests/testdata
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
./target/debug/sqlsmith --testdata ./src/tests/sqlsmith/tests/testdata
```

Additionally, in some cases where you may want to debug whether we have defined some function/operator incorrectly,
Expand Down
4 changes: 0 additions & 4 deletions src/tests/sqlsmith/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
expr.or_else(|| make_general_expr(func.func, exprs))
.unwrap_or_else(|| self.gen_simple_scalar(ret))
}

fn can_recurse(&mut self) -> bool {
self.rng.gen_bool(0.3)
}
}

fn make_unary_op(func: ExprType, expr: &Expr) -> Option<Expr> {
Expand Down
23 changes: 21 additions & 2 deletions src/tests/sqlsmith/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use rand::Rng;
use risingwave_frontend::binder::bind_data_type;
use risingwave_frontend::expr::DataTypeName;
use risingwave_sqlparser::ast::{
ColumnDef, Expr, Ident, OrderByExpr, Query, Select, SelectItem, SetExpr, Statement,
TableWithJoins, Value, With,
BinaryOperator, ColumnDef, Expr, Ident, Join, JoinConstraint, JoinOperator, OrderByExpr, Query,
Select, SelectItem, SetExpr, Statement, TableWithJoins, Value, With,
};

mod expr;
Expand All @@ -34,6 +34,18 @@ pub struct Table {
pub columns: Vec<Column>,
}

impl Table {
pub fn get_qualified_columns(&self) -> Vec<Column> {
self.columns
.iter()
.map(|c| Column {
name: format!("{}.{}", self.name, c.name),
data_type: c.data_type,
})
.collect()
}
}

#[derive(Clone)]
pub struct Column {
name: String,
Expand All @@ -53,6 +65,8 @@ struct SqlGenerator<'a, R: Rng> {
tables: Vec<Table>,
rng: &'a mut R,

/// Relations bound in generated query.
/// We might not read from all `tables.
bound_relations: Vec<Table>,
}

Expand Down Expand Up @@ -213,6 +227,11 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
fn flip_coin(&mut self) -> bool {
self.rng.gen_bool(0.5)
}

/// Provide recursion bounds.
pub(crate) fn can_recurse(&mut self) -> bool {
self.rng.gen_bool(0.3)
}
}

/// Generate a random SQL string.
Expand Down
115 changes: 92 additions & 23 deletions src/tests/sqlsmith/src/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,114 @@ use rand::prelude::SliceRandom;
use rand::Rng;
use risingwave_sqlparser::ast::{Ident, ObjectName, TableAlias, TableFactor, TableWithJoins};

use crate::{SqlGenerator, Table};
use crate::{
BinaryOperator, Column, Expr, Join, JoinConstraint, JoinOperator, SqlGenerator, Table,
};

fn create_join_on_expr(left: String, right: String) -> Expr {
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
let left = Box::new(Expr::Identifier(Ident::new(left)));
let right = Box::new(Expr::Identifier(Ident::new(right)));
Expr::BinaryOp {
left,
op: BinaryOperator::Eq,
right,
}
}

impl<'a, R: Rng> SqlGenerator<'a, R> {
/// A relation specified in the FROM clause.
pub(crate) fn gen_from_relation(&mut self) -> TableWithJoins {
match self.rng.gen_range(0..=9) {
0..=9 => self.gen_simple_table(),
// TODO: unreachable, should change to 9..=9,
// but currently it will cause panic due to some wrong assertions.
10..=10 => self.gen_subquery(),
_ => unreachable!(),
let (from_relation, _) = self.gen_from_relation_with_cols();
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
from_relation
}

/// A relation specified in the FROM clause.
fn gen_from_relation_with_cols(&mut self) -> (TableWithJoins, Vec<Column>) {
if self.can_recurse() {
return match self.rng.gen_range(0..=9) {
0..=8 => self.gen_simple_table(),
9..=9 => self.gen_equijoin_expr(),
// TODO: unreachable, should change to 9..=9,
// but currently it will cause panic due to some wrong assertions.
10..=10 => self.gen_subquery(),
_ => unreachable!(),
};
}

self.gen_simple_table()
}

fn gen_simple_table(&mut self) -> TableWithJoins {
fn gen_simple_table(&mut self) -> (TableWithJoins, Vec<Column>) {
let (relation, columns) = self.gen_simple_table_factor();
let simple_table = TableWithJoins {
relation,
joins: vec![],
};
(simple_table, columns)
}

fn gen_simple_table_factor(&mut self) -> (TableFactor, Vec<Column>) {
let alias = format!("t{}", self.bound_relations.len());
let mut table = self.tables.choose(&mut self.rng).unwrap().clone();
let relation = TableWithJoins {
relation: TableFactor::Table {
name: ObjectName(vec![Ident::new(table.name.clone())]),
alias: Some(TableAlias {
name: Ident::new(alias.clone()),
columns: vec![],
}),
args: vec![],
},
joins: vec![],
let table_factor = TableFactor::Table {
name: ObjectName(vec![Ident::new(table.name.clone())]),
alias: Some(TableAlias {
name: Ident::new(alias.clone()),
columns: vec![],
}),
args: vec![],
};
table.name = alias; // Rename the table.
let columns = table.get_qualified_columns();
self.bound_relations.push(table);
relation
(table_factor, columns)
}

/// Generates a table factor, and provides bound columns.
/// Generated column names should be qualified by table name.
fn gen_table_factor(&mut self) -> (TableFactor, Vec<Column>) {
// TODO: TableFactor::Derived, TableFactor::TableFunction, TableFactor::NestedJoin
self.gen_simple_table_factor()
}

fn gen_equijoin_expr(&mut self) -> (TableWithJoins, Vec<Column>) {
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
let (left_factor, mut left_columns) = self.gen_table_factor();
let (right_factor, right_columns) = self.gen_table_factor();

let mut available_join_on_columns = vec![];
for left_column in &left_columns {
for right_column in &right_columns {
// NOTE: We can support some composite types if we wish to in the future.
// see: https://www.postgresql.org/docs/14/functions-comparison.html.
// For simplicity only support scalar types for now.
let left_ty = left_column.data_type;
let right_ty = right_column.data_type;
if left_ty.is_scalar() && right_ty.is_scalar() && (left_ty == right_ty) {
available_join_on_columns.push((left_column, right_column))
}
}
}
let i = self.rng.gen_range(0..available_join_on_columns.len());
let (left_column, right_column) = available_join_on_columns[i];
let join_on_expr = create_join_on_expr(left_column.name.clone(), right_column.name.clone());

let right_factor_with_join = Join {
relation: right_factor,
join_operator: JoinOperator::Inner(JoinConstraint::On(join_on_expr)),
};
let table = TableWithJoins {
relation: left_factor,
joins: vec![right_factor_with_join],
};
left_columns.extend(right_columns);
(table, left_columns)
}

fn gen_subquery(&mut self) -> TableWithJoins {
fn gen_subquery(&mut self) -> (TableWithJoins, Vec<Column>) {
let (subquery, columns) = self.gen_query();
let alias = format!("t{}", self.bound_relations.len());
let table = Table {
name: alias.clone(),
columns,
columns: columns.clone(),
};
let relation = TableWithJoins {
relation: TableFactor::Derived {
Expand All @@ -68,6 +137,6 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
joins: vec![],
};
self.bound_relations.push(table);
relation
(relation, columns)
}
}