Skip to content

Commit

Permalink
Added support for automatic spot instance draining. (#2205)
Browse files Browse the repository at this point in the history
* Add ECS_SPOT_INSTANCE_DRAINING_ENABLED configuration variable (#2180)

* Add ECS_SPOT_INSTANCE_DRAINING_ENABLED configuration variable

* _ENABLED->ENABLE_

* Added support for automatic spot instance draining. (#2182)

* Added Spot termination poller routine

* Added unit tests for ECS client: UpdateContainerInstancesState and GetResourceTags

* Added unit tests ec2 metadata client: SpotTerminationTime

* Added unit tests to agent: isSpotTerminationTimeSet

* code review comment updates

* use assert library for unit tests

* Change termination-time query to instance-action (#2199)

* Change termination-time query to instance-action

* code review fixups

* more code review fixups

* refactor tests to be table-driven
  • Loading branch information
sparrc authored Sep 16, 2019
1 parent e38b51d commit 25efaa6
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 0 deletions.
10 changes: 10 additions & 0 deletions agent/api/ecsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,13 @@ func (client *APIECSClient) GetResourceTags(resourceArn string) ([]*ecs.Tag, err
}
return output.Tags, nil
}

// UpdateContainerInstancesState asks ECS to set the container instance
// identified by instanceARN to the given status, using the cluster from the
// client's config. The error from the underlying SDK call, if any, is
// returned unchanged.
func (client *APIECSClient) UpdateContainerInstancesState(instanceARN string, status string) error {
	seelog.Debugf("Invoking UpdateContainerInstancesState, status='%s' instanceARN='%s'", status, instanceARN)

	input := &ecs.UpdateContainerInstancesStateInput{
		ContainerInstances: []*string{aws.String(instanceARN)},
		Status:             aws.String(status),
		Cluster:            &client.config.Cluster,
	}
	// The SDK output carries nothing the caller needs; only the error matters.
	if _, err := client.standardClient.UpdateContainerInstancesState(input); err != nil {
		return err
	}
	return nil
}
64 changes: 64 additions & 0 deletions agent/api/ecsclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,70 @@ func TestDiscoverNilTelemetryEndpoint(t *testing.T) {
}
}

// TestUpdateContainerInstancesState checks that the client passes the
// instance ARN, the status and configuredCluster to the SDK call and returns
// no error when the SDK call succeeds.
func TestUpdateContainerInstancesState(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	client, mc, _ := NewMockClient(ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)

	const (
		instanceARN = "myInstanceARN"
		status      = "DRAINING"
	)
	expectedInput := &ecs.UpdateContainerInstancesStateInput{
		ContainerInstances: []*string{aws.String(instanceARN)},
		Status:             aws.String(status),
		Cluster:            aws.String(configuredCluster),
	}
	mc.EXPECT().
		UpdateContainerInstancesState(expectedInput).
		Return(&ecs.UpdateContainerInstancesStateOutput{}, nil)

	err := client.UpdateContainerInstancesState(instanceARN, status)
	assert.NoError(t, err, fmt.Sprintf("Unexpected error calling UpdateContainerInstancesState: %s", err))
}

// TestUpdateContainerInstancesStateError checks that an error returned by the
// SDK call is surfaced to the caller of UpdateContainerInstancesState.
func TestUpdateContainerInstancesStateError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	client, mc, _ := NewMockClient(ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)

	const (
		instanceARN = "myInstanceARN"
		status      = "DRAINING"
	)
	expectedInput := &ecs.UpdateContainerInstancesStateInput{
		ContainerInstances: []*string{aws.String(instanceARN)},
		Status:             aws.String(status),
		Cluster:            aws.String(configuredCluster),
	}
	mc.EXPECT().
		UpdateContainerInstancesState(expectedInput).
		Return(nil, fmt.Errorf("ERROR"))

	err := client.UpdateContainerInstancesState(instanceARN, status)
	assert.Error(t, err, "Expected an error calling UpdateContainerInstancesState but got nil")
}

// TestGetResourceTags checks that GetResourceTags issues a ListTagsForResource
// call for the given ARN and returns no error when that call succeeds.
func TestGetResourceTags(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	client, mc, _ := NewMockClient(ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)

	const instanceARN = "myInstanceARN"
	request := &ecs.ListTagsForResourceInput{
		ResourceArn: aws.String(instanceARN),
	}
	response := &ecs.ListTagsForResourceOutput{
		Tags: containerInstanceTags,
	}
	mc.EXPECT().ListTagsForResource(request).Return(response, nil)

	_, err := client.GetResourceTags(instanceARN)
	assert.NoError(t, err, fmt.Sprintf("Unexpected error calling GetResourceTags: %s", err))
}

// TestGetResourceTagsError checks that an error from ListTagsForResource is
// surfaced to the caller of GetResourceTags.
func TestGetResourceTagsError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	client, mc, _ := NewMockClient(ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)

	const instanceARN = "myInstanceARN"
	request := &ecs.ListTagsForResourceInput{
		ResourceArn: aws.String(instanceARN),
	}
	mc.EXPECT().ListTagsForResource(request).Return(nil, fmt.Errorf("ERROR"))

	_, err := client.GetResourceTags(instanceARN)
	assert.Error(t, err, "Expected an error calling GetResourceTags but got nil")
}

func TestDiscoverPollEndpointCacheHit(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand Down
4 changes: 4 additions & 0 deletions agent/api/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type ECSClient interface {
DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error)
// GetResourceTags retrieves the Tags associated with a certain resource
GetResourceTags(resourceArn string) ([]*ecs.Tag, error)
// UpdateContainerInstancesState updates the container instance identified by
// instanceARN to the given status. The only valid statuses are ACTIVE and DRAINING.
UpdateContainerInstancesState(instanceARN, status string) error
}

// ECSSDK is an interface that specifies the subset of the AWS Go SDK's ECS
Expand All @@ -55,6 +58,7 @@ type ECSSDK interface {
RegisterContainerInstance(*ecs.RegisterContainerInstanceInput) (*ecs.RegisterContainerInstanceOutput, error)
DiscoverPollEndpoint(*ecs.DiscoverPollEndpointInput) (*ecs.DiscoverPollEndpointOutput, error)
ListTagsForResource(*ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error)
UpdateContainerInstancesState(input *ecs.UpdateContainerInstancesStateInput) (*ecs.UpdateContainerInstancesStateOutput, error)
}

// ECSSubmitStateSDK is an interface with customized ecs client that
Expand Down
29 changes: 29 additions & 0 deletions agent/api/mocks/api_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 49 additions & 0 deletions agent/app/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ package app

import (
"context"
"encoding/json"
"errors"
"fmt"
"time"

"github.com/aws/amazon-ecs-agent/agent/metrics"

Expand Down Expand Up @@ -578,6 +580,11 @@ func (agent *ecsAgent) startAsyncRoutines(
go imageManager.StartImageCleanupProcess(agent.ctx)
}

// Start automatic spot instance draining poller routine
if agent.cfg.SpotInstanceDrainingEnabled {
go agent.startSpotInstanceDrainingPoller(client)
}

go agent.terminationHandler(stateManager, taskEngine)

// Agent introspection api
Expand Down Expand Up @@ -611,6 +618,48 @@ func (agent *ecsAgent) startAsyncRoutines(
go tcshandler.StartMetricsSession(&telemetrySessionParams)
}

// startSpotInstanceDrainingPoller repeatedly invokes spotInstanceDrainingPoller,
// waiting one second between attempts, until either the poller reports success
// (it returned true) or the agent's context is cancelled. It is intended to be
// run in its own goroutine.
func (agent *ecsAgent) startSpotInstanceDrainingPoller(client api.ECSClient) {
	for !agent.spotInstanceDrainingPoller(client) {
		// Wait between polls, but wake up immediately on cancellation so the
		// goroutine does not outlive the agent's context.
		timer := time.NewTimer(time.Second)
		select {
		case <-agent.ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// spotInstanceDrainingPoller performs one check of the spot instance-action
// metadata. It returns true only when a notice with a recognized action
// ("hibernate", "terminate" or "stop") was read AND the container instance
// state was successfully updated to DRAINING. Every other outcome (metadata
// error, undecodable response, unrecognized action, failed state update)
// returns false.
func (agent *ecsAgent) spotInstanceDrainingPoller(client api.ECSClient) bool {
	// This endpoint 404s unless an interruption has been set, so an error here
	// is the expected result in most cases and is deliberately not logged.
	resp, err := agent.ec2MetadataClient.SpotInstanceAction()
	if err != nil {
		return false
	}

	// Shape of the instance-action document; encoding/json matches these
	// field names against the JSON keys case-insensitively.
	type instanceAction struct {
		Time   string
		Action string
	}
	var notice instanceAction
	if decodeErr := json.Unmarshal([]byte(resp), &notice); decodeErr != nil {
		seelog.Errorf("Invalid response from /spot/instance-action endpoint: %s Error: %s", resp, decodeErr)
		return false
	}

	recognized := notice.Action == "hibernate" ||
		notice.Action == "terminate" ||
		notice.Action == "stop"
	if !recognized {
		seelog.Errorf("Invalid response from /spot/instance-action endpoint: %s, Error: unrecognized action (%s)", resp, notice.Action)
		return false
	}

	seelog.Infof("Received a spot interruption (%s) scheduled for %s, setting state to DRAINING", notice.Action, notice.Time)
	if updateErr := client.UpdateContainerInstancesState(agent.containerInstanceARN, "DRAINING"); updateErr != nil {
		seelog.Errorf("Error setting instance [ARN: %s] state to DRAINING: %s", agent.containerInstanceARN, updateErr)
		return false
	}
	return true
}

// startACSSession starts a session with ECS's Agent Communication service. This
// is a blocking call and only returns when the handler returns
func (agent *ecsAgent) startACSSession(
Expand Down
81 changes: 81 additions & 0 deletions agent/app/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,87 @@ func TestGetHostPublicIPv4AddressFromEC2MetadataFailWithError(t *testing.T) {
assert.Empty(t, agent.getHostPublicIPv4AddressFromEC2Metadata())
}

// TestSpotInstanceActionCheck_Sunny checks that, for each recognized action
// in a well-formed instance-action response, the poller sets the instance
// state to DRAINING and reports success.
func TestSpotInstanceActionCheck_Sunny(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
	ec2Client := mock_ec2.NewMockClient(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)

	const myARN = "myARN"
	responses := []string{
		`{"action": "terminate", "time": "2017-09-18T08:22:00Z"}`,
		`{"action": "hibernate", "time": "2017-09-18T08:22:00Z"}`,
		`{"action": "stop", "time": "2017-09-18T08:22:00Z"}`,
	}

	for _, jsonresp := range responses {
		agent := &ecsAgent{
			ec2MetadataClient:    ec2MetadataClient,
			ec2Client:            ec2Client,
			containerInstanceARN: myARN,
		}
		ec2MetadataClient.EXPECT().SpotInstanceAction().Return(jsonresp, nil)
		ecsClient.EXPECT().UpdateContainerInstancesState(myARN, "DRAINING").Return(nil)

		drained := agent.spotInstanceDrainingPoller(ecsClient)
		assert.True(t, drained)
	}
}

// TestSpotInstanceActionCheck_Fail checks that the poller reports failure and
// never updates the container instance state when the metadata call succeeds
// but returns an unusable response.
func TestSpotInstanceActionCheck_Fail(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
	ec2Client := mock_ec2.NewMockClient(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)

	const myARN = "myARN"
	responses := []string{
		`{"action": "terminate" "time": "2017-09-18T08:22:00Z"}`, // invalid json
		``, // empty json
		`{"action": "flip!", "time": "2017-09-18T08:22:00Z"}`, // invalid action
	}

	for _, jsonresp := range responses {
		agent := &ecsAgent{
			ec2MetadataClient:    ec2MetadataClient,
			ec2Client:            ec2Client,
			containerInstanceARN: myARN,
		}
		ec2MetadataClient.EXPECT().SpotInstanceAction().Return(jsonresp, nil)
		// Container state should NOT be updated because the response is not a
		// valid instance-action document with a recognized action.
		ecsClient.EXPECT().UpdateContainerInstancesState(gomock.Any(), gomock.Any()).Times(0)

		drained := agent.spotInstanceDrainingPoller(ecsClient)
		assert.False(t, drained)
	}
}

// TestSpotInstanceActionCheck_NoInstanceActionYet checks that the poller
// reports failure and leaves the container instance state untouched when the
// metadata client returns an error.
func TestSpotInstanceActionCheck_NoInstanceActionYet(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
	ec2Client := mock_ec2.NewMockClient(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)

	const myARN = "myARN"
	agent := &ecsAgent{
		ec2MetadataClient:    ec2MetadataClient,
		ec2Client:            ec2Client,
		containerInstanceARN: myARN,
	}

	ec2MetadataClient.EXPECT().SpotInstanceAction().Return("", fmt.Errorf("404"))
	// Container state should NOT be updated because no instance action could
	// be read from the metadata client.
	ecsClient.EXPECT().UpdateContainerInstancesState(gomock.Any(), gomock.Any()).Times(0)

	drained := agent.spotInstanceDrainingPoller(ecsClient)
	assert.False(t, drained)
}

func getTestConfig() config.Config {
cfg := config.DefaultConfig()
cfg.TaskCPUMemLimit = config.ExplicitlyDisabled
Expand Down
1 change: 1 addition & 0 deletions agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ func environmentConfig() (Config, error) {
NvidiaRuntime: os.Getenv("ECS_NVIDIA_RUNTIME"),
TaskMetadataAZDisabled: utils.ParseBool(os.Getenv("ECS_DISABLE_TASK_METADATA_AZ"), false),
CgroupCPUPeriod: parseCgroupCPUPeriod(),
SpotInstanceDrainingEnabled: utils.ParseBool(os.Getenv("ECS_ENABLE_SPOT_INSTANCE_DRAINING"), false),
}, err
}

Expand Down
3 changes: 3 additions & 0 deletions agent/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ func TestEnvironmentConfig(t *testing.T) {
assert.Equal(t, "nvidia", conf.NvidiaRuntime)
assert.True(t, conf.TaskMetadataAZDisabled, "Wrong value for TaskMetadataAZDisabled")
assert.Equal(t, 10*time.Millisecond, conf.CgroupCPUPeriod)
assert.False(t, conf.SpotInstanceDrainingEnabled)
}

func TestTrimWhitespaceWhenCreating(t *testing.T) {
Expand Down Expand Up @@ -194,10 +195,12 @@ func TestConfigBoolean(t *testing.T) {
defer setTestRegion()()
defer setTestEnv("ECS_DISABLE_DOCKER_HEALTH_CHECK", "true")()
defer setTestEnv("ECS_DISABLE_METRICS", "true")()
defer setTestEnv("ECS_ENABLE_SPOT_INSTANCE_DRAINING", "true")()
cfg, err := NewConfig(ec2.NewBlackholeEC2MetadataClient())
assert.NoError(t, err)
assert.True(t, cfg.DisableMetrics)
assert.True(t, cfg.DisableDockerHealthCheck)
assert.True(t, cfg.SpotInstanceDrainingEnabled)
}

func TestBadLoggingDriverSerialization(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions agent/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,12 @@ type Config struct {

// CgroupCPUPeriod is config option to set different CFS quota and period values in microsecond, defaults to 100 ms
CgroupCPUPeriod time.Duration

// SpotInstanceDrainingEnabled, if true, makes the agent poll the container instance's metadata endpoint for an EC2
// spot instance termination notice. If EC2 sends a spot termination notice, the agent sets the instance's state
// to DRAINING, which gracefully shuts down all running tasks on the instance.
// If the instance is not a spot instance, the poller still runs but never receives a termination notice.
// Defaults to false.
// see https://docs.aws.amazon.com/AmazonECS/latest/developerguide/container-instance-draining.html
SpotInstanceDrainingEnabled bool
}
4 changes: 4 additions & 0 deletions agent/ec2/blackhole_ec2_metadata_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,7 @@ func (blackholeMetadataClient) PrivateIPv4Address() (string, error) {
func (blackholeMetadataClient) PublicIPv4Address() (string, error) {
return "", errors.New("blackholed")
}

// SpotInstanceAction always returns an empty string together with a
// "blackholed" error.
func (blackholeMetadataClient) SpotInstanceAction() (string, error) {
	var action string
	return action, errors.New("blackholed")
}
10 changes: 10 additions & 0 deletions agent/ec2/ec2_metadata_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
AllMacResource = "network/interfaces/macs"
VPCIDResourceFormat = "network/interfaces/macs/%s/vpc-id"
SubnetIDResourceFormat = "network/interfaces/macs/%s/subnet-id"
SpotInstanceActionResource = "spot/instance-action"
InstanceIDResource = "instance-id"
PrivateIPv4Resource = "local-ipv4"
PublicIPv4Resource = "public-ipv4"
Expand Down Expand Up @@ -76,6 +77,7 @@ type EC2MetadataClient interface {
Region() (string, error)
PrivateIPv4Address() (string, error)
PublicIPv4Address() (string, error)
SpotInstanceAction() (string, error)
}

type ec2MetadataClientImpl struct {
Expand Down Expand Up @@ -184,3 +186,11 @@ func (c *ec2MetadataClientImpl) PublicIPv4Address() (string, error) {
func (c *ec2MetadataClientImpl) PrivateIPv4Address() (string, error) {
return c.client.GetMetadata(PrivateIPv4Resource)
}

// SpotInstanceAction fetches the SpotInstanceActionResource metadata item and
// returns its raw contents. Any error from the metadata lookup is returned
// unchanged; callers should expect one whenever no instance action has been
// set (i.e. the instance is not scheduled for interruption).
// see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-interruptions.html#using-spot-instances-managing-interruptions
func (c *ec2MetadataClientImpl) SpotInstanceAction() (string, error) {
	action, err := c.client.GetMetadata(SpotInstanceActionResource)
	return action, err
}
Loading

0 comments on commit 25efaa6

Please sign in to comment.