Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

broadcast qkv_op #35780

Merged
merged 2 commits into from
Sep 17, 2021
Merged

broadcast qkv_op #35780

merged 2 commits into from
Sep 17, 2021

Conversation

fengxiaoshuai
Copy link
Contributor

PR types

Others

PR changes

Others

Describe

to support qk_bias is [batch, 1, 1, seq_len]

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -233,6 +233,21 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
}

inline int round_up(int seq_len, int multiple = 32) {
assert(multiple);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请使用paddle规范的错误判断方式,如PADDLE_ENFORCE_NE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请使用paddle规范的错误判断方式,如PADDLE_ENFORCE_NE

好的,我改一下

Copy link
Member

@shangzhizhou shangzhizhou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
}
}

inline int round_up(int seq_len, int multiple = 32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不使用驼峰式命名,其他地方也一样
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不使用驼峰式命名,其他地方也一样
image

我看这个文件里面很多地方都用下划线的方式,为了风格统一就延续了这种风格

PADDLE_ENFORCE_GT(
multiple, 0,
platform::errors::InvalidArgument(
"multiple should be a positive number,but it's (%d)", multiple));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个multiple需要标记一下吗?比如The input argument multiple,这个报错句子直接看语法是错的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个multiple需要标记一下吗?比如The input argument multiple,这个报错句子直接看语法是错的

这个可以修改一下

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

先合入,@fengxiaoshuai,develop提个PR改一下,或者下一个PR带一下。

@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
}

inline int round_up(int seq_len, int multiple = 32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两处代码是重复的吗?方便复用吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑过,不过目前就用这两次,放到公共的头文件中发现这个函数和其他函数类型相比有点不伦不类,二者一个是trt,一个是cuda所以目前不太好放,后续如果常用或者有合适的地方会考虑重构一下

const float *input1_data = static_cast<const float *>(inputs[1]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework::Tensor temp_qk_bias_tensor;
float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是const void *const *类型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是const void *const *类型
这个是由于基类设置的接口的原因,目前没办法,trt这边plugin都是这么写的,具体也和秋良沟通过

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for this PR, maybe consider better design to avoid using const_cast

@shangzhizhou shangzhizhou merged commit cf9eae4 into PaddlePaddle:develop Sep 17, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
* broadcast qkv_op

* use PADDLE_ENFORCE_GT to replace assert
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants