Hi, your code really helps! I have one question:
In train_dataset_fn you call shard to split the data across the workers, and the parameter input_context.input_pipeline_id indicates the worker's index. I would therefore expect every worker to call train_dataset_fn to get its own slice of the data, but your code seems to show only the coordinator using train_dataset_fn.
Can you explain how input_context.input_pipeline_id works?
Thanks!
Users of ParameterServerStrategy with the Model.fit API need to use a DatasetCreator as the input. An instance of this class wraps a callable (which takes an input_context argument) that returns a tf.data.Dataset, and it is the DatasetCreator instance that gets passed to fit. According to TensorFlow's documentation:
If you instead create your dataset with tf.keras.utils.experimental.DatasetCreator, the code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.
So Model.fit usage with DatasetCreator is intended to work across all tf.distribute.Strategy variants, as long as the model is created under Strategy.scope. Even though fit is invoked on the coordinator, tf.distribute ships dataset_fn to the CPU device of each worker and calls it there, passing an input_context whose input_pipeline_id is that worker's index. That is why the shard call selects a different slice of the data on each worker.
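For reference, here is a minimal sketch of that pattern. The cluster resolver, the toy data, and the one-layer model are placeholders for illustration, not the code from this repo:

```python
import tensorflow as tf

# Cluster setup is assumed to come from the environment (e.g. TF_CONFIG);
# this resolver is a placeholder for whatever your job actually uses.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

def dataset_fn(input_context):
    # Runs on the CPU of *each* worker, not on the coordinator.
    # input_context describes the worker it runs on:
    #   input_context.num_input_pipelines -> total number of workers
    #   input_context.input_pipeline_id   -> this worker's index
    x = tf.random.uniform((1024, 1))  # toy data standing in for real input
    y = tf.random.uniform((1024, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    # Each worker keeps only its own shard of the data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    batch_size = input_context.get_per_replica_batch_size(64)
    return dataset.batch(batch_size).repeat()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

# fit receives the callable wrapped in a DatasetCreator, not a Dataset.
# steps_per_epoch is required with ParameterServerStrategy since the
# dataset is infinite (repeat()) and lives on the workers.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=2, steps_per_epoch=20)
```

The key point is that constructing the DatasetCreator on the coordinator does not mean the dataset is built there: the callable is serialized and invoked once per worker, each time with a different input_pipeline_id.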