-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add paddle.device.cuda.stream_guard API #35623
Conversation
Thanks for your contribution! |
c00a85d
to
4047772
Compare
4047772
to
19a17c9
Compare
@@ -104,5 +105,32 @@ def test_cuda_event_methods(self): | |||
self.assertTrue(event_query_2) | |||
|
|||
|
|||
class TestStreamGuard(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在PR上贴上验证的代码以及验证之后的效果
|
||
cur_stream = current_stream() | ||
if stream is None or id(stream) == id(cur_stream): | ||
yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里单测是不是要加上同样的stream
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
经讨论后不需要修改。
''' | ||
Set the current stream. | ||
|
||
Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameters->Args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
询问了陈龙,Args 或者 Parameters 都可以,为了与本页面其他API 保持统一,不进行修改。
|
||
if (device == nullptr) { | ||
int curr_device_id = platform::GetCurrentDeviceId(); | ||
auto device_tmp = platform::CUDAPlace(curr_device_id); | ||
device = &device_tmp; | ||
} | ||
|
||
new (&self) paddle::platform::stream::CUDAStream(*device, prio); | ||
new (&self) paddle::platform::stream::CUDAStream(*device, prio, | ||
stream_flag); | ||
#else | ||
PADDLE_THROW(platform::errors::Unavailable( | ||
"Class CUDAStream can only be initialized on the GPU platform.")); | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的Stream方法,是不是可以默认non_blocking方式
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
enum class StreamFlag : uint8_t { | ||
kDefaultFlag = 0x0, | ||
kStreamNonBlocking = 0x1, | ||
kStreamPerThread = 0x2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个kStreamPerThread 可以去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
A context manager that specifies the current stream context by the given stream. | ||
|
||
Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Paramters->Args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
return core._set_current_stream(stream) | ||
|
||
|
||
@signature_safe_contextmanager |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dygraph_only
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
待定
if stream is None or id(stream) == id(cur_stream): | ||
yield | ||
else: | ||
pre_stream = _set_current_stream(stream) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stream 是否影响分布式环境?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
会进行线下测试,相关结果后续会贴在开头的 comment 中。
@@ -200,14 +212,16 @@ void BindCudaStream(py::module *m_ptr) { | |||
"Priority should be 1(high) or 2(normal) ")); | |||
} | |||
auto prio = paddle::platform::stream::Priority(priority); | |||
auto stream_flag = paddle::platform::stream::StreamFlag(1); | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the hard code 1
means?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 means non-blocking stream. We init CUDA Stream with default non-blocking property following pytorch implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using paddle::platform::stream::StreamFlag::kStreamNonBlocking
instead of 1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
if stream is None: | ||
raise ValueError("input stream should not be None.") | ||
if not isinstance(stream, paddle.device.cuda.Stream): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下面的判断是否可以包含上面 None
的判断?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想问,可不可以统一成 TypeError?(其实我不应该写成 ValueError,想统一改成 TypeError)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Add paddle.cuda.device.stream_guard API
PR types
New features
PR changes
APIs
Describe
This API provide a way to switch Cuda Stream flexibly.
Offline Test
Async property test
From the picture above, we can see that CUDA Kernel and CUDA Memcpy can run asynchronously.