Skip to content

Commit

Permalink
Add support for timing C++ snippets. (pytorch#47864)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#47864

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25199262

Pulled By: robieta

fbshipit-source-id: 1c2114628ed543fba4f403bf49c065f4d71388e2
  • Loading branch information
Taylor Robie authored and facebook-github-bot committed Dec 2, 2020
1 parent 17ea112 commit 0225d3d
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 24 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def print_box(msg):
'share/cmake/Gloo/*.cmake',
'share/cmake/Tensorpipe/*.cmake',
'share/cmake/Torch/*.cmake',
'utils/benchmark/utils/*.cpp',
'utils/benchmark/utils/valgrind_wrapper/*.cpp',
'utils/benchmark/utils/valgrind_wrapper/*.h',
],
Expand Down
13 changes: 12 additions & 1 deletion test/benchmark_utils/test_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
import torch.utils.benchmark as benchmark_utils
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, slowTest
from torch.testing._internal.common_utils import TestCase, run_tests, IS_SANDCASTLE, IS_WINDOWS, slowTest
from torch.testing._internal import expecttest
import numpy as np

Expand Down Expand Up @@ -162,6 +162,17 @@ def test_timer(self):
).timeit(5).median
self.assertIsInstance(sample, float)

@slowTest
@unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.")
def test_cpp_timer(self):
    """Smoke test: a C++ `Timer` can be built and yields a float median."""
    cpp_timer = benchmark_utils.Timer(
        "torch::Tensor y = x + 1;",
        setup="torch::Tensor x = torch::empty({1});",
        language=benchmark_utils.Language.CPP,
    )
    measurement = cpp_timer.timeit(10)
    self.assertIsInstance(measurement.median, float)

class _MockTimer:
_seed = 0

Expand Down
28 changes: 25 additions & 3 deletions torch/utils/benchmark/utils/_stubs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
import sys
from typing import TYPE_CHECKING
from typing import Any, Callable, Dict, TYPE_CHECKING


if TYPE_CHECKING or sys.version_info >= (3, 8):
from typing import Protocol
from typing import runtime_checkable, Protocol
else:
from typing_extensions import Protocol
from typing_extensions import runtime_checkable, Protocol


class TimerClass(Protocol):
    """Structural type for the part of the `timeit.Timer` API that benchmark utils use.

    A class conforms if it offers a constructor and a `timeit` method with
    the signatures below.
    """

    def __init__(
        self,
        stmt: str,
        setup: str,
        timer: Callable[[], float],
        globals: Dict[str, Any],
    ) -> None:
        ...

    def timeit(self, number: int) -> float:
        ...


@runtime_checkable
class TimeitModuleType(Protocol):
    """Shape of the modules built from `timeit_template.cpp`.

    Marked runtime-checkable so `isinstance` can be used to verify that a
    freshly compiled module exposes `timeit`.
    """

    def timeit(self, number: int) -> float:
        ...


class CallgrindModuleType(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/benchmark/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def title(self) -> str:

def setup_str(self) -> str:
return (
"" if self.setup == "pass"
"" if (self.setup == "pass" or not self.setup)
else f"setup:\n{textwrap.indent(self.setup, ' ')}" if "\n" in self.setup
else f"setup: {self.setup}"
)
Expand Down
79 changes: 77 additions & 2 deletions torch/utils/benchmark/utils/cpp_jit.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
"""JIT C++ strings into executables."""
import atexit
import os
import re
import shutil
import textwrap
import threading
from typing import List, Optional
import uuid
from typing import Any, List, Optional

import torch
from torch.utils.benchmark.utils._stubs import CallgrindModuleType
from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
from torch.utils import cpp_extension


# Serializes build-directory creation and source-file writes across threads.
LOCK = threading.Lock()
# Directory containing this module; the C++ templates are read from here.
SOURCE_ROOT = os.path.dirname(os.path.abspath(__file__))

# We calculate the UUID once at import time so that separate processes will
# have separate build roots, but threads will share the same build root.
# `cpp_extension` uses the build root as part of the cache key, so
# per-invocation UUIDs (e.g. a different build root per `_compile_template`
# call) would lead to a 0% cache hit rate and spurious recompilation.
# Consider the following:
# ```
# setup = "auto x = torch::ones({1024, 1024});"
# stmt = "torch::mm(x, x);"
# for num_threads in [1, 2, 4, 8]:
# print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
# ```
# `setup` and `stmt` do not change, so we can reuse the executable from the
# first pass through the loop.
# Per-process build directory: "build_" followed by the 32 hex digits of a
# random UUID, placed under the user cache dir for "benchmark_utils_jit".
BUILD_ROOT = os.path.join(
    torch._appdirs.user_cache_dir(appname="benchmark_utils_jit"),
    "build_" + uuid.uuid4().hex,
)

# BACK_TESTING_NOTE:
# There are two workflows where this code could be used. One is the obvious
# case where someone simply builds or installs PyTorch and uses Timer.
Expand Down Expand Up @@ -66,3 +89,55 @@ def get_compat_bindings() -> CallgrindModuleType:
extra_include_paths=EXTRA_INCLUDE_PATHS,
)
return COMPAT_CALLGRIND_BINDINGS


def _compile_template(stmt: str, setup: str, src: str, is_standalone: bool) -> Any:
    """Splice `setup` and `stmt` into a C++ template and JIT-compile the result.

    Args:
        stmt: C++ code substituted for every `// STMT_TEMPLATE_LOCATION`
            marker in `src`.
        setup: C++ code substituted for every `// SETUP_TEMPLATE_LOCATION`
            marker in `src`.
        src: Text of the C++ template.
        is_standalone: Forwarded to `cpp_extension.load`. When False the
            build is loaded as a Python module.

    Returns:
        Whatever `cpp_extension.load` returns for the generated source.
    """
    for marker, snippet, indentation in (
        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
        ("// STMT_TEMPLATE_LOCATION", stmt, 8),
    ):
        # C++ doesn't care about indentation so this code isn't load
        # bearing the way it is with Python, but this makes the source
        # look nicer if a human has to look at it.
        #
        # The substitution must be literal. `re.sub` would treat the user's
        # snippet as a replacement template and process its backslashes, so
        # a C++ string literal such as "a\n" would gain a real newline and
        # sequences like `\d` or `\1` would raise `re.error`.
        src = src.replace(
            marker,
            textwrap.indent(snippet, " " * indentation)[indentation:],
        )

    # We want to isolate different Timers. However `cpp_extension` will
    # cache builds which will significantly reduce the cost of repeated
    # invocations.
    with LOCK:
        if not os.path.exists(BUILD_ROOT):
            os.makedirs(BUILD_ROOT)
            # Remove the per-process build root when the interpreter exits.
            atexit.register(shutil.rmtree, BUILD_ROOT)

        # Identical generated sources map to the same name and directory
        # within this process, which lets `cpp_extension` reuse the build.
        name = f"timer_cpp_{abs(hash(src))}"
        build_dir = os.path.join(BUILD_ROOT, name)
        os.makedirs(build_dir, exist_ok=True)

        src_path = os.path.join(build_dir, "timer_src.cpp")
        with open(src_path, "wt") as f:
            f.write(src)

    # `cpp_extension` has its own locking scheme, so we don't need our lock.
    return cpp_extension.load(
        name=name,
        sources=[src_path],
        build_directory=build_dir,
        extra_cflags=CXX_FLAGS,
        extra_include_paths=EXTRA_INCLUDE_PATHS,
        is_python_module=not is_standalone,
        is_standalone=is_standalone,
    )


def compile_timeit_template(stmt: str, setup: str) -> TimeitModuleType:
    """Build a Python-loadable timing module from `timeit_template.cpp`.

    Reads the template that sits next to this file, lets `_compile_template`
    fill in `stmt` / `setup` and compile it, and checks that the result
    exposes the `TimeitModuleType` interface.
    """
    with open(os.path.join(SOURCE_ROOT, "timeit_template.cpp"), "rt") as template_file:
        template_src: str = template_file.read()

    timeit_module = _compile_template(stmt, setup, template_src, is_standalone=False)
    assert isinstance(timeit_module, TimeitModuleType)
    return timeit_module
36 changes: 36 additions & 0 deletions torch/utils/benchmark/utils/timeit_template.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* C++ template for Timer.timeit
This template will be consumed by `cpp_jit.py`, and will replace:
`SETUP_TEMPLATE_LOCATION`
and
`STMT_TEMPLATE_LOCATION`
sections with user provided statements.
*/
#include <chrono>

#include <pybind11/pybind11.h>
#include <torch/extension.h>


// Runs the substituted setup code once, executes the statement once as a
// warmup inside its own scope, and then returns the elapsed time in seconds
// for `n` further executions of the statement.
double timeit(int n) {
    // Setup
    // SETUP_TEMPLATE_LOCATION

    {
        // Warmup
        // STMT_TEMPLATE_LOCATION
    }

    // Main loop
    // A monotonic clock is required for interval measurements:
    // `high_resolution_clock` may be an alias of the adjustable system
    // clock, in which case a clock change mid-run would corrupt the result.
    auto start_time = std::chrono::steady_clock::now();
    for (int loop_idx = 0; loop_idx < n; loop_idx++) {
        // STMT_TEMPLATE_LOCATION
    }
    auto end_time = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(end_time - start_time).count();
}


// Python bindings: registers the C++ `timeit` function above as the
// module-level function "timeit" of the generated extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("timeit", &timeit);
}
89 changes: 72 additions & 17 deletions torch/utils/benchmark/utils/timer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Timer class based on the timeit.Timer class, but torch aware."""

import enum
import timeit
import textwrap
from typing import Any, Callable, Dict, List, NoReturn, Optional
from typing import Any, Callable, Dict, List, NoReturn, Optional, Type, Union

import numpy as np
import torch
from torch.utils.benchmark.utils import common
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType
from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface


__all__ = ["Timer", "timer"]
__all__ = ["Timer", "timer", "Language"]


if torch.has_cuda and torch.cuda.is_available():
Expand All @@ -21,6 +22,46 @@ def timer() -> float:
timer = timeit.default_timer


class Language(enum.Enum):
    """Language of the code snippets passed to `Timer`."""

    PYTHON = 0
    CPP = 1


class CPPTimer:
    """Counterpart of `timeit.Timer` for C++ snippets.

    Construction only validates arguments and stores the dedented snippets.
    The C++ module is compiled on the first `timeit` call and reused by
    later calls.
    """

    def __init__(
        self,
        stmt: str,
        setup: str,
        timer: Callable[[], float],
        globals: Dict[str, Any],
    ) -> None:
        # Only the default wall-clock timer is accepted on the C++ path.
        uses_custom_timer = timer is not timeit.default_timer
        if uses_custom_timer:
            raise NotImplementedError(
                "PyTorch was built with CUDA and a GPU is present; however "
                "Timer does not yet support GPU measurements. If your "
                "code is CPU only, pass `timer=timeit.default_timer` to the "
                "Timer's constructor to indicate this. (Note that this will "
                "produce incorrect results if the GPU is in fact used, as "
                "Timer will not synchronize CUDA.)"
            )

        # Python globals cannot be handed to compiled C++ code.
        if globals:
            raise ValueError("C++ timing does not support globals.")

        self._stmt: str = textwrap.dedent(stmt)
        self._setup: str = textwrap.dedent(setup)
        # Filled in lazily by `timeit`.
        self._timeit_module: Optional[TimeitModuleType] = None

    def timeit(self, number: int) -> float:
        """Return the compiled module's timing for `number` runs of the statement."""
        compiled = self._timeit_module
        if compiled is None:
            compiled = cpp_jit.compile_timeit_template(
                self._stmt,
                self._setup,
            )
            self._timeit_module = compiled

        return compiled.timeit(number)


class Timer(object):
"""Helper class for measuring execution time of PyTorch statements.
Expand Down Expand Up @@ -122,7 +163,7 @@ class Timer(object):
threadpool size which tries to utilize all cores.
"""

_timer_cls = timeit.Timer
_timer_cls: Type[TimerClass] = timeit.Timer

def __init__(
self,
Expand All @@ -135,21 +176,32 @@ def __init__(
description: Optional[str] = None,
env: Optional[str] = None,
num_threads: int = 1,
language: Union[Language, str] = Language.PYTHON,
):
if not isinstance(stmt, str):
raise ValueError("Currently only a `str` stmt is supported.")

# We copy `globals` to prevent mutations from leaking, (for instance,
# `eval` adds the `__builtins__` key) and include `torch` if not
# specified as a convenience feature.
globals = dict(globals or {})
globals.setdefault("torch", torch)
self._globals = globals
# We copy `globals` to prevent mutations from leaking.
# (For instance, `eval` adds the `__builtins__` key)
self._globals = dict(globals or {})
if language in (Language.PYTHON, "py", "python"):
# Include `torch` if not specified as a convenience feature.
self._globals.setdefault("torch", torch)
self._language: Language = Language.PYTHON

elif language in (Language.CPP, "cpp", "c++"):
assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
self._timer_cls = CPPTimer
setup = ("" if setup == "pass" else setup)
self._language = Language.CPP

else:
raise ValueError(f"Invalid language `{language}`.")

# Convenience adjustment so that multi-line code snippets defined in
# functions do not IndentationError inside timeit.Timer. The leading
# newline removal is for the initial newline that appears when defining
# block strings. For instance:
# functions do not IndentationError (Python) or look odd (C++). The
# leading newline removal is for the initial newline that appears when
# defining block strings. For instance:
# textwrap.dedent("""
# print("This is a stmt")
# """)
Expand All @@ -158,15 +210,15 @@ def __init__(
# Stripping this down to 'print("This is a stmt")' doesn't change
# what gets executed, but it makes __repr__'s nicer.
stmt = textwrap.dedent(stmt)
stmt = (stmt[1:] if stmt[0] == "\n" else stmt).rstrip()
stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip()
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup[0] == "\n" else setup).rstrip()
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()

self._timer = self._timer_cls(
stmt=stmt,
setup=setup,
timer=timer,
globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(globals),
globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
)
self._task_spec = common.TaskSpec(
stmt=stmt,
Expand Down Expand Up @@ -369,6 +421,9 @@ def collect_callgrind(
if not isinstance(self._task_spec.stmt, str):
raise ValueError("`collect_callgrind` currently only supports string `stmt`")

if self._language != Language.PYTHON:
raise NotImplementedError("C++ Callgrind is later in the stack.")

# Check that the statement is valid. It doesn't guarantee success, but it's much
# simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
# the parent process rather than the valgrind subprocess.
Expand Down

0 comments on commit 0225d3d

Please sign in to comment.