From 4ef136edf144591417dfddfb0dbed31e4b476f17 Mon Sep 17 00:00:00 2001
From: Lydia Filipe
Date: Wed, 8 Sep 2021 11:34:04 -0700
Subject: [PATCH] prevent instances going into an EC2 Auto Scaling group warm pool from being registered with the cluster

---
 README.md                                  |   3 +-
 agent/app/agent.go                         |  88 ++++++++++++-
 agent/app/agent_test.go                    | 141 +++++++++++++++++++--
 agent/config/config.go                     |   1 +
 agent/config/config_test.go                |   2 +
 agent/config/types.go                      |   4 +
 agent/ec2/blackhole_ec2_metadata_client.go |   4 +
 agent/ec2/ec2_metadata_client.go           |   6 +
 agent/ec2/mocks/ec2_mocks.go               |  15 +++
 agent/utils/utils.go                       |  10 ++
 agent/utils/utils_test.go                  |  25 ++++
 11 files changed, 287 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 71e4f9c2f5b..b7d68e4ec95 100644
--- a/README.md
+++ b/README.md
@@ -195,7 +195,8 @@ additional details on each available environment variable.
 | `ECS_FSX_WINDOWS_FILE_SERVER_SUPPORTED` | `true` | Whether FSx for Windows File Server volume type is supported on the container instance. This variable is only supported on agent versions 1.47.0 and later. | `false` | `true` |
 | `ECS_ENABLE_RUNTIME_STATS` | `true` | Determines if [pprof](https://pkg.go.dev/net/http/pprof) is enabled for the agent. If enabled, the different profiles can be accessed through the agent's introspection port (e.g. `curl http://localhost:51678/debug/pprof/heap > heap.pprof`). In addition, agent's [runtime stats](https://pkg.go.dev/runtime#ReadMemStats) are logged to `/var/log/ecs/runtime-stats.log` file. | `false` | `false` |
 | `ECS_EXCLUDE_IPV6_PORTBINDING` | `true` | Determines if agent should exclude IPv6 port binding using default network mode. If enabled, IPv6 port binding will be filtered out, and the response of DescribeTasks API call will not show tasks' IPv6 port bindings, but it is still included in Task metadata endpoint. | `true` | `true` |
-
+| `ECS_WARM_POOLS_CHECK` | `true` | Whether to prevent instances that are entering an [EC2 Auto Scaling group warm pool](https://docs.aws.amazon.com/autoscaling/ec2/userguide/ec2-auto-scaling-warm-pools.html) from being registered with the cluster. Set to `true` only if using EC2 Auto Scaling. | `false` | `false` |
+
 ### Persistence
 
 When you run the Amazon ECS Container Agent in production, its `datadir` should be persisted between runs of the Docker
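For operators, opting in is a one-line addition to the agent's standard config file. A minimal sketch, assuming the usual `/etc/ecs/ecs.config` location on ECS-optimized AMIs:

```
# /etc/ecs/ecs.config
# Opt in to the warm pool check; leave unset (defaults to false) on
# instances that are not launched by an EC2 Auto Scaling group.
ECS_WARM_POOLS_CHECK=true
```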
diff --git a/agent/app/agent.go b/agent/app/agent.go
index 4505c0ab917..bfe71daa52a 100644
--- a/agent/app/agent.go
+++ b/agent/app/agent.go
@@ -83,6 +83,15 @@ const (
 	instanceIdBackoffJitter   = 0.2
 	instanceIdBackoffMultiple = 1.3
 	instanceIdMaxRetryCount   = 3
+
+	targetLifecycleBackoffMin      = time.Second
+	targetLifecycleBackoffMax      = time.Second * 5
+	targetLifecycleBackoffJitter   = 0.2
+	targetLifecycleBackoffMultiple = 1.3
+	targetLifecycleMaxRetryCount   = 3
+	inServiceState                 = "InService"
+	asgLifecyclePollWait           = time.Minute
+	asgLifecyclePollMax            = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out
 )
 
 var (
@@ -291,6 +300,19 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre
 		seelog.Criticalf("Unable to initialize new task engine: %v", err)
 		return exitcodes.ExitTerminal
 	}
+
+	// Start the termination handler in a goroutine so that it is running during the warm pool wait below
+	go agent.terminationHandler(state, agent.dataClient, taskEngine, agent.cancel)
+
+	// If the instance is part of an ASG, wait until it is about to go in service before registering with the cluster
+	if agent.cfg.WarmPoolsSupport.Enabled() {
+		err := agent.waitUntilInstanceInService(asgLifecyclePollWait, asgLifecyclePollMax, targetLifecycleMaxRetryCount)
+		if err != nil && err.Error() != blackholed {
+			seelog.Criticalf("Could not determine target lifecycle of instance: %v", err)
+			return exitcodes.ExitTerminal
+		}
+	}
+
 	agent.initMetricsEngine()
 
 	loadPauseErr := agent.loadPauseContainer()
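A note on the new backoff constants: the per-call retry loop in `getTargetLifecycle` (added in the next hunk) is deliberately cheap next to the minute-long poll cycle. The sketch below illustrates the schedule those constants imply; it re-implements a jittered exponential backoff rather than importing the agent's `retry` package, which is assumed to behave similarly:

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

func main() {
	const (
		min      = time.Second     // targetLifecycleBackoffMin
		max      = 5 * time.Second // targetLifecycleBackoffMax
		jitter   = 0.2             // targetLifecycleBackoffJitter
		multiple = 1.3             // targetLifecycleBackoffMultiple
		retries  = 3               // targetLifecycleMaxRetryCount
	)
	current := min
	total := time.Duration(0)
	for attempt := 1; attempt <= retries; attempt++ {
		// each delay lands within ±20% of the nominal value
		d := time.Duration(float64(current) * (1 - jitter + 2*jitter*rand.Float64()))
		total += d
		fmt.Printf("attempt %d: sleep ~%v\n", attempt, d.Round(100*time.Millisecond))
		if current = time.Duration(float64(current) * multiple); current > max {
			current = max
		}
	}
	// The agent only sleeps between attempts, so at most retries-1 of these
	// delays occur: roughly 5s of retrying in the worst case, dwarfed by the
	// one-minute asgLifecyclePollWait between poll cycles.
	fmt.Println("total retry delay:", total.Round(100*time.Millisecond))
}
```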
@@ -387,6 +409,70 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre
 		deregisterInstanceEventStream, client, state, taskHandler, doctor)
 }
 
+// waitUntilInstanceInService polls IMDS until the target lifecycle state indicates that the instance is going in
+// service, to avoid instances going into a warm pool being registered as container instances with the cluster.
+func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) error {
+	seelog.Info("Waiting for instance to go InService")
+	var err error
+	var targetState string
+	// Poll until a target lifecycle state is obtained from IMDS, or an unexpected error occurs
+	targetState, err = agent.pollUntilTargetLifecyclePresent(pollWaitDuration, pollMaxTimes, maxRetries)
+	if err != nil {
+		return err
+	}
+	// Poll while the instance is in a warmed state until it is going to go into service
+	for targetState != inServiceState {
+		time.Sleep(pollWaitDuration)
+		targetState, err = agent.getTargetLifecycle(maxRetries)
+		if err != nil {
+			// Do not exit if the error is due to throttling or a temporary server error;
+			// these are likely transient, as at this point IMDS has been successfully queried for state
+			switch utils.GetRequestFailureStatusCode(err) {
+			case 429, 500, 502, 503, 504:
+				seelog.Warnf("Encountered error while waiting for warmed instance to go in service: %v", err)
+			default:
+				return err
+			}
+		}
+	}
+	return err
+}
+
+// pollUntilTargetLifecyclePresent polls until it obtains a target lifecycle state or receives an unexpected error
+func (agent *ecsAgent) pollUntilTargetLifecyclePresent(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) (string, error) {
+	var err error
+	var targetState string
+	for i := 0; i < pollMaxTimes; i++ {
+		targetState, err = agent.getTargetLifecycle(maxRetries)
+		if targetState != "" ||
+			(err != nil && utils.GetRequestFailureStatusCode(err) != 404) {
+			break
+		}
+		time.Sleep(pollWaitDuration)
+	}
+	return targetState, err
+}
+
+// getTargetLifecycle obtains the target lifecycle state for the instance from IMDS. This is populated for instances
+// associated with an ASG.
+func (agent *ecsAgent) getTargetLifecycle(maxRetries int) (string, error) {
+	var targetState string
+	var err error
+	backoff := retry.NewExponentialBackoff(targetLifecycleBackoffMin, targetLifecycleBackoffMax, targetLifecycleBackoffJitter, targetLifecycleBackoffMultiple)
+	for i := 0; i < maxRetries; i++ {
+		targetState, err = agent.ec2MetadataClient.TargetLifecycleState()
+		if err == nil {
+			break
+		}
+		seelog.Debugf("Error when getting intended lifecycle state: %v", err)
+		if i < maxRetries-1 {
+			time.Sleep(backoff.Duration())
+		}
+	}
+	seelog.Debugf("Target lifecycle state of instance: %v", targetState)
+	return targetState, err
+}
+
 // newTaskEngine creates a new docker task engine object. It tries to load the
 // local state if needed, else initializes a new one
 func (agent *ecsAgent) newTaskEngine(containerChangeEventStream *eventstream.EventStream,
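Taken together, the three functions implement a two-phase wait. The following self-contained sketch restates that control flow against a scripted stub (`fakeIMDS` is invented for illustration; sleeps, per-call retries, and logging are stripped out) to make the error classification explicit: a 404 is tolerated in phase one, 429/5xx in phase two, and anything else aborts startup:

```go
package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/awserr"
)

// fakeIMDS replays a scripted sequence of lifecycle states and errors.
type fakeIMDS struct {
	states []string
	errs   []error
	i      int
}

func (f *fakeIMDS) TargetLifecycleState() (string, error) {
	defer func() { f.i++ }()
	return f.states[f.i], f.errs[f.i]
}

func statusCode(err error) int {
	if reqErr, ok := err.(awserr.RequestFailure); ok {
		return reqErr.StatusCode()
	}
	return 0
}

func waitInService(c *fakeIMDS, maxPolls int) error {
	var state string
	var err error
	// Phase 1: poll until a state exists at all; a 404 means the lifecycle
	// state is not (yet) populated for this instance.
	for i := 0; i < maxPolls; i++ {
		if state, err = c.TargetLifecycleState(); state != "" || (err != nil && statusCode(err) != 404) {
			break
		}
	}
	if err != nil {
		return err
	}
	// Phase 2: wait out the warm pool; throttling and 5xx are transient.
	for state != "InService" {
		if state, err = c.TargetLifecycleState(); err != nil {
			switch statusCode(err) {
			case 429, 500, 502, 503, 504:
				// transient; keep polling
			default:
				return err
			}
		}
	}
	return nil
}

func main() {
	notFound := awserr.NewRequestFailure(awserr.New("NotFound", "", nil), 404, "")
	throttled := awserr.NewRequestFailure(awserr.New("Throttling", "", nil), 429, "")
	c := &fakeIMDS{
		states: []string{"", "Warmed:Running", "", "InService"},
		errs:   []error{notFound, nil, throttled, nil},
	}
	fmt.Println(waitInService(c, 120)) // <nil>: the agent would now register
}
```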
@@ -687,8 +773,6 @@ func (agent *ecsAgent) startAsyncRoutines(
 		go agent.startSpotInstanceDrainingPoller(agent.ctx, client)
 	}
 
-	go agent.terminationHandler(state, agent.dataClient, taskEngine, agent.cancel)
-
 	// Agent introspection api
 	go handlers.ServeIntrospectionHTTPEndpoint(agent.ctx, &agent.containerInstanceARN, taskEngine, agent.cfg)
 
diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go
index e12cecad3b7..79210cb8569 100644
--- a/agent/app/agent_test.go
+++ b/agent/app/agent_test.go
@@ -24,6 +24,7 @@ import (
 	"sort"
 	"sync"
 	"testing"
+	"time"
 
 	apierrors "github.com/aws/amazon-ecs-agent/agent/api/errors"
 	mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks"
@@ -50,7 +51,6 @@ import (
 	mock_statemanager "github.com/aws/amazon-ecs-agent/agent/statemanager/mocks"
 	mock_mobypkgwrapper "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper/mocks"
 	"github.com/aws/amazon-ecs-agent/agent/version"
-
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	aws_credentials "github.com/aws/aws-sdk-go/aws/credentials"
@@ -60,14 +60,19 @@
 )
 
 const (
-	clusterName            = "some-cluster"
-	containerInstanceARN   = "container-instance1"
-	availabilityZone       = "us-west-2b"
-	hostPrivateIPv4Address = "127.0.0.1"
-	hostPublicIPv4Address  = "127.0.0.1"
-	instanceID             = "i-123"
+	clusterName                      = "some-cluster"
+	containerInstanceARN             = "container-instance1"
+	availabilityZone                 = "us-west-2b"
+	hostPrivateIPv4Address           = "127.0.0.1"
+	hostPublicIPv4Address            = "127.0.0.1"
+	instanceID                       = "i-123"
+	warmedState                      = "Warmed:Running"
+	testTargetLifecycleMaxRetryCount = 1
 )
 
+var notFoundErr = awserr.NewRequestFailure(awserr.Error(awserr.New("NotFound", "", errors.New(""))), 404, "")
+var badReqErr = awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, "")
+var serverErr = awserr.NewRequestFailure(awserr.Error(awserr.New("InternalServerError", "", errors.New(""))), 500, "")
 var apiVersions = []dockerclient.DockerVersion{
 	dockerclient.Version_1_21,
 	dockerclient.Version_1_22,
@@ -235,6 +240,8 @@ func TestDoStartRegisterContainerInstanceErrorTerminal(t *testing.T) {
 		dockerClient:      dockerClient,
 		mobyPlugins:       mockMobyPlugins,
 		ec2MetadataClient: mockEC2Metadata,
+		terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
+		},
 	}
 
 	exitCode := agent.doStart(eventstream.NewEventStream("events", ctx),
@@ -279,6 +286,8 @@ func TestDoStartRegisterContainerInstanceErrorNonTerminal(t *testing.T) {
 		credentialProvider: aws_credentials.NewCredentials(mockCredentialsProvider),
 		mobyPlugins:        mockMobyPlugins,
 		ec2MetadataClient:  mockEC2Metadata,
+		terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
+		},
 	}
 
 	exitCode := agent.doStart(eventstream.NewEventStream("events", ctx),
@@ -286,7 +295,60 @@ func TestDoStartRegisterContainerInstanceErrorNonTerminal(t *testing.T) {
 	assert.Equal(t, exitcodes.ExitError, exitCode)
 }
 
+func TestDoStartWarmPoolsError(t *testing.T) {
+	ctrl, credentialsManager, state, imageManager, client,
+		dockerClient, _, _, execCmdMgr := setup(t)
+	defer ctrl.Finish()
+	mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl)
+	gomock.InOrder(
+		dockerClient.EXPECT().SupportedVersions().Return(apiVersions),
+	)
+
+	cfg := getTestConfig()
+	cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
+	ctx, cancel := context.WithCancel(context.TODO())
+	// Cancel the context to cancel async routines
+	defer cancel()
+	terminationHandlerChan := make(chan bool)
+	terminationHandlerInvoked := false
+	agent := &ecsAgent{
+		ctx:               ctx,
+		cfg:               &cfg,
+		dockerClient:      dockerClient,
+		ec2MetadataClient: mockEC2Metadata,
+		terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
+			terminationHandlerChan <- true
+		},
+	}
+
+	err := errors.New("error")
+	mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(targetLifecycleMaxRetryCount)
+
+	exitCode := agent.doStart(eventstream.NewEventStream("events", ctx),
+		credentialsManager, state, imageManager, client, execCmdMgr)
+
+	select {
+	case terminationHandlerInvoked = <-terminationHandlerChan:
+	case <-time.After(10 * time.Second):
+	}
+	assert.Equal(t, exitcodes.ExitTerminal, exitCode)
+	// verify that the termination handler was started before polling began
+	assert.True(t, terminationHandlerInvoked)
+}
+
 func TestDoStartHappyPath(t *testing.T) {
+	testDoStartHappyPathWithConditions(t, false, false)
+}
+
+func TestDoStartWarmPoolsEnabled(t *testing.T) {
+	testDoStartHappyPathWithConditions(t, false, true)
+}
+
+func TestDoStartWarmPoolsBlackholed(t *testing.T) {
+	testDoStartHappyPathWithConditions(t, true, true)
+}
+
+func testDoStartHappyPathWithConditions(t *testing.T, blackholed bool, warmPoolsEnv bool) {
 	ctrl, credentialsManager, _, imageManager, client,
 		dockerClient, stateManagerFactory, saveableOptionFactory, execCmdMgr := setup(t)
 	defer ctrl.Finish()
@@ -299,7 +361,19 @@ func TestDoStartHappyPath(t *testing.T) {
 	ec2MetadataClient.EXPECT().PrivateIPv4Address().Return(hostPrivateIPv4Address, nil)
 	ec2MetadataClient.EXPECT().PublicIPv4Address().Return(hostPublicIPv4Address, nil)
 	ec2MetadataClient.EXPECT().OutpostARN().Return("", nil)
-	ec2MetadataClient.EXPECT().InstanceID().Return(instanceID, nil)
+
+	if blackholed {
+		if warmPoolsEnv {
+			ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(targetLifecycleMaxRetryCount)
+		}
+		ec2MetadataClient.EXPECT().InstanceID().Return("", errors.New("blackholed"))
+	} else {
+		if warmPoolsEnv {
+			ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("error"))
+			ec2MetadataClient.EXPECT().TargetLifecycleState().Return(inServiceState, nil)
+		}
+		ec2MetadataClient.EXPECT().InstanceID().Return(instanceID, nil)
+	}
 
 	var discoverEndpointsInvoked sync.WaitGroup
 	discoverEndpointsInvoked.Add(2)
@@ -347,6 +421,9 @@ func TestDoStartHappyPath(t *testing.T) {
 	cfg := getTestConfig()
 	cfg.ContainerMetadataEnabled = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
 	cfg.Checkpoint = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
+	if warmPoolsEnv {
+		cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
+	}
 	cfg.Cluster = clusterName
 
 	ctx, cancel := context.WithCancel(context.TODO())
@@ -386,7 +463,9 @@ func TestDoStartHappyPath(t *testing.T) {
 	assertMetadata(t, data.AvailabilityZoneKey, availabilityZone, dataClient)
 	assertMetadata(t, data.ClusterNameKey, clusterName, dataClient)
 	assertMetadata(t, data.ContainerInstanceARNKey, containerInstanceARN, dataClient)
-	assertMetadata(t, data.EC2InstanceIDKey, instanceID, dataClient)
+	if !blackholed {
+		assertMetadata(t, data.EC2InstanceIDKey, instanceID, dataClient)
+	}
 }
 
 func assertMetadata(t *testing.T, key, expectedVal string, dataClient data.Client) {
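One detail of `TestDoStartWarmPoolsError` worth calling out: it asserts ordering (handler started before polling) without sleeping, by having the injected handler signal a channel and bounding the wait with a timeout. The pattern in isolation, with a hypothetical `startHandler` standing in for `doStart`:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	invokedCh := make(chan bool)

	// Hypothetical stand-in for doStart, which now launches the
	// termination handler before any warm pool polling.
	startHandler := func() {
		go func() { invokedCh <- true }()
	}
	startHandler()

	// Bounded wait: if the handler goroutine never runs, the flag stays
	// false and the test fails via the assertion instead of hanging.
	invoked := false
	select {
	case invoked = <-invokedCh:
	case <-time.After(10 * time.Second):
	}
	fmt.Println("handler started:", invoked) // true
}
```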
@@ -1195,6 +1274,8 @@ func TestRegisterContainerInstanceInvalidParameterTerminalError(t *testing.T) {
 		credentialProvider: aws_credentials.NewCredentials(mockCredentialsProvider),
 		dockerClient:       dockerClient,
 		mobyPlugins:        mockMobyPlugins,
+		terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
+		},
 	}
 
 	exitCode := agent.doStart(eventstream.NewEventStream("events", ctx),
@@ -1473,3 +1554,45 @@ func newTestDataClient(t *testing.T) (data.Client, func()) {
 	}
 	return testClient, cleanup
 }
+
+type targetLifecycleFuncDetail struct {
+	val         string
+	err         error
+	returnTimes int
+}
+
+func TestWaitUntilInstanceInServicePolling(t *testing.T) {
+	warmedResult := targetLifecycleFuncDetail{warmedState, nil, 1}
+	inServiceResult := targetLifecycleFuncDetail{inServiceState, nil, 1}
+	notFoundErrResult := targetLifecycleFuncDetail{"", notFoundErr, testTargetLifecycleMaxRetryCount}
+	unexpectedErrResult := targetLifecycleFuncDetail{"", badReqErr, testTargetLifecycleMaxRetryCount}
+	serverErrResult := targetLifecycleFuncDetail{"", serverErr, testTargetLifecycleMaxRetryCount}
+	testCases := []struct {
+		name            string
+		funcTestDetails []targetLifecycleFuncDetail
+		result          error
+		maxPolls        int
+	}{
+		{"TestWaitUntilInServicePollWarmed", []targetLifecycleFuncDetail{warmedResult, warmedResult, inServiceResult}, nil, asgLifecyclePollMax},
+		{"TestWaitUntilInServicePollMissing", []targetLifecycleFuncDetail{notFoundErrResult, inServiceResult}, nil, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceErrPollMaxReached", []targetLifecycleFuncDetail{notFoundErrResult}, notFoundErr, 1},
+		{"TestWaitUntilInServiceNoStateUnexpectedErr", []targetLifecycleFuncDetail{unexpectedErrResult}, badReqErr, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceUnexpectedErr", []targetLifecycleFuncDetail{warmedResult, unexpectedErrResult}, badReqErr, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceServerErrContinue", []targetLifecycleFuncDetail{warmedResult, serverErrResult, inServiceResult}, nil, asgLifecyclePollMax},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+			cfg := getTestConfig()
+			cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
+			ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
+			agent := &ecsAgent{ec2MetadataClient: ec2MetadataClient, cfg: &cfg}
+			for _, detail := range tc.funcTestDetails {
+				ec2MetadataClient.EXPECT().TargetLifecycleState().Return(detail.val, detail.err).Times(detail.returnTimes)
+			}
+			assert.Equal(t, tc.result, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls, testTargetLifecycleMaxRetryCount))
+		})
+	}
+}
diff --git a/agent/config/config.go b/agent/config/config.go
index d9307b8cad3..b7e296df33a 100644
--- a/agent/config/config.go
+++ b/agent/config/config.go
@@ -592,6 +592,7 @@ func environmentConfig() (Config, error) {
 		External:                     parseBooleanDefaultFalseConfig("ECS_EXTERNAL"),
 		EnableRuntimeStats:           parseBooleanDefaultFalseConfig("ECS_ENABLE_RUNTIME_STATS"),
 		ShouldExcludeIPv6PortBinding: parseBooleanDefaultTrueConfig("ECS_EXCLUDE_IPV6_PORTBINDING"),
+		WarmPoolsSupport:             parseBooleanDefaultFalseConfig("ECS_WARM_POOLS_CHECK"),
 	}, err
 }
diff --git a/agent/config/config_test.go b/agent/config/config_test.go
index dd4aac1bb62..c2d14738364 100644
--- a/agent/config/config_test.go
+++ b/agent/config/config_test.go
@@ -158,6 +158,7 @@ func TestEnvironmentConfig(t *testing.T) {
 	defer setTestEnv("ECS_PULL_DEPENDENT_CONTAINERS_UPFRONT", "true")()
 	defer setTestEnv("ECS_ENABLE_RUNTIME_STATS", "true")()
 	defer setTestEnv("ECS_EXCLUDE_IPV6_PORTBINDING", "true")()
+	defer setTestEnv("ECS_WARM_POOLS_CHECK", "false")()
 	additionalLocalRoutesJSON := `["1.2.3.4/22","5.6.7.8/32"]`
 	setTestEnv("ECS_AWSVPC_ADDITIONAL_LOCAL_ROUTES", additionalLocalRoutesJSON)
 	setTestEnv("ECS_ENABLE_CONTAINER_METADATA", "true")
@@ -216,6 +217,7 @@ func TestEnvironmentConfig(t *testing.T) {
 	assert.True(t, conf.DependentContainersPullUpfront.Enabled(), "Wrong value for DependentContainersPullUpfront")
 	assert.True(t, conf.EnableRuntimeStats.Enabled(), "Wrong value for EnableRuntimeStats")
 	assert.True(t, conf.ShouldExcludeIPv6PortBinding.Enabled(), "Wrong value for ShouldExcludeIPv6PortBinding")
+	assert.False(t, conf.WarmPoolsSupport.Enabled(), "Wrong value for WarmPoolsSupport")
 }
 
 func TestTrimWhitespaceWhenCreating(t *testing.T) {
diff --git a/agent/config/types.go b/agent/config/types.go
index d0eafd2f9b1..2e533f41bf7 100644
--- a/agent/config/types.go
+++ b/agent/config/types.go
@@ -354,4 +354,8 @@ type Config struct {
 	// is set to true by default, and can be overridden by the ECS_EXCLUDE_IPV6_PORTBINDING environment variable. This is a workaround
 	// for docker's bug as detailed in https://github.com/aws/amazon-ecs-agent/issues/2870.
 	ShouldExcludeIPv6PortBinding BooleanDefaultTrue
+
+	// WarmPoolsSupport specifies whether the agent should poll IMDS to check the target lifecycle state for a starting
+	// instance.
+	WarmPoolsSupport BooleanDefaultFalse
 }
diff --git a/agent/ec2/blackhole_ec2_metadata_client.go b/agent/ec2/blackhole_ec2_metadata_client.go
index eb7d33d8758..6e3ed5d5523 100644
--- a/agent/ec2/blackhole_ec2_metadata_client.go
+++ b/agent/ec2/blackhole_ec2_metadata_client.go
@@ -84,3 +84,7 @@ func (blackholeMetadataClient) SpotInstanceAction() (string, error) {
 func (blackholeMetadataClient) OutpostARN() (string, error) {
 	return "", errors.New("blackholed")
 }
+
+func (blackholeMetadataClient) TargetLifecycleState() (string, error) {
+	return "", errors.New("blackholed")
+}
diff --git a/agent/ec2/ec2_metadata_client.go b/agent/ec2/ec2_metadata_client.go
index 3ab16928e70..3486a4d8c8c 100644
--- a/agent/ec2/ec2_metadata_client.go
+++ b/agent/ec2/ec2_metadata_client.go
@@ -40,6 +40,7 @@ const (
 	PublicIPv4Resource               = "public-ipv4"
 	OutpostARN                       = "outpost-arn"
 	PrimaryIPV4VPCCIDRResourceFormat = "network/interfaces/macs/%s/vpc-ipv4-cidr-block"
+	TargetLifecycleState             = "autoscaling/target-lifecycle-state"
 )
 
 const (
@@ -82,6 +83,7 @@ type EC2MetadataClient interface {
 	PublicIPv4Address() (string, error)
 	SpotInstanceAction() (string, error)
 	OutpostARN() (string, error)
+	TargetLifecycleState() (string, error)
 }
 
 type ec2MetadataClientImpl struct {
@@ -203,3 +205,7 @@ func (c *ec2MetadataClientImpl) SpotInstanceAction() (string, error) {
 func (c *ec2MetadataClientImpl) OutpostARN() (string, error) {
 	return c.client.GetMetadata(OutpostARN)
 }
+
+func (c *ec2MetadataClientImpl) TargetLifecycleState() (string, error) {
+	return c.client.GetMetadata(TargetLifecycleState)
+}
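`TargetLifecycleState()` is a thin wrapper over a plain IMDS GET. For reference, a rough sketch of the equivalent one-off query using the stock aws-sdk-go v1 `ec2metadata` client instead of the agent's wrapper (error handling for illustration only):

```go
package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	sess := session.Must(session.NewSession())
	imds := ec2metadata.New(sess)
	// Same resource path as the TargetLifecycleState constant above; it is
	// only populated for instances launched by an Auto Scaling group.
	state, err := imds.GetMetadata("autoscaling/target-lifecycle-state")
	if err != nil {
		log.Fatalf("no target lifecycle state (not in an ASG, or IMDS unreachable): %v", err)
	}
	fmt.Println(state) // "InService", or a warmed state such as "Warmed:Running"
}
```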
diff --git a/agent/ec2/mocks/ec2_mocks.go b/agent/ec2/mocks/ec2_mocks.go
index 8257b54cabf..bf25d6e2be5 100644
--- a/agent/ec2/mocks/ec2_mocks.go
+++ b/agent/ec2/mocks/ec2_mocks.go
@@ -261,6 +261,21 @@ func (mr *MockEC2MetadataClientMockRecorder) SubnetID(arg0 interface{}) *gomock.
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubnetID", reflect.TypeOf((*MockEC2MetadataClient)(nil).SubnetID), arg0)
 }
 
+// TargetLifecycleState mocks base method
+func (m *MockEC2MetadataClient) TargetLifecycleState() (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "TargetLifecycleState")
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// TargetLifecycleState indicates an expected call of TargetLifecycleState
+func (mr *MockEC2MetadataClientMockRecorder) TargetLifecycleState() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TargetLifecycleState", reflect.TypeOf((*MockEC2MetadataClient)(nil).TargetLifecycleState))
+}
+
 // VPCID mocks base method
 func (m *MockEC2MetadataClient) VPCID(arg0 string) (string, error) {
 	m.ctrl.T.Helper()
diff --git a/agent/utils/utils.go b/agent/utils/utils.go
index 78a51b2293f..341e3bc5dfb 100644
--- a/agent/utils/utils.go
+++ b/agent/utils/utils.go
@@ -159,6 +159,16 @@ func IsAWSErrorCodeEqual(err error, code string) bool {
 	return ok && awsErr.Code() == code
 }
 
+// GetRequestFailureStatusCode returns the status code from a
+// RequestFailure error, or 0 if the error is not of that type
+func GetRequestFailureStatusCode(err error) int {
+	var statusCode int
+	if reqErr, ok := err.(awserr.RequestFailure); ok {
+		statusCode = reqErr.StatusCode()
+	}
+	return statusCode
+}
+
 // MapToTags converts a map to a slice of tags.
 func MapToTags(tagsMap map[string]string) []*ecs.Tag {
 	tags := make([]*ecs.Tag, 0)
diff --git a/agent/utils/utils_test.go b/agent/utils/utils_test.go
index 0b3fd70471f..1e477b6bfd9 100644
--- a/agent/utils/utils_test.go
+++ b/agent/utils/utils_test.go
@@ -162,6 +162,31 @@ func TestIsAWSErrorCodeEqual(t *testing.T) {
 	}
 }
 
+func TestGetRequestFailureStatusCode(t *testing.T) {
+	testcases := []struct {
+		name string
+		err  error
+		res  int
+	}{
+		{
+			name: "TestGetRequestFailureStatusCodeSuccess",
+			err:  awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, ""),
+			res:  400,
+		},
+		{
+			name: "TestGetRequestFailureStatusCodeWrongErrType",
+			err:  errors.New("err"),
+			res:  0,
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.res, GetRequestFailureStatusCode(tc.err))
+		})
+	}
+}
+
 func TestMapToTags(t *testing.T) {
 	tagKey1 := "tagKey1"
 	tagKey2 := "tagKey2"
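A caveat on `GetRequestFailureStatusCode` that the tests above imply but do not spell out: the helper uses a bare type assertion, so it only classifies errors returned directly by the SDK; a RequestFailure wrapped with `fmt.Errorf("%w", ...)` reports 0. A small demonstration:

```go
package main

import (
	"errors"
	"fmt"

	"github.com/aws/aws-sdk-go/aws/awserr"
)

// Same shape as the helper added to agent/utils/utils.go above.
func getRequestFailureStatusCode(err error) int {
	var statusCode int
	if reqErr, ok := err.(awserr.RequestFailure); ok {
		statusCode = reqErr.StatusCode()
	}
	return statusCode
}

func main() {
	reqErr := awserr.NewRequestFailure(awserr.New("Throttling", "rate exceeded", nil), 429, "req-id")
	fmt.Println(getRequestFailureStatusCode(reqErr)) // 429

	// Wrapping defeats the bare type assertion; callers must pass the SDK
	// error through unwrapped, which the agent code does.
	wrapped := fmt.Errorf("polling IMDS: %w", reqErr)
	fmt.Println(getRequestFailureStatusCode(wrapped)) // 0

	// errors.As would still find it, if unwrapping were ever needed.
	var rf awserr.RequestFailure
	fmt.Println(errors.As(wrapped, &rf), rf.StatusCode()) // true 429
}
```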