Support max_concurrency for running Looker queries #110

Merged · 8 commits · Nov 25, 2019
11 changes: 10 additions & 1 deletion spectacles/cli.py
@@ -175,6 +175,7 @@ def main():
args.api_version,
args.mode,
args.remote_reset,
args.concurrency,
)
elif args.command == "assert":
run_assert(
@@ -352,6 +353,13 @@ def _build_sql_subparser(
user's branch to the revision of the branch that is on the remote. \
WARNING: This will delete any uncommited changes in the user's workspace.",
)
subparser.add_argument(
"--concurrency",
default=10,
type=int,
help="Specify how many concurrent queries you want to have running \
against your data warehouse. The default is 10.",
)


def _build_assert_subparser(
@@ -414,6 +422,7 @@ def run_sql(
api_version,
mode,
remote_reset,
concurrency,
) -> None:
"""Runs and validates the SQL for each selected LookML dimension."""
runner = Runner(
@@ -426,7 +435,7 @@ def run_sql(
api_version,
remote_reset,
)
errors = runner.validate_sql(explores, mode)
errors = runner.validate_sql(explores, mode, concurrency)
if errors:
for error in sorted(errors, key=lambda x: x["path"]):
printer.print_sql_error(error)
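For reference, the new flag follows the standard argparse pattern already used by the sql subparser. A minimal, self-contained sketch (the prog name and parse_args input here are illustrative, not the project's real wiring) of how a value such as 20 is parsed and then handed down to Runner.validate_sql:

import argparse

# Build a throwaway parser that mirrors the --concurrency argument added above.
parser = argparse.ArgumentParser(prog="spectacles-sketch")
parser.add_argument(
    "--concurrency",
    default=10,
    type=int,
    help="How many queries to run against the warehouse at once (default 10).",
)

args = parser.parse_args(["--concurrency", "20"])
print(args.concurrency)  # 20 -- this value is what run_sql passes to Runner.validate_sql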
19 changes: 7 additions & 12 deletions spectacles/client.py
@@ -391,7 +391,9 @@ async def create_query_task(
logger.debug("Query %d is running under query task %s", query_id, query_task_id)
return query_task_id

def get_query_task_multi_results(self, query_task_ids: List[str]) -> JsonDict:
async def get_query_task_multi_results(
self, session: aiohttp.ClientSession, query_task_ids: List[str]
) -> JsonDict:
"""Returns query task results.

If a ClientError or TimeoutError is received, attempts to retry.
@@ -408,16 +410,9 @@ def get_query_task_multi_results(self, query_task_ids: List[str]) -> JsonDict:
"Attempting to get results for %d query tasks", len(query_task_ids)
)
url = utils.compose_url(self.api_url, path=["query_tasks", "multi_results"])
response = self.session.get(
async with session.get(
url=url, params={"query_task_ids": ",".join(query_task_ids)}
)
try:
) as response:
result = await response.json()
response.raise_for_status()
except requests.exceptions.HTTPError as error:
details = utils.details_from_http_error(response)
raise ApiConnectionError(
f"Looker API error encountered: {error}\n"
+ "Message received from Looker's API: "
f'"{details}"'
)
return response.json()
return result
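The rewritten method now uses aiohttp instead of requests, reading the JSON body before raising for status so error details stay available. A minimal sketch of that pattern, assuming a placeholder URL and task IDs (not a real Looker endpoint):

import aiohttp

async def fetch_multi_results(url: str, query_task_ids: list) -> dict:
    # GET the multi_results endpoint with the task IDs as a comma-separated param.
    async with aiohttp.ClientSession() as session:
        async with session.get(
            url, params={"query_task_ids": ",".join(query_task_ids)}
        ) as response:
            result = await response.json()   # parse the body first...
            response.raise_for_status()      # ...then fail loudly on HTTP errors
            return result

# Example (placeholder values):
# asyncio.run(fetch_multi_results("https://example.looker.com/api/3.1/query_tasks/multi_results", ["abc", "def"]))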
6 changes: 4 additions & 2 deletions spectacles/runner.py
@@ -39,8 +39,10 @@ def __init__(
self.client.update_session(project, branch, remote_reset)

@log_duration
def validate_sql(self, selectors: List[str], mode: str = "batch") -> List[dict]:
sql_validator = SqlValidator(self.client, self.project)
def validate_sql(
self, selectors: List[str], mode: str = "batch", concurrency: int = 10
) -> List[dict]:
sql_validator = SqlValidator(self.client, self.project, concurrency)
sql_validator.build_project(selectors)
errors = sql_validator.validate(mode)
return [vars(error) for error in errors]
172 changes: 95 additions & 77 deletions spectacles/validators.py
@@ -1,6 +1,5 @@
from typing import List, Sequence, DefaultDict, Tuple
from typing import List, Sequence, DefaultDict
import asyncio
import time
from abc import ABC, abstractmethod
from collections import defaultdict
import aiohttp
@@ -81,11 +80,13 @@ class SqlValidator(Validator):

timeout = aiohttp.ClientTimeout(total=300)

def __init__(self, client: LookerClient, project: str):
def __init__(self, client: LookerClient, project: str, concurrency: int = 10):
super().__init__(client)

self.project = Project(project, models=[])
self.query_tasks: dict = {}
self.query_slots = asyncio.BoundedSemaphore(concurrency)
self.running_query_tasks: asyncio.Queue = asyncio.Queue()

@staticmethod
def parse_selectors(selectors: List[str]) -> DefaultDict[str, set]:
@@ -207,9 +208,10 @@ def validate(self, mode: str = "batch") -> List[SqlError]:
f"[{mode} mode]"
)

errors = self._query(mode)
loop = asyncio.get_event_loop()
errors = list(loop.run_until_complete(self._query(mode)))
if mode == "hybrid" and self.project.errored:
errors = self._query(mode)
errors = list(loop.run_until_complete(self._query(mode)))

for model in sorted(self.project.models, key=lambda x: x.name):
for explore in sorted(model.explores, key=lambda x: x.name):
@@ -221,53 +223,32 @@ def validate(self, mode: str = "batch") -> List[SqlError]:

return errors
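validate() stays synchronous and drives the new async _query coroutine through the event loop. A minimal sketch of that bridging pattern (names are illustrative; on Python 3.7+ asyncio.run is the simpler alternative to get_event_loop/run_until_complete):

import asyncio

async def _query_sketch() -> list:
    await asyncio.sleep(0)  # stand-in for creating queries and collecting errors
    return []

def validate_sketch() -> list:
    loop = asyncio.get_event_loop()
    errors = list(loop.run_until_complete(_query_sketch()))
    return errors

print(validate_sketch())  # []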

def _query(self, mode: str = "batch") -> List[SqlError]:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def _query(self, mode: str = "batch") -> List[SqlError]:
session = aiohttp.ClientSession(
loop=loop, headers=self.client.session.headers, timeout=self.timeout
headers=self.client.session.headers, timeout=self.timeout
)
tasks = []

query_tasks = []
for model in self.project.models:
for explore in model.explores:
if explore.dimensions:
if mode == "batch" or (mode == "hybrid" and not explore.queried):
logger.debug("Querying one explore at at time")
task = loop.create_task(
self._query_explore(session, model, explore)
if mode == "batch" or (mode == "hybrid" and not explore.queried):
task = asyncio.create_task(
self._query_explore(session, model, explore)
)
query_tasks.append(task)
elif mode == "single" or (mode == "hybrid" and explore.errored):
for dimension in explore.dimensions:
task = asyncio.create_task(
self._query_dimension(session, model, explore, dimension)
)
tasks.append(task)
elif mode == "single" or (mode == "hybrid" and explore.errored):
logger.debug("Querying one dimension at at time")
for dimension in explore.dimensions:
task = loop.create_task(
self._query_dimension(
session, model, explore, dimension
)
)
tasks.append(task)

query_task_ids = list(loop.run_until_complete(asyncio.gather(*tasks)))
loop.run_until_complete(session.close())
loop.run_until_complete(asyncio.sleep(0.250))
loop.close()

MAX_QUERY_FETCH = 250

tasks_to_check = query_task_ids[:MAX_QUERY_FETCH]
del query_task_ids[:MAX_QUERY_FETCH]
logger.debug(f"{len(query_task_ids)} left in queue")
tasks_to_check, errors = self._get_query_results(tasks_to_check)

while tasks_to_check or query_task_ids:
number_of_tasks_to_add = MAX_QUERY_FETCH - len(tasks_to_check)
tasks_to_check.extend(query_task_ids[:number_of_tasks_to_add])
del query_task_ids[:number_of_tasks_to_add]
logger.debug(f"{len(query_task_ids)} left in queue")
tasks_to_check, more_errors = self._get_query_results(tasks_to_check)
errors.extend(more_errors)
if tasks_to_check or query_task_ids:
time.sleep(0.5)
query_tasks.append(task)

queries = asyncio.gather(*query_tasks)
query_results = asyncio.gather(self._check_for_results(session, query_tasks))
results = await asyncio.gather(queries, query_results)
errors = results[1][0] # Ignore the results from creating the queries

await session.close()

return errors
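_query now gathers the query-creation coroutines alongside a single results checker and keeps only the checker's output. A minimal sketch of that orchestration with stand-in coroutines (all names are illustrative):

import asyncio

async def create_one_query(n: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for create_query / create_query_task
    return n

async def check_for_results(tasks: list) -> list:
    # Poll until every producer task has finished, then collect their results.
    while any(not task.done() for task in tasks):
        await asyncio.sleep(0.01)
    return [task.result() for task in tasks]

async def run_sketch() -> list:
    producers = [asyncio.create_task(create_one_query(i)) for i in range(3)]
    queries = asyncio.gather(*producers)
    checker = asyncio.gather(check_for_results(producers))
    results = await asyncio.gather(queries, checker)
    return results[1][0]  # ignore the producers' return values, as above

print(asyncio.run(run_sketch()))  # [0, 1, 2]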

@@ -299,47 +280,88 @@ def _extract_error_details(query_result: dict) -> dict:

return {"message": message, "sql": sql, "line_number": line_number}

def _get_query_results(
self, query_task_ids: List[str]
) -> Tuple[List[str], List[SqlError]]:
results = self.client.get_query_task_multi_results(query_task_ids)
still_running = []
async def _run_query(
self,
session: aiohttp.ClientSession,
model: str,
explore: str,
dimensions: List[str],
) -> str:
query_id = await self.client.create_query(session, model, explore, dimensions)
await self.query_slots.acquire() # Wait for available slots before launching
query_task_id = await self.client.create_query_task(session, query_id)
await self.running_query_tasks.put(query_task_id)
return query_task_id

async def _get_query_results(
self, session: aiohttp.ClientSession
) -> List[SqlError]:
logger.debug("%d queries running", self.running_query_tasks.qsize())

# Empty the queue (up to 250) to get all running query tasks
query_task_ids: List[str] = []
while not self.running_query_tasks.empty() and len(query_task_ids) <= 250:
query_task_ids.append(await self.running_query_tasks.get())

logger.debug("Getting results for %d query tasks", len(query_task_ids))
results = await self.client.get_query_task_multi_results(
session, query_task_ids
)
pending_task_ids = []
errors = []

for query_task_id, query_result in results.items():
query_status = query_result["status"]
logger.debug("Query task %s status is %s", query_task_id, query_status)

if query_status in ("running", "added", "expired"):
still_running.append(query_task_id)
pending_task_ids.append(query_task_id)
# Put the running query tasks back in the queue
await self.running_query_tasks.put(query_task_id)
continue
elif query_status in ("complete", "error"):
# We can release a query slot for each completed query
self.query_slots.release()
lookml_object = self.query_tasks[query_task_id]
lookml_object.queried = True

if query_status == "error":
try:
details = self._extract_error_details(query_result)
except (KeyError, TypeError, IndexError) as error:
raise SpectaclesException(
"Encountered an unexpected API query result format, "
"unable to extract error details. "
f"The query result was: {query_result}"
) from error
sql_error = SqlError(
path=lookml_object.name,
url=getattr(lookml_object, "url", None),
**details,
)
lookml_object.error = sql_error
errors.append(sql_error)
else:
raise SpectaclesException(
f'Unexpected query result status "{query_status}" '
"returned by the Looker API"
)

if query_status == "error":
try:
details = self._extract_error_details(query_result)
except (KeyError, TypeError, IndexError) as error:
raise SpectaclesException(
"Encountered an unexpected API query result format, "
"unable to extract error details. "
f"The query result was: {query_result}"
) from error
sql_error = SqlError(
path=lookml_object.name,
url=getattr(lookml_object, "url", None),
**details,
)
lookml_object.error = sql_error
errors.append(sql_error)
return errors

return still_running, errors
async def _check_for_results(
self, session: aiohttp.ClientSession, query_tasks: List[asyncio.Task]
):
results = []
while (
any(not task.done() for task in query_tasks)
or not self.running_query_tasks.empty()
):
if not self.running_query_tasks.empty():
result = await self._get_query_results(session)
results.extend(result)
await asyncio.sleep(0.5)

return results

async def _query_explore(
self, session: aiohttp.ClientSession, model: Model, explore: Explore
@@ -355,11 +377,9 @@ async def _query_explore(

"""
dimensions = [dimension.name for dimension in explore.dimensions]
query_id = await self.client.create_query(
query_task_id = await self._run_query(
session, model.name, explore.name, dimensions
)
query_task_id = await self.client.create_query_task(session, query_id)

self.query_tasks[query_task_id] = explore
return query_task_id

@@ -381,11 +401,9 @@ async def _query_dimension(
str: Query task ID for the running query.

"""
query_id = await self.client.create_query(
query_task_id = await self._run_query(
session, model.name, explore.name, [dimension.name]
)
query_task_id = await self.client.create_query_task(session, query_id)

self.query_tasks[query_task_id] = dimension
return query_task_id

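The concurrency cap itself comes from the BoundedSemaphore and Queue added to SqlValidator: _run_query acquires a slot before launching each query task and records its ID, and _get_query_results releases a slot for every completed query. A minimal sketch of that pattern with placeholder launch/complete logic (no real Looker calls):

import asyncio

class ThrottledRunner:
    def __init__(self, concurrency: int = 10) -> None:
        self.query_slots = asyncio.BoundedSemaphore(concurrency)  # caps in-flight queries
        self.running: asyncio.Queue = asyncio.Queue()             # tracks running task IDs

    async def run_query(self, query_id: int) -> None:
        await self.query_slots.acquire()   # wait for a free slot before launching
        await self.running.put(query_id)   # record the in-flight query task

    async def check_results(self) -> list:
        finished = []
        while not self.running.empty():
            query_id = await self.running.get()
            finished.append(query_id)      # pretend the query completed
            self.query_slots.release()     # free a slot for the next query
        return finished

async def run_sketch() -> list:
    runner = ThrottledRunner(concurrency=2)
    await asyncio.gather(*(runner.run_query(i) for i in range(2)))
    return await runner.check_results()

print(asyncio.run(run_sketch()))  # [0, 1]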
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from tests.constants import ENV_VARS
from constants import ENV_VARS
from spectacles.cli import main, create_parser, handle_exceptions
from spectacles.exceptions import SpectaclesException, ValidationError
