diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip index f7288195ed..9a71f43853 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip @@ -40,7 +40,10 @@ at::Tensor fused_moe_impl( auto tokens = input.size(0); auto hidden_size = input.size(1); auto experts = gate_up_weight.size(0); - auto intermediate_size = gate_up_weight.size(1); + // Interface requires that you pass intermediate size. On |gate_only| = False, + // |gate_up_weight| might be 2 * intermediate size, so extract the size from + // |down_weight| + auto intermediate_size = down_weight.size(2); auto topk = topk_ids.size(1); auto stride = input.stride(0); @@ -81,6 +84,7 @@ at::Tensor fused_moe_impl( "fp32", // prec_sq (smooth quant) "fp32", // prec_kw (topk weight) static_cast(block_m), + 1, static_cast(gate_only), static_cast(fused_quant)};