Optimize calculation of shared memory size for reduction
Signed-off-by: wchen61 <[email protected]>
wchen61 committed Jan 2, 2025
1 parent aa2f07a commit 4762127
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1864,7 +1864,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,

float pipe_size = (a_size + b_size) * pipe_stages;

-  float reduce_size = max(th_config.num_threads * 2 * 4 * 4,
+  float reduce_size = max(th_config.num_threads * 32 * 4,
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);

TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
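
To make the arithmetic in the changed line easier to check, here is a minimal host-side sketch of the two terms passed to `max`. It is not the code from `gptq_marlin.cu`: `ThreadConfig`, the helper function names, and the sample values are hypothetical, and treating the result as a byte count is an assumption. Note that `2 * 4 * 4` is 32 while `32 * 4` is 128, so the thread-dependent term in the new line is four times the old one.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical stand-in for thread_config_t; only the field used here is modeled.
struct ThreadConfig {
  int num_threads;
};

// Thread-dependent term on the removed line: num_threads * 2 * 4 * 4.
int old_thread_term(ThreadConfig cfg) { return cfg.num_threads * 2 * 4 * 4; }

// Thread-dependent term on the added line: num_threads * 32 * 4.
int new_thread_term(ThreadConfig cfg) { return cfg.num_threads * 32 * 4; }

// Tile-dependent term, identical on both sides of the diff.
int tile_term(int tb_n, int tb_max_m) {
  return (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2;
}

// Reduction size as computed by the removed line.
int reduce_size_old(ThreadConfig cfg, int tb_n, int tb_max_m) {
  return std::max(old_thread_term(cfg), tile_term(tb_n, tb_max_m));
}

// Reduction size as computed by the added line.
int reduce_size_new(ThreadConfig cfg, int tb_n, int tb_max_m) {
  return std::max(new_thread_term(cfg), tile_term(tb_n, tb_max_m));
}

// Checks that the new thread term is exactly four times the old one.
void check_ratio(ThreadConfig cfg) {
  assert(new_thread_term(cfg) == 4 * old_thread_term(cfg));
}
```

With these definitions, the two versions agree whenever the tile-dependent term dominates, and differ only when the thread-dependent term is the larger argument to `max`.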