
[HybridParallel] Add Recompute for PipeLineParallel #34607

Merged: 12 commits into PaddlePaddle:develop on Aug 12, 2021

Conversation

@ForFishes (Member) commented Aug 4, 2021

PR types

New features

PR changes

Others

Describe

Add Recompute for PipeLineParallel

1. Interface

```python
class PipelineLayer(Layer):
    def __init__(self,
                 layers,
                 num_stages=None,
                 topology=None,
                 loss_fn=None,
                 seg_method="uniform",
                 recompute_interval=0,
                 recompute_offload=False,
                 recompute_partition=False):
```
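
For illustration, a minimal construction sketch under MP + PP hybrid parallelism; `decoder_layers`, `gpt_loss_fn`, and the surrounding `fleet` setup are placeholders rather than code from this PR:

```python
from paddle.distributed.fleet.meta_parallel import PipelineLayer

# Hypothetical inputs: any list of Layer/LayerDesc entries and any loss callable.
model = PipelineLayer(
    layers=decoder_layers,         # placeholder: e.g. LayerDesc entries for GPT blocks
    num_stages=2,                  # PP=2, matching the experiments below
    loss_fn=gpt_loss_fn,           # placeholder loss callable
    seg_method="uniform",
    recompute_interval=1,          # 0 (the default) disables recompute
    recompute_offload=True,        # stash activation checkpoints in host memory
    recompute_partition=True)      # shard checkpoints across MP ranks
```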

2. Supported features

Compared with Paddle's native recompute, this implementation differs in the following ways:

  • Supports PipeLineParallel, by adapting the input format expected by recompute.
  • Supports offload.
  • Under MP, supports further partitioning of checkpoints to reduce GPU memory usage.
  • Adapts randomness control to hybrid parallelism (see the sketch after this list).
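
On the randomness point: under MP, dropout should use a per-rank seed on partitioned activations and a shared seed on replicated ones. A sketch of that pattern, assuming the `get_rng_state_tracker` helper in `paddle.distributed.fleet.meta_parallel` (seed values and `mp_rank` are placeholders):

```python
import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

mp_rank = 0  # placeholder: the model-parallel rank of this process

tracker = get_rng_state_tracker()
tracker.add("global_seed", 1024)           # identical on every MP rank
tracker.add("local_seed", 1024 + mp_rank)  # differs across MP ranks

x = paddle.randn([4, 8])
with tracker.rng_state("local_seed"):
    # Dropout on MP-partitioned activations: each rank drops different elements.
    y = paddle.nn.functional.dropout(x, p=0.1)
```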

3. Performance comparison

GPT-117M model, V100-32G, FP32, MP=4, PP=2, microbatch=2, global_batch_size=128; GPU memory measured on a middle pipeline card:

| Configuration | GPU memory (MB) | Throughput (tokens/s) |
| --- | --- | --- |
| w/o recompute | 4999 | 54584 |
| recompute | 3531 | 43251 |
| recompute + offload | 3489 | 22529 |
| recompute + MP partition | 3473 | 35842 |
| recompute + offload + MP partition | 3495 | 28078 |

Q: Why does recompute + offload + MP partition show higher memory usage than recompute + MP partition alone?
A: nvidia-smi reports memory that may already be freed but is still cached by Paddle's allocator.

4. Accuracy comparison

Accuracy verified on GPT-117M under MP2_PP2:

(figure: DP2_MP2_PP2 + AMP)

5. TODO

  • Reduce GPU memory fragmentation

paddle-bot-old (bot) commented Aug 4, 2021

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@ForFishes force-pushed the add_recompute_for_hybrid branch from f28cc8a to ae6ac75 on August 5, 2021 03:06
Review thread on:

```python
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03): do the D2H copy on a side stream instead of the
# compute stream, to speed this up
arg = partition.cpu() if _recompute_offload else partition
```
Contributor: Offload in dygraph is Sooooo easy!!! lol

Contributor: I think fleet.utils.recompute could be done the same way.

@ForFishes (Member, Author) replied Aug 11, 2021: Yes, we currently support hybrid_parallel first.
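
For context, the stash side of the offload pattern shown in the diff above, as a standalone sketch (names are illustrative; the PR additionally partitions the checkpoint across MP ranks via `_split_activation` before offloading):

```python
import paddle

def stash_checkpoint(arg, offload=True):
    """Forward pass: detach the activation checkpoint and, when offloading,
    copy it to host memory (D2H) so it no longer occupies GPU memory."""
    saved = arg.detach()
    return saved.cpu() if offload else saved
```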

Review thread on:

```python
tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(device_id) if _recompute_offload else tensors[i]
```
Contributor: Should we sync here?

Contributor: We should wait for the H2D copy to finish before running the subsequent computation.

ForFishes (Member, Author): cpu() is a synchronous operation, so we don't need to do this.
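
The restore side of the same sketch; following the thread above, it relies on the blocking semantics of these copies rather than an explicit stream synchronization (an assumption carried over from the discussion, not verified independently):

```python
def restore_checkpoint(saved, device_id=0, offload=True):
    """Backward pass: copy the stashed checkpoint back to the device (H2D)
    and re-enable gradient tracking before the forward is recomputed."""
    t = saved.cuda(device_id) if offload else saved
    t.stop_gradient = False  # gradients must flow through the recomputed forward
    return t
```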

@JZ-LIANG (Contributor) previously approved these changes Aug 5, 2021: LGTM

@ForFishes force-pushed the add_recompute_for_hybrid branch from 88ee4ad to 5e50e53 on August 11, 2021 09:02
@JZ-LIANG (Contributor) left a comment: LGTM

@zhiqiu (Contributor) left a comment: LGTM for op_function_generator

@ForFishes merged commit 589d13c into PaddlePaddle:develop on Aug 12, 2021
@ForFishes deleted the add_recompute_for_hybrid branch on August 12, 2021 03:45