diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go index d0cbe62b07..fdaec7256c 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -129,11 +129,24 @@ func adjustResourceRequirement(resourceName v1.ResourceName, resourceRequirement resourceRequirements.Limits[resourceName] = resourceValue.Limit } +// Convert GPU resource requirements named 'gpu' to the recognized 'nvidia.com/gpu' identifier. +func SanitizeGPUResourceRequirements(resources *v1.ResourceRequirements) { + gpuResourceName := config.GetK8sPluginConfig().GpuResourceName + + if res, found := resources.Requests[resourceGPU]; found { + resources.Requests[gpuResourceName] = res + delete(resources.Requests, resourceGPU) + } + + if res, found := resources.Limits[resourceGPU]; found { + resources.Limits[gpuResourceName] = res + delete(resources.Limits, resourceGPU) + } +} + // ApplyResourceOverrides handles resource resolution, allocation and validation. Primarily, it ensures that container // resources do not exceed defined platformResource limits and in the case of assignIfUnset, ensures that limits and // requests are sensibly set for resources of all types. -// Furthermore, this function handles some clean-up such as converting GPU resources to the recognized Nvidia gpu -// resource name and deleting unsupported Storage-type resources. 
func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements, assignIfUnset bool) v1.ResourceRequirements { if len(resources.Requests) == 0 { resources.Requests = make(v1.ResourceList) @@ -169,19 +182,6 @@ func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements shouldAdjustGPU = true } - // Override GPU - if res, found := resources.Requests[resourceGPU]; found { - resources.Requests[gpuResourceName] = res - delete(resources.Requests, resourceGPU) - shouldAdjustGPU = true - } - - if res, found := resources.Limits[resourceGPU]; found { - resources.Limits[gpuResourceName] = res - delete(resources.Limits, resourceGPU) - shouldAdjustGPU = true - } - if shouldAdjustGPU { adjustResourceRequirement(gpuResourceName, resources, platformResources, assignIfUnset) } @@ -308,6 +308,8 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template. overrideResources = &v1.ResourceRequirements{} } + SanitizeGPUResourceRequirements(&container.Resources) + logger.Infof(ctx, "ApplyResourceOverrides with Resources [%v], Platform Resources [%v] and Container"+ " Resources [%v] with mode [%v]", overrideResources, platformResources, container.Resources, mode) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper_test.go index d48b99ab39..ece0f724c4 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper_test.go @@ -237,19 +237,37 @@ func TestApplyResourceOverrides_OverrideGpu(t *testing.T) { gpuRequest := resource.MustParse("1") overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ - resourceGPU: gpuRequest, + ResourceNvidiaGPU: gpuRequest, }, }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, gpuRequest, overrides.Requests[ResourceNvidiaGPU]) overrides = 
ApplyResourceOverrides(v1.ResourceRequirements{ Limits: v1.ResourceList{ - resourceGPU: gpuRequest, + ResourceNvidiaGPU: gpuRequest, }, }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, gpuRequest, overrides.Limits[ResourceNvidiaGPU]) } +func TestSanitizeGPUResourceRequirements(t *testing.T) { + gpuRequest := resource.MustParse("4") + requirements := v1.ResourceRequirements{ + Requests: v1.ResourceList{ + resourceGPU: gpuRequest, + }, + } + + expectedRequirements := v1.ResourceRequirements{ + Requests: v1.ResourceList{ + ResourceNvidiaGPU: gpuRequest, + }, + } + + SanitizeGPUResourceRequirements(&requirements) + assert.EqualValues(t, expectedRequirements, requirements) +} + func TestMergeResources_EmptyIn(t *testing.T) { requestedResourceQuantity := resource.MustParse("1") expectedResources := v1.ResourceRequirements{ @@ -602,6 +620,42 @@ func TestAddFlyteCustomizationsToContainer_Resources(t *testing.T) { assert.True(t, container.Resources.Requests.Memory().Equal(resource.MustParse("2"))) assert.True(t, container.Resources.Limits.Memory().Equal(resource.MustParse("2"))) }) + t.Run("ensure gpu resource overriding works for tasks with pod templates", func(t *testing.T) { + container := &v1.Container{ + Command: []string{ + "{{ .Input }}", + }, + Args: []string{ + "{{ .OutputPrefix }}", + }, + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + resourceGPU: resource.MustParse("2"), // Tasks with pod templates request resource via the "gpu" key + }, + Limits: v1.ResourceList{ + resourceGPU: resource.MustParse("2"), + }, + }, + } + + overrideRequests := v1.ResourceList{ + ResourceNvidiaGPU: resource.MustParse("4"), // Resource overrides specify the "nvidia.com/gpu" key + } + + overrideLimits := v1.ResourceList{ + ResourceNvidiaGPU: resource.MustParse("4"), + } + + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{ + Requests: overrideRequests, + Limits: overrideLimits, + }, &v1.ResourceRequirements{}) + + err := 
AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeMergeExistingResources, container) + assert.NoError(t, err) + assert.Equal(t, container.Resources.Requests[ResourceNvidiaGPU], overrideRequests[ResourceNvidiaGPU]) + assert.Equal(t, container.Resources.Limits[ResourceNvidiaGPU], overrideLimits[ResourceNvidiaGPU]) + }) } func TestAddFlyteCustomizationsToContainer_ValidateExistingResources(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go index c75ca61c42..68dc88c883 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go @@ -95,6 +95,8 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon if platformResources == nil { platformResources = &v1.ResourceRequirements{} } + + flytek8s.SanitizeGPUResourceRequirements(res) resources := flytek8s.ApplyResourceOverrides(*res, *platformResources, assignResources) submitJobInput := &batch.SubmitJobInput{}