diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index d8595d2355fb3..04ef842fbdf95 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1864,7 +1864,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, float pipe_size = (a_size + b_size) * pipe_stages; - float reduce_size = max(th_config.num_threads * 2 * 4 * 4, + float reduce_size = max(th_config.num_threads * 32 * 4, (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2); TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity