Removing batch dimension from default layout maps for Gemma and Llama #2035
Conversation
I don't think this change is needed. As I mentioned in keras-team/keras#20674 (comment), models should not be sharded on the batch dimension, and the fact that we don't shard the model on the batch dimension is not a bug. Only data is sharded on the batch dimension. I believe that when we provide the batch dimension in the layout, we're setting up how the data is sharded, so we should keep it as is. btw, this is a good resource about data distribution: https://jax.readthedocs.io/en/latest/distributed_data_loading.html
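To make the intent concrete, here is a minimal JAX-level sketch of "data is sharded on the batch dimension, the model is replicated over it" (illustrative only, not the Keras internals; the 2x4 mesh and the shapes are assumptions):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices arranged as a 2x4 ("batch", "model") mesh.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("batch", "model"))

data = np.ones((16, 128))     # a global batch of inputs
kernel = np.ones((128, 256))  # a dense weight

# Inputs: first axis split across the "batch" axis of the mesh.
data_sharded = jax.device_put(data, NamedSharding(mesh, P("batch", None)))
# Weights: replicated over "batch", split over "model".
kernel_sharded = jax.device_put(kernel, NamedSharding(mesh, P(None, "model")))
```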
Could you be more explicit when you say "we don't shard the model on the batch dimension"? What do you mean exactly? The code of the Gemma default layout map says:
I don't think the attention weights will be sharded on the batch dimension if we set
Would you have a pointer to the code where this behavior is implemented? I have checked the implementation of distribute_variable and distribute_tensor as well as the documentation, but I could find no reference to special-casing of one dimension in the LayoutMap based on its name. Quick test showing that there is no special casing for model weights and the "batch" dimension: https://www.kaggle.com/code/martingorner/keras-model-sharding-test
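For reference, a condensed sketch of the kind of check the notebook performs (the layer name, mesh shape, and the keyword-only `ModelParallel` form are assumptions, not the notebook's exact code): put "batch" in a weight's layout and inspect the resulting sharding.

```python
import keras

# Assumes the JAX backend and 8 devices arranged as a 2x4 ("batch", "model") mesh.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh((2, 4), ("batch", "model"), devices)

layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = ("batch", "model")  # "batch" placed on a weight axis

# Newer Keras versions accept the keyword-only form; older ones also take the mesh.
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distribution)

layer = keras.layers.Dense(256, name="dense")
layer.build((None, 128))
# If "batch" were special-cased for weights, the kernel would stay replicated on
# axis 0; in practice its sharding follows the layout tuple literally.
print(layer.kernel.value.sharding)
```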
num_model_replicas_total is equal to the size of the batch dimension of the mesh. Because data is sharded on the batch dimension, the model is replicated along this dimension so that each data shard sees a full replica of the model. mesh_model_dim_size is the second dimension of the mesh and dictates how the model is sharded. In JAX, we shard the data and the computation follows the data. I recommend reading through distribute_data_input to see how the layout is used. PS: I'll look at your Kaggle notebook tomorrow.
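A rough sketch of that bookkeeping (simplified arithmetic only; the real logic lives in the Keras JAX trainer and distribute_data_input, and the mesh and batch sizes here are assumptions):

```python
# Mesh of shape ("batch", "model") = (2, 4); 8 devices assumed.
batch_dim_size, mesh_model_dim_size = 2, 4

# Data is split only along "batch", so there is one (model-sharded) replica of
# the model per "batch" slice, and each replica is sharded mesh_model_dim_size-way.
num_model_replicas_total = batch_dim_size  # -> 2

global_batch_size = 16
per_replica_batch_size = global_batch_size // num_model_replicas_total  # -> 8
```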
The notebook just shows that when a weight tensor has a layout map of
The number of model replicas should be 2 in this case, as the batch dim is 2. But there is an issue if your notebook is showing that the model is sharded 8-way in this case. We need to debug to see where the disconnect happens.
This is a matter of opinion and API design. I actually prefer the current implementation, which is more direct and where a layout of
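A sketch of the "direct" semantics this refers to (an assumed illustration, not code from the repo): the layout tuple is read positionally, one mesh-axis name per weight axis, just like a JAX PartitionSpec.

```python
import keras

# Assumes 8 devices arranged as a 2x4 ("batch", "model") mesh.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh((2, 4), ("batch", "model"), devices)

# For a (hidden, output) kernel: axis 0 replicated, axis 1 split over "model".
# On the JAX backend this corresponds to PartitionSpec(None, "model") on the mesh.
layout = keras.distribution.TensorLayout((None, "model"), mesh)
```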
btw, you can hit me on chat to discuss more interactively
I see two points here:
Of course sharding data on the batch dimension should be supported, and it is. For model weight sharding though, the current implementation maps exactly to what low-level JAX APIs do, which is a good thing IMHO. See my notebook, linked below.
But maybe what you have in mind is sharding hints for outputs? This is implemented in Keras 3 as
The Kaggle notebook: https://www.kaggle.com/code/martingorner/keras-model-sharding-test/ (the latest version was showing a runtime error, now fixed)
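For comparison, a minimal sketch of the low-level JAX behaviour being referenced (the 2x4 mesh and kernel shape are assumptions): JAX attaches no special meaning to any mesh axis name, so naming "batch" in a weight's PartitionSpec simply shards that weight along it.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices arranged as a 2x4 ("batch", "model") mesh.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("batch", "model"))

kernel = np.ones((128, 256))
sharded = jax.device_put(kernel, NamedSharding(mesh, P("batch", None)))
# The kernel is now split 2-way along axis 0, not replicated, even though the
# mesh axis happens to be called "batch".
print(sharded.sharding)
```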
How about this? Why do we need to specify the model sharding layout in a 2-dimensional way?
Or are you suggesting diverging from JAX semantics here? What would the new semantics be?
Sounds good! Thanks, Martin!
This is to align with Keras PR 20674, which fixes data sharding in the JAX trainer but does not support sharding the model on the "batch" dimension.
To be clear, before Keras PR 20674, using the "batch" dimension for sharding the model was not supported either (unless that dimension was of size 1). Keras PR 20674 fixes use cases where the "batch" dimension is not the first dimension of the device mesh and where model and data parallelism are used at the same time. However, its data sharding expressions assume that only the data, not the model, is sharded on the "batch" dimension. That is why this PR removes model sharding on the "batch" dimension from the default layout maps.
Sharding on the "batch" dimension was added to the Gemma default layout map by Keras-hub PR 1491. The reason it was added is unclear.
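For illustration, a paraphrased sketch of the kind of change this PR makes to the default layout maps (the entry key follows the style of the Gemma get_layout_map code but is not copied verbatim; the 2x4 mesh is an assumption):

```python
import keras

# Assumes 8 devices arranged as a 2x4 ("batch", "model") mesh.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh((2, 4), ("batch", "model"), devices)
layout_map = keras.distribution.LayoutMap(device_mesh)

model_dim = "model"

# Before this PR (paraphrased): weight layouts referenced the "batch" mesh axis,
# e.g. (model_dim, "batch", None) for the attention query/key/value kernels.
# After this PR: weights are sharded only on "model" and replicated over "batch".
layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
    model_dim, None, None)
```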