Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sqllab] Fix sqllab limit regex issue with sqlparse #5295

Merged
merged 2 commits into from
Jul 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tableschema import Table
from werkzeug.utils import secure_filename

from superset import app, cache_util, conf, db, utils
from superset import app, cache_util, conf, db, sql_parse, utils
from superset.exceptions import SupersetTemplateException
from superset.utils import QueryStatus

Expand Down Expand Up @@ -110,32 +110,19 @@ def apply_limit_to_sql(cls, sql, limit, database):
)
return database.compile_sqla_query(qry)
elif LimitMethod.FORCE_LIMIT:
sql_without_limit = cls.get_query_without_limit(sql)
return '{sql_without_limit} LIMIT {limit}'.format(**locals())
parsed_query = sql_parse.SupersetQuery(sql)
sql = parsed_query.get_query_with_new_limit(limit)
return sql

@classmethod
def get_limit_from_sql(cls, sql):
limit_pattern = re.compile(r"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the get_limit_from_sql method now obsolete?

(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+(\d+) # LIMIT $ROWS
;? # optional semi-colon
(\s|;)*$ # remove trailing spaces tabs or semicolons
""")
matches = limit_pattern.findall(sql)
if matches:
return int(matches[0][0])

@classmethod
def get_query_without_limit(cls, sql):
return re.sub(r"""
(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+\d+ # LIMIT $ROWS
;? # optional semi-colon
(\s|;)*$ # remove trailing spaces tabs or semicolons
""", '', sql)
parsed_query = sql_parse.SupersetQuery(sql)
return parsed_query.limit

@classmethod
def get_query_with_new_limit(cls, sql, limit):
    """Return ``sql`` with its row limit set to ``limit``.

    Thin wrapper: parses ``sql`` with ``sql_parse.SupersetQuery`` and
    delegates the rewrite to that object's ``get_query_with_new_limit``.
    """
    return sql_parse.SupersetQuery(sql).get_query_with_new_limit(limit)

@staticmethod
def csv_to_df(**kwargs):
Expand Down
44 changes: 44 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ def __init__(self, sql_statement):
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
self._limit = None
# TODO: multistatement support

logging.info('Parsing with sqlparse statement {}'.format(self.sql))
self._parsed = sqlparse.parse(self.sql)
for statement in self._parsed:
self.__extract_from_token(statement)
self._limit = self._extract_limit_from_query(statement)
self._table_names = self._table_names - self._alias_names

@property
def tables(self):
    """Return the set of table names collected while parsing.

    The constructor subtracts the collected alias names from this set
    before it is exposed here.
    """
    return self._table_names

@property
def limit(self):
    """Return the limit value extracted from the SQL by the constructor.

    The value starts out as ``None`` and is assigned from
    ``_extract_limit_from_query`` for the parsed statement(s).
    """
    return self._limit

def is_select(self):
    """Return True when the first parsed statement is a SELECT.

    Returns False when nothing was parsed (for example empty SQL text)
    instead of raising ``IndexError`` on the missing first statement.
    """
    if not self._parsed:
        return False
    return self._parsed[0].get_type() == 'SELECT'

Expand Down Expand Up @@ -128,3 +134,41 @@ def __extract_from_token(self, token):
for token in item.tokens:
if self.__is_identifier(token):
self.__process_identifier(token)

def _get_limit_from_token(self, token):
    """Convert the token that follows a LIMIT keyword into an int.

    An integer-literal token is converted directly.  For a grouped token
    the value of the child token found at character offset 1 is converted
    (presumably the first number of a ``LIMIT a, b`` list -- confirm
    against sqlparse's ``get_token_at_offset``).  Any other token yields
    ``None``.
    """
    integer_type = sqlparse.tokens.Literal.Number.Integer
    if token.ttype == integer_type:
        return int(token.value)
    if not token.is_group:
        return None
    inner = token.get_token_at_offset(1)
    return int(inner.value)

def _extract_limit_from_query(self, statement):
    """Return the limit found in ``statement``, or ``None``.

    Scans the statement's top-level tokens for the first ``LIMIT``
    keyword and hands the token two positions later (keyword, one
    separator token, value) to ``_get_limit_from_token``.

    Returns ``None`` when there is no LIMIT keyword, or when the keyword
    sits so close to the end of the statement that no value token exists
    (e.g. an unfinished ``... LIMIT``); previously that case raised
    ``IndexError``.
    """
    tokens = statement.tokens
    for pos, item in enumerate(tokens):
        if item.ttype in Keyword and item.value.lower() == 'limit':
            value_pos = pos + 2
            if value_pos >= len(tokens):
                # Dangling LIMIT with nothing after it: no usable value.
                return None
            return self._get_limit_from_token(tokens[value_pos])
    return None

def get_query_with_new_limit(self, new_limit):
    """Return the query text with its limit set to ``new_limit``.

    Neither ``self.sql`` nor the cached parsed tokens are modified; the
    result is assembled from copies of the token values.

    * No limit was extracted (``self._limit is None``): ``' LIMIT n'`` is
      appended to the original SQL.
    * The token after ``LIMIT`` is an integer literal: it is replaced.
    * The token after ``LIMIT`` is a group (``LIMIT a, b``): the first
      identifier is kept and the second part becomes ``new_limit``.
    * The LIMIT keyword cannot be located in the first statement, or has
      no value token: falls back to appending, instead of raising.
    """
    appended = self.sql + ' LIMIT ' + str(new_limit)
    # Compare with None (not truthiness) so an explicit ``LIMIT 0`` is
    # rewritten in place rather than getting a second LIMIT appended.
    if self._limit is None:
        return appended

    tokens = self._parsed[0].tokens
    limit_pos = None
    for pos, item in enumerate(tokens):
        if item.ttype in Keyword and item.value.lower() == 'limit':
            limit_pos = pos
            break
    # self._limit may come from a later statement than the first one, so
    # the keyword is not guaranteed to be present here.
    if limit_pos is None or limit_pos + 2 >= len(tokens):
        return appended

    value_pos = limit_pos + 2
    limit = tokens[value_pos]
    if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
        replacement = str(new_limit)
    elif limit.is_group:
        replacement = '{}, {}'.format(
            next(limit.get_identifiers()), new_limit)
    else:
        # Unrecognised token after LIMIT: leave its text untouched.
        replacement = str(limit.value)

    parts = [str(item.value) for item in tokens]
    parts[value_pos] = replacement
    return ''.join(parts)
91 changes: 63 additions & 28 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import textwrap

from superset.db_engine_specs import (
BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, PrestoEngineSpec,
Expand Down Expand Up @@ -143,18 +141,6 @@ def test_modify_limit_query(self):
'SELECT * FROM a LIMIT 1000',
)

def test_modify_newline_query(self):
self.sql_limit_regex(
'SELECT * FROM a\nLIMIT 9999',
'SELECT * FROM a LIMIT 1000',
)

def test_modify_lcase_limit_query(self):
self.sql_limit_regex(
'SELECT * FROM a\tlimit 9999',
'SELECT * FROM a LIMIT 1000',
)

def test_limit_query_with_limit_subquery(self):
self.sql_limit_regex(
'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
Expand All @@ -163,41 +149,90 @@ def test_limit_query_with_limit_subquery(self):

def test_limit_with_expr(self):
self.sql_limit_regex(
textwrap.dedent("""\
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT
99990"""),
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000"""),
table
LIMIT 1000""",
)

def test_limit_expr_and_semicolon(self):
self.sql_limit_regex(
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;"""),
textwrap.dedent("""\
LIMIT 99990 ;""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000"""),
table
LIMIT 1000 ;""",
)

def test_get_datatype(self):
    """Check ``get_datatype`` results for several engine specs."""
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual('STRING', PrestoEngineSpec.get_datatype('string'))
    self.assertEqual('TINY', MySQLEngineSpec.get_datatype(1))
    self.assertEqual('VARCHAR', MySQLEngineSpec.get_datatype(15))
    self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))

def test_limit_with_implicit_offset(self):
    """A ``LIMIT a, b`` clause keeps ``a`` and gets ``b`` replaced.

    Also carries a ``'LIMIT 777'`` string literal in the select list,
    which the expected output leaves untouched.
    """
    # NOTE(review): sql_limit_regex is defined outside this view;
    # presumably it takes (input sql, expected sql) -- confirm.
    self.sql_limit_regex(
        """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
        """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
    )

def test_limit_with_explicit_offset(self):
    """A ``LIMIT n`` followed by ``OFFSET m`` has only ``n`` replaced.

    The expected output keeps the ``OFFSET 999999`` line unchanged.
    """
    # NOTE(review): sql_limit_regex is defined outside this view;
    # presumably it takes (input sql, expected sql) -- confirm.
    self.sql_limit_regex(
        """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
        """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
    )

def test_limit_with_non_token_limit(self):
    """``LIMIT`` inside a string literal is not treated as a limit.

    The expected output appends a new ``LIMIT 1000`` after the literal
    rather than rewriting the text inside the quotes.
    """
    # NOTE(review): sql_limit_regex is defined outside this view;
    # presumably it takes (input sql, expected sql) -- confirm.
    self.sql_limit_regex(
        """
SELECT
'LIMIT 777'""",
        """
SELECT
'LIMIT 777' LIMIT 1000""",
    )