[BUG] Pipeline Parallelism cannot work with BFloat16 + Optimizer offload #3866

Open
SparkJiao opened this issue Jul 3, 2023 · 5 comments
Labels: bug, training


@SparkJiao

When combining bfloat16 with optimizer offload, I get the following error:

❱ 159     model, optimizer, _, scheduler = deepspeed.initialize(model=model,
  160                                                           model_parameters=[p for p in m
  161                                                           config=ds_config)
  162

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/__init__.py:186 in initialize

  183         assert mpu is None, "mpu must be None with pipeline parallelism"
  184         mpu = model.mpu()
  185         config_class = DeepSpeedConfig(config, mpu)
❱ 186         engine = PipelineEngine(args=args,
  187                                 model=model,
  188                                 optimizer=optimizer,
  189                                 model_parameters=model_parameters,

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:55 in __init__

   52     DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
   53
   54     def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
❱  55         super().__init__(*super_args, **super_kwargs)
   56         assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
   57
   58         assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible w

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/engine.py:310 in __init__

  307             model_parameters = list(model_parameters)
  308
  309         if has_optimizer:
❱ 310             self._configure_optimizer(optimizer, model_parameters)
  311             self._configure_lr_scheduler(lr_scheduler)
  312             self._report_progress(0)
  313         elif self.zero_optimization():

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/engine.py:1220 in _configure_optimizer

 1217         elif optimizer_wrapper == FP16:
 1218             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
 1219         elif optimizer_wrapper == BFLOAT16:
❱1220             self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
 1221         else:
 1222             self.optimizer = basic_optimizer
 1223

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/engine.py:1400 in _configure_bf16_optimizer

 1397         log_dist('Creating BF16 optimizer', ranks=[0])
 1398
 1399         timers = self.timers if self.wall_clock_breakdown() else None
❱1400         optimizer = BF16_Optimizer(optimizer,
 1401                                    self.param_names,
 1402                                    mpu=self.mpu,
 1403                                    clip_grad=clip_grad,

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/bf16_optimizer.py:82 in __init__

   79         self.group_paddings = []
   80
   81         if self.using_real_optimizer:
❱  82             self._setup_for_real_optimizer()
   83
   84         see_memory_usage('end bf16_optimizer', force=True)
   85

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/bf16_optimizer.py:157 in _setup_for_real_optimizer

  154             see_memory_usage(f'after initializing group {i}', force=True)
  155
  156         see_memory_usage('before initialize_optimizer', force=True)
❱ 157         self.initialize_optimizer_states()
  158         see_memory_usage('end initialize_optimizer', force=True)
  159
  160         # Need optimizer states initialized before linking lp to optimizer state

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/bf16_optimizer.py:210 in initialize_optimizer_states

  207                                                self.fp32_groups_gradient_flat_partit
  208             param_partition.grad = grad_partition
  209
❱ 210         self.optimizer.step()
  211
  212         self.clear_hp_grads()
  213

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/torch/optim/optimizer.py:280 in wrapper

  277                             raise RuntimeError(f"{func} must return None or a tuple of (
  278                                                f"but got {result}.")
  279
❱ 280                 out = func(*args, **kwargs)
  281                 self._optimizer_step_code()
  282
  283                 # call optimizer step post hooks

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in decorate_context

  112     @functools.wraps(func)
  113     def decorate_context(*args, **kwargs):
  114         with ctx_factory():
❱ 115             return func(*args, **kwargs)
  116
  117     return decorate_context
  118

/home/fangkai/anaconda3/envs/py3.9/lib/python3.9/site-packages/deepspeed/ops/adam/cpu_adam.py:150 in step

  147                 if p.grad is None:
  148                     continue
  149
❱ 150                 assert p.device == device, f"CPUAdam param is on {p.device} and must be
  151                         "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config
  152
  153                 state = self.state[p]

AssertionError: CPUAdam param is on cuda:2 and must be 'cpu', make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.
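From the trace, BF16_Optimizer.initialize_optimizer_states() calls optimizer.step() on fp32 master partitions that it allocates on the accelerator, while DeepSpeedCPUAdam asserts that every parameter it steps lives on the CPU. A standalone sketch of that mismatch (illustration only, not DeepSpeed's actual code; requires a CUDA device):

    import torch

    def cpu_adam_style_step(params, expected=torch.device("cpu")):
        # Mirrors the device check in deepspeed/ops/adam/cpu_adam.py:150
        for p in params:
            if p.grad is None:
                continue
            assert p.device == expected, f"CPUAdam param is on {p.device} and must be 'cpu'"

    # BF16_Optimizer keeps its fp32 master partitions on the training device:
    fp32_partition = torch.zeros(16, device="cuda:0", requires_grad=True)
    fp32_partition.grad = torch.zeros_like(fp32_partition)

    cpu_adam_style_step([fp32_partition])  # raises the same AssertionError as above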

The deepspeed config is as follows:

ds_cfg:
  train_micro_batch_size_per_gpu: ${per_gpu_train_batch_size}
  gradient_accumulation_steps: ${gradient_accumulation_steps}
  optimizer:
    type: AdamW
    params:
      lr: ${learning_rate}
      betas: [ 0.9, 0.95 ]
      eps: ${adam_epsilon}
      weight_decay: ${weight_decay}
  scheduler:
    type: WarmupLR
    params:
      warmup_max_lr: ${learning_rate}
      warmup_num_steps:
      warmup_type: linear
  gradient_clipping: ${max_grad_norm}
  bf16:
    enabled: ${fp16}
  data_types:
    grad_accum_dtype: "fp32"
  zero_optimization:
    stage: 1 # https://github.com/microsoft/DeepSpeed/issues/1835#issuecomment-1175836585
    contiguous_gradients: True
    overlap_comm: True
    reduce_scatter: True
    reduce_bucket_size: 5e8
    allgather_partitions: True
    allgather_bucket_size: 5e8
    offload_optimizer:
      device: cpu
      pin_memory: True
  steps_per_print: 1024

When I remove the offload_optimizer sub-config, training runs normally. Likewise, fp16 + optimizer offload also works without problems.
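For reference, a minimal repro sketch of the failing combination (hypothetical toy model and sizes; the config mirrors the YAML above as a Python dict):

    # Launch with, e.g.: deepspeed --num_gpus 2 repro.py
    import torch
    import deepspeed
    from deepspeed.pipe import PipelineModule

    layers = [torch.nn.Linear(1024, 1024) for _ in range(8)]
    model = PipelineModule(layers=layers, num_stages=2)

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
        "bf16": {"enabled": True},  # with "fp16" instead, offload reportedly works
        "zero_optimization": {
            "stage": 1,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
        },
    }

    # Fails during BF16_Optimizer.initialize_optimizer_states() -> CPUAdam.step():
    # the fp32 master partitions live on the GPU while CPUAdam expects CPU tensors.
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=ds_config,
    )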

@SparkJiao added the bug and compression labels on Jul 3, 2023
@tjruwase
Contributor

tjruwase commented Jul 3, 2023

@SparkJiao, this is correct. Offloading is not enabled for pipeline parallelism.

@SparkJiao
Author

I see, thanks. But would there be any unexpected behaviour when enabling optimizer offload under pipeline parallelism? I haven't noticed anything wrong in my current training (fp16 + optimizer offload on LLaMA-65B). Also, I'm not sure what potential risk led to optimizer offload being designed as unsupported in pipeline parallelism.

@tjruwase
Contributor

@SparkJiao, sorry for the confusion. I was trying to say that the bf16_optimizer.py implementation does not include offloading at all. We have not previously tested offloading with pipeline parallelism, so I don't know whether there are any issues. However, your results are promising evidence that the two can be combined. Can you share more details of your results and what you are trying to do?

@SparkJiao
Author

@tjruwase Thanks for your reply. Currently I can successfully complete the training procedure with optimizer offload and gradient checkpointing under fp16. I just wanted to use bf16 in the same procedure. I have created a repo with my implementation of DeepSpeed pipeline-parallelism training, so you may check it here.

I think offload is necessary when you only have 8 * 80G A100s to train LLaMA-65B. (I have tried an 8-stage pipeline with 2 data-parallel groups, i.e. 16 GPUs in total, but it failed when optimizer offload was disabled.) A rough memory estimate is sketched below.
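As a back-of-the-envelope check (my own estimate, assuming Adam keeps fp32 master weights plus two fp32 moments), the optimizer state alone nearly fills the cluster:

    # Rough estimate; all figures are assumptions, not measured values.
    params = 65e9                            # LLaMA-65B parameter count
    bytes_optimizer = params * (4 + 4 + 4)   # fp32 master weights, exp_avg, exp_avg_sq
    bytes_weights_bf16 = params * 2          # bf16 model weights
    print(f"optimizer states: {bytes_optimizer / 2**30:.0f} GiB")    # ~727 GiB
    print(f"bf16 weights:     {bytes_weights_bf16 / 2**30:.0f} GiB") # ~121 GiB
    # 16 x 80 GB A100s give ~1192 GiB total, so optimizer states alone take ~61%
    # of it, before gradients, activations, and fragmentation -- hence CPU offload.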

@tjruwase
Contributor

@SparkJiao, can you share simple steps to repro the failure you are seeing with bf16 + offload? Thanks!
