-
Notifications
You must be signed in to change notification settings - Fork 9
/
util.py
executable file
·73 lines (65 loc) · 2.44 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper
import logging
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
def parallel_apply(fct, model, inputs, device_ids):
modules = nn.parallel.replicate(model, device_ids)
assert len(modules) == len(inputs)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input):
torch.set_grad_enabled(grad_enabled)
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = fct(module, *input)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))
if len(modules) > 1:
threads = [threading.Thread(target=_worker, args=(i, module, input))
for i, (module, input) in enumerate(zip(modules, inputs))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger