diff --git a/ecs-agent/api/container/restart/restart_tracker.go b/ecs-agent/api/container/restart/restart_tracker.go new file mode 100644 index 00000000000..ba6d9ed7697 --- /dev/null +++ b/ecs-agent/api/container/restart/restart_tracker.go @@ -0,0 +1,71 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package restart + +import ( + "fmt" + "time" + + api "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + + "github.com/aws/aws-sdk-go/aws" +) + +type RestartTracker struct { + RestartCount int `json:"restartCount,omitempty"` + restartPolicy *api.ContainerRestartPolicy +} + +func NewRestartTracker(restartPolicy *api.ContainerRestartPolicy) *RestartTracker { + return &RestartTracker{ + restartPolicy: restartPolicy, + } +} + +func (rt *RestartTracker) GetRestartCount() int { + return rt.RestartCount +} + +// RecordRestart updates the restart tracker's metadata after a restart has occurred. +// This metadata is used to calculate when restarts should occur and track how many +// have occurred. It is not the job of this method to determine if a restart should +// occur or restart the container. It is expected to receive a startedAt time from the container runtime. +func (rt *RestartTracker) RecordRestart() { + rt.RestartCount += 1 +} + +// ShouldRestart returns whether the container should restart and a reason string +// explaining why not. +func (rt *RestartTracker) ShouldRestart(exitCode *int64, startedAt time.Time, + desiredStatus apicontainerstatus.ContainerStatus) (bool, string) { + if !*rt.restartPolicy.Enabled { + return false, "restart policy is not enabled" + } + if desiredStatus == apicontainerstatus.ContainerStopped { + return false, "container's desired status is stopped" + } + if exitCode == nil { + return false, "exit code is nil" + } + for _, ignoredCode := range rt.restartPolicy.IgnoredExitCodes { + if aws.Int64Value(ignoredCode) == aws.Int64Value(exitCode) { + return false, fmt.Sprintf("exit code %d should be ignored", *exitCode) + } + } + if time.Since(startedAt) < time.Duration(*rt.restartPolicy.RestartAttemptPeriod*int64(time.Second)) { + return false, "attempt reset period has not elapsed" + } + return true, "" +} diff --git a/ecs-agent/api/container/restart/restart_tracker_test.go b/ecs-agent/api/container/restart/restart_tracker_test.go new file mode 100644 index 00000000000..d7ba13337a0 --- /dev/null +++ b/ecs-agent/api/container/restart/restart_tracker_test.go @@ -0,0 +1,158 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package restart + +import ( + "testing" + "time" + + api "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" +) + +func TestShouldRestart(t *testing.T) { + ignoredCode := aws.Int64(0) + rt := NewRestartTracker(&api.ContainerRestartPolicy{ + Enabled: aws.Bool(false), + IgnoredExitCodes: []*int64{ignoredCode}, + RestartAttemptPeriod: aws.Int64(60), + }) + testCases := []struct { + name string + rp api.ContainerRestartPolicy + exitCode int64 + startedAt time.Time + desiredStatus apicontainerstatus.ContainerStatus + expected bool + expectedReason string + }{ + { + name: "restart policy disabled", + rp: api.ContainerRestartPolicy{ + Enabled: aws.Bool(false), + IgnoredExitCodes: []*int64{ignoredCode}, + RestartAttemptPeriod: aws.Int64(60), + }, + exitCode: 1, + startedAt: time.Now().Add(2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "restart policy is not enabled", + }, + { + name: "ignored exit code", + rp: api.ContainerRestartPolicy{ + Enabled: aws.Bool(true), + IgnoredExitCodes: []*int64{ignoredCode}, + RestartAttemptPeriod: aws.Int64(60), + }, + exitCode: 0, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "exit code 0 should be ignored", + }, + { + name: "non ignored exit code", + rp: api.ContainerRestartPolicy{Enabled: aws.Bool(true), IgnoredExitCodes: []*int64{ignoredCode}, RestartAttemptPeriod: aws.Int64(60)}, + exitCode: 1, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: true, + expectedReason: "", + }, + { + name: "nil exit code", + rp: api.ContainerRestartPolicy{Enabled: aws.Bool(true), IgnoredExitCodes: []*int64{ignoredCode}, RestartAttemptPeriod: aws.Int64(60)}, + exitCode: -1, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "exit code is nil", + }, + { + name: "desired status stopped", + rp: api.ContainerRestartPolicy{Enabled: aws.Bool(true), IgnoredExitCodes: []*int64{ignoredCode}, RestartAttemptPeriod: aws.Int64(60)}, + exitCode: 1, + startedAt: time.Now().Add(2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerStopped, + expected: false, + expectedReason: "container's desired status is stopped", + }, + { + name: "attempt reset period not elapsed", + rp: api.ContainerRestartPolicy{Enabled: aws.Bool(true), IgnoredExitCodes: []*int64{ignoredCode}, RestartAttemptPeriod: aws.Int64(60)}, + exitCode: 1, + startedAt: time.Now(), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "attempt reset period has not elapsed", + }, + { + name: "attempt reset period not elapsed within one second", + rp: api.ContainerRestartPolicy{Enabled: aws.Bool(true), IgnoredExitCodes: []*int64{ignoredCode}, RestartAttemptPeriod: aws.Int64(60)}, + exitCode: 1, + startedAt: time.Now().Add(-time.Second * 59), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "attempt reset period has not elapsed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rt.restartPolicy = &tc.rp + + // Because we cannot instantiate int pointers directly, + // check for the exit code and leave this int pointer as nil + // if there is no value to override it. + var exitCodeAdjusted *int64 + if tc.exitCode != -1 { + exitCodeAdjusted = &tc.exitCode + } + + shouldRestart, reason := rt.ShouldRestart(exitCodeAdjusted, tc.startedAt, tc.desiredStatus) + assert.Equal(t, tc.expected, shouldRestart) + assert.Equal(t, tc.expectedReason, reason) + }) + } +} + +func TestRecordRestart(t *testing.T) { + rt := NewRestartTracker(&api.ContainerRestartPolicy{ + Enabled: aws.Bool(false), + RestartAttemptPeriod: aws.Int64(60), + }) + assert.Equal(t, 0, rt.RestartCount) + for i := 1; i < 1000; i++ { + rt.RecordRestart() + assert.Equal(t, i, rt.RestartCount) + } +} + +func TestRecordRestartPolicy(t *testing.T) { + rt := NewRestartTracker(&api.ContainerRestartPolicy{ + Enabled: aws.Bool(false), + RestartAttemptPeriod: aws.Int64(60), + }) + assert.Equal(t, 0, rt.RestartCount) + assert.Equal(t, 0, len(rt.restartPolicy.IgnoredExitCodes)) + assert.NotNil(t, rt.restartPolicy) +}