Skip to content

Commit

Permalink
feat(ingest/looker): add backpressure-aware executor (#9615)
Browse files Browse the repository at this point in the history
Co-authored-by: Tamas Nemeth <[email protected]>
  • Loading branch information
hsheth2 and treff7es authored Jan 12, 2024
1 parent dc93f2e commit 98e3da4
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import print_function

import datetime
import itertools
import logging
import re
from contextlib import contextmanager
from dataclasses import dataclass, field as dataclasses_field
from enum import Enum
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -1126,6 +1126,14 @@ def report_stage_end(self, stage_name: str) -> None:
if self.stage_latency[-1].name == stage_name:
self.stage_latency[-1].end_time = datetime.datetime.now()

@contextmanager
def report_stage(self, stage_name: str) -> Iterator[None]:
    """Context manager that brackets a block with stage start/end reporting.

    Calls ``report_stage_start(stage_name)`` on entry and
    ``report_stage_end(stage_name)`` on exit, including when the wrapped
    block raises.

    Args:
        stage_name: Name of the stage to report.
    """
    # Start the stage *before* entering the try block: if starting fails there
    # is no stage to end, and calling report_stage_end anyway could raise a
    # secondary error that hides the original one.
    self.report_stage_start(stage_name)
    try:
        yield
    finally:
        self.report_stage_end(stage_name)

def compute_stats(self) -> None:
if self.total_dashboards:
self.dashboard_process_percentage_completion = round(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ def __init__(self, config: LookerAPIConfig) -> None:
self.client.transport.session.mount("http://", adapter)
self.client.transport.session.mount("https://", adapter)
elif self.config.max_retries > 0:
raise ConfigurationError(
"Unable to configure retries on the Looker SDK transport."
)
logger.warning("Unable to configure retries on the Looker SDK transport.")

self.transport_options = (
config.transport_options.get_transport_options()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import concurrent.futures
import datetime
import json
import logging
Expand Down Expand Up @@ -91,6 +90,7 @@
OwnershipClass,
OwnershipTypeClass,
)
from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -700,28 +700,19 @@ def _make_explore_metadata_events(
explores_to_fetch = list(self.list_all_explores())
explores_to_fetch.sort()

with concurrent.futures.ThreadPoolExecutor(
max_workers=self.source_config.max_threads
) as async_executor:
self.reporter.total_explores = len(explores_to_fetch)

explore_futures = {
async_executor.submit(self.fetch_one_explore, model, explore): (
model,
explore,
)
for (model, explore) in explores_to_fetch
}

for future in concurrent.futures.wait(explore_futures).done:
events, explore_id, start_time, end_time = future.result()
del explore_futures[future]
self.reporter.explores_scanned += 1
yield from events
self.reporter.report_upstream_latency(start_time, end_time)
logger.debug(
f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}"
)
self.reporter.total_explores = len(explores_to_fetch)
for future in BackpressureAwareExecutor.map(
self.fetch_one_explore,
((model, explore) for (model, explore) in explores_to_fetch),
max_workers=self.source_config.max_threads,
):
events, explore_id, start_time, end_time = future.result()
self.reporter.explores_scanned += 1
yield from events
self.reporter.report_upstream_latency(start_time, end_time)
logger.debug(
f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}"
)

def list_all_explores(self) -> Iterable[Tuple[str, str]]:
# returns a list of (model, explore) tuples
Expand Down Expand Up @@ -1277,28 +1268,24 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
]

looker_dashboards_for_usage: List[looker_usage.LookerDashboardForUsage] = []
self.reporter.report_stage_start("dashboard_chart_metadata")

with concurrent.futures.ThreadPoolExecutor(
max_workers=self.source_config.max_threads
) as async_executor:
async_workunits = {}
for dashboard_id in dashboard_ids:
if dashboard_id is not None:
job = async_executor.submit(
self.process_dashboard, dashboard_id, fields
)
async_workunits[job] = dashboard_id

for job in concurrent.futures.as_completed(async_workunits):
with self.reporter.report_stage("dashboard_chart_metadata"):
for job in BackpressureAwareExecutor.map(
self.process_dashboard,
(
(dashboard_id, fields)
for dashboard_id in dashboard_ids
if dashboard_id is not None
),
max_workers=self.source_config.max_threads,
):
(
work_units,
dashboard_usage,
dashboard_id,
start_time,
end_time,
) = job.result()
del async_workunits[job]
logger.debug(
f"Running time of process_dashboard for {dashboard_id} = {(end_time - start_time).total_seconds()}"
)
Expand All @@ -1308,8 +1295,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
if dashboard_usage is not None:
looker_dashboards_for_usage.append(dashboard_usage)

self.reporter.report_stage_end("dashboard_chart_metadata")

if (
self.source_config.extract_owners
and self.reporter.resolved_user_ids > 0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from __future__ import annotations

import collections
import concurrent.futures
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import Any, Callable, Deque, Dict, Optional, Tuple, TypeVar
from typing import (
Any,
Callable,
Deque,
Dict,
Iterable,
Iterator,
Optional,
Set,
Tuple,
TypeVar,
)

from datahub.ingestion.api.closeable import Closeable

Expand Down Expand Up @@ -130,3 +144,74 @@ def shutdown(self) -> None:

def close(self) -> None:
    """Close this object by delegating to ``shutdown()``.

    NOTE(review): presumably present to satisfy the ``Closeable`` interface
    imported by this module -- confirm against the class declaration.
    """
    self.shutdown()


class BackpressureAwareExecutor:
    """Thread-pool ``map`` variant that never runs far ahead of its consumer."""

    # This couldn't be a real executor because the semantics of submit wouldn't really make sense.
    # In this variant, if we blocked on submit, then we would also be blocking the thread that
    # we expect to be consuming the results. As such, it accepts the full list of args
    # up front, and that way the consumer can read results at its own pace.

    @classmethod
    def map(
        cls,
        fn: Callable[..., _R],
        args_list: Iterable[Tuple[Any, ...]],
        max_workers: int,
        max_pending: Optional[int] = None,
    ) -> Iterator[Future[_R]]:
        """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.

        The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
        objects in memory if the consumer is slow. Instead, the consumer can read the results
        at its own pace and the executor threads will idle if they need to.

        This is a generator: no thread pool is created and no task is submitted
        until the caller starts iterating. If the caller stops iterating early
        (or the generator is closed), tasks that were submitted but have not
        started yet are cancelled; tasks that are already running are allowed
        to finish before the pool shuts down.

        Args:
            fn: The function to apply to each input.
            args_list: The inputs. In contrast to the builtin map, this is an iterable
                of tuples, where each tuple is the positional arguments to fn.
            max_workers: The maximum number of threads to use.
            max_pending: The maximum number of pending results to keep in memory.
                If not set, it will be set to 2*max_workers.

        Returns:
            An iterator of futures.

            This differs from a traditional map because it returns futures
            instead of the actual results, so that the caller is required
            to handle exceptions.

            Additionally, it does not maintain the order of the arguments.
            If you want to know which result corresponds to which input,
            the mapped function should return some form of an identifier.
        """

        if max_pending is None:
            max_pending = 2 * max_workers
        assert max_pending >= max_workers

        pending_futures: Set[Future] = set()

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            try:
                for args in args_list:
                    # If the pending list is full, wait until one is done.
                    if len(pending_futures) >= max_pending:
                        (done, _) = concurrent.futures.wait(
                            pending_futures,
                            return_when=concurrent.futures.FIRST_COMPLETED,
                        )
                        for future in done:
                            pending_futures.remove(future)

                            # We don't want to call result() here because we want the caller
                            # to handle exceptions/cancellation.
                            yield future

                    # Now that there's space in the pending list, enqueue the next task.
                    pending_futures.add(executor.submit(fn, *args))

                # Wait for all the remaining tasks to complete.
                # as_completed snapshots its input, so removing from the set
                # while iterating is safe.
                for future in concurrent.futures.as_completed(pending_futures):
                    pending_futures.remove(future)
                    yield future

                assert not pending_futures
            finally:
                # Only non-empty when the consumer abandoned the iterator or an
                # error interrupted the loop. Cancel whatever has not started so
                # the executor's shutdown does not run work nobody will read.
                for future in pending_futures:
                    future.cancel()
1 change: 0 additions & 1 deletion metadata-ingestion/tests/integration/lookml/test_lookml.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,6 @@ def ingestion_test(
"client_id": "fake_client_id",
"client_secret": "fake_secret",
"base_url": "fake_account.looker.com",
"max_retries": 0,
},
"parse_table_names_from_sql": True,
"model_pattern": {"deny": ["data2"]},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import time
from concurrent.futures import Future

from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.advanced_thread_executor import (
BackpressureAwareExecutor,
PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer


Expand Down Expand Up @@ -68,3 +71,58 @@ def task(id: str) -> str:
# Wait for everything to finish.
executor.flush()
assert len(done_tasks) == 16


def test_backpressure_aware_executor_simple():
    """Every input should come back exactly once, regardless of completion order."""

    def identity(value):
        return value

    futures = BackpressureAwareExecutor.map(
        identity, ((n,) for n in range(10)), max_workers=2
    )
    observed = {fut.result() for fut in futures}

    assert observed == set(range(10))


def test_backpressure_aware_executor_advanced():
    """Check laziness, bounded look-ahead and overall throughput of ``map``.

    The test is timing based: each task sleeps for ``task_duration`` and the
    assertions compare elapsed wall-clock time against multiples of it.
    """
    task_duration = 0.5
    # Inputs whose task has begun / finished executing, recorded by the workers.
    started = set()
    executed = set()

    def task(x, y):
        # Verifies that each args tuple is unpacked into positional arguments.
        assert x + 1 == y
        started.add(x)
        time.sleep(task_duration)
        executed.add(x)
        return x

    args_list = [(i, i + 1) for i in range(10)]

    with PerfTimer() as timer:
        # Creating the iterator must not block on any task.
        results = BackpressureAwareExecutor.map(
            task, args_list, max_workers=2, max_pending=4
        )
        assert timer.elapsed_seconds() < task_duration

        # No tasks should have completed yet.
        assert len(executed) == 0

        # Consume the first result.
        # With max_pending=4, the first future handed back must belong to one
        # of the first four inputs, and at least one task must have run fully.
        first_result = next(results)
        assert 0 <= first_result.result() < 4
        assert timer.elapsed_seconds() > task_duration

        # By now, the first four tasks should have started.
        time.sleep(task_duration)
        assert {0, 1, 2, 3}.issubset(started)
        assert 2 <= len(executed) <= 4

        # Finally, consume the rest of the results.
        assert set(r.result() for r in results) == {
            i for i in range(10) if i != first_result.result()
        }

        # Validate that the entire process took about 5-10x the task duration.
        # That's because we have 2 workers and 10 tasks.
        assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration

0 comments on commit 98e3da4

Please sign in to comment.