-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
broadcast qkv_op #35780
broadcast qkv_op #35780
Conversation
Thanks for your contribution! |
@@ -233,6 +233,21 @@ __global__ void apply_scale(T *data, T scale, int n) { | |||
#endif | |||
} | |||
|
|||
inline int round_up(int seq_len, int multiple = 32) { | |||
assert(multiple); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请使用paddle规范的错误判断方式,如PADDLE_ENFORCE_NE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请使用paddle规范的错误判断方式,如PADDLE_ENFORCE_NE
好的,我改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, | |||
} | |||
} | |||
|
|||
inline int round_up(int seq_len, int multiple = 32) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE_GT( | ||
multiple, 0, | ||
platform::errors::InvalidArgument( | ||
"multiple should be a positive number,but it's (%d)", multiple)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个multiple需要标记一下吗?比如The input argument multiple
,这个报错句子直接看语法是错的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个multiple需要标记一下吗?比如The input argument
multiple
,这个报错句子直接看语法是错的
这个可以修改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
先合入,@fengxiaoshuai,develop提个PR改一下,或者下一个PR带一下。
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) { | |||
#endif | |||
} | |||
|
|||
inline int round_up(int seq_len, int multiple = 32) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两处代码是重复的吗?方便复用吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
考虑过,不过目前就用这两次,放到公共的头文件中发现这个函数和其他函数类型相比有点不伦不类,二者一个是trt,一个是cuda所以目前不太好放,后续如果常用或者有合适的地方会考虑重构一下
const float *input1_data = static_cast<const float *>(inputs[1]); | ||
// fit to [batch, head_num, length, length] + [batch, 1, 1, length] | ||
framework::Tensor temp_qk_bias_tensor; | ||
float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是const void *const *
类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是
const void *const *
类型
这个是由于基类设置的接口的原因,目前没办法,trt这边plugin都是这么写的,具体也和秋良沟通过
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for this PR, maybe consider better design to avoid using const_cast
* broadcast qkv_op * use PADDLE_ENFORCE_GT to replace assert
PR types
Others
PR changes
Others
Describe
to support qk_bias is [batch, 1, 1, seq_len]