diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index 853e1a710cb24..6fa11200fd5be 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -214,13 +214,13 @@ Status CheckInputs(const Tensor* query,
                            "head_size shall be a multiple of 16. Got head_size % 16 == ",
                            head_size % 16);
   }
-  if (cos_dims[0] != present_sequence_length) {
+  if (cos_dims[0] < present_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "cos_cache dimension 0 must be of present_sequence_length.");
+                           "cos_cache dimension 0 should be of max_sequence_length.");
   }
-  if (sin_dims[0] != present_sequence_length) {
+  if (sin_dims[0] < present_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "sin_cache dimension 0 must be of present_sequence_length.");
+                           "sin_cache dimension 0 should be of max_sequence_length.");
   }
   if (cos_dims[1] != (head_size / 16) * 8) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
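
A minimal standalone sketch of the relaxed check above, assuming the rotary cos/sin caches are pre-allocated for the model's max_sequence_length: validation now only requires the cache to cover present_sequence_length instead of matching it exactly. ValidateRotaryCache and its signature are illustrative assumptions, not ORT's actual helper.

// Sketch only: mirrors the diff's relaxed validation, not the ORT implementation.
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper: cache_rows is dimension 0 of cos_cache/sin_cache.
void ValidateRotaryCache(int64_t cache_rows, int64_t present_sequence_length,
                         const std::string& name) {
  // Old check: cache_rows != present_sequence_length, which rejected caches
  // allocated once for max_sequence_length.
  // New check: only reject when the cache is too short to cover the current
  // total sequence length.
  if (cache_rows < present_sequence_length) {
    throw std::invalid_argument(name + " dimension 0 should be of max_sequence_length.");
  }
}

int main() {
  // Cache allocated for max_sequence_length = 4096; current total length = 128.
  // Previously this failed because 4096 != 128; with the relaxed check it passes.
  ValidateRotaryCache(/*cache_rows=*/4096, /*present_sequence_length=*/128, "cos_cache");
  return 0;
}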