Skip to content

Commit

Permalink
More dask features (#959)
Browse files Browse the repository at this point in the history
* upgrades to mypy 1.6

* pr feedback

* changelog

* adds sigint, mem and cpus support

* changelog

* weakref handle_kill

* test dask in CI

* typing

* ci

* ci

* fix tests
  • Loading branch information
normanrz authored Nov 16, 2023
1 parent 2662300 commit 336e2b6
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 14 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
executors: [multiprocessing, slurm, kubernetes]
executors: [multiprocessing, slurm, kubernetes, dask]
python-version: ["3.11", "3.10", "3.9", "3.8"]
defaults:
run:
Expand Down Expand Up @@ -88,7 +88,7 @@ jobs:
./kind load docker-image scalableminds/cluster-tools:latest
- name: Install dependencies (without docker)
if: ${{ matrix.executors == 'multiprocessing' || matrix.executors == 'kubernetes' }}
if: ${{ matrix.executors != 'slurm' }}
run: |
pip install -r ../requirements.txt
poetry install
Expand Down Expand Up @@ -130,6 +130,12 @@ jobs:
cd tests
PYTEST_EXECUTORS=kubernetes poetry run python -m pytest -sv test_all.py test_kubernetes.py
- name: Run dask tests
if: ${{ matrix.executors == 'dask' && matrix.python-version != '3.8' }}
run: |
cd tests
PYTEST_EXECUTORS=dask poetry run python -m pytest -sv test_all.py
webknossos_linux:
needs: changes
if: |
Expand Down
4 changes: 4 additions & 0 deletions cluster_tools/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ For upgrade instructions, please check the respective *Breaking Changes* section
### Breaking Changes

### Added
- Added SIGINT handling to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
- Added support for resources (e.g. mem, cpus) to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
- The cluster address for the `DaskExecutor` can be configured via the `DASK_ADDRESS` env var. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)

### Changed
- Tasks using the `DaskExecutor` are run in their own process. This is required to not block the GIL for the dask worker to communicate with the scheduler. Env variables are propagated to the task processes. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)

### Fixed

Expand Down
151 changes: 145 additions & 6 deletions cluster_tools/cluster_tools/executors/dask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import re
import signal
import traceback
from concurrent import futures
from concurrent.futures import Future
from functools import partial
from multiprocessing import Queue, get_context
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -11,9 +15,11 @@
Iterator,
List,
Optional,
Set,
TypeVar,
cast,
)
from weakref import ReferenceType, ref

from typing_extensions import ParamSpec

Expand All @@ -28,23 +34,119 @@
_S = TypeVar("_S")


def _run_in_nanny(
queue: Queue, __fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> None:
try:
__env = cast(Dict[str, str], kwargs.pop("__env"))
for key, value in __env.items():
os.environ[key] = value

ret = __fn(*args, **kwargs)
queue.put({"value": ret})
except Exception as exc:
queue.put({"exception": exc})


def _run_with_nanny(
__fn: Callable[_P, _T],
*args: _P.args,
**kwargs: _P.kwargs,
) -> _T:
mp_context = get_context("spawn")
q = mp_context.Queue()
p = mp_context.Process(target=_run_in_nanny, args=(q, __fn) + args, kwargs=kwargs)
p.start()
p.join()
ret = q.get(timeout=0.1)
if "exception" in ret:
raise ret["exception"]
else:
return ret["value"]


def _parse_mem(size: str) -> int:
units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
m = re.match(r"^([\d\.]+)\s*([kmgtKMGT]{0,1})$", str(size).strip())
assert m is not None, f"Could not parse {size}"
number, unit = float(m.group(1)), m.group(2).upper()
assert unit in units
return int(number * units[unit])


def _handle_kill_through_weakref(
executor_ref: "ReferenceType[DaskExecutor]",
existing_sigint_handler: Any,
signum: Optional[int],
frame: Any,
) -> None:
executor = executor_ref()
if executor is None:
return
executor.handle_kill(existing_sigint_handler, signum, frame)


class DaskExecutor(futures.Executor):
"""
The `DaskExecutor` allows to run workloads on a dask cluster.
The executor can be constructed with an existing dask `Client` or
from a declarative configuration. The address of the dask scheduler
can be part of the configuration or supplied as environment variable
`DASK_ADDRESS`.
There is support for resource-based scheduling. As default, `mem` and
`cpus-per-task` are supported. To make use of them, the dask workers
should be started with:
`python -m dask worker --no-nanny --nthreads 6 tcp://... --resources "mem=1073741824 cpus=8"`
"""

client: "Client"
pending_futures: Set[Future]
job_resources: Optional[Dict[str, Any]]
is_shutting_down = False

def __init__(
self,
client: "Client",
self, client: "Client", job_resources: Optional[Dict[str, Any]] = None
) -> None:
self.client = client
self.pending_futures = set()
self.job_resources = job_resources

if self.job_resources is not None:
# `mem` needs to be a number for dask, so we need to parse it
if "mem" in self.job_resources:
self.job_resources["mem"] = _parse_mem(self.job_resources["mem"])
if "cpus-per-task" in self.job_resources:
self.job_resources["cpus"] = int(
self.job_resources.pop("cpus-per-task")
)

# Clean up if a SIGINT signal is received. However, do not interfere with the
# existing signal handler of the process or the
# shutdown of the main process which sends SIGTERM signals to terminate all
# child processes.
existing_sigint_handler = signal.getsignal(signal.SIGINT)
signal.signal(
signal.SIGINT,
partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
)

@classmethod
def from_config(
cls,
job_resources: Dict[str, Any],
job_resources: Dict[str, str],
**_kwargs: Any,
) -> "DaskExecutor":
from distributed import Client

return cls(Client(**job_resources))
job_resources = job_resources.copy()
address = job_resources.pop("address", None)
if address is None:
address = os.environ.get("DASK_ADDRESS", None)

client = Client(address=address)
return cls(client, job_resources=job_resources)

@classmethod
def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
Expand Down Expand Up @@ -72,7 +174,20 @@ def submit( # type: ignore[override]
__fn,
),
)
fut = self.client.submit(partial(__fn, *args, **kwargs))

kwargs["__env"] = os.environ.copy()

# We run the functions in dask as a separate process to not hold the
# GIL for too long, because dask workers need to be able to communicate
# with the scheduler regularly.
__fn = partial(_run_with_nanny, __fn)

fut = self.client.submit(
partial(__fn, *args, **kwargs), pure=False, resources=self.job_resources
)

self.pending_futures.add(fut)
fut.add_done_callback(self.pending_futures.remove)

enrich_future_with_uncaught_warning(fut)
return fut
Expand Down Expand Up @@ -125,8 +240,32 @@ def map( # type: ignore[override]
def forward_log(self, fut: "Future[_T]") -> _T:
return fut.result()

def handle_kill(
self,
existing_sigint_handler: Any,
signum: Optional[int],
frame: Any,
) -> None:
if self.is_shutting_down:
return

self.is_shutting_down = True

self.client.cancel(list(self.pending_futures))

if (
existing_sigint_handler # pylint: disable=comparison-with-callable
!= signal.default_int_handler
and callable(existing_sigint_handler) # Could also be signal.SIG_IGN
):
existing_sigint_handler(signum, frame)

def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
print(f"{wait=} {cancel_futures=}")
traceback.print_stack()
if wait:
self.client.close(timeout=60 * 60 * 24)
for fut in list(self.pending_futures):
fut.result()
self.client.close(timeout=60 * 60) # 1 hour
else:
self.client.close()
18 changes: 17 additions & 1 deletion cluster_tools/cluster_tools/schedulers/cluster_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Union,
cast,
)
from weakref import ReferenceType, ref

from typing_extensions import ParamSpec

Expand All @@ -45,6 +46,18 @@
_S = TypeVar("_S")


def _handle_kill_through_weakref(
executor_ref: "ReferenceType[ClusterExecutor]",
existing_sigint_handler: Any,
signum: Optional[int],
frame: Any,
) -> None:
executor = executor_ref()
if executor is None:
return
executor.handle_kill(existing_sigint_handler, signum, frame)


def join_messages(strings: List[str]) -> str:
return " ".join(x.strip() for x in strings if x.strip())

Expand Down Expand Up @@ -130,7 +143,10 @@ def __init__(
# shutdown of the main process which sends SIGTERM signals to terminate all
# child processes.
existing_sigint_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, partial(self.handle_kill, existing_sigint_handler))
signal.signal(
signal.SIGINT,
partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
)

self.meta_data = {}
assert not (
Expand Down
14 changes: 9 additions & 5 deletions cluster_tools/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from distributed import LocalCluster

import cluster_tools
from cluster_tools.executors.dask import DaskExecutor


# "Worker" functions.
Expand Down Expand Up @@ -79,10 +78,14 @@ def get_executors(with_debug_sequential: bool = False) -> List[cluster_tools.Exe
executors.append(cluster_tools.get_executor("sequential"))
if "dask" in executor_keys:
if not _dask_cluster:
from distributed import LocalCluster
from distributed import LocalCluster, Worker

_dask_cluster = LocalCluster()
executors.append(cluster_tools.get_executor("dask", address=_dask_cluster))
_dask_cluster = LocalCluster(
worker_class=Worker, resources={"mem": 20e9, "cpus": 4}, nthreads=6
)
executors.append(
cluster_tools.get_executor("dask", job_resources={"address": _dask_cluster})
)
if "test_pickling" in executor_keys:
executors.append(cluster_tools.get_executor("test_pickling"))
if "pbs" in executor_keys:
Expand Down Expand Up @@ -328,7 +331,8 @@ def run_map(executor: cluster_tools.Executor) -> None:
assert list(result) == [4, 9, 16]

for exc in get_executors():
run_map(exc)
if not isinstance(exc, cluster_tools.DaskExecutor):
run_map(exc)


def test_executor_args() -> None:
Expand Down

0 comments on commit 336e2b6

Please sign in to comment.