Skip to content

Commit

Permalink
Fix broken gpu resource override when using pod templates (#4925)
Browse files Browse the repository at this point in the history
* Fix broken gpu resource override when using pod templates

Signed-off-by: Fabio Graetz <[email protected]>

* Adapt existing tests and add test that would have caught bug

Signed-off-by: Fabio Graetz <[email protected]>

---------

Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 authored Apr 26, 2024
1 parent c77e2c3 commit 9853abe
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 17 deletions.
32 changes: 17 additions & 15 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,24 @@ func adjustResourceRequirement(resourceName v1.ResourceName, resourceRequirement
resourceRequirements.Limits[resourceName] = resourceValue.Limit
}

// SanitizeGPUResourceRequirements renames GPU entries stored under the generic resourceGPU key to the GPU resource
// name configured in the k8s plugin config (GpuResourceName). Requests and Limits are both rewritten in place; an
// entry already present under the configured name is overwritten by the value found under resourceGPU.
//
// A nil resources pointer is a no-op. Nothing is changed when the configured name is identical to resourceGPU:
// without that guard the assignment followed by the delete would target the same key and drop the GPU entry.
func SanitizeGPUResourceRequirements(resources *v1.ResourceRequirements) {
	if resources == nil {
		return
	}

	gpuResourceName := config.GetK8sPluginConfig().GpuResourceName
	if gpuResourceName == resourceGPU {
		// Source and target keys coincide, so there is nothing to rename.
		return
	}

	// Reading from a nil map is safe, so unset Requests/Limits simply report "not found".
	if res, found := resources.Requests[resourceGPU]; found {
		resources.Requests[gpuResourceName] = res
		delete(resources.Requests, resourceGPU)
	}

	if res, found := resources.Limits[resourceGPU]; found {
		resources.Limits[gpuResourceName] = res
		delete(resources.Limits, resourceGPU)
	}
}

// ApplyResourceOverrides handles resource resolution, allocation and validation. Primarily, it ensures that container
// resources do not exceed defined platformResource limits and in the case of assignIfUnset, ensures that limits and
// requests are sensibly set for resources of all types.
// Furthermore, this function handles some clean-up such as converting GPU resources to the recognized Nvidia gpu
// resource name and deleting unsupported Storage-type resources.
func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements, assignIfUnset bool) v1.ResourceRequirements {
if len(resources.Requests) == 0 {
resources.Requests = make(v1.ResourceList)
Expand Down Expand Up @@ -169,19 +182,6 @@ func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements
shouldAdjustGPU = true
}

// Override GPU
if res, found := resources.Requests[resourceGPU]; found {
resources.Requests[gpuResourceName] = res
delete(resources.Requests, resourceGPU)
shouldAdjustGPU = true
}

if res, found := resources.Limits[resourceGPU]; found {
resources.Limits[gpuResourceName] = res
delete(resources.Limits, resourceGPU)
shouldAdjustGPU = true
}

if shouldAdjustGPU {
adjustResourceRequirement(gpuResourceName, resources, platformResources, assignIfUnset)
}
Expand Down Expand Up @@ -308,6 +308,8 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template.
overrideResources = &v1.ResourceRequirements{}
}

SanitizeGPUResourceRequirements(&container.Resources)

logger.Infof(ctx, "ApplyResourceOverrides with Resources [%v], Platform Resources [%v] and Container"+
" Resources [%v] with mode [%v]", overrideResources, platformResources, container.Resources, mode)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,37 @@ func TestApplyResourceOverrides_OverrideGpu(t *testing.T) {
gpuRequest := resource.MustParse("1")
overrides := ApplyResourceOverrides(v1.ResourceRequirements{
Requests: v1.ResourceList{
resourceGPU: gpuRequest,
ResourceNvidiaGPU: gpuRequest,
},
}, v1.ResourceRequirements{}, assignIfUnset)
assert.EqualValues(t, gpuRequest, overrides.Requests[ResourceNvidiaGPU])

overrides = ApplyResourceOverrides(v1.ResourceRequirements{
Limits: v1.ResourceList{
resourceGPU: gpuRequest,
ResourceNvidiaGPU: gpuRequest,
},
}, v1.ResourceRequirements{}, assignIfUnset)
assert.EqualValues(t, gpuRequest, overrides.Limits[ResourceNvidiaGPU])
}

// TestSanitizeGPUResourceRequirements checks that entries stored under resourceGPU are moved to ResourceNvidiaGPU
// in both Requests and Limits, that unrelated resources are left alone, and that empty requirements are unchanged.
// Like the rest of this file, it relies on the plugin config resolving the GPU resource name to ResourceNvidiaGPU.
func TestSanitizeGPUResourceRequirements(t *testing.T) {
	gpuRequest := resource.MustParse("4")
	cpuRequest := resource.MustParse("1")

	t.Run("requests only", func(t *testing.T) {
		requirements := v1.ResourceRequirements{
			Requests: v1.ResourceList{
				resourceGPU: gpuRequest,
			},
		}

		expectedRequirements := v1.ResourceRequirements{
			Requests: v1.ResourceList{
				ResourceNvidiaGPU: gpuRequest,
			},
		}

		SanitizeGPUResourceRequirements(&requirements)
		assert.EqualValues(t, expectedRequirements, requirements)
	})

	t.Run("requests and limits, other resources untouched", func(t *testing.T) {
		requirements := v1.ResourceRequirements{
			Requests: v1.ResourceList{
				resourceGPU:    gpuRequest,
				v1.ResourceCPU: cpuRequest,
			},
			Limits: v1.ResourceList{
				resourceGPU:    gpuRequest,
				v1.ResourceCPU: cpuRequest,
			},
		}

		expectedRequirements := v1.ResourceRequirements{
			Requests: v1.ResourceList{
				ResourceNvidiaGPU: gpuRequest,
				v1.ResourceCPU:    cpuRequest,
			},
			Limits: v1.ResourceList{
				ResourceNvidiaGPU: gpuRequest,
				v1.ResourceCPU:    cpuRequest,
			},
		}

		SanitizeGPUResourceRequirements(&requirements)
		assert.EqualValues(t, expectedRequirements, requirements)
	})

	t.Run("empty requirements stay empty", func(t *testing.T) {
		requirements := v1.ResourceRequirements{}

		SanitizeGPUResourceRequirements(&requirements)
		assert.EqualValues(t, v1.ResourceRequirements{}, requirements)
	})
}

func TestMergeResources_EmptyIn(t *testing.T) {
requestedResourceQuantity := resource.MustParse("1")
expectedResources := v1.ResourceRequirements{
Expand Down Expand Up @@ -602,6 +620,42 @@ func TestAddFlyteCustomizationsToContainer_Resources(t *testing.T) {
assert.True(t, container.Resources.Requests.Memory().Equal(resource.MustParse("2")))
assert.True(t, container.Resources.Limits.Memory().Equal(resource.MustParse("2")))
})
// Regression test: a container that arrives with GPU resources under the "gpu" key must still pick up overrides
// that are expressed under the "nvidia.com/gpu" key when resources are merged.
t.Run("ensure gpu resource overriding works for tasks with pod templates", func(t *testing.T) {
	container := &v1.Container{
		Command: []string{
			"{{ .Input }}",
		},
		Args: []string{
			"{{ .OutputPrefix }}",
		},
		Resources: v1.ResourceRequirements{
			Requests: v1.ResourceList{
				resourceGPU: resource.MustParse("2"), // Tasks with pod templates request resource via the "gpu" key
			},
			Limits: v1.ResourceList{
				resourceGPU: resource.MustParse("2"),
			},
		},
	}

	overrideRequests := v1.ResourceList{
		ResourceNvidiaGPU: resource.MustParse("4"), // Resource overrides specify the "nvidia.com/gpu" key
	}

	overrideLimits := v1.ResourceList{
		ResourceNvidiaGPU: resource.MustParse("4"),
	}

	templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{
		Requests: overrideRequests,
		Limits:   overrideLimits,
	}, &v1.ResourceRequirements{})

	err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeMergeExistingResources, container)
	assert.NoError(t, err)
	// testify expects (expected, actual); the override values are the expectation here.
	assert.Equal(t, overrideRequests[ResourceNvidiaGPU], container.Resources.Requests[ResourceNvidiaGPU])
	assert.Equal(t, overrideLimits[ResourceNvidiaGPU], container.Resources.Limits[ResourceNvidiaGPU])
})
}

func TestAddFlyteCustomizationsToContainer_ValidateExistingResources(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon
if platformResources == nil {
platformResources = &v1.ResourceRequirements{}
}

flytek8s.SanitizeGPUResourceRequirements(res)
resources := flytek8s.ApplyResourceOverrides(*res, *platformResources, assignResources)

submitJobInput := &batch.SubmitJobInput{}
Expand Down

0 comments on commit 9853abe

Please sign in to comment.