Authors:
- @bgedik
- @eapolinario
- @fg91
Flyte can schedule distributed training jobs leveraging e.g. the kubeflow training operator and its PyTorchJob, TFJob, MPIJob, ...
For these distributed jobs, multiple Kubernetes pods are launched. Any of these worker pods can crash, causing all other worker pods in the distributed job to fail subsequently because one worker disappeared.
In Flyte, error propagation happens by the pod entrypoint uploading a file called error.pb to blob storage, which contains (among other things) the error message and the information whether the error is retriable.
In a distributed training job, all worker pods currently try to create the same error.pb file in blob storage, leading to a race condition. It is not guaranteed that the root-cause error is the one being reported to the user and used to determine whether the task can be retried. In fact, the current behavior typically results in the worst outcome, as the later errors override the earlier ones, which is the exact opposite of the desired behavior of identifying the first error as the root cause.
- As a Flyte user trying to understand why a distributed training task failed, I currently cannot rely on the error reported in the Flyte Console (UI) being the root cause error.
- Instead, I have to search the logs of each worker pod. For distributed training jobs with dozens or even hundreds of worker pods, this can be tedious.
- (Current remedies include combining all worker pods in stackdriver logs using a wildcard in the pod name and then filtering by severity.)
- As a Flyte user marking specific errors that can occur in distributed training jobs as retriable (using a FlyteRecoverableException, see the usage example after this list), I want Flyte to deterministically determine the root-cause error so that the retry behaviour does not suffer from a race condition.
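For illustration, a minimal example of marking an error as retriable in a task, assuming flytekit's usual import path for FlyteRecoverableException (the failure condition is contrived):

```python
from flytekit import task
from flytekit.exceptions.user import FlyteRecoverableException


@task(retries=3)
def train_worker(simulate_transient_failure: bool = False) -> None:
    if simulate_transient_failure:
        # Recoverable: Flyte may retry the task up to `retries` times.
        raise FlyteRecoverableException("lost connection to a peer worker")
    # ... actual training code would go here ...
```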
For distributed training tasks, the pod entrypoint pyflyte-execute must not upload a file called error.pb (which is the same for all worker pods) but should instead choose a file name that differs for each worker pod. We propose to simply include the pod name in the file name, i.e. error-<pod name>.pb. This prevents the race condition and has the added benefit that, with this information displayed in the error message in the UI, it is easy to identify the problematic pod.
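A minimal sketch of the proposed naming logic, assuming the pod name is made available to the entrypoint via an environment variable (the exact mechanism, e.g. the Kubernetes downward API, is an implementation detail):

```python
import os


def error_file_name(is_distributed: bool) -> str:
    """Illustrative only: choose the name of the error file the entrypoint uploads."""
    if not is_distributed:
        return "error.pb"  # current behavior, identical for all worker pods
    # Assumes the pod name is injected (e.g. via the downward API); falls back to
    # the hostname, which by default equals the pod name in Kubernetes.
    pod_name = os.getenv("POD_NAME") or os.getenv("HOSTNAME", "unknown-pod")
    return f"error-{pod_name}.pb"
```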
Open questions:
- How does the pod entrypoint determine that a specific task plugin requires separate error files for each worker pod?
  - One of the task base classes, e.g. PythonFunctionTask (at which level should we do this?), could have an attribute is_distributed which defaults to False. Distributed training tasks which inherit from this base class would override this attribute to True. The entrypoint would check this attribute to determine the correct naming of the error file (see the sketch after this list).
  - Should we not configure this in the task classes in flytekit but instead inject this information, e.g. via an env var, in flyteplugins (backend)?
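A simplified stand-in for how the flag could look in flytekit (the class below is a sketch, not the real PythonFunctionTask, and the attribute name is not final):

```python
class PythonFunctionTask:
    # Stand-in for flytekit's task base class (heavily simplified).
    is_distributed: bool = False  # default: the entrypoint uploads a single error.pb


class PyTorchElasticFunctionTask(PythonFunctionTask):
    # A distributed task type (hypothetical name) flips the flag so that the
    # entrypoint uploads error-<pod name>.pb instead.
    is_distributed: bool = True
```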
When a distributed training job fails, often one of the worker pods dies first due to the actual root-cause error; the other worker pods then crash only because that worker disappeared. We are interested in the root-cause error, not in the error that one worker disappeared.
We propose to use the timestamp of the exception as a proxy to determine the root cause. PyTorch distributed, for instance, raises a ChildFailedError exception which contains a so-called ProcessFailure, which in turn contains the exception timestamp. The flytekit pytorch elastic plugin, which catches ChildFailedErrors here, would extract the timestamp and re-raise the error as a Flyte exception that carries this timestamp.
We propose to make it possible to include timestamps in the error proto message reported by flytekit to the backend.
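A rough sketch of how the pytorch elastic plugin could surface the timestamp, assuming torch's ChildFailedError.get_first_failure() and the ProcessFailure.timestamp field; the Flyte exception type carrying a timestamp is hypothetical and does not exist yet:

```python
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError


class FlyteTimestampedRecoverableException(Exception):
    """Hypothetical: a recoverable Flyte exception that also carries a timestamp."""

    def __init__(self, message: str, timestamp: int):
        super().__init__(message)
        self.timestamp = timestamp  # epoch seconds of the root-cause failure


def run_and_translate_errors(launch_training):
    try:
        return launch_training()
    except ChildFailedError as e:
        # get_first_failure() returns the (global rank, ProcessFailure) pair with
        # the earliest timestamp among all failed trainer processes.
        _, first_failure = e.get_first_failure()
        raise FlyteTimestampedRecoverableException(
            str(first_failure.message), first_failure.timestamp
        ) from e
```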
Open questions:
- Where exactly do we include the timestamp?
  - message Error
  - message ErrorDocument
  - message ContainerError (I tend to include the timestamp here.)
- At which level in the flytekit exceptions do we include the timestamp? Do we need system-scoped and user-scoped exceptions with timestamps? Do we need recoverable and non-recoverable exceptions with a timestamp?
Currently, here in the plugin manager, upon completion of a node execution, a new RemoteFileOutputReader is constructed, which is responsible for reading the error file uploaded to blob storage. This RemoteFileOutputReader implements the OutputReader interface.
We propose to implement a new MultiErrorFileRemoteFileOutputReader which (for future flexibility) can be configured with different policies that determine which of multiple errors to report downstream. Initially, the only available policy is "earliest", i.e. report the error with the earliest timestamp.
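To illustrate the "earliest" policy conceptually (the actual reader lives in the Go backend; the Python below uses made-up types purely for illustration):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WorkerError:
    # Conceptual stand-in for the contents of one error-<pod name>.pb file.
    pod_name: str
    message: str
    recoverable: bool
    timestamp: float  # epoch seconds, taken from the proposed timestamp field


def pick_root_cause(errors: List[WorkerError]) -> WorkerError:
    # "earliest" policy: the error with the smallest timestamp is treated as the
    # root cause; its message and recoverability are reported downstream.
    return min(errors, key=lambda e: e.timestamp)
```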
Open questions:
- How do we configure distributed plugins to use this new MultiErrorFileRemoteFileOutputReader instead of the default one?
  - We could add a MultipleErrorFiles property to PluginProperties (see ). The PyTorch plugin, for instance, would then pass true for MultipleErrorFiles here. Currently, here in the plugin manager, where we call NewRemoteFileOutputReader, we do have access to e.plugin, and thus to PluginProperties, and could make use of that information to instantiate another output reader.
  - Could we alternatively add an OutputReader to the PluginContext? Where would we customize this plugin context for e.g. the kubeflow plugins?
What are the main metrics we should be measuring? For example, when interacting with an external system, it might be the external system latency. When adding a new table, how fast would it fill up?
We don't see any drawbacks to making the error handling of distributed training tasks deterministic.
A poor man's version would be to not override the error file if it already exists. While this is a worse solution than the one proposed above, as there still is a race condition, it would still be better than the current behavior because at least we would favor earlier errors over later ones.
The authors of this RFC have experience with pytorch (elastic and non-elastic) distributed training jobs. Are there any community members who have experience with MPI or tensorflow jobs whom we can include in the discussions?
Are there any problems regarding backwards compatibility? What happens when the flytekit and distributed task plugin versions do not upload multiple error files but the backend expects them (and vice versa)?
With ML models getting bigger and bigger, distributed training jobs become increasingly important to the Flyte community. Removing the race condition outlined above from Flyte's error handling for such jobs will significantly improve the UX because we will be able to determine recoverability and report the root-cause error in the Flyte UI in a deterministic way.