
Removing batch dimension from default layout maps for Gemma and Llama #2035

Open
wants to merge 1 commit into master

Conversation

@martin-gorner (Contributor)

This is to align with Keras PR 20674, which fixes data sharding in the JAX trainer but does not support sharding the model on the "batch" dimension.

To be clear, before Keras PR 20674, using the "batch" dimension for sharding the model was not supported either (unless that dimension was of size 1). Keras PR 20674 fixes use cases where the "batch" dimension is not the first dimension in the device mesh, and where model and data parallelism are used at the same time. However, the data sharding expressions it uses assume that only the data, not the model, is sharded on the "batch" dimension. That is why this PR removes model sharding on the "batch" dimension from the default layout maps.

Sharding on the "batch" dimension was added in the Gemma default layout map by Keras-hub PR 1491. The reason why this was added is unclear.

github-actions bot added the Gemma label (Gemma model specific issues) on Jan 6, 2025
@SamanehSaadat (Member)

I don't think this change is needed. As I mentioned in keras-team/keras#20674 (comment), models should not be sharded on the batch dimension, and not sharding the model on the batch dimension is not a bug. Only data is sharded on the batch dimension. I believe that when we provide the batch dimension in the layout here, we are setting up how the data is sharded, so we should keep it as is. By the way, this is a good resource on data distribution: https://jax.readthedocs.io/en/latest/distributed_data_loading.html

@martin-gorner (Contributor, Author)

Could you be more explicit about what you mean when you say "we don't shard the model on the batch dimension"? The code of the Gemma default layout says:
layout_map["decoder_block.*attention.*(query|key|value).kernel"] = ('model', 'batch', None)
which will result in the attention weights being sharded on both the 'model' and 'batch' dims when the mesh is keras.distribution.DeviceMesh((len(devices)//2, 2), ["model", "batch"], devices), will it not?
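This claim can be checked directly against JAX semantics; the sketch below assumes 8 devices and uses a stand-in kernel shape (the layout entry above corresponds to the PartitionSpec used here):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()  # assume 8 accelerators
mesh = Mesh(np.array(devices).reshape(len(devices) // 2, 2), ("model", "batch"))

# Stand-in for a (query|key|value) kernel of shape (num_heads, hidden_dim, head_dim).
kernel = np.zeros((8, 2048, 256), dtype=np.float32)
sharded = jax.device_put(kernel, NamedSharding(mesh, PartitionSpec("model", "batch", None)))

# Each device holds a distinct (2, 1024, 256) slice: the kernel is split on both
# the "model" and "batch" axes, i.e. 8-way with no replication.
print(sharded.sharding.shard_shape(kernel.shape))  # (2, 1024, 256)
print(len(sharded.addressable_shards))             # 8
```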

@SamanehSaadat (Member)

> Could you be more explicit about what you mean when you say "we don't shard the model on the batch dimension"? The code of the Gemma default layout says layout_map["decoder_block.*attention.*(query|key|value).kernel"] = ('model', 'batch', None), which will result in the attention weights being sharded on both the 'model' and 'batch' dims when the mesh is keras.distribution.DeviceMesh((len(devices)//2, 2), ["model", "batch"], devices), will it not?

I don't think the attention weights will be sharded on the batch dimension if we set layout_map["decoder_block.*attention.*(query|key|value).kernel"] = ('model', 'batch', None). I think the attention weights will be sharded on the model dimension only, and the data input to this layer has been sharded on the batch dimension, so the attention layer doesn't get all the data, only a portion of it.

@martin-gorner (Contributor, Author) commented Jan 8, 2025

Would you have a pointer to the code where this behavior is implemented? I have checked the implementations of distribute_variable and distribute_tensor as well as the documentation, but I could find no reference to any special-casing of one dimension in the LayoutMap based on its name.

Quick test showing that there is no special-casing of the "batch" dimension for model weights: https://www.kaggle.com/code/martingorner/keras-model-sharding-test
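The notebook itself is not reproduced here, but its gist can be sketched with the public Keras distribution API (assuming the JAX backend and 8 accelerators; the Dense layer and the regex key are only illustrative):

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # this sketch assumes the JAX backend

import keras
from keras import distribution

devices = distribution.list_devices()  # assume 8 accelerators
mesh = distribution.DeviceMesh(shape=(2, 4), axis_names=("batch", "model"), devices=devices)

# Deliberately shard a weight on BOTH mesh axes, including "batch".
layout_map = distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = ("batch", "model")

distribution.set_distribution(
    distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

# A kernel created under this distribution ends up laid out as
# PartitionSpec("batch", "model"): split 2-way x 4-way = 8-way, no replication.
layer = keras.layers.Dense(16)
layer.build((None, 32))
print(layer.kernel.value.sharding)  # on the JAX backend, a NamedSharding
```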

@SamanehSaadat (Member)

num_model_replicas_total is equal to the batch dimension size. Because data is sharded on the batch dimension, we need to replicate the model along this dimension so that each data shard has access to a full replica of the model.

mesh_model_dim_size is the second dimension of the mesh and dictates how the model is sharded.

In JAX, we shard the data and then the computation follows the data. I recommend reading through distribute_data_input to see how the layout is used.
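A rough sketch of the arithmetic described above (the variable names mirror this comment, not necessarily the actual Keras source):

```python
# For a logical mesh of shape (batch_dim_size, model_dim_size), e.g. (2, 4):
batch_dim_size, model_dim_size = 2, 4

# Data is split along the "batch" axis of the mesh...
num_data_shards = batch_dim_size            # 2-way data parallelism

# ...while the model is split along the "model" axis and replicated along the
# "batch" axis, one full copy per data shard.
mesh_model_dim_size = model_dim_size        # 4-way model sharding
num_model_replicas_total = batch_dim_size   # 2 full replicas of the model
```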

PS: I'll look at your Kaggle notebook tomorrow.

@martin-gorner (Contributor, Author)

The notebook just shows that when a weights tensor has a layout map of ("batch", "model") on a mesh like ((2,4), ("batch", "model")), it does get split up into 8 pieces. So num_model_replicas in that case is 1 (full sharding, no replication), but the Keras code (even with my fix) computes 2.

@SamanehSaadat (Member)

> The notebook just shows that when a weights tensor has a layout map of ("batch", "model") on a mesh like ((2,4), ("batch", "model")), it does get split up into 8 pieces. So num_model_replicas in that case is 1 (full sharding, no replication), but the Keras code (even with my fix) computes 2.

The number of model replicas should be 2 in this case, as the batch dim is 2. But there is an issue if your notebook is showing the model being sharded 8-way in this case; we need to debug to see where the disconnect happens. ((2,4), ("batch", "model")) should shard the data 2-way and the model 4-way.

@martin-gorner (Contributor, Author)

This is a matter of opinion and API design. I actually prefer the current implementation, which is more direct and where a layout of ((2,4), ("a", "b")), applied to a specific weights tensor, shards that tensor 8 ways, no matter the names of the mesh axes. We can add an error or warning if the user asks for the model to be sharded on the batch_dim_name dimension, if we believe there is no use case for that.

@martin-gorner (Contributor, Author)

By the way, you can ping me on chat to discuss this more interactively.

@SamanehSaadat (Member)

> This is a matter of opinion and API design. I actually prefer the current implementation, which is more direct and where a layout of ((2,4), ("a", "b")), applied to a specific weights tensor, shards that tensor 8 ways, no matter the names of the mesh axes. We can add an error or warning if the user asks for the model to be sharded on the batch_dim_name dimension, if we believe there is no use case for that.

I see two points here:

  1. If we want to shard the model 8-way, why do it in a 2-dimensional logical mesh? Assume we have a (2, 4) or (4, 2) or (2, 2, 2) physical mesh, then we just pass ((1, 8), ("batch", "model")) and the model will be sharded 8-way on any of those physical meshes (see the sketch after this list). This way, we abstract away the complexities of the physical mesh from the user.
  2. I do believe allowing users to shard the data on the batch dimension should be supported, and I think allowing the user to provide their desired data and model parallelism through ("batch", "model") is a nice API design (given that the model doesn't need to be sharded on a 2-dimensional mesh).
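For point 1, here is what such a (1, 8) logical mesh looks like in plain JAX terms (a sketch assuming 8 devices; any 8-device physical topology can back this mesh):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A (1, 8) logical mesh: the "batch" axis has size 1, so nothing is split along it.
devices = np.array(jax.devices()).reshape(1, 8)
mesh = Mesh(devices, ("batch", "model"))

weights = np.zeros((2048, 16384), dtype=np.float32)  # stand-in weight matrix
sharded = jax.device_put(weights, NamedSharding(mesh, PartitionSpec("batch", "model")))

# 8-way sharding along "model"; the size-1 "batch" axis changes nothing.
print(sharded.sharding.shard_shape(weights.shape))  # (2048, 2048)
```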

@martin-gorner (Contributor, Author)

Of course sharding data on the batch dimension should be supported, and it is.

For model weight sharding though, the current implementation maps exactly to what the low-level JAX APIs do, which is a good thing IMHO. See my notebook:

  • A Keras layout spec of layout_map["dense/kernel"] = ("a", "b") translates exactly into JAX as:
  • jax.device_put(x, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("a", "b")))

But maybe what you have in mind are sharding hints for outputs? This is implemented in Keras 3 as layout_map["dense/output"], and yes, in this case it makes sense for the user to set it to something like ("batch", None).

@martin-gorner (Contributor, Author)

The Kaggle notebook: https://www.kaggle.com/code/martingorner/keras-model-sharding-test/ (the latest version was showing a runtime error, now fixed).

@SamanehSaadat (Member)

> 1. If we want to shard the model 8-way, why do it in a 2-dimensional logical mesh? Assume we have a (2, 4) or (4, 2) or (2, 2, 2) physical mesh, then we just pass ((1, 8), ("batch", "model")) and the model will be sharded 8-way on any of those physical meshes. This way, we abstract away the complexities of the physical mesh from the user.

How about this? Why do we need to specify the model sharding layout in a 2-dimensional way?

@martin-gorner (Contributor, Author)

  1. Not sure what the layout map for Gemma's attention weights is in your proposal. Assuming it is ('model', 'batch'), then yes, when the "batch" dimension is 1, using 'batch' in the layout map does not matter. But if people want to do data and model parallelism by setting a logical mesh of ((2,4), ('batch', 'model')), then, with a layout of ('model', 'batch'), the current implementation will shard the attention weights 8-way (which is what was specified according to JAX PartitionSpec and NamedSharding semantics, but obviously not a correct outcome).
  2. Also, mesh=((1, 8), ("batch", "model")) with an attention weights sharding spec of ('model', 'batch') is one way of sharding a model 8-way. In the JAX API, it is not the only way. mesh=((2, 4), ('a', 'b')) with a sharding spec of ('a', 'b') is another way of expressing 8-way weights sharding (both ways are sketched below). And since the Keras layout map implementation follows JAX PartitionSpec and NamedSharding exactly, these settings should also make sense and mean the same thing in Keras.

Or are you suggesting diverging from JAX semantics here? What would the new semantics be?
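To make point 2 above concrete, both mesh/spec pairs land on 8-way weight sharding under plain JAX NamedSharding semantics (a sketch, assuming 8 devices):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

w = np.zeros((1024, 4096), dtype=np.float32)  # stand-in weight matrix
devs = np.array(jax.devices())                # assume 8 devices

# Way 1: a (1, 8) "batch"/"model" mesh with spec ("model", "batch").
m1 = Mesh(devs.reshape(1, 8), ("batch", "model"))
s1 = jax.device_put(w, NamedSharding(m1, PartitionSpec("model", "batch")))

# Way 2: a (2, 4) mesh with arbitrary axis names and spec ("a", "b").
m2 = Mesh(devs.reshape(2, 4), ("a", "b"))
s2 = jax.device_put(w, NamedSharding(m2, PartitionSpec("a", "b")))

# Both put the weights in 8 distinct pieces, one per device, with no replication.
print(len(s1.addressable_shards), len(s2.addressable_shards))  # 8 8
```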

@SamanehSaadat (Member)

Sounds good! Thanks, Martin!

Labels: Gemma (Gemma model specific issues)
2 participants