A question about parameter server training #1

Open

SearchVera opened this issue Nov 4, 2022 · 1 comment


SearchVera commented Nov 4, 2022

Hi, your code really helps! I have one question:
In the coordinator's train_dataset_fn, you use shard to split the data across workers, and the input parameter input_context.input_pipeline_id indicates the worker's index. So I would expect every worker to call train_dataset_fn to get its own part of the data, but your code shows that only the coordinator uses train_dataset_fn.
Can you explain how this parameter (input_context.input_pipeline_id) works?
Thanks!
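
(For reference, the sharding pattern being asked about looks roughly like the sketch below; the toy range dataset stands in for the real input, and this is not the repository's actual code:)

```python
import tensorflow as tf

def train_dataset_fn(input_context):
    # Toy stand-in for the real training data.
    dataset = tf.data.Dataset.range(1000)
    # Each input pipeline (one per worker) keeps only its own shard:
    # input_pipeline_id is this worker's index among
    # num_input_pipelines total pipelines.
    dataset = dataset.shard(
        num_shards=input_context.num_input_pipelines,
        index=input_context.input_pipeline_id,
    )
    batch_size = input_context.get_per_replica_batch_size(64)
    return dataset.batch(batch_size).repeat()
```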

18520339 (Owner) commented Mar 4, 2023

Hi, sorry for the late reply.

Users of ParameterServerStrategy with the Model.fit API need to use a DatasetCreator as the input: an instance of this class wraps a callable (taking an input_context argument) that returns a tf.data.Dataset, and it is that instance that gets passed to fit. According to TensorFlow's documentation:

If you instead create your dataset with tf.keras.utils.experimental.DatasetCreator, the code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.

So Model.fit usage with DatasetCreator is intended to work across all tf.distribute.Strategy implementations, as long as Strategy.scope is used at model creation: tf.distribute calls the input function on the CPU device of each worker. In other words, the coordinator only defines train_dataset_fn and hands it to fit; it is each worker that actually invokes the function, with input_context (including input_pipeline_id) filled in for that worker. A sketch of this wiring is below.
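
Here is a minimal sketch of the whole setup, assuming a TF_CONFIG-based cluster is already configured (the resolver, model, and dataset are placeholders, not the repository's code):

```python
import tensorflow as tf

# Assumes the TF_CONFIG environment variable describes the cluster
# (coordinator, workers, and parameter servers).
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

def train_dataset_fn(input_context):
    # Invoked on each worker's CPU, not on the coordinator.
    dataset = tf.data.Dataset.range(1000).map(
        lambda x: (tf.cast(tf.reshape(x, [1]), tf.float32),
                   tf.cast(tf.reshape(x % 2, [1]), tf.float32)))
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(8).repeat()

with strategy.scope():  # Variables are placed on parameter servers.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

# fit receives the DatasetCreator, and each worker calls
# train_dataset_fn locally to build its own input pipeline.
model.fit(tf.keras.utils.experimental.DatasetCreator(train_dataset_fn),
          epochs=2, steps_per_epoch=20)
```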
