diff --git a/agent/app/agent.go b/agent/app/agent.go index fa6d73c1c63..6b7cbd7ff1b 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -18,6 +18,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/aws/amazon-ecs-agent/agent/stats/reporter" "time" acshandler "github.com/aws/amazon-ecs-agent/agent/acs/handler" @@ -48,7 +49,6 @@ import ( "github.com/aws/amazon-ecs-agent/agent/statemanager" "github.com/aws/amazon-ecs-agent/agent/stats" "github.com/aws/amazon-ecs-agent/agent/taskresource" - tcshandler "github.com/aws/amazon-ecs-agent/agent/tcs/handler" "github.com/aws/amazon-ecs-agent/agent/utils" "github.com/aws/amazon-ecs-agent/agent/utils/loader" "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper" @@ -871,21 +871,13 @@ func (agent *ecsAgent) startAsyncRoutines( } go statsEngine.StartMetricsPublish() - telemetrySessionParams := tcshandler.TelemetrySessionParams{ - Ctx: agent.ctx, - CredentialProvider: agent.credentialProvider, - Cfg: agent.cfg, - ContainerInstanceArn: agent.containerInstanceARN, - DeregisterInstanceEventStream: deregisterInstanceEventStream, - ECSClient: client, - TaskEngine: taskEngine, - StatsEngine: statsEngine, - MetricsChannel: telemetryMessages, - HealthChannel: healthMessages, - Doctor: doctor, - } - // Start metrics session in a go routine - go tcshandler.StartMetricsSession(&telemetrySessionParams) + session := reporter.NewDockerTelemetrySession(agent.containerInstanceARN, agent.credentialProvider, agent.cfg, deregisterInstanceEventStream, + client, taskEngine, telemetryMessages, healthMessages, doctor) + if session == nil { + seelog. + } + + go session.Start(agent.ctx) } func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, client api.ECSClient) { diff --git a/agent/stats/reporter/reporter.go b/agent/stats/reporter/reporter.go index 8cee4359d85..fe503f08992 100644 --- a/agent/stats/reporter/reporter.go +++ b/agent/stats/reporter/reporter.go @@ -48,22 +48,18 @@ func NewDockerTelemetrySession( taskEngine engine.TaskEngine, metricsChannel <-chan ecstcs.TelemetryMessage, healthChannel <-chan ecstcs.HealthMessage, - doctor *doctor.Doctor) *DockerTelemetrySession { + doctor *doctor.Doctor) (*DockerTelemetrySession, error) { ok, cfgParseErr := isContainerHealthMetricsDisabled(cfg) if cfgParseErr != nil { seelog.Warnf("Error starting metrics session: %v", cfgParseErr) - return nil + return nil, cfgParseErr } if ok { seelog.Warnf("Metrics were disabled, not starting the telemetry session") - return nil + return nil, nil } agentVersion, agentHash, containerRuntimeVersion := generateVersionInfo(taskEngine) - if cfg == nil { - logger.Error("Config is empty in the tcs session parameter") - return nil - } session := tcshandler.NewTelemetrySession( containerInstanceArn, @@ -90,7 +86,7 @@ func NewDockerTelemetrySession( healthChannel, doctor, ) - return &DockerTelemetrySession{session, ecsClient, containerInstanceArn} + return &DockerTelemetrySession{session, ecsClient, containerInstanceArn}, nil } // Start "overloads" tcshandler.TelemetrySession's Start with extra handling of discoverTelemetryEndpoint result. diff --git a/agent/stats/reporter/reporter_test.go b/agent/stats/reporter/reporter_test.go new file mode 100644 index 00000000000..e28dfc6555a --- /dev/null +++ b/agent/stats/reporter/reporter_test.go @@ -0,0 +1,125 @@ +package reporter + +import ( + "context" + "errors" + "testing" + + "github.com/aws/amazon-ecs-agent/agent/config" + mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" + "github.com/aws/amazon-ecs-agent/agent/version" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +const ( + testContainerInstanceArn = "testContainerInstanceArn" + testCluster = "testCluster" + testRegion = "us-west-2" + testDockerEndpoint = "testDockerEndpoint" + testDockerVersion = "testDockerVersion" +) + +func TestNewDockerTelemetrySession(t *testing.T) { + emptyDoctor, _ := doctor.NewDoctor([]doctor.Healthcheck{}, testCluster, testContainerInstanceArn) + testCredentials := credentials.NewStaticCredentials("test-id", "test-secret", "test-token") + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockEngine := mock_engine.NewMockTaskEngine(ctrl) + mockEngine.EXPECT().Version().Return(testDockerVersion, nil) + testCases := []struct { + name string + cfg *config.Config + expectedSession bool + expectedError bool + }{ + { + name: "happy case", + cfg: &config.Config{ + DisableMetrics: config.BooleanDefaultFalse{}, + DisableDockerHealthCheck: config.BooleanDefaultFalse{}, + Cluster: testCluster, + AWSRegion: testRegion, + AcceptInsecureCert: false, + DockerEndpoint: testDockerEndpoint, + }, + expectedSession: true, + expectedError: false, + }, + { + name: "cfg parsing error", + cfg: nil, + expectedSession: false, + expectedError: true, + }, + { + name: "metrics disabled", + cfg: &config.Config{ + DisableMetrics: config.BooleanDefaultFalse{ + Value: config.ExplicitlyEnabled, + }, + DisableDockerHealthCheck: config.BooleanDefaultFalse{ + Value: config.ExplicitlyEnabled, + }, + Cluster: testCluster, + AWSRegion: testRegion, + AcceptInsecureCert: false, + DockerEndpoint: testDockerEndpoint, + }, + expectedSession: false, + expectedError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dockerTelemetrySession, err := NewDockerTelemetrySession( + testContainerInstanceArn, + testCredentials, + tc.cfg, + eventstream.NewEventStream("Deregister_Instance", context.Background()), + nil, + mockEngine, + nil, + nil, + emptyDoctor, + ) + if tc.expectedSession { + assert.NotNil(t, dockerTelemetrySession) + } else { + assert.Nil(t, dockerTelemetrySession) + } + + if tc.expectedError { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestGenerateVersionInfo_GetVersionError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockEngine := mock_engine.NewMockTaskEngine(ctrl) + mockEngine.EXPECT().Version().Times(1).Return(nil, errors.New("error")) + agentVersion, agentHash, containerRuntimeVersion := generateVersionInfo(mockEngine) + assert.Equal(t, version.Version, agentVersion) + assert.Equal(t, version.GitShortHash, agentHash) + assert.Equal(t, "", containerRuntimeVersion) +} + +func TestGenerateVersionInfo_NoError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockEngine := mock_engine.NewMockTaskEngine(ctrl) + mockEngine.EXPECT().Version().Times(1).Return(testDockerVersion, nil) + agentVersion, agentHash, containerRuntimeVersion := generateVersionInfo(mockEngine) + assert.Equal(t, version.Version, agentVersion) + assert.Equal(t, version.GitShortHash, agentHash) + assert.Equal(t, testDockerVersion, containerRuntimeVersion) +}