diff --git a/agent/api/container/container.go b/agent/api/container/container.go index cbd4a312374..e07f0a174cf 100644 --- a/agent/api/container/container.go +++ b/agent/api/container/container.go @@ -15,8 +15,10 @@ package container import ( "encoding/json" + "errors" "fmt" "strconv" + "strings" "sync" "time" @@ -1325,6 +1327,44 @@ func (c *Container) UpdateManagedAgentSentStatus(agentName string, status apicon return false } +// RequiresCredentialSpec checks if container needs a credentialspec resource +func (c *Container) RequiresCredentialSpec() bool { + credSpec, err := c.getCredentialSpec() + if err != nil || credSpec == "" { + return false + } + + return true +} + +// GetCredentialSpec is used to retrieve the current credentialspec resource +func (c *Container) GetCredentialSpec() (string, error) { + return c.getCredentialSpec() +} + +func (c *Container) getCredentialSpec() (string, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.DockerConfig.HostConfig == nil { + return "", errors.New("empty container hostConfig") + } + + hostConfig := &dockercontainer.HostConfig{} + err := json.Unmarshal([]byte(*c.DockerConfig.HostConfig), hostConfig) + if err != nil || len(hostConfig.SecurityOpt) == 0 { + return "", errors.New("unable to obtain security options from container hostConfig") + } + + for _, opt := range hostConfig.SecurityOpt { + if strings.HasPrefix(opt, "credentialspec") { + return opt, nil + } + } + + return "", errors.New("unable to obtain credentialspec") +} + func (c *Container) GetManagedAgentStatus(agentName string) apicontainerstatus.ManagedAgentStatus { c.lock.RLock() defer c.lock.RUnlock() diff --git a/agent/api/container/container_test.go b/agent/api/container/container_test.go index 041532883ec..bfbfcdcf9ea 100644 --- a/agent/api/container/container_test.go +++ b/agent/api/container/container_test.go @@ -970,3 +970,113 @@ func TestUpdateManagedAgentSentStatus(t *testing.T) { }) } } + +func TestRequiresCredentialSpec(t *testing.T) { + testCases := []struct { + name string + container *Container + expectedOutput bool + }{ + { + name: "hostconfig_nil", + container: &Container{}, + expectedOutput: false, + }, + { + name: "invalid_case", + container: getContainer("invalid"), + expectedOutput: false, + }, + { + name: "empty_sec_opt", + container: getContainer("{\"NetworkMode\":\"bridge\"}"), + expectedOutput: false, + }, + { + name: "missing_credentialspec", + container: getContainer("{\"SecurityOpt\": [\"invalid-sec-opt\"]}"), + expectedOutput: false, + }, + { + name: "valid_credentialspec_file", + container: getContainer("{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}"), + expectedOutput: true, + }, + { + name: "valid_credentialspec_s3", + container: getContainer("{\"SecurityOpt\": [\"credentialspec:arn:aws:s3:::${BucketName}/${ObjectName}\"]}"), + expectedOutput: true, + }, + { + name: "valid_credentialspec_ssm", + container: getContainer("{\"SecurityOpt\": [\"credentialspec:arn:aws:ssm:region:aws_account_id:parameter/parameter_name\"]}"), + expectedOutput: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOutput, tc.container.RequiresCredentialSpec()) + }) + } +} + +func TestGetCredentialSpecErr(t *testing.T) { + testCases := []struct { + name string + container *Container + expectedOutputString string + expectedErrorString string + }{ + { + name: "hostconfig_nil", + container: &Container{}, + expectedOutputString: "", + expectedErrorString: "empty container hostConfig", + }, + { + name: "invalid_case", + container: getContainer("invalid"), + expectedOutputString: "", + expectedErrorString: "unable to obtain security options from container hostConfig", + }, + { + name: "empty_sec_opt", + container: getContainer("{\"NetworkMode\":\"bridge\"}"), + expectedOutputString: "", + expectedErrorString: "unable to obtain security options from container hostConfig", + }, + { + name: "missing_credentialspec", + container: getContainer("{\"SecurityOpt\": [\"invalid-sec-opt\"]}"), + expectedOutputString: "", + expectedErrorString: "unable to obtain credentialspec", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expectedOutputStr, err := tc.container.GetCredentialSpec() + assert.Equal(t, tc.expectedOutputString, expectedOutputStr) + assert.EqualError(t, err, tc.expectedErrorString) + }) + } +} + +func TestGetCredentialSpecHappyPath(t *testing.T) { + c := getContainer("{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}") + + expectedCredentialSpec := "credentialspec:file://gmsa_gmsa-acct.json" + + credentialspec, err := c.GetCredentialSpec() + assert.NoError(t, err) + assert.EqualValues(t, expectedCredentialSpec, credentialspec) +} + +func getContainer(hostConfig string) *Container { + c := &Container{ + Name: "c", + } + c.DockerConfig.HostConfig = &hostConfig + return c +} diff --git a/agent/api/container/container_unix.go b/agent/api/container/container_unix.go index defdaf30970..3bbd4a40815 100644 --- a/agent/api/container/container_unix.go +++ b/agent/api/container/container_unix.go @@ -16,22 +16,8 @@ package container -import ( - "github.com/pkg/errors" -) - const ( // DockerContainerMinimumMemoryInBytes is the minimum amount of // memory to be allocated to a docker container DockerContainerMinimumMemoryInBytes = 4 * 1024 * 1024 // 4MB ) - -// RequiresCredentialSpec checks if container needs a credentialspec resource -func (c *Container) RequiresCredentialSpec() bool { - return false -} - -// GetCredentialSpec is used to retrieve the current credentialspec resource -func (c *Container) GetCredentialSpec() (string, error) { - return "", errors.New("unsupported platform") -} diff --git a/agent/api/container/container_windows.go b/agent/api/container/container_windows.go index f481aaabeab..49b3eeb36a9 100644 --- a/agent/api/container/container_windows.go +++ b/agent/api/container/container_windows.go @@ -16,54 +16,8 @@ package container -import ( - "encoding/json" - "strings" - - dockercontainer "github.com/docker/docker/api/types/container" - "github.com/pkg/errors" -) - const ( // DockerContainerMinimumMemoryInBytes is the minimum amount of // memory to be allocated to a docker container DockerContainerMinimumMemoryInBytes = 256 * 1024 * 1024 // 256MB ) - -// RequiresCredentialSpec checks if container needs a credentialspec resource -func (c *Container) RequiresCredentialSpec() bool { - credSpec, err := c.getCredentialSpec() - if err != nil || credSpec == "" { - return false - } - - return true -} - -// GetCredentialSpec is used to retrieve the current credentialspec resource -func (c *Container) GetCredentialSpec() (string, error) { - return c.getCredentialSpec() -} - -func (c *Container) getCredentialSpec() (string, error) { - c.lock.RLock() - defer c.lock.RUnlock() - - if c.DockerConfig.HostConfig == nil { - return "", errors.New("empty container hostConfig") - } - - hostConfig := &dockercontainer.HostConfig{} - err := json.Unmarshal([]byte(*c.DockerConfig.HostConfig), hostConfig) - if err != nil || len(hostConfig.SecurityOpt) == 0 { - return "", errors.New("unable to obtain security options from container hostConfig") - } - - for _, opt := range hostConfig.SecurityOpt { - if strings.HasPrefix(opt, "credentialspec") { - return opt, nil - } - } - - return "", errors.New("unable to obtain credentialspec") -} diff --git a/agent/api/container/container_windows_test.go b/agent/api/container/container_windows_test.go deleted file mode 100644 index 669019ad850..00000000000 --- a/agent/api/container/container_windows_test.go +++ /dev/null @@ -1,133 +0,0 @@ -//go:build windows && unit -// +build windows,unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package container - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRequiresCredentialSpec(t *testing.T) { - testCases := []struct { - name string - container *Container - expectedOutput bool - }{ - { - name: "hostconfig_nil", - container: &Container{}, - expectedOutput: false, - }, - { - name: "invalid_case", - container: getContainer("invalid"), - expectedOutput: false, - }, - { - name: "empty_sec_opt", - container: getContainer("{\"NetworkMode\":\"bridge\"}"), - expectedOutput: false, - }, - { - name: "missing_credentialspec", - container: getContainer("{\"SecurityOpt\": [\"invalid-sec-opt\"]}"), - expectedOutput: false, - }, - { - name: "valid_credentialspec_file", - container: getContainer("{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}"), - expectedOutput: true, - }, - { - name: "valid_credentialspec_s3", - container: getContainer("{\"SecurityOpt\": [\"credentialspec:arn:aws:s3:::${BucketName}/${ObjectName}\"]}"), - expectedOutput: true, - }, - { - name: "valid_credentialspec_ssm", - container: getContainer("{\"SecurityOpt\": [\"credentialspec:arn:aws:ssm:region:aws_account_id:parameter/parameter_name\"]}"), - expectedOutput: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expectedOutput, tc.container.RequiresCredentialSpec()) - }) - } -} - -func TestGetCredentialSpecErr(t *testing.T) { - testCases := []struct { - name string - container *Container - expectedOutputString string - expectedErrorString string - }{ - { - name: "hostconfig_nil", - container: &Container{}, - expectedOutputString: "", - expectedErrorString: "empty container hostConfig", - }, - { - name: "invalid_case", - container: getContainer("invalid"), - expectedOutputString: "", - expectedErrorString: "unable to obtain security options from container hostConfig", - }, - { - name: "empty_sec_opt", - container: getContainer("{\"NetworkMode\":\"bridge\"}"), - expectedOutputString: "", - expectedErrorString: "unable to obtain security options from container hostConfig", - }, - { - name: "missing_credentialspec", - container: getContainer("{\"SecurityOpt\": [\"invalid-sec-opt\"]}"), - expectedOutputString: "", - expectedErrorString: "unable to obtain credentialspec", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - expectedOutputStr, err := tc.container.GetCredentialSpec() - assert.Equal(t, tc.expectedOutputString, expectedOutputStr) - assert.EqualError(t, err, tc.expectedErrorString) - }) - } -} - -func TestGetCredentialSpecHappyPath(t *testing.T) { - c := getContainer("{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}") - - expectedCredentialSpec := "credentialspec:file://gmsa_gmsa-acct.json" - - credentialspec, err := c.GetCredentialSpec() - assert.NoError(t, err) - assert.EqualValues(t, expectedCredentialSpec, credentialspec) -} - -func getContainer(hostConfig string) *Container { - c := &Container{ - Name: "c", - } - c.DockerConfig.HostConfig = &hostConfig - return c -} diff --git a/agent/api/task/task.go b/agent/api/task/task.go index a0a026f8732..7b6bbdf2cd0 100644 --- a/agent/api/task/task.go +++ b/agent/api/task/task.go @@ -25,9 +25,9 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/api/serviceconnect" - "github.com/aws/amazon-ecs-agent/agent/logger" "github.com/aws/amazon-ecs-agent/agent/logger/field" + "github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec" "github.com/aws/amazon-ecs-agent/agent/utils/ttime" "github.com/aws/aws-sdk-go/aws" "github.com/docker/docker/api/types" @@ -2784,6 +2784,38 @@ func (task *Task) AddResource(resourceType string, resource taskresource.TaskRes task.ResourcesMapUnsafe[resourceType] = append(task.ResourcesMapUnsafe[resourceType], resource) } +// requiresCredentialSpecResource returns true if at least one container in the task +// needs a valid credentialspec resource +func (task *Task) requiresCredentialSpecResource() bool { + for _, container := range task.Containers { + if container.RequiresCredentialSpec() { + return true + } + } + return false +} + +// GetCredentialSpecResource retrieves credentialspec resource from resource map +func (task *Task) GetCredentialSpecResource() ([]taskresource.TaskResource, bool) { + task.lock.RLock() + defer task.lock.RUnlock() + + res, ok := task.ResourcesMapUnsafe[credentialspec.ResourceName] + return res, ok +} + +// getAllCredentialSpecRequirements is used to build all the credential spec requirements for the task +func (task *Task) getAllCredentialSpecRequirements() map[string]string { + reqsContainerMap := make(map[string]string) + for _, container := range task.Containers { + credentialSpec, err := container.GetCredentialSpec() + if err == nil && credentialSpec != "" { + reqsContainerMap[credentialSpec] = container.Name + } + } + return reqsContainerMap +} + // SetTerminalReason sets the terminalReason string and this can only be set // once per the task's lifecycle. This field does not accept updates. func (task *Task) SetTerminalReason(reason string) { diff --git a/agent/api/task/task_linux.go b/agent/api/task/task_linux.go index 7113f504914..ce1167dec3f 100644 --- a/agent/api/task/task_linux.go +++ b/agent/api/task/task_linux.go @@ -239,23 +239,15 @@ func (task *Task) dockerCPUShares(containerCPU uint) int64 { return int64(containerCPU) } -// requiresCredentialSpecResource returns true if at least one container in the task -// needs a valid credentialspec resource -func (task *Task) requiresCredentialSpecResource() bool { - return false -} - // initializeCredentialSpecResource builds the resource dependency map for the credentialspec resource func (task *Task) initializeCredentialSpecResource(config *config.Config, credentialsManager credentials.Manager, resourceFields *taskresource.ResourceFields) error { + //TBD: Add code to support gMSA on linux + credspecContainerMapping := task.getAllCredentialSpecRequirements() + seelog.Info(credspecContainerMapping) return errors.New("task credentialspec is only supported on windows") } -// GetCredentialSpecResource retrieves credentialspec resource from resource map -func (task *Task) GetCredentialSpecResource() ([]taskresource.TaskResource, bool) { - return []taskresource.TaskResource{}, false -} - func enableIPv6SysctlSetting(hostConfig *dockercontainer.HostConfig) { if hostConfig.Sysctls == nil { hostConfig.Sysctls = make(map[string]string) diff --git a/agent/api/task/task_test.go b/agent/api/task/task_test.go index edd70f84e71..9d56c13c8b4 100644 --- a/agent/api/task/task_test.go +++ b/agent/api/task/task_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/api/serviceconnect" + "github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec" "github.com/docker/go-connections/nat" @@ -4374,3 +4375,123 @@ func TestTaskWithoutServiceConnectAttachment(t *testing.T) { assert.Equal(t, BridgeNetworkMode, task.NetworkMode) assert.Nil(t, task.ServiceConnectConfig, "Should be no service connect config") } + +func TestRequiresCredentialSpecResource(t *testing.T) { + container1 := &apicontainer.Container{} + task1 := &Task{ + Arn: "test", + Containers: []*apicontainer.Container{container1}, + } + + hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" + container2 := &apicontainer.Container{} + container2.DockerConfig.HostConfig = &hostConfig + task2 := &Task{ + Arn: "test", + Containers: []*apicontainer.Container{container2}, + } + + testCases := []struct { + name string + task *Task + expectedOutput bool + }{ + { + name: "missing_credentialspec", + task: task1, + expectedOutput: false, + }, + { + name: "valid_credentialspec", + task: task2, + expectedOutput: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOutput, tc.task.requiresCredentialSpecResource()) + }) + } + +} + +func TestGetAllCredentialSpecRequirements(t *testing.T) { + hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" + container := &apicontainer.Container{Name: "webapp1"} + container.DockerConfig.HostConfig = &hostConfig + + task := &Task{ + Arn: "test", + Containers: []*apicontainer.Container{container}, + } + + credentialSpecContainerMap := task.getAllCredentialSpecRequirements() + + credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json" + expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp1"} + + assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) +} + +func TestGetAllCredentialSpecRequirementsWithMultipleContainersUsingSameSpec(t *testing.T) { + hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" + c1 := &apicontainer.Container{Name: "webapp1"} + c1.DockerConfig.HostConfig = &hostConfig + + c2 := &apicontainer.Container{Name: "webapp2"} + c2.DockerConfig.HostConfig = &hostConfig + + task := &Task{ + Arn: "test", + Containers: []*apicontainer.Container{c1, c2}, + } + + credentialSpecContainerMap := task.getAllCredentialSpecRequirements() + + credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json" + expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp2"} + + assert.Equal(t, len(expectedCredentialSpecContainerMap), len(credentialSpecContainerMap)) + assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) +} + +func TestGetAllCredentialSpecRequirementsWithMultipleContainers(t *testing.T) { + hostConfig1 := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct-1.json\"]}" + hostConfig2 := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct-2.json\"]}" + + c1 := &apicontainer.Container{Name: "webapp1"} + c1.DockerConfig.HostConfig = &hostConfig1 + + c2 := &apicontainer.Container{Name: "webapp2"} + c2.DockerConfig.HostConfig = &hostConfig1 + + c3 := &apicontainer.Container{Name: "webapp3"} + c3.DockerConfig.HostConfig = &hostConfig2 + + task := &Task{ + Arn: "test", + Containers: []*apicontainer.Container{c1, c2, c3}, + } + + credentialSpecContainerMap := task.getAllCredentialSpecRequirements() + + credentialspec1 := "credentialspec:file://gmsa_gmsa-acct-1.json" + credentialspec2 := "credentialspec:file://gmsa_gmsa-acct-2.json" + + expectedCredentialSpecContainerMap := map[string]string{credentialspec1: "webapp2", credentialspec2: "webapp3"} + + assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) +} + +func TestGetCredentialSpecResource(t *testing.T) { + credentialspecResource := &credentialspec.CredentialSpecResource{} + task := &Task{ + ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource), + } + task.AddResource(credentialspec.ResourceName, credentialspecResource) + + credentialspecTaskResource, ok := task.GetCredentialSpecResource() + assert.True(t, ok) + assert.NotEmpty(t, credentialspecTaskResource) +} diff --git a/agent/api/task/task_windows.go b/agent/api/task/task_windows.go index 04b11f760fb..d5fa0e2bc63 100644 --- a/agent/api/task/task_windows.go +++ b/agent/api/task/task_windows.go @@ -139,17 +139,6 @@ func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPerio return errors.New("unsupported platform") } -// requiresCredentialSpecResource returns true if at least one container in the task -// needs a valid credentialspec resource -func (task *Task) requiresCredentialSpecResource() bool { - for _, container := range task.Containers { - if container.RequiresCredentialSpec() { - return true - } - } - return false -} - // initializeCredentialSpecResource builds the resource dependency map for the credentialspec resource func (task *Task) initializeCredentialSpecResource(config *config.Config, credentialsManager credentials.Manager, resourceFields *taskresource.ResourceFields) error { @@ -174,27 +163,6 @@ func (task *Task) initializeCredentialSpecResource(config *config.Config, creden return nil } -// getAllCredentialSpecRequirements is used to build all the credential spec requirements for the task -func (task *Task) getAllCredentialSpecRequirements() map[string]string { - reqsContainerMap := make(map[string]string) - for _, container := range task.Containers { - credentialSpec, err := container.GetCredentialSpec() - if err == nil && credentialSpec != "" { - reqsContainerMap[credentialSpec] = container.Name - } - } - return reqsContainerMap -} - -// GetCredentialSpecResource retrieves credentialspec resource from resource map -func (task *Task) GetCredentialSpecResource() ([]taskresource.TaskResource, bool) { - task.lock.RLock() - defer task.lock.RUnlock() - - res, ok := task.ResourcesMapUnsafe[credentialspec.ResourceName] - return res, ok -} - func enableIPv6SysctlSetting(hostConfig *dockercontainer.HostConfig) { return } diff --git a/agent/api/task/task_windows_test.go b/agent/api/task/task_windows_test.go index fb22d12420c..4df599902b7 100644 --- a/agent/api/task/task_windows_test.go +++ b/agent/api/task/task_windows_test.go @@ -19,7 +19,6 @@ package task import ( "encoding/json" "fmt" - "reflect" "runtime" "testing" @@ -356,114 +355,6 @@ func TestGetCanonicalPath(t *testing.T) { } } -func TestRequiresCredentialSpecResource(t *testing.T) { - container1 := &apicontainer.Container{} - task1 := &Task{ - Arn: "test", - Containers: []*apicontainer.Container{container1}, - } - - hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" - container2 := &apicontainer.Container{} - container2.DockerConfig.HostConfig = &hostConfig - task2 := &Task{ - Arn: "test", - Containers: []*apicontainer.Container{container2}, - } - - testCases := []struct { - name string - task *Task - expectedOutput bool - }{ - { - name: "missing_credentialspec", - task: task1, - expectedOutput: false, - }, - { - name: "valid_credentialspec", - task: task2, - expectedOutput: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expectedOutput, tc.task.requiresCredentialSpecResource()) - }) - } - -} - -func TestGetAllCredentialSpecRequirements(t *testing.T) { - hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" - container := &apicontainer.Container{Name: "webapp1"} - container.DockerConfig.HostConfig = &hostConfig - - task := &Task{ - Arn: "test", - Containers: []*apicontainer.Container{container}, - } - - credentialSpecContainerMap := task.getAllCredentialSpecRequirements() - - credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json" - expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp1"} - - assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) -} - -func TestGetAllCredentialSpecRequirementsWithMultipleContainersUsingSameSpec(t *testing.T) { - hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" - c1 := &apicontainer.Container{Name: "webapp1"} - c1.DockerConfig.HostConfig = &hostConfig - - c2 := &apicontainer.Container{Name: "webapp2"} - c2.DockerConfig.HostConfig = &hostConfig - - task := &Task{ - Arn: "test", - Containers: []*apicontainer.Container{c1, c2}, - } - - credentialSpecContainerMap := task.getAllCredentialSpecRequirements() - - credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json" - expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp2"} - - assert.Equal(t, len(expectedCredentialSpecContainerMap), len(credentialSpecContainerMap)) - assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) -} - -func TestGetAllCredentialSpecRequirementsWithMultipleContainers(t *testing.T) { - hostConfig1 := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct-1.json\"]}" - hostConfig2 := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct-2.json\"]}" - - c1 := &apicontainer.Container{Name: "webapp1"} - c1.DockerConfig.HostConfig = &hostConfig1 - - c2 := &apicontainer.Container{Name: "webapp2"} - c2.DockerConfig.HostConfig = &hostConfig1 - - c3 := &apicontainer.Container{Name: "webapp3"} - c3.DockerConfig.HostConfig = &hostConfig2 - - task := &Task{ - Arn: "test", - Containers: []*apicontainer.Container{c1, c2, c3}, - } - - credentialSpecContainerMap := task.getAllCredentialSpecRequirements() - - credentialspec1 := "credentialspec:file://gmsa_gmsa-acct-1.json" - credentialspec2 := "credentialspec:file://gmsa_gmsa-acct-2.json" - - expectedCredentialSpecContainerMap := map[string]string{credentialspec1: "webapp2", credentialspec2: "webapp3"} - - assert.True(t, reflect.DeepEqual(expectedCredentialSpecContainerMap, credentialSpecContainerMap)) -} - func TestInitializeAndGetCredentialSpecResource(t *testing.T) { hostConfig := "{\"SecurityOpt\": [\"credentialspec:file://gmsa_gmsa-acct.json\"]}" container := &apicontainer.Container{ @@ -510,18 +401,6 @@ func TestInitializeAndGetCredentialSpecResource(t *testing.T) { assert.True(t, ok) } -func TestGetCredentialSpecResource(t *testing.T) { - credentialspecResource := &credentialspec.CredentialSpecResource{} - task := &Task{ - ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource), - } - task.AddResource(credentialspec.ResourceName, credentialspecResource) - - credentialspecTaskResource, ok := task.GetCredentialSpecResource() - assert.True(t, ok) - assert.NotEmpty(t, credentialspecTaskResource) -} - func TestRequiresFSxWindowsFileServerResource(t *testing.T) { task1 := &Task{ Arn: "test1", diff --git a/agent/s3/factory/factory.go b/agent/s3/factory/factory.go index b6ae683860c..3c7f42670c9 100644 --- a/agent/s3/factory/factory.go +++ b/agent/s3/factory/factory.go @@ -33,35 +33,48 @@ const ( ) type S3ClientCreator interface { - NewS3ClientForBucket(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3Client, error) + NewS3ManagerClient(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error) + NewS3Client(region string, creds credentials.IAMRoleCredentials) s3client.S3Client } +// NewS3ClientCreator provide 2 implementations +// NewS3ManagerClient implements methods from aws-sdk-go/service/s3manager. +// NewS3Client implements methods from aws-sdk-go/service/s3. func NewS3ClientCreator() S3ClientCreator { return &s3ClientCreator{} } type s3ClientCreator struct{} -// NewS3Client returns a new S3 client based on the region of the bucket. -func (*s3ClientCreator) NewS3ClientForBucket(bucket, region string, - creds credentials.IAMRoleCredentials) (s3client.S3Client, error) { +// NewS3ManagerClient returns a new S3 client based on the region of the bucket. +func (*s3ClientCreator) NewS3ManagerClient(bucket, region string, + creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error) { cfg := aws.NewConfig(). WithHTTPClient(httpclient.New(roundtripTimeout, false)). WithCredentials( awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)).WithRegion(region) sess := session.Must(session.NewSession(cfg)) - svc := s3.New(sess) bucketRegion, err := getRegionFromBucket(svc, bucket) if err != nil { return nil, err } - sessWithRegion := session.Must(session.NewSession(cfg.WithRegion(bucketRegion))) return s3manager.NewDownloaderWithClient(s3.New(sessWithRegion)), nil } +// NewS3Client returns a new S3 client to support s3 operations which are not provided by s3manager. +func (*s3ClientCreator) NewS3Client(region string, + creds credentials.IAMRoleCredentials) s3client.S3Client { + cfg := aws.NewConfig(). + WithHTTPClient(httpclient.New(roundtripTimeout, false)). + WithCredentials( + awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey, + creds.SessionToken)).WithRegion(region) + sess := session.Must(session.NewSession(cfg)) + return s3.New(sess) +} func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) { input := &s3.GetBucketLocationInput{ Bucket: aws.String(bucket), @@ -73,6 +86,5 @@ func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) { if result.LocationConstraint == nil { // GetBucketLocation returns nil for bucket in us-east-1. return bucketLocationDefault, nil } - return aws.StringValue(result.LocationConstraint), nil } diff --git a/agent/s3/factory/mocks/factory_mocks.go b/agent/s3/factory/mocks/factory_mocks.go index 1b675987b2d..d19603715a8 100644 --- a/agent/s3/factory/mocks/factory_mocks.go +++ b/agent/s3/factory/mocks/factory_mocks.go @@ -49,17 +49,31 @@ func (m *MockS3ClientCreator) EXPECT() *MockS3ClientCreatorMockRecorder { return m.recorder } -// NewS3ClientForBucket mocks base method -func (m *MockS3ClientCreator) NewS3ClientForBucket(arg0, arg1 string, arg2 credentials.IAMRoleCredentials) (s3.S3Client, error) { +// NewS3Client mocks base method +func (m *MockS3ClientCreator) NewS3Client(arg0 string, arg1 credentials.IAMRoleCredentials) s3.S3Client { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewS3ClientForBucket", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "NewS3Client", arg0, arg1) ret0, _ := ret[0].(s3.S3Client) + return ret0 +} + +// NewS3Client indicates an expected call of NewS3Client +func (mr *MockS3ClientCreatorMockRecorder) NewS3Client(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewS3Client", reflect.TypeOf((*MockS3ClientCreator)(nil).NewS3Client), arg0, arg1) +} + +// NewS3ManagerClient mocks base method +func (m *MockS3ClientCreator) NewS3ManagerClient(arg0, arg1 string, arg2 credentials.IAMRoleCredentials) (s3.S3ManagerClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewS3ManagerClient", arg0, arg1, arg2) + ret0, _ := ret[0].(s3.S3ManagerClient) ret1, _ := ret[1].(error) return ret0, ret1 } -// NewS3ClientForBucket indicates an expected call of NewS3ClientForBucket -func (mr *MockS3ClientCreatorMockRecorder) NewS3ClientForBucket(arg0, arg1, arg2 interface{}) *gomock.Call { +// NewS3ManagerClient indicates an expected call of NewS3ManagerClient +func (mr *MockS3ClientCreatorMockRecorder) NewS3ManagerClient(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewS3ClientForBucket", reflect.TypeOf((*MockS3ClientCreator)(nil).NewS3ClientForBucket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewS3ManagerClient", reflect.TypeOf((*MockS3ClientCreator)(nil).NewS3ManagerClient), arg0, arg1, arg2) } diff --git a/agent/s3/generate_mocks.go b/agent/s3/generate_mocks.go index ba2fccdbe03..1a540728623 100644 --- a/agent/s3/generate_mocks.go +++ b/agent/s3/generate_mocks.go @@ -13,4 +13,5 @@ package s3 +//go:generate mockgen -destination=mocks/s3manager/s3_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/s3 S3ManagerClient //go:generate mockgen -destination=mocks/s3_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/s3 S3Client diff --git a/agent/s3/interface.go b/agent/s3/interface.go index f68807e79b6..16eab56df2a 100644 --- a/agent/s3/interface.go +++ b/agent/s3/interface.go @@ -21,7 +21,14 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3manager" ) -// S3Client interface wraps the S3 API. -type S3Client interface { +// S3ManagerClient interface wraps the S3Manager APIs. +// Any method that belongs aws-sdk-go/service/s3manager goes here. +type S3ManagerClient interface { DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*s3manager.Downloader)) (n int64, err error) } + +// S3Client interface wraps the generic S3 APIs. +// Any method that belongs to aws-sdk-go/service/s3 goes here. +type S3Client interface { + GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error) +} diff --git a/agent/s3/mocks/s3_mocks.go b/agent/s3/mocks/s3_mocks.go index 97f6de8c006..6bab27370e6 100644 --- a/agent/s3/mocks/s3_mocks.go +++ b/agent/s3/mocks/s3_mocks.go @@ -19,12 +19,9 @@ package mock_s3 import ( - context "context" - io "io" reflect "reflect" s3 "github.com/aws/aws-sdk-go/service/s3" - s3manager "github.com/aws/aws-sdk-go/service/s3/s3manager" gomock "github.com/golang/mock/gomock" ) @@ -51,22 +48,17 @@ func (m *MockS3Client) EXPECT() *MockS3ClientMockRecorder { return m.recorder } -// DownloadWithContext mocks base method -func (m *MockS3Client) DownloadWithContext(arg0 context.Context, arg1 io.WriterAt, arg2 *s3.GetObjectInput, arg3 ...func(*s3manager.Downloader)) (int64, error) { +// GetObject mocks base method +func (m *MockS3Client) GetObject(arg0 *s3.GetObjectInput) (*s3.GetObjectOutput, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DownloadWithContext", varargs...) - ret0, _ := ret[0].(int64) + ret := m.ctrl.Call(m, "GetObject", arg0) + ret0, _ := ret[0].(*s3.GetObjectOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -// DownloadWithContext indicates an expected call of DownloadWithContext -func (mr *MockS3ClientMockRecorder) DownloadWithContext(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +// GetObject indicates an expected call of GetObject +func (mr *MockS3ClientMockRecorder) GetObject(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithContext", reflect.TypeOf((*MockS3Client)(nil).DownloadWithContext), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObject", reflect.TypeOf((*MockS3Client)(nil).GetObject), arg0) } diff --git a/agent/s3/mocks/s3manager/s3_mocks.go b/agent/s3/mocks/s3manager/s3_mocks.go new file mode 100644 index 00000000000..3ca9b4cd8f6 --- /dev/null +++ b/agent/s3/mocks/s3manager/s3_mocks.go @@ -0,0 +1,72 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/agent/s3 (interfaces: S3ManagerClient) + +// Package mock_s3 is a generated GoMock package. +package mock_s3 + +import ( + context "context" + io "io" + reflect "reflect" + + s3 "github.com/aws/aws-sdk-go/service/s3" + s3manager "github.com/aws/aws-sdk-go/service/s3/s3manager" + gomock "github.com/golang/mock/gomock" +) + +// MockS3ManagerClient is a mock of S3ManagerClient interface +type MockS3ManagerClient struct { + ctrl *gomock.Controller + recorder *MockS3ManagerClientMockRecorder +} + +// MockS3ManagerClientMockRecorder is the mock recorder for MockS3ManagerClient +type MockS3ManagerClientMockRecorder struct { + mock *MockS3ManagerClient +} + +// NewMockS3ManagerClient creates a new mock instance +func NewMockS3ManagerClient(ctrl *gomock.Controller) *MockS3ManagerClient { + mock := &MockS3ManagerClient{ctrl: ctrl} + mock.recorder = &MockS3ManagerClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockS3ManagerClient) EXPECT() *MockS3ManagerClientMockRecorder { + return m.recorder +} + +// DownloadWithContext mocks base method +func (m *MockS3ManagerClient) DownloadWithContext(arg0 context.Context, arg1 io.WriterAt, arg2 *s3.GetObjectInput, arg3 ...func(*s3manager.Downloader)) (int64, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DownloadWithContext", varargs...) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DownloadWithContext indicates an expected call of DownloadWithContext +func (mr *MockS3ManagerClientMockRecorder) DownloadWithContext(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithContext", reflect.TypeOf((*MockS3ManagerClient)(nil).DownloadWithContext), varargs...) +} diff --git a/agent/s3/s3.go b/agent/s3/s3.go index 71c86b2579f..edd58639b3e 100644 --- a/agent/s3/s3.go +++ b/agent/s3/s3.go @@ -29,7 +29,7 @@ const ( ) // DownloadFile downloads a file from s3 and writes it with the writer. -func DownloadFile(bucket, key string, timeout time.Duration, w io.WriterAt, client S3Client) error { +func DownloadFile(bucket, key string, timeout time.Duration, w io.WriterAt, client S3ManagerClient) error { input := &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), @@ -50,3 +50,24 @@ func ParseS3ARN(s3ARN string) (bucket string, key string, err error) { } return match[2], match[3], nil } + +func GetObject(bucket string, key string, client S3Client) (string, error) { + requestInput := &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + + result, err := client.GetObject(requestInput) + if err != nil { + return "", err + } + + defer result.Body.Close() + resultBody, err := io.ReadAll(result.Body) + if err != nil { + return "", err + } + credSpecData := string(resultBody) + + return credSpecData, nil +} diff --git a/agent/s3/s3_test.go b/agent/s3/s3_test.go index b44a221651c..6186255c788 100644 --- a/agent/s3/s3_test.go +++ b/agent/s3/s3_test.go @@ -19,6 +19,7 @@ package s3 import ( "errors" "io" + "strings" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/stretchr/testify/assert" mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks" + mock_s3manager "github.com/aws/amazon-ecs-agent/agent/s3/mocks/s3manager" mock_oswrapper "github.com/aws/amazon-ecs-agent/agent/utils/oswrapper/mocks" ) @@ -42,15 +44,15 @@ func TestDownloadFile(t *testing.T) { defer ctrl.Finish() mockFile := mock_oswrapper.NewMockFile() - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3ManagerClient := mock_s3manager.NewMockS3ManagerClient(ctrl) - mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Do(func(ctx aws.Context, + mockS3ManagerClient.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Do(func(ctx aws.Context, w io.WriterAt, input *s3sdk.GetObjectInput) { assert.Equal(t, testBucket, aws.StringValue(input.Bucket)) assert.Equal(t, testKey, aws.StringValue(input.Key)) }) - err := DownloadFile(testBucket, testKey, testTimeout, mockFile, mockS3Client) + err := DownloadFile(testBucket, testKey, testTimeout, mockFile, mockS3ManagerClient) assert.NoError(t, err) } @@ -58,12 +60,12 @@ func TestDownloadFileError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3ManagerClient := mock_s3manager.NewMockS3ManagerClient(ctrl) mockFile := mock_oswrapper.NewMockFile() - mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), errors.New("test error")) + mockS3ManagerClient.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), errors.New("test error")) - err := DownloadFile(testBucket, testKey, testTimeout, mockFile, mockS3Client) + err := DownloadFile(testBucket, testKey, testTimeout, mockFile, mockS3ManagerClient) assert.Error(t, err) } @@ -78,3 +80,31 @@ func TestParseS3ARNInvalid(t *testing.T) { _, _, err := ParseS3ARN("arn:aws:xxx:::xxx") assert.Error(t, err) } + +func TestGetObject(t *testing.T) { + expectedValue := "testdata" + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockGetObjectResponse := &s3sdk.GetObjectOutput{ + Body: io.NopCloser(strings.NewReader(expectedValue)), + } + mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client.EXPECT().GetObject(gomock.Any()).Return(mockGetObjectResponse, nil) + + actualValue, err := GetObject(testBucket, testKey, mockS3Client) + assert.NoError(t, err) + assert.Equal(t, actualValue, expectedValue) +} + +func TestGetObjectErr(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockGetObjectResponse := &s3sdk.GetObjectOutput{} + mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client.EXPECT().GetObject(gomock.Any()).Return(mockGetObjectResponse, errors.New("test error")) + + _, err := GetObject(testBucket, testKey, mockS3Client) + assert.Error(t, err) +} diff --git a/agent/taskresource/credentialspec/credentialspec_windows.go b/agent/taskresource/credentialspec/credentialspec_windows.go index 41874c579c9..ac606b8a83c 100644 --- a/agent/taskresource/credentialspec/credentialspec_windows.go +++ b/agent/taskresource/credentialspec/credentialspec_windows.go @@ -398,7 +398,7 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentials return err } - s3Client, err := cs.s3ClientCreator.NewS3ClientForBucket(bucket, cs.region, iamCredentials) + s3Client, err := cs.s3ClientCreator.NewS3ManagerClient(bucket, cs.region, iamCredentials) if err != nil { cs.setTerminalReason(err.Error()) return err diff --git a/agent/taskresource/credentialspec/credentialspec_windows_test.go b/agent/taskresource/credentialspec/credentialspec_windows_test.go index b67f6509784..15dfc60b7d6 100644 --- a/agent/taskresource/credentialspec/credentialspec_windows_test.go +++ b/agent/taskresource/credentialspec/credentialspec_windows_test.go @@ -28,7 +28,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/agent/credentials/mocks" mock_s3_factory "github.com/aws/amazon-ecs-agent/agent/s3/factory/mocks" - mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks" + mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks/s3manager" mock_factory "github.com/aws/amazon-ecs-agent/agent/ssm/factory/mocks" mock_ssmiface "github.com/aws/amazon-ecs-agent/agent/ssm/mocks" "github.com/aws/amazon-ecs-agent/agent/taskresource" @@ -464,7 +464,7 @@ func TestHandleS3CredentialspecFile(t *testing.T) { s3ClientCreator := mock_s3_factory.NewMockS3ClientCreator(ctrl) mockIO := mock_ioutilwrapper.NewMockIOUtil(ctrl) mockFile := mock_oswrapper.NewMockFile() - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) iamCredentials := credentials.IAMRoleCredentials{ CredentialsID: "test-cred-id", } @@ -495,7 +495,7 @@ func TestHandleS3CredentialspecFile(t *testing.T) { return testTempFile } gomock.InOrder( - s3ClientCreator.EXPECT().NewS3ClientForBucket(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), + s3ClientCreator.EXPECT().NewS3ManagerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockIO.EXPECT().TempFile(gomock.Any(), gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), nil), ) @@ -531,7 +531,7 @@ func TestHandleS3CredentialspecFileS3ClientErr(t *testing.T) { credentialsManager := mock_credentials.NewMockManager(ctrl) ssmClientCreator := mock_factory.NewMockSSMClientCreator(ctrl) s3ClientCreator := mock_s3_factory.NewMockS3ClientCreator(ctrl) - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) iamCredentials := credentials.IAMRoleCredentials{ CredentialsID: "test-cred-id", } @@ -551,7 +551,7 @@ func TestHandleS3CredentialspecFileS3ClientErr(t *testing.T) { }, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning) gomock.InOrder( - s3ClientCreator.EXPECT().NewS3ClientForBucket(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, errors.New("test-error")), + s3ClientCreator.EXPECT().NewS3ManagerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, errors.New("test-error")), ) err := cs.handleS3CredentialspecFile(s3CredentialSpec, credentialSpecS3ARN, iamCredentials) @@ -567,7 +567,7 @@ func TestHandleS3CredentialspecFileWriteErr(t *testing.T) { s3ClientCreator := mock_s3_factory.NewMockS3ClientCreator(ctrl) mockIO := mock_ioutilwrapper.NewMockIOUtil(ctrl) mockFile := mock_oswrapper.NewMockFile() - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) iamCredentials := credentials.IAMRoleCredentials{ CredentialsID: "test-cred-id", @@ -605,7 +605,7 @@ func TestHandleS3CredentialspecFileWriteErr(t *testing.T) { }() gomock.InOrder( - s3ClientCreator.EXPECT().NewS3ClientForBucket(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), + s3ClientCreator.EXPECT().NewS3ManagerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockIO.EXPECT().TempFile(gomock.Any(), gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), nil), ) @@ -692,7 +692,7 @@ func TestCreateS3(t *testing.T) { s3ClientCreator := mock_s3_factory.NewMockS3ClientCreator(ctrl) mockIO := mock_ioutilwrapper.NewMockIOUtil(ctrl) mockFile := mock_oswrapper.NewMockFile() - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) s3CredentialSpec := "credentialspec:arn:aws:s3:::bucket_name/test" @@ -725,7 +725,7 @@ func TestCreateS3(t *testing.T) { defer mockRename()() gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true), - s3ClientCreator.EXPECT().NewS3ClientForBucket(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), + s3ClientCreator.EXPECT().NewS3ManagerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockIO.EXPECT().TempFile(gomock.Any(), gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), nil), ) diff --git a/agent/taskresource/envFiles/envfile.go b/agent/taskresource/envFiles/envfile.go index fe56ad67410..546844d02b8 100644 --- a/agent/taskresource/envFiles/envfile.go +++ b/agent/taskresource/envFiles/envfile.go @@ -354,7 +354,7 @@ func (envfile *EnvironmentFileResource) downloadEnvfileFromS3(envFilePath string return } - s3Client, err := envfile.s3ClientCreator.NewS3ClientForBucket(bucket, envfile.region, iamCredentials) + s3Client, err := envfile.s3ClientCreator.NewS3ManagerClient(bucket, envfile.region, iamCredentials) if err != nil { errorEvents <- fmt.Errorf("unable to initialize s3 client for bucket %s, error: %v", bucket, err) return diff --git a/agent/taskresource/envFiles/envfile_test.go b/agent/taskresource/envFiles/envfile_test.go index 6e15702de62..ea94c32da11 100644 --- a/agent/taskresource/envFiles/envfile_test.go +++ b/agent/taskresource/envFiles/envfile_test.go @@ -29,7 +29,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/agent/credentials/mocks" mock_factory "github.com/aws/amazon-ecs-agent/agent/s3/factory/mocks" - mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks" + mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks/s3manager" "github.com/aws/amazon-ecs-agent/agent/taskresource" mock_bufio "github.com/aws/amazon-ecs-agent/agent/utils/bufiowrapper/mocks" mock_ioutilwrapper "github.com/aws/amazon-ecs-agent/agent/utils/ioutilwrapper/mocks" @@ -58,14 +58,14 @@ const ( ) func setup(t *testing.T) (oswrapper.File, *mock_ioutilwrapper.MockIOUtil, - *mock_credentials.MockManager, *mock_factory.MockS3ClientCreator, *mock_s3.MockS3Client, func()) { + *mock_credentials.MockManager, *mock_factory.MockS3ClientCreator, *mock_s3.MockS3ManagerClient, func()) { ctrl := gomock.NewController(t) mockFile := mock_oswrapper.NewMockFile() mockIOUtil := mock_ioutilwrapper.NewMockIOUtil(ctrl) mockCredentialsManager := mock_credentials.NewMockManager(ctrl) mockS3ClientCreator := mock_factory.NewMockS3ClientCreator(ctrl) - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) return mockFile, mockIOUtil, mockCredentialsManager, mockS3ClientCreator, mockS3Client, ctrl.Finish } @@ -139,7 +139,7 @@ func TestCreateWithEnvVarFile(t *testing.T) { gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), mockIOUtil.EXPECT().TempFile(resourceDir, gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Do( func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput) { @@ -193,7 +193,7 @@ func TestCreateUnableToRetrieveDataFromS3(t *testing.T) { gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), mockIOUtil.EXPECT().TempFile(resourceDir, gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), errors.New("error response")), ) @@ -221,7 +221,7 @@ func TestCreateUnableToCreateTmpFile(t *testing.T) { gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), mockIOUtil.EXPECT().TempFile(resourceDir, gomock.Any()).Return(nil, errors.New("error response")), ) @@ -256,7 +256,7 @@ func TestCreateRenameFileError(t *testing.T) { gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient(s3Bucket, region, creds.IAMRoleCredentials).Return(mockS3Client, nil), mockIOUtil.EXPECT().TempFile(resourceDir, gomock.Any()).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), nil), ) diff --git a/agent/taskresource/firelens/firelens_unix.go b/agent/taskresource/firelens/firelens_unix.go index 062001f377c..e76f8433903 100644 --- a/agent/taskresource/firelens/firelens_unix.go +++ b/agent/taskresource/firelens/firelens_unix.go @@ -494,7 +494,7 @@ func (firelens *FirelensResource) downloadConfigFromS3() error { return errors.Wrap(err, "unable to parse bucket and key from s3 arn") } - s3Client, err := firelens.s3ClientCreator.NewS3ClientForBucket(bucket, firelens.region, creds.GetIAMRoleCredentials()) + s3Client, err := firelens.s3ClientCreator.NewS3ManagerClient(bucket, firelens.region, creds.GetIAMRoleCredentials()) if err != nil { return errors.Wrapf(err, "unable to initialize s3 client for bucket %s", bucket) } diff --git a/agent/taskresource/firelens/firelens_unix_test.go b/agent/taskresource/firelens/firelens_unix_test.go index d7cf8624897..0dc550417c5 100644 --- a/agent/taskresource/firelens/firelens_unix_test.go +++ b/agent/taskresource/firelens/firelens_unix_test.go @@ -32,7 +32,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/agent/credentials/mocks" mock_factory "github.com/aws/amazon-ecs-agent/agent/s3/factory/mocks" - mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks" + mock_s3 "github.com/aws/amazon-ecs-agent/agent/s3/mocks/s3manager" "github.com/aws/amazon-ecs-agent/agent/taskresource" resourcestatus "github.com/aws/amazon-ecs-agent/agent/taskresource/status" mock_ioutilwrapper "github.com/aws/amazon-ecs-agent/agent/utils/ioutilwrapper/mocks" @@ -70,14 +70,14 @@ var ( ) func setup(t *testing.T) (oswrapper.File, *mock_ioutilwrapper.MockIOUtil, - *mock_credentials.MockManager, *mock_factory.MockS3ClientCreator, *mock_s3.MockS3Client, func()) { + *mock_credentials.MockManager, *mock_factory.MockS3ClientCreator, *mock_s3.MockS3ManagerClient, func()) { ctrl := gomock.NewController(t) mockFile := mock_oswrapper.NewMockFile() mockIOUtil := mock_ioutilwrapper.NewMockIOUtil(ctrl) mockCredentialsManager := mock_credentials.NewMockManager(ctrl) mockS3ClientCreator := mock_factory.NewMockS3ClientCreator(ctrl) - mockS3Client := mock_s3.NewMockS3Client(ctrl) + mockS3Client := mock_s3.NewMockS3ManagerClient(ctrl) return mockFile, mockIOUtil, mockCredentialsManager, mockS3ClientCreator, mockS3Client, ctrl.Finish } @@ -363,7 +363,7 @@ func TestCreateFirelensResourceWithS3Config(t *testing.T) { gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(testExecutionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket("bucket", testRegion, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient("bucket", testRegion, creds.IAMRoleCredentials).Return(mockS3Client, nil), // write external config file downloaded from s3 mockIOUtil.EXPECT().TempFile(testResourceDir, tempFile).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Do( @@ -435,7 +435,7 @@ func TestCreateFirelensResourceWithS3ConfigDownloadFailure(t *testing.T) { } gomock.InOrder( mockCredentialsManager.EXPECT().GetTaskCredentials(testExecutionCredentialsID).Return(creds, true), - mockS3ClientCreator.EXPECT().NewS3ClientForBucket("bucket", testRegion, creds.IAMRoleCredentials).Return(mockS3Client, nil), + mockS3ClientCreator.EXPECT().NewS3ManagerClient("bucket", testRegion, creds.IAMRoleCredentials).Return(mockS3Client, nil), mockIOUtil.EXPECT().TempFile(testResourceDir, tempFile).Return(mockFile, nil), mockS3Client.EXPECT().DownloadWithContext(gomock.Any(), mockFile, gomock.Any()).Return(int64(0), errors.New("test error")), )