From 9114d86ecd0dbd0ad20c24d56f1fd29604de2047 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Mon, 6 Mar 2017 16:20:55 -0800 Subject: [PATCH] Add hive to superset + monkey patch the pyhive (#2134) * Initial hive implementation * Fix select star query for hive. * Exclude generated code. * Address code coverage and linting. * Exclude generated code from coveralls. * Fix lint errors * Move TCLIService to it's own repo. * Address comments * Implement special postgres case, --- superset/db_engine_specs.py | 334 +++++++++++++++++++++++++++++--- superset/db_engines/__init__.py | 0 superset/db_engines/hive.py | 41 ++++ superset/jinja_context.py | 112 +---------- superset/models.py | 17 +- superset/sql_lab.py | 54 +++--- superset/views.py | 13 +- tests/celery_tests.py | 64 +++--- tests/db_engine_specs_test.py | 87 +++++++++ 9 files changed, 504 insertions(+), 218 deletions(-) create mode 100644 superset/db_engines/__init__.py create mode 100644 superset/db_engines/hive.py create mode 100644 tests/db_engine_specs_test.py diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 2bb7159f41e99..81e5c3b550db2 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -17,14 +17,19 @@ from __future__ import unicode_literals from collections import namedtuple, defaultdict -from flask_babel import lazy_gettext as _ from superset import utils import inspect +import re +import sqlparse import textwrap import time from superset import cache_util +from sqlalchemy import select +from sqlalchemy.sql import text +from superset.utils import SupersetTemplateException +from flask_babel import lazy_gettext as _ Grain = namedtuple('Grain', 'name label function') @@ -37,9 +42,16 @@ class LimitMethod(object): class BaseEngineSpec(object): engine = 'base' # str as defined in sqlalchemy.engine.engine + cursor_execute_kwargs = {} time_grains = tuple() limit_method = LimitMethod.FETCH_MANY + @classmethod + def fetch_data(cls, cursor, limit): + if cls.limit_method == LimitMethod.FETCH_MANY: + return cursor.fetchmany(limit) + return cursor.fetchall() + @classmethod def epoch_to_dttm(cls): raise NotImplementedError() @@ -106,6 +118,40 @@ def sql_preprocessor(cls, sql): """ return sql + @classmethod + def patch(cls): + pass + + @classmethod + def where_latest_partition( + cls, table_name, schema, database, qry, columns=None): + return False + + @classmethod + def select_star(cls, my_db, table_name, schema=None, limit=100, + show_cols=False, indent=True): + fields = '*' + table = my_db.get_table(table_name, schema=schema) + if show_cols: + fields = [my_db.get_quoter()(c.name) for c in table.columns] + full_table_name = table_name + if schema: + full_table_name = schema + '.' + table_name + qry = select(fields) + if limit: + qry = qry.limit(limit) + partition_query = cls.where_latest_partition( + table_name, schema, my_db, qry, columns=table.columns) + # if not partition_query condition fails. 
+ if partition_query == False: # noqa + qry = qry.select_from(text(full_table_name)) + else: + qry = partition_query + sql = my_db.compile_sqla_query(qry) + if indent: + sql = sqlparse.format(sql, reindent=True) + return sql + class PostgresEngineSpec(BaseEngineSpec): engine = 'postgresql' @@ -122,6 +168,14 @@ class PostgresEngineSpec(BaseEngineSpec): Grain("year", _('year'), "DATE_TRUNC('year', {col})"), ) + @classmethod + def fetch_data(cls, cursor, limit): + if not cursor.description: + return [] + if cls.limit_method == LimitMethod.FETCH_MANY: + return cursor.fetchmany(limit) + return cursor.fetchall() + @classmethod def epoch_to_dttm(cls): return "(timestamp 'epoch' + {col} * interval '1 second')" @@ -235,27 +289,6 @@ def convert_dttm(cls, target_type, dttm): def epoch_to_dttm(cls): return "from_unixtime({col})" - @staticmethod - def show_partition_pql( - table_name, schema_name=None, order_by=None, limit=100): - if schema_name: - table_name = schema_name + '.' + table_name - order_by = order_by or [] - order_by_clause = '' - if order_by: - order_by_clause = "ORDER BY " + ', '.join(order_by) + " DESC" - - limit_clause = '' - if limit: - limit_clause = "LIMIT {}".format(limit) - - return textwrap.dedent("""\ - SHOW PARTITIONS - FROM {table_name} - {order_by_clause} - {limit_clause} - """).format(**locals()) - @classmethod @cache_util.memoized_func( timeout=600, @@ -284,16 +317,14 @@ def extra_table_metadata(cls, database, table_name, schema_name): if not indexes: return {} cols = indexes[0].get('column_names', []) - pql = cls.show_partition_pql(table_name, schema_name, cols) - df = database.get_df(pql, schema_name) - latest_part = df.to_dict(orient='records')[0] if not df.empty else None - - partition_query = cls.show_partition_pql(table_name, schema_name, cols) + pql = cls._partition_query(table_name, schema_name, cols) + col_name, latest_part = cls.latest_partition( + table_name, schema_name, database) return { 'partitions': { 'cols': cols, - 'latest': latest_part, - 'partitionQuery': partition_query, + 'latest': {col_name: latest_part}, + 'partitionQuery': pql, } } @@ -332,6 +363,251 @@ def extract_error_message(cls, e): ) return utils.error_msg_from_exception(e) + @classmethod + def _partition_query( + cls, table_name, limit=0, order_by=None, filters=None): + """Returns a partition query + + :param table_name: the name of the table to get partitions from + :type table_name: str + :param limit: the number of partitions to be returned + :type limit: int + :param order_by: a list of tuples of field name and a boolean + that determines if that field should be sorted in descending + order + :type order_by: list of (str, bool) tuples + :param filters: a list of filters to apply + :param filters: dict of field name and filter value combinations + """ + limit_clause = "LIMIT {}".format(limit) if limit else '' + order_by_clause = '' + if order_by: + l = [] + for field, desc in order_by: + l.append(field + ' DESC' if desc else '') + order_by_clause = 'ORDER BY ' + ', '.join(l) + + where_clause = '' + if filters: + l = [] + for field, value in filters.items(): + l.append("{field} = '{value}'".format(**locals())) + where_clause = 'WHERE ' + ' AND '.join(l) + + sql = textwrap.dedent("""\ + SHOW PARTITIONS FROM {table_name} + {where_clause} + {order_by_clause} + {limit_clause} + """).format(**locals()) + return sql + + @classmethod + def _latest_partition_from_df(cls, df): + return df.to_records(index=False)[0][0] + + @classmethod + def latest_partition(cls, table_name, schema, database): + 
"""Returns col name and the latest (max) partition value for a table + + :param table_name: the name of the table + :type table_name: str + :param schema: schema / database / namespace + :type schema: str + :param database: database query will be run against + :type database: models.Database + + >>> latest_partition('foo_table') + '2018-01-01' + """ + indexes = database.get_indexes(table_name, schema) + if len(indexes[0]['column_names']) < 1: + raise SupersetTemplateException( + "The table should have one partitioned field") + elif len(indexes[0]['column_names']) > 1: + raise SupersetTemplateException( + "The table should have a single partitioned field " + "to use this function. You may want to use " + "`presto.latest_sub_partition`") + part_field = indexes[0]['column_names'][0] + sql = cls._partition_query(table_name, 1, [(part_field, True)]) + df = database.get_df(sql, schema) + return part_field, cls._latest_partition_from_df(df) + + @classmethod + def latest_sub_partition(cls, table_name, schema, database, **kwargs): + """Returns the latest (max) partition value for a table + + A filtering criteria should be passed for all fields that are + partitioned except for the field to be returned. For example, + if a table is partitioned by (``ds``, ``event_type`` and + ``event_category``) and you want the latest ``ds``, you'll want + to provide a filter as keyword arguments for both + ``event_type`` and ``event_category`` as in + ``latest_sub_partition('my_table', + event_category='page', event_type='click')`` + + :param table_name: the name of the table, can be just the table + name or a fully qualified table name as ``schema_name.table_name`` + :type table_name: str + :param schema: schema / database / namespace + :type schema: str + :param database: database query will be run against + :type database: models.Database + + :param kwargs: keyword arguments define the filtering criteria + on the partition list. There can be many of these. + :type kwargs: str + >>> latest_sub_partition('sub_partition_table', event_type='click') + '2018-01-01' + """ + indexes = database.get_indexes(table_name, schema) + part_fields = indexes[0]['column_names'] + for k in kwargs.keys(): + if k not in k in part_fields: + msg = "Field [{k}] is not part of the portioning key" + raise SupersetTemplateException(msg) + if len(kwargs.keys()) != len(part_fields) - 1: + msg = ( + "A filter needs to be specified for {} out of the " + "{} fields." 
+            ).format(len(part_fields)-1, len(part_fields))
+            raise SupersetTemplateException(msg)
+
+        for field in part_fields:
+            if field not in kwargs.keys():
+                field_to_return = field
+
+        sql = cls._partition_query(
+            table_name, 1, [(field_to_return, True)], kwargs)
+        df = database.get_df(sql, schema)
+        if df.empty:
+            return ''
+        return df.to_dict()[field_to_return][0]
+
+
+class HiveEngineSpec(PrestoEngineSpec):
+
+    """Reuses PrestoEngineSpec functionality."""
+
+    engine = 'hive'
+    cursor_execute_kwargs = {'async': True}
+
+    @classmethod
+    def patch(cls):
+        from pyhive import hive
+        from superset.db_engines import hive as patched_hive
+        from pythrifthiveapi.TCLIService import (
+            constants as patched_constants,
+            ttypes as patched_ttypes,
+            TCLIService as patched_TCLIService)
+
+        hive.TCLIService = patched_TCLIService
+        hive.constants = patched_constants
+        hive.ttypes = patched_ttypes
+        hive.Cursor.fetch_logs = patched_hive.fetch_logs
+
+    @classmethod
+    @cache_util.memoized_func(
+        timeout=600,
+        key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]))
+    def fetch_result_sets(cls, db, datasource_type, force=False):
+        return BaseEngineSpec.fetch_result_sets(
+            db, datasource_type, force=force)
+
+    @classmethod
+    def progress(cls, logs):
+        # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
+        jobs_stats_r = re.compile(
+            r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
+        # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
+        launching_job_r = re.compile(
+            '.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
+            '(?P<max_jobs>[0-9]+)')
+        # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
+        # map = 0%, reduce = 0%
+        stage_progress = re.compile(
+            r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
+            r'map = (?P<map_progress>[0-9]+)%.*'
+            r'reduce = (?P<reduce_progress>[0-9]+)%.*')
+        total_jobs = None
+        current_job = None
+        stages = {}
+        lines = logs.splitlines()
+        for line in lines:
+            match = jobs_stats_r.match(line)
+            if match:
+                total_jobs = int(match.groupdict()['max_jobs'])
+            match = launching_job_r.match(line)
+            if match:
+                current_job = int(match.groupdict()['job_number'])
+                stages = {}
+            match = stage_progress.match(line)
+            if match:
+                stage_number = int(match.groupdict()['stage_number'])
+                map_progress = int(match.groupdict()['map_progress'])
+                reduce_progress = int(match.groupdict()['reduce_progress'])
+                stages[stage_number] = (map_progress + reduce_progress) / 2
+
+        if not total_jobs or not current_job:
+            return 0
+        stage_progress = sum(
+            stages.values()) / len(stages.values()) if stages else 0
+
+        progress = (
+            100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
+        )
+        return int(progress)
+
+    @classmethod
+    def handle_cursor(cls, cursor, query, session):
+        """Updates progress information"""
+        from pyhive import hive
+        print("PATCHED TCLIService {}".format(hive.TCLIService.__file__))
+        unfinished_states = (
+            hive.ttypes.TOperationState.INITIALIZED_STATE,
+            hive.ttypes.TOperationState.RUNNING_STATE,
+        )
+        polled = cursor.poll()
+        while polled.operationState in unfinished_states:
+            resp = cursor.fetch_logs()
+            if resp and resp.log:
+                progress = cls.progress(resp.log)
+                if progress > query.progress:
+                    query.progress = progress
+                session.commit()
+            time.sleep(5)
+            polled = cursor.poll()
+
+    @classmethod
+    def where_latest_partition(
+            cls, table_name, schema, database, qry, columns=None):
+        try:
+            col_name, value = cls.latest_partition(
+                table_name, schema, database)
+        except Exception:
+            # table is not partitioned
+            return False
+        for c in columns:
+            if str(c.name) == str(col_name):
+                return qry.where(c == str(value))
+        return False
+
+    @classmethod
+    def latest_sub_partition(cls, table_name, **kwargs):
+        # TODO(bogdan): implement
+        pass
+
+    @classmethod
+    def _latest_partition_from_df(cls, df):
+        """Hive partitions look like ds={partition name}"""
+        return df.ix[:, 0].max().split('=')[1]
+
+    @classmethod
+    def _partition_query(
+            cls, table_name, limit=0, order_by=None, filters=None):
+        return "SHOW PARTITIONS {table_name}".format(**locals())
+
 
 class MssqlEngineSpec(BaseEngineSpec):
     engine = 'mssql'
diff --git a/superset/db_engines/__init__.py b/superset/db_engines/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/superset/db_engines/hive.py b/superset/db_engines/hive.py
new file mode 100644
index 0000000000000..d3244feac62b0
--- /dev/null
+++ b/superset/db_engines/hive.py
@@ -0,0 +1,41 @@
+from pyhive import hive
+from pythrifthiveapi.TCLIService import ttypes
+
+
+# TODO: contribute back to pyhive.
+def fetch_logs(self, max_rows=1024,
+               orientation=ttypes.TFetchOrientation.FETCH_NEXT):
+    """Mocked. Retrieve the logs produced by the execution of the query.
+    Can be called multiple times to fetch the logs produced after
+    the previous call.
+    :returns: list
+    :raises: ``ProgrammingError`` when no query has been started
+    .. note::
+        This is not a part of DB-API.
+    """
+    try:
+        req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
+        logs = self._connection.client.GetLog(req)
+        return logs
+    except ttypes.TApplicationException as e:  # raised if Hive is used
+        if self._state == self._STATE_NONE:
+            raise hive.ProgrammingError("No query yet")
+        logs = []
+        while True:
+            req = ttypes.TFetchResultsReq(
+                operationHandle=self._operationHandle,
+                orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+                maxRows=self.arraysize,
+                fetchType=1  # 0: results, 1: logs
+            )
+            response = self._connection.client.FetchResults(req)
+            hive._check_status(response)
+            assert not response.results.rows, (
+                'expected data in columnar format'
+            )
+            assert len(response.results.columns) == 1, response.results.columns
+            new_logs = hive._unwrap_column(response.results.columns[0])
+            logs += new_logs
+            if not new_logs:
+                break
+        return logs
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index c861d9ac20643..51948dee418a6 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -10,12 +10,10 @@
 from datetime import datetime, timedelta
 from dateutil.relativedelta import relativedelta
 import time
-import textwrap
 import uuid
 import random
 
 from superset import app
-from superset.utils import SupersetTemplateException
 
 config = app.config
 BASE_CONTEXT = {
@@ -79,44 +77,6 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
     """
     engine = 'presto'
 
-    @staticmethod
-    def _partition_query(table_name, limit=0, order_by=None, filters=None):
-        """Returns a partition query
-
-        :param table_name: the name of the table to get partitions from
-        :type table_name: str
-        :param limit: the number of partitions to be returned
-        :type limit: int
-        :param order_by: a list of tuples of field name and a boolean
-            that determines if that field should be sorted in descending
-            order
-        :type order_by: list of (str, bool) tuples
-        :param filters: a list of filters to apply
-        :param filters: dict of field name and filter value combinations
-        """
-        limit_clause = "LIMIT {}".format(limit) if limit else ''
-        order_by_clause = ''
-        if order_by:
-            l = []
-            for field, desc in order_by:
-                l.append(field + ' DESC' if desc else '')
-            order_by_clause = 'ORDER BY ' + ', '.join(l)
-
-        where_clause = ''
-        if
filters: - l = [] - for field, value in filters.items(): - l.append("{field} = '{value}'".format(**locals())) - where_clause = 'WHERE ' + ' AND '.join(l) - - sql = textwrap.dedent("""\ - SHOW PARTITIONS FROM {table_name} - {where_clause} - {order_by_clause} - {limit_clause} - """).format(**locals()) - return sql - @staticmethod def _schema_table(table_name, schema): if '.' in table_name: @@ -124,74 +84,18 @@ def _schema_table(table_name, schema): return table_name, schema def latest_partition(self, table_name): - """Returns the latest (max) partition value for a table - - :param table_name: the name of the table, can be just the table - name or a fully qualified table name as ``schema_name.table_name`` - :type table_name: str - >>> latest_partition('foo_table') - '2018-01-01' - """ table_name, schema = self._schema_table(table_name, self.schema) - indexes = self.database.get_indexes(table_name, schema) - if len(indexes[0]['column_names']) < 1: - raise SupersetTemplateException( - "The table should have one partitioned field") - elif len(indexes[0]['column_names']) > 1: - raise SupersetTemplateException( - "The table should have a single partitioned field " - "to use this function. You may want to use " - "`presto.latest_sub_partition`") - part_field = indexes[0]['column_names'][0] - sql = self._partition_query(table_name, 1, [(part_field, True)]) - df = self.database.get_df(sql, schema) - return df.to_records(index=False)[0][0] + return self.database.db_engine_spec.latest_partition( + table_name, schema, self.database)[1] def latest_sub_partition(self, table_name, **kwargs): - """Returns the latest (max) partition value for a table - - A filtering criteria should be passed for all fields that are - partitioned except for the field to be returned. For example, - if a table is partitioned by (``ds``, ``event_type`` and - ``event_category``) and you want the latest ``ds``, you'll want - to provide a filter as keyword arguments for both - ``event_type`` and ``event_category`` as in - ``latest_sub_partition('my_table', - event_category='page', event_type='click')`` - - :param table_name: the name of the table, can be just the table - name or a fully qualified table name as ``schema_name.table_name`` - :type table_name: str - :param kwargs: keyword arguments define the filtering criteria - on the partition list. There can be many of these. - :type kwargs: str - >>> latest_sub_partition('sub_partition_table', event_type='click') - '2018-01-01' - """ table_name, schema = self._schema_table(table_name, self.schema) - indexes = self.database.get_indexes(table_name, schema) - part_fields = indexes[0]['column_names'] - for k in kwargs.keys(): - if k not in k in part_fields: - msg = "Field [{k}] is not part of the portioning key" - raise SupersetTemplateException(msg) - if len(kwargs.keys()) != len(part_fields) - 1: - msg = ( - "A filter needs to be specified for {} out of the " - "{} fields." 
- ).format(len(part_fields)-1, len(part_fields)) - raise SupersetTemplateException(msg) - - for field in part_fields: - if field not in kwargs.keys(): - field_to_return = field - - sql = self._partition_query( - table_name, 1, [(field_to_return, True)], kwargs) - df = self.database.get_df(sql, schema) - if df.empty: - return '' - return df.to_dict()[field_to_return][0] + return self.database.db_engine_spec.latest_sub_partition( + table_name, schema, self.database, kwargs) + + +class HiveTemplateProcessor(PrestoTemplateProcessor): + engine = 'hive' template_processors = {} diff --git a/superset/models.py b/superset/models.py index 7a8a83680eee2..cd732fb4832ef 100644 --- a/superset/models.py +++ b/superset/models.py @@ -830,20 +830,9 @@ def select_star( self, table_name, schema=None, limit=100, show_cols=False, indent=True): """Generates a ``select *`` statement in the proper dialect""" - quote = self.get_quoter() - fields = '*' - table = self.get_table(table_name, schema=schema) - if show_cols: - fields = [quote(c.name) for c in table.columns] - if schema: - table_name = schema + '.' + table_name - qry = select(fields).select_from(text(table_name)) - if limit: - qry = qry.limit(limit) - sql = self.compile_sqla_query(qry) - if indent: - sql = sqlparse.format(sql, reindent=True) - return sql + return self.db_engine_spec.select_star( + self, table_name, schema=schema, limit=limit, show_cols=show_cols, + indent=indent) def wrap_sql_limit(self, sql, limit=1000): qry = ( diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 8b57901f9a445..13a787145d082 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -2,6 +2,7 @@ from datetime import datetime import json import logging +import numpy as np import pandas as pd import sqlalchemy import uuid @@ -56,6 +57,7 @@ def get_sql_results(self, query_id, return_results=True, store_results=False): query = session.query(models.Query).filter_by(id=query_id).one() database = query.database db_engine_spec = database.db_engine_spec + db_engine_spec.patch() def handle_error(msg): """Local method handling error while processing the SQL""" @@ -91,7 +93,6 @@ def handle_error(msg): db_engine_spec.limit_method == LimitMethod.WRAP_SQL): executed_sql = database.wrap_sql_limit(executed_sql, query.limit) query.limit_used = True - engine = database.get_sqla_engine(schema=query.schema) try: template_processor = get_template_processor( database=database, query=query) @@ -104,34 +105,42 @@ def handle_error(msg): query.executed_sql = executed_sql logging.info("Running query: \n{}".format(executed_sql)) + engine = database.get_sqla_engine(schema=query.schema) + conn = engine.raw_connection() + cursor = conn.cursor() try: - result_proxy = engine.execute(query.executed_sql, schema=query.schema) + cursor.execute( + query.executed_sql, **db_engine_spec.cursor_execute_kwargs) except Exception as e: logging.exception(e) + conn.close() handle_error(db_engine_spec.extract_error_message(e)) - cursor = result_proxy.cursor query.status = QueryStatus.RUNNING session.flush() - db_engine_spec.handle_cursor(cursor, query, session) - - cdf = None - if result_proxy.cursor: - column_names = [col[0] for col in result_proxy.cursor.description] - column_names = dedup(column_names) - if db_engine_spec.limit_method == LimitMethod.FETCH_MANY: - data = result_proxy.fetchmany(query.limit) - else: - data = result_proxy.fetchall() - cdf = dataframe.SupersetDataFrame( - pd.DataFrame(data, columns=column_names)) + try: + logging.info("Handling cursor") + db_engine_spec.handle_cursor(cursor, 
query, session) + logging.info("Fetching data: {}".format(query.to_dict())) + data = db_engine_spec.fetch_data(cursor, query.limit) + except Exception as e: + logging.exception(e) + conn.close() + handle_error(db_engine_spec.extract_error_message(e)) + + conn.commit() + conn.close() + + column_names = ( + [col[0] for col in cursor.description] if cursor.description else []) + column_names = dedup(column_names) + df_data = np.array(data) if data else [] + cdf = dataframe.SupersetDataFrame(pd.DataFrame( + df_data, columns=column_names)) - query.rows = result_proxy.rowcount + query.rows = cdf.size query.progress = 100 query.status = QueryStatus.SUCCESS - if query.rows == -1 and cdf: - # Presto doesn't provide result_proxy.row_count - query.rows = cdf.size if query.select_as_cta: query.select_sql = '{}'.format(database.select_star( query.tmp_table_name, @@ -144,11 +153,10 @@ def handle_error(msg): payload = { 'query_id': query.id, 'status': query.status, - 'data': [], + 'data': cdf.data if cdf.data else [], + 'columns': cdf.columns_dict if cdf.columns_dict else {}, + 'query': query.to_dict(), } - payload['data'] = cdf.data if cdf else [] - payload['columns'] = cdf.columns_dict if cdf else [] - payload['query'] = query.to_dict() payload = json.dumps(payload, default=utils.json_iso_dttm_ser) if store_results: diff --git a/superset/views.py b/superset/views.py index 8882ba041f71d..6e3eab0a98197 100755 --- a/superset/views.py +++ b/superset/views.py @@ -2477,20 +2477,9 @@ def extra_table_metadata(self, database_id, table_name, schema): def select_star(self, database_id, table_name): mydb = db.session.query( models.Database).filter_by(id=database_id).first() - quote = mydb.get_quoter() - t = mydb.get_table(table_name) - - # Prevent exposing column fields to users that cannot access DB. - if not self.datasource_access(t.perm): - flash(get_datasource_access_error_msg(t.name), 'danger') - return redirect("/tablemodelview/list/") - - fields = ", ".join( - [quote(c.name) for c in t.columns] or "*") - s = "SELECT\n{}\nFROM {}".format(fields, table_name) return self.render_template( "superset/ajah.html", - content=s + content=mydb.select_star(table_name, show_cols=True) ) @expose("/theme/") diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f172e6b1fe544..a4e48f4b191ea 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -170,27 +170,25 @@ def test_add_limit_to_the_query(self): ' '.join(updated_multi_line_query.split()) ) - def test_run_sync_query(self): + def test_run_sync_query_dont_exist(self): main_db = self.get_main_database(db.session) - eng = main_db.get_sqla_engine() - perm_name = 'can_sql_json' - db_id = main_db.id - # Case 1. - # Table doesn't exist. sql_dont_exist = 'SELECT name FROM table_dont_exist' result1 = self.run_sql(db_id, sql_dont_exist, "1", cta='true') self.assertTrue('error' in result1) - # Case 2. - # Table and DB exists, CTA call to the backend. + def test_run_sync_query_cta(self): + main_db = self.get_main_database(db.session) + db_id = main_db.id + eng = main_db.get_sqla_engine() + perm_name = 'can_sql_json' sql_where = ( "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) result2 = self.run_sql( db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true') self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) self.assertEqual([], result2['data']) - self.assertEqual([], result2['columns']) + self.assertEqual({}, result2['columns']) query2 = self.get_query_by_id(result2['query']['serverId']) # Check the data in the tmp table. 
@@ -198,14 +196,15 @@ def test_run_sync_query(self): data2 = df2.to_dict(orient='records') self.assertEqual([{'name': perm_name}], data2) - # Case 3. - # Table and DB exists, CTA call to the backend, no data. + def test_run_sync_query_cta_no_data(self): + main_db = self.get_main_database(db.session) + db_id = main_db.id sql_empty_result = 'SELECT * FROM ab_user WHERE id=666' result3 = self.run_sql( - db_id, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',) + db_id, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true') self.assertEqual(QueryStatus.SUCCESS, result3['query']['state']) self.assertEqual([], result3['data']) - self.assertEqual([], result3['columns']) + self.assertEqual({}, result3['columns']) query3 = self.get_query_by_id(result3['query']['serverId']) self.assertEqual(QueryStatus.SUCCESS, query3.status) @@ -213,38 +212,31 @@ def test_run_sync_query(self): def test_run_async_query(self): main_db = self.get_main_database(db.session) eng = main_db.get_sqla_engine() - - # Schedule queries - - # Case 1. - # Table and DB exists, async CTA call to the backend. sql_where = "SELECT name FROM ab_role WHERE name='Admin'" - result1 = self.run_sql( + result = self.run_sql( main_db.id, sql_where, "4", async='true', tmp_table='tmp_async_1', cta='true') - assert result1['query']['state'] in ( + assert result['query']['state'] in ( QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) time.sleep(1) - # Case 1. - query1 = self.get_query_by_id(result1['query']['serverId']) - df1 = pd.read_sql_query(query1.select_sql, con=eng) - self.assertEqual(QueryStatus.SUCCESS, query1.status) - self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records')) - self.assertEqual(QueryStatus.SUCCESS, query1.status) - self.assertTrue("FROM tmp_async_1" in query1.select_sql) - self.assertTrue("LIMIT 666" in query1.select_sql) + query = self.get_query_by_id(result['query']['serverId']) + df = pd.read_sql_query(query.select_sql, con=eng) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertTrue("FROM tmp_async_1" in query.select_sql) + self.assertTrue("LIMIT 666" in query.select_sql) self.assertEqual( "CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role " - "WHERE name='Admin'", query1.executed_sql) - self.assertEqual(sql_where, query1.sql) - if eng.name != 'sqlite': - self.assertEqual(1, query1.rows) - self.assertEqual(666, query1.limit) - self.assertEqual(False, query1.limit_used) - self.assertEqual(True, query1.select_as_cta) - self.assertEqual(True, query1.select_as_cta_used) + "WHERE name='Admin'", query.executed_sql) + self.assertEqual(sql_where, query.sql) + self.assertEqual(0, query.rows) + self.assertEqual(666, query.limit) + self.assertEqual(False, query.limit_used) + self.assertEqual(True, query.select_as_cta) + self.assertEqual(True, query.select_as_cta_used) def test_get_columns_dict(self): main_db = self.get_main_database(db.session) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py new file mode 100644 index 0000000000000..626a97bb3f9c3 --- /dev/null +++ b/tests/db_engine_specs_test.py @@ -0,0 +1,87 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest + +from superset import db_engine_specs + + +class DbEngineSpecsTestCase(unittest.TestCase): + def test_0_progress(self): + log = """ + 
17/02/07 18:26:27 INFO log.PerfLogger: + 17/02/07 18:26:27 INFO log.PerfLogger: + """ + self.assertEquals(0, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_0_progress(self): + log = """ + 17/02/07 18:26:27 INFO log.PerfLogger: + 17/02/07 18:26:27 INFO log.PerfLogger: + """ + self.assertEquals(0, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_number_of_jobs_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + """ + self.assertEquals(0, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_1_launched_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + """ + self.assertEquals(0, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_1_launched_stage_1_0_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + """ + self.assertEquals(0, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_1_launched_stage_1_map_40_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + """ + self.assertEquals(10, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_1_launched_stage_1_map_80_reduce_40_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% + """ + self.assertEquals(30, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_1_launched_stage_2_stages_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% + """ + self.assertEquals(12, db_engine_specs.HiveEngineSpec.progress(log)) + + def test_job_2_launched_stage_2_stages_progress(self): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + """ + self.assertEquals(60, db_engine_specs.HiveEngineSpec.progress(log))
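
A note on the progress numbers asserted above: HiveEngineSpec.progress() averages the map and reduce percentages of each stage, averages the stages seen since the current job was launched, and scales that by the job's share of the total job count. The following is a minimal standalone sketch of that arithmetic, independent of Superset; the helper name expected_progress is illustrative only and is not part of the patch.

    # Standalone sketch of the arithmetic in HiveEngineSpec.progress().
    # `expected_progress` is a hypothetical helper, not part of the patch.
    def expected_progress(current_job, total_jobs, stages):
        # stages: {stage_number: (map %, reduce %)} for the current job only
        stage_avgs = [(m + r) / 2.0 for m, r in stages.values()]
        stage_progress = sum(stage_avgs) / len(stage_avgs) if stage_avgs else 0
        return int(
            100 * (current_job - 1) / total_jobs + stage_progress / total_jobs)

    # Job 1 of 2, Stage-1 at map 80% / reduce 40%:
    # stage avg = 60, progress = 0 + 60 / 2 = 30
    assert expected_progress(1, 2, {1: (80, 40)}) == 30

    # Job 1 of 2, Stage-1 at map 100% / reduce 0%, Stage-2 at 0% / 0%:
    # stage avgs = [50, 0] -> 25, progress = 0 + 25 / 2 = 12 (truncated)
    assert expected_progress(1, 2, {1: (100, 0), 2: (0, 0)}) == 12

    # Job 2 of 2, Stage-1 at map 40% / reduce 0%:
    # stage avg = 20, progress = 50 + 20 / 2 = 60
    assert expected_progress(2, 2, {1: (40, 0)}) == 60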