Stop benchmarking compile time of dead code (#145590)
Summary:
Fixes pytorch/pytorch#144775.

See details on the problem: pytorch/pytorch#144775 (comment)
We fixed some silent incorrectness, but the fix results in fewer nodes being DCE'd. The benchmark iteration loop contained dead code that could include side-effecting ops, which are not safe to DCE, so the regression is expected.

This PR removes the compile-time benchmarking of the dead code, which should reduce benchmark noise and aligns the measurement with the benchmarking used by the performance tests.
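
As a rough illustration of the measurement change, here is a minimal sketch with simplified, hypothetical names (`toy_model_iter_fn`, the toy `run_n_iterations`), not the actual harness code: previously the whole n-iteration loop was wrapped by the compile context, so the discarded warm-up iterations were traced and timed as well; now only the per-iteration function is compiled and an eager loop drives it.

```python
import torch


def toy_model_iter_fn(mod, inputs, collect_outputs=True):
    # One forward pass; outputs are only kept when collect_outputs is True.
    out = mod(inputs)
    return out if collect_outputs else None


def run_n_iterations(mod, inputs, model_iter_fn, n=3):
    # Eager driver loop: the first n - 1 iterations discard their outputs
    # (the dead code that is no longer wrapped by the compile context).
    for _ in range(n - 1):
        model_iter_fn(mod, inputs, collect_outputs=False)
    return model_iter_fn(mod, inputs, collect_outputs=True)


mod = torch.nn.Linear(4, 4)
inputs = torch.randn(2, 4)

# Before: compiling the whole loop also traced and timed the discarded iterations.
compiled_loop = torch.compile(lambda m, x: run_n_iterations(m, x, toy_model_iter_fn))
_ = compiled_loop(mod, inputs)

# After: compile only the per-iteration function; the eager loop just drives it,
# so compilation latency reflects the same code path the performance tests measure.
compiled_step = torch.compile(toy_model_iter_fn)
_ = run_n_iterations(mod, inputs, compiled_step)
```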

New benchmark results:
```csv
dev,name,batch_size,accuracy,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks,autograd_captures,autograd_compiles,cudagraph_skips,compilation_latency
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,39.322364  # after pytorch/pytorch#144319
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,38.972257  # before pytorch/pytorch#144319
```

X-link: pytorch/pytorch#145590
Approved by: https://github.com/jansel
ghstack dependencies: #145447

Reviewed By: ZainRizvi

Differential Revision: D68860252

fbshipit-source-id: 60371bdf3ba6e6f38766d6589690a221f8cebda4
xmfan authored and facebook-github-bot committed Jan 30, 2025
1 parent 0e370a0 commit d9cc213
Showing 2 changed files with 43 additions and 13 deletions.
40 changes: 29 additions & 11 deletions userbenchmark/dynamo/dynamobench/common.py
@@ -2781,11 +2781,11 @@ def batch_size_finder(self, device, model_name, initial_batch_size=1024):
batch_size = self.decay_batch_exp(batch_size)
return 1

-    def run_n_iterations(self, mod, inputs):
+    def run_n_iterations(self, mod, inputs, model_iter_fn):
n = self.args.iterations
for _ in range(n - 1):
-            self.model_iter_fn(mod, inputs, collect_outputs=False)
-        return self.model_iter_fn(mod, inputs, collect_outputs=True)
+            model_iter_fn(mod, inputs, collect_outputs=False)
+        return model_iter_fn(mod, inputs, collect_outputs=True)

@torch._disable_dynamo(recursive=True)
def optimizer_zero_grad(self, mod):
@@ -2953,7 +2953,9 @@ def record_status(accuracy_status, dynamo_start_stats):
clone_inputs(example_inputs),
)
self.init_optimizer(name, current_device, model_fp64.parameters())
-                fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
+                fp64_outputs = self.run_n_iterations(
+                    model_fp64, inputs_fp64, self.model_iter_fn
+                )
fp64_outputs = tree_map(
lambda x: x.to(torch.float64)
if isinstance(x, torch.Tensor) and x.is_floating_point()
@@ -2986,7 +2988,7 @@ def record_status(accuracy_status, dynamo_start_stats):
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_result = self.run_n_iterations(
-                    model_copy, clone_inputs(example_inputs)
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
except Exception as e:
accuracy_status = (
@@ -3007,7 +3009,7 @@ def record_status(accuracy_status, dynamo_start_stats):
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_rerun_result = self.run_n_iterations(
-                    model_copy, clone_inputs(example_inputs)
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
except Exception as e:
accuracy_status = (
@@ -3066,13 +3068,15 @@ def record_status(accuracy_status, dynamo_start_stats):
)
new_result = optimized_model_iter_fn(model_copy, example_inputs)
else:
-                    optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
+                    optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
):
-                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
+                        new_result = self.run_n_iterations(
+                            model_copy, example_inputs, optimized_model_iter_fn
+                        )
except Exception as e:
log.exception("")
print(
@@ -3167,7 +3171,9 @@ def check_tolerance(
lambda x: x.to(base_device), example_inputs_copy
)
self.init_optimizer(name, base_device, model_copy.parameters())
-            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)
+            correct_result = self.run_n_iterations(
+                model_copy, example_inputs_copy, self.model_iter_fn
+            )

# Run with Dynamo
# Sometime CI fails with random triton compilation failure which will be skipped for now
@@ -3176,8 +3182,10 @@ def check_tolerance(
torch._dynamo.reset()
try:
self.init_optimizer(name, current_device, model.parameters())
-            optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
-            new_result = optimized_model_iter_fn(model, example_inputs)
+            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            new_result = self.run_n_iterations(
+                model_copy, example_inputs, optimized_model_iter_fn
+            )
except Exception:
log.exception("")
print(
@@ -4460,6 +4468,16 @@ def run(runner, args, original_dir=None):
# Stricter check to disable fallbacks
args.suppress_errors = False

+    if not args.disable_cudagraphs:
+        runner.skip_models.update(
+            {
+                # xfail: https://github.com/pytorch/pytorch/issues/145773
+                "convit_base",
+                "llama",
+                "cm3leon_generate",
+            }
+        )
+
if args.device_index is not None:
if args.multiprocess:
print("Cannot specify both --device_index and --multiprocess")
16 changes: 14 additions & 2 deletions userbenchmark/dynamo/dynamobench/timm_models.py
@@ -10,9 +10,9 @@


try:
-    from .common import BenchmarkRunner, download_retry_decorator, main
+    from .common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main
except ImportError:
-    from common import BenchmarkRunner, download_retry_decorator, main
+    from common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
@@ -218,6 +218,18 @@ def __init__(self):
super().__init__()
self.suite_name = "timm_models"

+    @property
+    def _config(self):
+        return load_yaml_file("timm_models.yaml")
+
+    @property
+    def _skip(self):
+        return self._config["skip"]
+
+    @property
+    def skip_models(self):
+        return self._skip["all"]
+
@property
def force_amp_for_fp16_bf16_models(self):
return FORCE_AMP_FOR_FP16_BF16_MODELS
