-
Notifications
You must be signed in to change notification settings - Fork 210
/
utils.py
50 lines (39 loc) · 1.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import functools
import importlib
from typing import Callable
import torch
import triton
from packaging.version import Version
def ensure_contiguous(fn):
    """Decorate ``fn`` so tensor arguments reach it in contiguous memory layout.

    The first argument (``ctx``) is forwarded untouched. Every other positional
    or keyword argument that is a ``torch.Tensor`` is replaced by
    ``tensor.contiguous()``. Anything else is forwarded as-is. The decorated
    callable keeps ``fn``'s metadata via ``functools.wraps``.

    Presumably meant for ``torch.autograd.Function`` forward/backward
    staticmethods (hence the ``ctx`` first argument) — confirm at call sites.
    """

    def _as_contiguous(value):
        # Only tensors are touched; ints, None, dtypes, etc. pass straight through.
        if isinstance(value, torch.Tensor):
            return value.contiguous()
        return value

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        contiguous_args = tuple(map(_as_contiguous, args))
        contiguous_kwargs = {
            name: _as_contiguous(val) for name, val in kwargs.items()
        }
        return fn(ctx, *contiguous_args, **contiguous_kwargs)

    return wrapper
def calculate_settings(n):
    """Choose a Triton launch ``BLOCK_SIZE`` and ``num_warps`` for size ``n``.

    ``BLOCK_SIZE`` is ``triton.next_power_of_2(n)``. The warp count grows with
    the block size:

    * ``BLOCK_SIZE >= 32768`` -> 32 warps
    * ``BLOCK_SIZE >= 8192``  -> 16 warps
    * ``BLOCK_SIZE >= 2048``  -> 8 warps
    * otherwise               -> 4 warps

    Returns:
        A ``(BLOCK_SIZE, num_warps)`` tuple.

    Raises:
        RuntimeError: if ``BLOCK_SIZE`` would exceed 65536.
    """
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
    MAX_FUSED_SIZE = 65536
    block_size = triton.next_power_of_2(n)

    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    # (minimum block size, warps) pairs, checked from largest threshold down.
    warp_table = ((32768, 32), (8192, 16), (2048, 8))
    for threshold, warps in warp_table:
        if block_size >= threshold:
            return block_size, warps
    return block_size, 4
def compare_version(package: str, operator: Callable, target: str):
    """Compare the installed version of ``package`` against ``target``.

    Args:
        package: Import name of the module to inspect (e.g. ``"torch"``).
        operator: Binary comparison applied as ``operator(installed, target)``,
            e.g. ``operator.ge`` from the standard ``operator`` module.
        target: Version string to compare against; parsed with
            ``packaging.version.Version``.

    Returns:
        ``False`` if the module cannot be imported, or if its version can be
        found neither in ``__version__`` nor in installed distribution
        metadata. Otherwise, whatever ``operator`` returns.

    Raises:
        packaging.version.InvalidVersion: if either version string cannot be
            parsed by ``packaging.version.Version``.
    """
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False

    raw_version = getattr(pkg, "__version__", None)
    if raw_version is None:
        # Not every module exposes ``__version__``; previously this raised
        # AttributeError. Fall back to installed distribution metadata.
        # NOTE: this assumes the distribution name matches the import name,
        # which does not hold for every package (e.g. PIL vs Pillow).
        from importlib import metadata

        try:
            raw_version = metadata.version(package)
        except metadata.PackageNotFoundError:
            return False

    return operator(Version(raw_version), Version(target))