Skip to content

Commit

Permalink
Fixing broke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sparrc committed Dec 21, 2019
1 parent 1a0c755 commit 2c8f46c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 53 deletions.
96 changes: 53 additions & 43 deletions agent/acs/handler/payload_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ package handler

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"runtime"
"sync"
"testing"

Expand All @@ -33,11 +34,11 @@ import (
"github.com/aws/amazon-ecs-agent/agent/eventhandler"
"github.com/aws/amazon-ecs-agent/agent/statemanager"
mock_statemanager "github.com/aws/amazon-ecs-agent/agent/statemanager/mocks"
"github.com/aws/amazon-ecs-agent/agent/taskresource"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"
"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -147,13 +148,7 @@ func TestHandlePayloadMessageStateSaveError(t *testing.T) {
})
assert.Error(t, err, "Expected error while adding a task from statemanager")

// We expect task to be added to the engine even though it hasn't been saved
expectedTask := &apitask.Task{
Arn: "t1",
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}

assert.Equal(t, addedTask, expectedTask, "added task is not expected")
validateTask(t, addedTask, "t1")
}

// TestHandlePayloadMessageAckedWhenTaskAdded tests if the handler generates an ack
Expand Down Expand Up @@ -194,12 +189,7 @@ func TestHandlePayloadMessageAckedWhenTaskAdded(t *testing.T) {
// Verify the message id acked
assert.Equal(t, aws.StringValue(ackRequested.MessageId), payloadMessageId, "received message is not expected")

// Verify if task added == expected task
expectedTask := &apitask.Task{
Arn: "t1",
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}
assert.Equal(t, addedTask, expectedTask, "received task is not expected")
validateTask(t, addedTask, "t1")
}

// TestHandlePayloadMessageCredentialsAckedWhenTaskAdded tests if the handler generates
Expand Down Expand Up @@ -290,8 +280,7 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
SessionToken: credentialsSessionToken,
CredentialsID: credentialsId,
}
err = validateTaskAndCredentials(taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials)
assert.NoError(t, err, "error validating added task or credentials ack for the same")
validateTaskAndCredentials(t, taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials, "t1")
}

// TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks tests if tasks with desired status
Expand Down Expand Up @@ -371,12 +360,7 @@ func TestPayloadBufferHandler(t *testing.T) {
// Verify if payloadMessageId read from the ack buffer is correct
assert.Equal(t, aws.StringValue(ackRequested.MessageId), payloadMessageId, "received task is not expected")

// Verify if the task added to the engine is correct
expectedTask := &apitask.Task{
Arn: taskArn,
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}
assert.Equal(t, addedTask, expectedTask, "received task is not expected")
validateTask(t, addedTask, "t1")
}

// TestPayloadBufferHandlerWithCredentials tests if the async payloadBufferHandler routine
Expand Down Expand Up @@ -495,8 +479,7 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
SessionToken: firstTaskCredentialsSessionToken,
CredentialsID: firstTaskCredentialsId,
}
err := validateTaskAndCredentials(firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask)
assert.NoError(t, err, "error validating added task or credentials ack for the same")
validateTaskAndCredentials(t, firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask, "t1")

// Verify the correctness of the second task added to the engine and the
// credentials ack generated for it
Expand All @@ -513,8 +496,7 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
SessionToken: secondTaskCredentialsSessionToken,
CredentialsID: secondTaskCredentialsId,
}
err = validateTaskAndCredentials(secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask)
assert.NoError(t, err, "error validating added task or credentials ack for the same")
validateTaskAndCredentials(t, secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask, "t2")
}

// TestAddPayloadTaskAddsExecutionRoles tests the payload handler will add
Expand Down Expand Up @@ -596,24 +578,18 @@ func TestAddPayloadTaskAddsExecutionRoles(t *testing.T) {
// validateTaskAndCredentials compares a task and a credentials ack object
// against expected values. It returns an error if either of the the
// comparisons fail
func validateTaskAndCredentials(taskCredentialsAck, expectedCredentialsAckForTask *ecsacs.IAMRoleCredentialsAckRequest,
func validateTaskAndCredentials(
t *testing.T,
taskCredentialsAck *ecsacs.IAMRoleCredentialsAckRequest,
expectedCredentialsAckForTask *ecsacs.IAMRoleCredentialsAckRequest,
addedTask *apitask.Task,
expectedTaskArn string,
expectedTaskCredentials credentials.IAMRoleCredentials) error {
if !reflect.DeepEqual(taskCredentialsAck, expectedCredentialsAckForTask) {
return fmt.Errorf("Mismatch between expected and received credentials ack requests, expected: %s, got: %s", expectedCredentialsAckForTask.String(), taskCredentialsAck.String())
}

expectedTask := &apitask.Task{
Arn: expectedTaskArn,
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}
expectedTask.SetCredentialsID(expectedTaskCredentials.CredentialsID)

if !reflect.DeepEqual(addedTask, expectedTask) {
return fmt.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
}
return nil
expectedTaskCredentials credentials.IAMRoleCredentials,
taskName string,
) {
require.Equal(t, expectedCredentialsAckForTask, taskCredentialsAck)
require.Equal(t, expectedTaskCredentials.CredentialsID, addedTask.GetCredentialsID())
validateTask(t, addedTask, taskName)
}

func TestPayloadHandlerAddedENIToTask(t *testing.T) {
Expand Down Expand Up @@ -949,3 +925,37 @@ func TestPayloadHandlerAddedFirelensData(t *testing.T) {
assert.NotNil(t, actual.Options)
assert.Equal(t, aws.StringValue(expected.Options["enable-ecs-log-metadata"]), actual.Options["enable-ecs-log-metadata"])
}

func validateTask(t *testing.T, addedTask *apitask.Task, expectedTaskName string) {
// We expect task to be added to the engine even though it hasn't been saved
addedTaskJSON, err := json.Marshal(addedTask)
require.NoError(t, err)
platformFields := "{}"
if runtime.GOOS == "windows" {
platformFields = `{"cpuUnbounded": false, "memoryUnbounded": false}`
}
expectedTaskJSON := fmt.Sprintf(`
{
"Arn": "%s",
"Family": "",
"Version": "",
"Containers": null,
"associations": null,
"resources": {},
"volumes": null,
"DesiredStatus": "NONE",
"KnownStatus": "NONE",
"KnownTime": "0001-01-01T00:00:00Z",
"PullStartedAt": "0001-01-01T00:00:00Z",
"PullStoppedAt": "0001-01-01T00:00:00Z",
"ExecutionStoppedAt": "0001-01-01T00:00:00Z",
"SentStatus": "NONE",
"StartSequenceNumber": 0,
"StopSequenceNumber": 0,
"executionCredentialsID": "",
"ENI": null,
"AppMesh": null,
"PlatformFields": %s
}`, expectedTaskName, platformFields)
require.JSONEq(t, expectedTaskJSON, string(addedTaskJSON))
}
1 change: 1 addition & 0 deletions agent/api/task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,7 @@ func TestTaskFromACS(t *testing.T) {

seqNum := int64(42)
task, err := TaskFromACS(&taskFromAcs, &ecsacs.PayloadMessage{SeqNum: &seqNum})
expectedTask.log = task.log

assert.NoError(t, err)
assert.EqualValues(t, expectedTask, task)
Expand Down
7 changes: 4 additions & 3 deletions agent/logger/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var Config *logConfig
func logfmtFormatter(params string) seelog.FormatterFunc {
return func(message string, level seelog.LogLevel, context seelog.LogContextInterface) interface{} {
cc, ok := context.CustomContext().(map[string]string)
if !ok {
if !ok || len(cc) == 0 {
cc = map[string]string{}
}
if _, ok = cc["module"]; !ok {
Expand All @@ -77,7 +77,7 @@ func logfmtFormatter(params string) seelog.FormatterFunc {
func jsonFormatter(params string) seelog.FormatterFunc {
return func(message string, level seelog.LogLevel, context seelog.LogContextInterface) interface{} {
cc, ok := context.CustomContext().(map[string]string)
if !ok {
if !ok || len(cc) == 0 {
cc = map[string]string{}
}
if _, ok = cc["module"]; !ok {
Expand All @@ -94,7 +94,7 @@ func jsonFormatter(params string) seelog.FormatterFunc {

func seelogConfig() string {
c := `
<seelog type="asyncloop" minlevel="` + Config.level + `">
<seelog type="sync" minlevel="` + Config.level + `">
<outputs formatid="` + Config.outputFormat + `">
<console />`
c += platformLogConfig()
Expand All @@ -114,6 +114,7 @@ func seelogConfig() string {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`
return c
Expand Down
20 changes: 13 additions & 7 deletions agent/logger/log_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build !windows
// +build !windows,unit

// Copyright 2014-2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestSeelogConfig_Default(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="info">
<seelog type="sync" minlevel="info">
<outputs formatid="logfmt">
<console />
<rollingfile filename="foo.log" type="date"
Expand All @@ -124,6 +124,7 @@ func TestSeelogConfig_Default(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand All @@ -139,7 +140,7 @@ func TestSeelogConfig_DebugLevel(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="debug">
<seelog type="sync" minlevel="debug">
<outputs formatid="logfmt">
<console />
<rollingfile filename="foo.log" type="date"
Expand All @@ -148,6 +149,7 @@ func TestSeelogConfig_DebugLevel(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand All @@ -163,7 +165,7 @@ func TestSeelogConfig_SizeRollover(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="info">
<seelog type="sync" minlevel="info">
<outputs formatid="logfmt">
<console />
<rollingfile filename="foo.log" type="size"
Expand All @@ -172,6 +174,7 @@ func TestSeelogConfig_SizeRollover(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand All @@ -187,7 +190,7 @@ func TestSeelogConfig_SizeRolloverFileSizeChange(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="info">
<seelog type="sync" minlevel="info">
<outputs formatid="logfmt">
<console />
<rollingfile filename="foo.log" type="size"
Expand All @@ -196,6 +199,7 @@ func TestSeelogConfig_SizeRolloverFileSizeChange(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand All @@ -211,7 +215,7 @@ func TestSeelogConfig_SizeRolloverRollCountChange(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="info">
<seelog type="sync" minlevel="info">
<outputs formatid="logfmt">
<console />
<rollingfile filename="foo.log" type="size"
Expand All @@ -220,6 +224,7 @@ func TestSeelogConfig_SizeRolloverRollCountChange(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand All @@ -235,7 +240,7 @@ func TestSeelogConfig_JSONOutput(t *testing.T) {
}
c := seelogConfig()
require.Equal(t, `
<seelog type="asyncloop" minlevel="info">
<seelog type="sync" minlevel="info">
<outputs formatid="json">
<console />
<rollingfile filename="foo.log" type="date"
Expand All @@ -244,6 +249,7 @@ func TestSeelogConfig_JSONOutput(t *testing.T) {
<formats>
<format id="logfmt" format="%EcsAgentLogfmt" />
<format id="json" format="%EcsAgentJson" />
<format id="windows" format="%Msg" />
</formats>
</seelog>`, c)
}
Expand Down

0 comments on commit 2c8f46c

Please sign in to comment.