diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go index f978d92ef5..6b6da84fb9 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go @@ -25,6 +25,7 @@ import ( const ( ArrayJobIndex = "BATCH_JOB_ARRAY_INDEX_VAR_NAME" arrayJobIDFormatter = "%v:%v" + failOnError = "FLYTE_FAIL_ON_ERROR" ) const assignResources = true @@ -118,7 +119,6 @@ func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInpu envVars = append(envVars, &batch.KeyValuePair{Name: refStr(ArrayJobIndex), Value: refStr("FAKE_JOB_ARRAY_INDEX")}, &batch.KeyValuePair{Name: refStr("FAKE_JOB_ARRAY_INDEX"), Value: refStr("0")}) } - batchInput.ArrayProperties = arrayProps batchInput.ContainerOverrides.Environment = envVars @@ -136,7 +136,7 @@ func getEnvVarsForTask(ctx context.Context, execID pluginCore.TaskExecutionID, c for key, value := range defaultEnvVars { m[key] = value } - + m[failOnError] = "true" finalEnvVars := make([]v1.EnvVar, 0, len(m)) for key, val := range m { finalEnvVars = append(finalEnvVars, v1.EnvVar{ @@ -144,7 +144,6 @@ func getEnvVarsForTask(ctx context.Context, execID pluginCore.TaskExecutionID, c Value: val, }) } - return finalEnvVars } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go index 29fc8022cc..473fd3bad0 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -130,6 +130,7 @@ func TestArrayJobToBatchInput(t *testing.T) { ContainerOverrides: &batch.ContainerOverrides{ Command: []*string{ref("cmd"), ref("/inputs/prefix")}, Environment: []*batch.KeyValuePair{ + {Name: ref(failOnError), Value: refStr("true")}, {Name: refStr("BATCH_JOB_ARRAY_INDEX_VAR_NAME"), Value: refStr("AWS_BATCH_JOB_ARRAY_INDEX")}, }, Memory: refInt(1074), @@ -237,5 +238,9 @@ func Test_getEnvVarsForTask(t *testing.T) { Name: "MyKey", Value: "MyVal", }, + { + Name: "FLYTE_FAIL_ON_ERROR", + Value: "true", + }, }, envVars) }