diff --git a/spectacles/cli.py b/spectacles/cli.py
index c2f76c19..c150f543 100644
--- a/spectacles/cli.py
+++ b/spectacles/cli.py
@@ -175,6 +175,7 @@ def main():
             args.api_version,
             args.mode,
             args.remote_reset,
+            args.concurrency,
         )
     elif args.command == "assert":
         run_assert(
@@ -352,6 +353,13 @@ def _build_sql_subparser(
         user's branch to the revision of the branch that is on the remote. \
         WARNING: This will delete any uncommited changes in the user's workspace.",
     )
+    subparser.add_argument(
+        "--concurrency",
+        default=10,
+        type=int,
+        help="Specify how many concurrent queries you want to have running \
+        against your data warehouse. The default is 10.",
+    )
 
 
 def _build_assert_subparser(
@@ -414,6 +422,7 @@ def run_sql(
     api_version,
     mode,
     remote_reset,
+    concurrency,
 ) -> None:
     """Runs and validates the SQL for each selected LookML dimension."""
     runner = Runner(
@@ -426,7 +435,7 @@ def run_sql(
         api_version,
         remote_reset,
     )
-    errors = runner.validate_sql(explores, mode)
+    errors = runner.validate_sql(explores, mode, concurrency)
     if errors:
         for error in sorted(errors, key=lambda x: x["path"]):
             printer.print_sql_error(error)
diff --git a/spectacles/client.py b/spectacles/client.py
index 6bbd0600..5954a92e 100644
--- a/spectacles/client.py
+++ b/spectacles/client.py
@@ -391,7 +391,9 @@ async def create_query_task(
         logger.debug("Query %d is running under query task %s", query_id, query_task_id)
         return query_task_id
 
-    def get_query_task_multi_results(self, query_task_ids: List[str]) -> JsonDict:
+    async def get_query_task_multi_results(
+        self, session: aiohttp.ClientSession, query_task_ids: List[str]
+    ) -> JsonDict:
         """Returns query task results.
 
         If a ClientError or TimeoutError is received, attempts to retry.
@@ -408,16 +410,9 @@ def get_query_task_multi_results(self, query_task_ids: List[str]) -> JsonDict:
             "Attempting to get results for %d query tasks", len(query_task_ids)
         )
         url = utils.compose_url(self.api_url, path=["query_tasks", "multi_results"])
-        response = self.session.get(
+        async with session.get(
             url=url, params={"query_task_ids": ",".join(query_task_ids)}
-        )
-        try:
+        ) as response:
+            result = await response.json()
             response.raise_for_status()
-        except requests.exceptions.HTTPError as error:
-            details = utils.details_from_http_error(response)
-            raise ApiConnectionError(
-                f"Looker API error encountered: {error}\n"
-                + "Message received from Looker's API: "
-                f'"{details}"'
-            )
-        return response.json()
+            return result
diff --git a/spectacles/runner.py b/spectacles/runner.py
index 714230d7..915c3977 100644
--- a/spectacles/runner.py
+++ b/spectacles/runner.py
@@ -39,8 +39,10 @@ def __init__(
         self.client.update_session(project, branch, remote_reset)
 
     @log_duration
-    def validate_sql(self, selectors: List[str], mode: str = "batch") -> List[dict]:
-        sql_validator = SqlValidator(self.client, self.project)
+    def validate_sql(
+        self, selectors: List[str], mode: str = "batch", concurrency: int = 10
+    ) -> List[dict]:
+        sql_validator = SqlValidator(self.client, self.project, concurrency)
         sql_validator.build_project(selectors)
         errors = sql_validator.validate(mode)
         return [vars(error) for error in errors]
diff --git a/spectacles/validators.py b/spectacles/validators.py
index 531df8cd..c4681d42 100644
--- a/spectacles/validators.py
+++ b/spectacles/validators.py
@@ -1,6 +1,5 @@
-from typing import List, Sequence, DefaultDict, Tuple
+from typing import List, Sequence, DefaultDict
 import asyncio
-import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 import aiohttp
@@ -81,11 +80,13 @@ class SqlValidator(Validator):
     """
 
     timeout = aiohttp.ClientTimeout(total=300)
 
-    def __init__(self, client: LookerClient, project: str):
+    def __init__(self, client: LookerClient, project: str, concurrency: int = 10):
         super().__init__(client)
         self.project = Project(project, models=[])
         self.query_tasks: dict = {}
+        self.query_slots = asyncio.BoundedSemaphore(concurrency)
+        self.running_query_tasks: asyncio.Queue = asyncio.Queue()
 
     @staticmethod
     def parse_selectors(selectors: List[str]) -> DefaultDict[str, set]:
@@ -207,9 +208,10 @@ def validate(self, mode: str = "batch") -> List[SqlError]:
             f"[{mode} mode]"
         )
 
-        errors = self._query(mode)
+        loop = asyncio.get_event_loop()
+        errors = list(loop.run_until_complete(self._query(mode)))
         if mode == "hybrid" and self.project.errored:
-            errors = self._query(mode)
+            errors = list(loop.run_until_complete(self._query(mode)))
 
         for model in sorted(self.project.models, key=lambda x: x.name):
             for explore in sorted(model.explores, key=lambda x: x.name):
@@ -221,53 +223,32 @@ def validate(self, mode: str = "batch") -> List[SqlError]:
 
         return errors
 
-    def _query(self, mode: str = "batch") -> List[SqlError]:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+    async def _query(self, mode: str = "batch") -> List[SqlError]:
         session = aiohttp.ClientSession(
-            loop=loop, headers=self.client.session.headers, timeout=self.timeout
+            headers=self.client.session.headers, timeout=self.timeout
         )
-        tasks = []
+
+        query_tasks = []
         for model in self.project.models:
             for explore in model.explores:
-                if explore.dimensions:
-                    if mode == "batch" or (mode == "hybrid" and not explore.queried):
-                        logger.debug("Querying one explore at at time")
-                        task = loop.create_task(
-                            self._query_explore(session, model, explore)
+                if mode == "batch" or (mode == "hybrid" and not explore.queried):
+                    task = asyncio.create_task(
+                        self._query_explore(session, model, explore)
+                    )
+                    query_tasks.append(task)
+                elif mode == "single" or (mode == "hybrid" and explore.errored):
+                    for dimension in explore.dimensions:
+                        task = asyncio.create_task(
+                            self._query_dimension(session, model, explore, dimension)
                         )
-                        tasks.append(task)
-                    elif mode == "single" or (mode == "hybrid" and explore.errored):
-                        logger.debug("Querying one dimension at at time")
-                        for dimension in explore.dimensions:
-                            task = loop.create_task(
-                                self._query_dimension(
-                                    session, model, explore, dimension
-                                )
-                            )
-                            tasks.append(task)
-
-        query_task_ids = list(loop.run_until_complete(asyncio.gather(*tasks)))
-        loop.run_until_complete(session.close())
-        loop.run_until_complete(asyncio.sleep(0.250))
-        loop.close()
-
-        MAX_QUERY_FETCH = 250
-
-        tasks_to_check = query_task_ids[:MAX_QUERY_FETCH]
-        del query_task_ids[:MAX_QUERY_FETCH]
-        logger.debug(f"{len(query_task_ids)} left in queue")
-        tasks_to_check, errors = self._get_query_results(tasks_to_check)
-
-        while tasks_to_check or query_task_ids:
-            number_of_tasks_to_add = MAX_QUERY_FETCH - len(tasks_to_check)
-            tasks_to_check.extend(query_task_ids[:number_of_tasks_to_add])
-            del query_task_ids[:number_of_tasks_to_add]
-            logger.debug(f"{len(query_task_ids)} left in queue")
-            tasks_to_check, more_errors = self._get_query_results(tasks_to_check)
-            errors.extend(more_errors)
-            if tasks_to_check or query_task_ids:
-                time.sleep(0.5)
+                        query_tasks.append(task)
+
+        queries = asyncio.gather(*query_tasks)
+        query_results = asyncio.gather(self._check_for_results(session, query_tasks))
+        results = await asyncio.gather(queries, query_results)
+        errors = results[1][0]  # Ignore the results from creating the queries
+
+        await session.close()
 
         return errors
 
@@ -299,47 +280,88 @@ def _extract_error_details(query_result: dict) -> dict:
 
         return {"message": message, "sql": sql, "line_number": line_number}
 
-    def _get_query_results(
-        self, query_task_ids: List[str]
-    ) -> Tuple[List[str], List[SqlError]]:
-        results = self.client.get_query_task_multi_results(query_task_ids)
-        still_running = []
+    async def _run_query(
+        self,
+        session: aiohttp.ClientSession,
+        model: str,
+        explore: str,
+        dimensions: List[str],
+    ) -> str:
+        query_id = await self.client.create_query(session, model, explore, dimensions)
+        await self.query_slots.acquire()  # Wait for available slots before launching
+        query_task_id = await self.client.create_query_task(session, query_id)
+        await self.running_query_tasks.put(query_task_id)
+        return query_task_id
+
+    async def _get_query_results(
+        self, session: aiohttp.ClientSession
+    ) -> List[SqlError]:
+        logger.debug("%d queries running", self.running_query_tasks.qsize())
+
+        # Empty the queue (up to 250) to get all running query tasks
+        query_task_ids: List[str] = []
+        while not self.running_query_tasks.empty() and len(query_task_ids) <= 250:
+            query_task_ids.append(await self.running_query_tasks.get())
+
+        logger.debug("Getting results for %d query tasks", len(query_task_ids))
+        results = await self.client.get_query_task_multi_results(
+            session, query_task_ids
+        )
+        pending_task_ids = []
         errors = []
         for query_task_id, query_result in results.items():
             query_status = query_result["status"]
             logger.debug("Query task %s status is %s", query_task_id, query_status)
-
             if query_status in ("running", "added", "expired"):
-                still_running.append(query_task_id)
+                pending_task_ids.append(query_task_id)
+                # Put the running query tasks back in the queue
+                await self.running_query_tasks.put(query_task_id)
                 continue
             elif query_status in ("complete", "error"):
+                # We can release a query slot for each completed query
+                self.query_slots.release()
                 lookml_object = self.query_tasks[query_task_id]
                 lookml_object.queried = True
+
+                if query_status == "error":
+                    try:
+                        details = self._extract_error_details(query_result)
+                    except (KeyError, TypeError, IndexError) as error:
+                        raise SpectaclesException(
+                            "Encountered an unexpected API query result format, "
+                            "unable to extract error details. "
+                            f"The query result was: {query_result}"
+                        ) from error
+                    sql_error = SqlError(
+                        path=lookml_object.name,
+                        url=getattr(lookml_object, "url", None),
+                        **details,
+                    )
+                    lookml_object.error = sql_error
+                    errors.append(sql_error)
             else:
                 raise SpectaclesException(
                     f'Unexpected query result status "{query_status}" '
                     "returned by the Looker API"
                 )
-            if query_status == "error":
-                try:
-                    details = self._extract_error_details(query_result)
-                except (KeyError, TypeError, IndexError) as error:
-                    raise SpectaclesException(
-                        "Encountered an unexpected API query result format, "
-                        "unable to extract error details. "
-                        f"The query result was: {query_result}"
-                    ) from error
-                sql_error = SqlError(
-                    path=lookml_object.name,
-                    url=getattr(lookml_object, "url", None),
-                    **details,
-                )
-                lookml_object.error = sql_error
-                errors.append(sql_error)
+        return errors
 
-        return still_running, errors
+    async def _check_for_results(
+        self, session: aiohttp.ClientSession, query_tasks: List[asyncio.Task]
+    ):
+        results = []
+        while (
+            any(not task.done() for task in query_tasks)
+            or not self.running_query_tasks.empty()
+        ):
+            if not self.running_query_tasks.empty():
+                result = await self._get_query_results(session)
+                results.extend(result)
+            await asyncio.sleep(0.5)
+
+        return results
 
     async def _query_explore(
         self, session: aiohttp.ClientSession, model: Model, explore: Explore
@@ -355,11 +377,9 @@ async def _query_explore(
 
         """
         dimensions = [dimension.name for dimension in explore.dimensions]
-        query_id = await self.client.create_query(
+        query_task_id = await self._run_query(
             session, model.name, explore.name, dimensions
         )
-        query_task_id = await self.client.create_query_task(session, query_id)
-
         self.query_tasks[query_task_id] = explore
         return query_task_id
 
@@ -381,11 +401,9 @@ async def _query_dimension(
             str: Query task ID for the running query.
 
         """
-        query_id = await self.client.create_query(
+        query_task_id = await self._run_query(
             session, model.name, explore.name, [dimension.name]
         )
-        query_task_id = await self.client.create_query_task(session, query_id)
-
         self.query_tasks[query_task_id] = dimension
         return query_task_id
diff --git a/tests/test_cli.py b/tests/test_cli.py
index c8589cf3..a6e6d7cf 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,6 +1,6 @@
 from unittest.mock import patch
 import pytest
-from tests.constants import ENV_VARS
+from constants import ENV_VARS
 from spectacles.cli import main, create_parser, handle_exceptions
 from spectacles.exceptions import SpectaclesException, ValidationError
 
diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py
index fe4b23a5..0d35ca08 100644
--- a/tests/test_sql_validator.py
+++ b/tests/test_sql_validator.py
@@ -2,10 +2,10 @@
 import json
 from unittest.mock import patch, Mock
 import pytest
+import asynctest
 from spectacles.lookml import Project, Model, Explore, Dimension
 from spectacles.client import LookerClient
 from spectacles.validators import SqlValidator
-from spectacles.exceptions import SpectaclesException
 
 TEST_BASE_URL = "https://test.looker.com"
 TEST_CLIENT_ID = "test_client_id"
@@ -72,135 +72,95 @@ def test_build_project(mock_get_models, mock_get_dimensions, project, validator)
     assert validator.project == project
 
 
-@patch("spectacles.client.LookerClient.get_query_task_multi_results")
-def test_get_query_results_task_running(mock_get_query_task_multi_results, validator):
+@pytest.mark.asyncio
+@asynctest.patch("spectacles.client.LookerClient.get_query_task_multi_results")
+async def test_get_query_results_task_running(
+    mock_get_query_task_multi_results, validator
+):
+    await validator.query_slots.acquire()
     mock_response = {"status": "running"}
     mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response}
-    still_running, errors = validator._get_query_results(["query_task_a"])
-    assert not errors
-    assert still_running == ["query_task_a"]
+    errors = validator._get_query_results(["query_task_a"])
+    assert not await errors
 
 
-@patch("spectacles.client.LookerClient.get_query_task_multi_results")
-def test_get_query_results_task_complete(
+@pytest.mark.asyncio
+@asynctest.patch("spectacles.client.LookerClient.get_query_task_multi_results") +async def test_get_query_results_task_complete( mock_get_query_task_multi_results, validator, project ): + await validator.query_slots.acquire() lookml_object = project.models[0].explores[0] validator.query_tasks = {"query_task_a": lookml_object} mock_response = {"status": "complete"} mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - still_running, errors = validator._get_query_results(["query_task_a"]) - assert not errors - assert not still_running + errors = validator._get_query_results(["query_task_a"]) + assert not await errors -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_task_error_dict( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_message = "An error message." - mock_details = "Shocking details." - mock_sql = "SELECT * FROM orders" - mock_response = { +def test_extract_error_details_error_dict(validator, project): + message = "An error message." + message_details = "Shocking details." + sql = "SELECT * FROM orders" + query_result = { "status": "error", "data": { - "errors": [{"message": mock_message, "message_details": mock_details}], - "sql": mock_sql, + "errors": [{"message": message, "message_details": message_details}], + "sql": sql, }, } - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - still_running, errors = validator._get_query_results(["query_task_a"]) - assert errors[0].path == lookml_object.name - assert errors[0].message == f"{mock_message} {mock_details}" - assert errors[0].sql == mock_sql - assert not still_running + extracted = validator._extract_error_details(query_result) + assert extracted["message"] == f"{message} {message_details}" + assert extracted["sql"] == sql -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_task_error_list( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_message = "An error message." - mock_response = {"status": "error", "data": [mock_message]} - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - still_running, errors = validator._get_query_results(["query_task_a"]) - assert errors[0].path == lookml_object.name - assert errors[0].message == mock_message - assert errors[0].sql is None - assert not still_running +def test_extract_error_details_error_list(validator, project): + message = "An error message." 
+ query_result = {"status": "error", "data": [message]} + extracted = validator._extract_error_details(query_result) + assert extracted["message"] == message + assert extracted["sql"] is None -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_task_error_other( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_response = {"status": "error", "data": "some string"} - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - with pytest.raises(SpectaclesException): - still_running, errors = validator._get_query_results(["query_task_a"]) +def test_extract_error_details_error_other(validator, project): + query_result = {"status": "error", "data": "some string"} + with pytest.raises(TypeError): + validator._extract_error_details(query_result) -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_non_str_message_details( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_message = {"message": "An error messsage.", "details": "More details."} - mock_sql = "SELECT * FROM orders" - mock_response = { +def test_extract_error_details_error_non_str_message_details(validator, project): + message = {"message": "An error messsage.", "details": "More details."} + sql = "SELECT * FROM orders" + query_result = { "status": "error", - "data": {"errors": [{"message_details": mock_message}], "sql": mock_sql}, + "data": {"errors": [{"message_details": message}], "sql": sql}, } - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - with pytest.raises(SpectaclesException): - still_running, errors = validator._get_query_results(["query_task_a"]) + with pytest.raises(TypeError): + validator._extract_error_details(query_result) -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_task_error_loc_wo_msg_details( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_message = "An error message." - mock_sql = "SELECT * FROM orders" - mock_response = { +def test_extract_error_details_error_loc_wo_msg_details(validator, project): + message = "An error message." + sql = "SELECT * FROM orders" + query_result = { "status": "error", - "data": {"errors": [{"message": mock_message}], "sql": mock_sql}, + "data": {"errors": [{"message": message}], "sql": sql}, } - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - still_running, errors = validator._get_query_results(["query_task_a"]) - assert errors[0].path == lookml_object.name - assert errors[0].message == mock_message - assert errors[0].sql == mock_sql - assert not still_running + extracted = validator._extract_error_details(query_result) + assert extracted["message"] == message + assert extracted["sql"] == sql -@patch("spectacles.client.LookerClient.get_query_task_multi_results") -def test_get_query_results_task_error_loc_wo_line( - mock_get_query_task_multi_results, validator, project -): - lookml_object = project.models[0].explores[0] - validator.query_tasks = {"query_task_a": lookml_object} - mock_message = "An error message." 
- mock_sql = "SELECT x FROM orders" - mock_response = { +def test_extract_error_details_error_loc_wo_line(validator, project): + message = "An error message." + sql = "SELECT x FROM orders" + query_result = { "status": "error", "data": { - "errors": [{"message": mock_message, "sql_error_loc": {"character": 8}}], - "sql": mock_sql, + "errors": [{"message": message, "sql_error_loc": {"character": 8}}], + "sql": sql, }, } - mock_get_query_task_multi_results.return_value = {"query_task_a": mock_response} - still_running, errors = validator._get_query_results(["query_task_a"]) - assert errors[0].path == lookml_object.name - assert errors[0].message == mock_message - assert errors[0].sql == mock_sql - assert not still_running + extracted = validator._extract_error_details(query_result) + assert extracted["message"] == message + assert extracted["sql"] == sql