Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: builder methods for basic ops #1516

Merged
merged 1 commit into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Expression as Expression,
alias_ as alias,
and_ as and_,
coalesce as coalesce,
column as column,
condition as condition,
except_ as except_,
Expand Down
176 changes: 152 additions & 24 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,119 @@ def not_(self):
"""
return not_(self)

def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)

def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])

def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
)

def like(self, other: ExpOrStr) -> Like:
return self._binop(Like, other)

def ilike(self, other: ExpOrStr) -> ILike:
return self._binop(ILike, other)

def eq(self, other: ExpOrStr) -> EQ:
return self._binop(EQ, other)

def neq(self, other: ExpOrStr) -> NEQ:
return self._binop(NEQ, other)

def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)

def __lt__(self, other: ExpOrStr) -> LT:
return self._binop(LT, other)

def __le__(self, other: ExpOrStr) -> LTE:
return self._binop(LTE, other)

def __gt__(self, other: ExpOrStr) -> GT:
return self._binop(GT, other)

def __ge__(self, other: ExpOrStr) -> GTE:
return self._binop(GTE, other)

def __add__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other)

def __radd__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other, reverse=True)

def __sub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other)

def __rsub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other, reverse=True)

def __mul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other)

def __rmul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other, reverse=True)

def __truediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other)

def __rtruediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other, reverse=True)

def __floordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other)

def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other, reverse=True)

def __mod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other)

def __rmod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other, reverse=True)

def __pow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other)

def __rpow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other, reverse=True)

def __and__(self, other: ExpOrStr) -> And:
return self._binop(And, other)

def __rand__(self, other: ExpOrStr) -> And:
return self._binop(And, other, reverse=True)

def __or__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other)

def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)

def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))

def __invert__(self) -> Not:
return not_(self)


class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
Expand Down Expand Up @@ -3006,7 +3119,7 @@ class DropPartition(Expression):


# Binary expressions like (ADD a b)
class Binary(Expression):
class Binary(Condition):
arg_types = {"this": True, "expression": True}

@property
Expand All @@ -3022,7 +3135,7 @@ class Add(Binary):
pass


class Connector(Binary, Condition):
class Connector(Binary):
pass


Expand Down Expand Up @@ -3184,19 +3297,19 @@ class ArrayOverlaps(Binary):

# Unary Expressions
# (NOT a)
class Unary(Expression):
class Unary(Condition):
pass


class BitwiseNot(Unary):
pass


class Not(Unary, Condition):
class Not(Unary):
pass


class Paren(Unary, Condition):
class Paren(Unary):
arg_types = {"this": True, "with": False}


Expand Down Expand Up @@ -4290,15 +4403,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
this = _wrap(this, Connector)
for expression in expressions[1:]:
this = operator(this=this, expression=_wrap_operator(expression))
this = operator(this=this, expression=_wrap(expression, Connector))
return this


def _wrap_operator(expression):
if isinstance(expression, (And, Or, Not)):
expression = Paren(this=expression)
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
if isinstance(expression, kind):
return Paren(this=expression)
return expression


Expand Down Expand Up @@ -4596,7 +4709,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
return Not(this=_wrap_operator(this))
return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:
Expand Down Expand Up @@ -4838,6 +4951,23 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca
return Cast(this=expression, to=DataType.build(to, **opts))


def coalesce(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Coalesce:
"""Create a coalesce node.

Example:
>>> coalesce('x + 1', '0').sql()
'COALESCE(x + 1, 0)'

Args:
expressions: The expressions to coalesce.

Returns:
A coalesce node.
"""
this, *exprs = [maybe_parse(e, **opts) for e in expressions]
return Coalesce(this=this, expressions=exprs)


def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.

Expand Down Expand Up @@ -4956,16 +5086,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
Expand All @@ -4975,14 +5111,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")


Expand Down
4 changes: 2 additions & 2 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _annotate_binary(self, expression):
left_type = expression.left.type.this
right_type = expression.right.type.this

if isinstance(expression, (exp.And, exp.Or)):
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
Expand All @@ -347,7 +347,7 @@ def _annotate_binary(self, expression):
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(left_type, right_type)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlglot import (
alias,
and_,
coalesce,
condition,
except_,
exp,
Expand All @@ -18,7 +19,55 @@

class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")

for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
(lambda: 1 + x, "1 + x"),
(lambda: x - 1, "x - 1"),
(lambda: 1 - x, "1 - x"),
(lambda: x * 1, "x * 1"),
(lambda: 1 * x, "1 * x"),
(lambda: x / 1, "x / 1"),
(lambda: 1 / x, "1 / x"),
(lambda: x // 1, "CAST(x / 1 AS INT)"),
(lambda: 1 // x, "CAST(1 / x AS INT)"),
(lambda: x % 1, "x % 1"),
(lambda: 1 % x, "1 % x"),
(lambda: x**1, "POWER(x, 1)"),
(lambda: 1**x, "POWER(1, x)"),
(lambda: x & 1, "x AND 1"),
(lambda: 1 & x, "1 AND x"),
(lambda: x | 1, "x OR 1"),
(lambda: 1 | x, "1 OR x"),
(lambda: x < 1, "x < 1"),
(lambda: 1 < x, "x > 1"),
(lambda: x <= 1, "x <= 1"),
(lambda: 1 <= x, "x >= 1"),
(lambda: x > 1, "x > 1"),
(lambda: 1 > x, "x < 1"),
(lambda: x >= 1, "x >= 1"),
(lambda: 1 >= x, "x <= 1"),
(lambda: x.eq(1), "x = 1"),
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
(lambda: 1 + (x * 2) / 3, "1 + ((x * 2) / 3)"),
(lambda: x & "y", "x AND 'y'"),
(lambda: x | "y", "x OR 'y'"),
(lambda: -x, "-x"),
(lambda: ~x, "NOT x"),
(lambda: x[1], "x[1]"),
(lambda: x[1, 2], "x[1, 2]"),
(lambda: x["y"] + 1, "x['y'] + 1"),
(lambda: x.like("y"), "x LIKE 'y'"),
(lambda: x.ilike("y"), "x ILIKE 'y'"),
(lambda: x.rlike("y"), "REGEXP_LIKE(x, 'y')"),
(lambda: coalesce("x", 1), "COALESCE(x, 1)"),
(lambda: select("x"), "SELECT x"),
(lambda: select("x"), "SELECT x"),
(lambda: select("x", "y"), "SELECT x, y"),
(lambda: select("x").from_("tbl"), "SELECT x FROM tbl"),
Expand Down