-
Notifications
You must be signed in to change notification settings - Fork 52
/
run.py
73 lines (67 loc) · 2.68 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import hydra
from omegaconf import DictConfig
from flagscale.runner.runner_train import SSHTrainRunner, CloudTrainRunner
from flagscale.runner.runner_inference import SSHInferenceRunner
from flagscale.runner.runner_serve import SSHServeRunner
from flagscale.runner.runner_compress import SSHCompressRunner
# Maps each CLI ``action`` value to the runner call that implements it.
# Which actions are legal depends on the task type; see ``main``.
_ACTION_HANDLERS = {
    "run": lambda runner: runner.run(),
    "dryrun": lambda runner: runner.run(dryrun=True),
    "test": lambda runner: runner.run(with_test=True),
    "stop": lambda runner: runner.stop(),
    "query": lambda runner: runner.query(),
}


def _dispatch_action(runner, action, supported) -> None:
    """Invoke ``action`` on ``runner`` if the current task supports it.

    Args:
        runner: The task runner built by ``main``.
        action: The requested action name (``config.action``).
        supported: Collection of action names the task type accepts.

    Raises:
        ValueError: If ``action`` is not in ``supported``.
    """
    if action not in supported:
        raise ValueError(f"Unknown action {action}")
    _ACTION_HANDLERS[action](runner)


def _run_auto_tuner(config: DictConfig) -> None:
    """Run the training auto-tuner, only when ``is_master(config)`` is true."""
    # Imported lazily: only needed for the ``auto_tune`` action.
    from flagscale.auto_tuner import AutoTuner

    # For MPIRUN scene, just one autotuner process.
    # NOTE: This is a temporary solution and will be updated with cloud runner.
    from flagscale.auto_tuner.utils import is_master

    if is_master(config):
        tuner = AutoTuner(config)
        tuner.tune()


def _build_train_runner(config: DictConfig):
    """Create the training runner selected by ``config.experiment.runner.type``.

    The runner type defaults to ``"ssh"`` when the key is absent.

    Raises:
        ValueError: If the runner type is neither ``"ssh"`` nor ``"cloud"``.
    """
    runner_type = config.experiment.runner.get("type", "ssh")
    if runner_type == "ssh":
        return SSHTrainRunner(config)
    if runner_type == "cloud":
        return CloudTrainRunner(config)
    # Report the value actually read; ``config.runner`` is not where the
    # runner settings live, so it must not be used here.
    raise ValueError(f"Unknown runner type {runner_type}")


@hydra.main(version_base=None, config_name="config")
def main(config: DictConfig) -> None:
    """Build the runner for the configured task and perform the requested action.

    The task is chosen by ``config.experiment.task.type`` (default ``"train"``)
    and the operation by ``config.action``. Accepted actions per task:

    - train: ``auto_tune``, ``run``, ``dryrun``, ``test``, ``stop``, ``query``
    - inference: ``run``, ``dryrun``, ``stop``
    - serve: ``run``, ``test``
    - compress: ``run``, ``dryrun``, ``stop``

    Raises:
        ValueError: If the task type, the training runner type, or the
            action is not recognized.
    """
    task_type = config.experiment.task.get("type", "train")
    action = config.action

    if task_type == "train":
        if action == "auto_tune":
            _run_auto_tuner(config)
            return
        runner = _build_train_runner(config)
        supported = ("run", "dryrun", "test", "stop", "query")
    elif task_type == "inference":
        runner = SSHInferenceRunner(config)
        supported = ("run", "dryrun", "stop")
    elif task_type == "serve":
        runner = SSHServeRunner(config)
        # Previously unknown serve actions were silently ignored; they now
        # raise like every other task type.
        supported = ("run", "test")
    elif task_type == "compress":
        runner = SSHCompressRunner(config)
        supported = ("run", "dryrun", "stop")
    else:
        raise ValueError(f"Unknown task type {task_type}")

    # The runner is constructed before the action is validated, matching the
    # original evaluation order.
    _dispatch_action(runner, action, supported)


if __name__ == "__main__":
    main()