diff --git a/agent/app/agent.go b/agent/app/agent.go index f5b837acf1a..fa6d73c1c63 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -20,36 +20,29 @@ import ( "fmt" "time" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" - - dockerdoctor "github.com/aws/amazon-ecs-agent/agent/doctor" // for Docker specific container instance health checks - "github.com/aws/amazon-ecs-agent/agent/eni/watcher" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" - "github.com/aws/aws-sdk-go/aws/awserr" - - "github.com/aws/amazon-ecs-agent/agent/credentials/instancecreds" - "github.com/aws/amazon-ecs-agent/agent/engine/execcmd" - "github.com/aws/amazon-ecs-agent/agent/metrics" - acshandler "github.com/aws/amazon-ecs-agent/agent/acs/handler" "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/api/ecsclient" "github.com/aws/amazon-ecs-agent/agent/app/factory" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/containermetadata" + "github.com/aws/amazon-ecs-agent/agent/credentials/instancecreds" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/dockerclient" "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" "github.com/aws/amazon-ecs-agent/agent/dockerclient/sdkclientfactory" + dockerdoctor "github.com/aws/amazon-ecs-agent/agent/doctor" // for Docker specific container instance health checks "github.com/aws/amazon-ecs-agent/agent/ec2" "github.com/aws/amazon-ecs-agent/agent/ecscni" "github.com/aws/amazon-ecs-agent/agent/engine" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" + "github.com/aws/amazon-ecs-agent/agent/engine/execcmd" engineserviceconnect "github.com/aws/amazon-ecs-agent/agent/engine/serviceconnect" "github.com/aws/amazon-ecs-agent/agent/eni/pause" + "github.com/aws/amazon-ecs-agent/agent/eni/watcher" "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/handlers" + "github.com/aws/amazon-ecs-agent/agent/metrics" "github.com/aws/amazon-ecs-agent/agent/sighandlers" "github.com/aws/amazon-ecs-agent/agent/sighandlers/exitcodes" "github.com/aws/amazon-ecs-agent/agent/statemanager" @@ -63,12 +56,17 @@ import ( acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" aws_credentials "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/cihub/seelog" "github.com/pborman/uuid" ) @@ -866,6 +864,13 @@ func (agent *ecsAgent) startAsyncRoutines( // Start sending events to the backend go eventhandler.HandleEngineEvents(agent.ctx, taskEngine, client, taskHandler, attachmentEventHandler) + err := statsEngine.MustInit(agent.ctx, taskEngine, agent.cfg.Cluster, agent.containerInstanceARN) + if err != nil { + seelog.Warnf("Error initializing metrics engine: %v", err) + return + } + go statsEngine.StartMetricsPublish() + telemetrySessionParams := tcshandler.TelemetrySessionParams{ Ctx: 
agent.ctx, CredentialProvider: agent.credentialProvider, @@ -879,14 +884,6 @@ func (agent *ecsAgent) startAsyncRoutines( HealthChannel: healthMessages, Doctor: doctor, } - - err := statsEngine.MustInit(agent.ctx, taskEngine, agent.cfg.Cluster, agent.containerInstanceARN) - if err != nil { - seelog.Warnf("Error initializing metrics engine: %v", err) - return - } - go statsEngine.StartMetricsPublish() - // Start metrics session in a go routine go tcshandler.StartMetricsSession(&telemetrySessionParams) } diff --git a/agent/stats/reporter/reporter.go b/agent/stats/reporter/reporter.go new file mode 100644 index 00000000000..11118ec6143 --- /dev/null +++ b/agent/stats/reporter/reporter.go @@ -0,0 +1,147 @@ +package reporter + +import ( + "context" + "errors" + "io" + "time" + + "github.com/aws/amazon-ecs-agent/agent/api" + "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/engine" + "github.com/aws/amazon-ecs-agent/agent/version" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + tcshandler "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/cihub/seelog" +) + +const ( + // The maximum time to wait between heartbeats without disconnecting + defaultHeartbeatTimeout = 1 * time.Minute + defaultHeartbeatJitter = 1 * time.Minute + // Default websocket client disconnection timeout initiated by agent + defaultDisconnectionTimeout = 15 * time.Minute + defaultDisconnectionJitter = 30 * time.Minute +) + +type DockerTelemetrySession struct { + s tcshandler.TelemetrySession + ecsClient api.ECSClient + containerInstanceArn string +} + +// NewDockerTelemetrySession creates a DockerTelemetrySession, which has a tcshandler.TelemetrySession embedded.
+// tcshandler.TelemetrySession contains the logic to manage the TCSClient and corresponding websocket connection +func NewDockerTelemetrySession( + containerInstanceArn string, + credentialProvider *credentials.Credentials, + cfg *config.Config, + deregisterInstanceEventStream *eventstream.EventStream, + ecsClient api.ECSClient, + taskEngine engine.TaskEngine, + metricsChannel <-chan ecstcs.TelemetryMessage, + healthChannel <-chan ecstcs.HealthMessage, + doctor *doctor.Doctor) *DockerTelemetrySession { + ok, cfgParseErr := isContainerHealthMetricsDisabled(cfg) + if cfgParseErr != nil { + seelog.Warnf("Error starting metrics session: %v", cfgParseErr) + return nil + } + if ok { + seelog.Warnf("Metrics were disabled, not starting the telemetry session") + return nil + } + + agentVersion, agentHash, containerRuntimeVersion := generateVersionInfo(taskEngine) + if cfg == nil { + logger.Error("Config is empty in the tcs session parameter") + return nil + } + + session := tcshandler.NewTelemetrySession( + containerInstanceArn, + cfg.Cluster, + agentVersion, + agentHash, + containerRuntimeVersion, + "", // this will be overridden by DockerTelemetrySession.Start() + cfg.DisableMetrics.Enabled(), + credentialProvider, + &wsclient.WSClientMinAgentConfig{ + AWSRegion: cfg.AWSRegion, + AcceptInsecureCert: cfg.AcceptInsecureCert, + DockerEndpoint: cfg.DockerEndpoint, + IsDocker: true, + }, + deregisterInstanceEventStream, + defaultHeartbeatTimeout, + defaultHeartbeatJitter, + defaultDisconnectionTimeout, + defaultDisconnectionJitter, + nil, + metricsChannel, + healthChannel, + doctor, + ) + return &DockerTelemetrySession{session, ecsClient, containerInstanceArn} +} + +// Start "overloads" tcshandler.TelemetrySession's Start with extra handling of discoverTelemetryEndpoint result. 
+// Errors from discoverTelemetryEndpoint and tcshandler.TelemetrySession's StartTelemetrySession are handled +// together (retried with backoff, or returned). +func (session *DockerTelemetrySession) Start(ctx context.Context) error { + backoff := retry.NewExponentialBackoff(time.Second, 1*time.Minute, 0.2, 2) + for { + endpoint, tcsError := discoverPollEndpoint(session.containerInstanceArn, session.ecsClient) + if tcsError == nil { + tcsError = session.s.StartTelemetrySession(ctx, endpoint) + } + switch tcsError { + case context.Canceled, context.DeadlineExceeded: + return tcsError + case io.EOF, nil: + logger.Info("TCS Websocket connection closed for a valid reason") + backoff.Reset() + default: + seelog.Errorf("Error: lost websocket connection with ECS Telemetry service (TCS): %v", tcsError) + time.Sleep(backoff.Duration()) + } + } +} + +// generateVersionInfo generates the agentVersion, agentHash and containerRuntimeVersion from dockerTaskEngine state +func generateVersionInfo(taskEngine engine.TaskEngine) (string, string, string) { + agentVersion := version.Version + agentHash := version.GitHashString() + var containerRuntimeVersion string + if dockerVersion, getVersionErr := taskEngine.Version(); getVersionErr == nil { + containerRuntimeVersion = dockerVersion + } + + return agentVersion, agentHash, containerRuntimeVersion } + +// discoverPollEndpoint calls DiscoverTelemetryEndpoint to get the TCS endpoint URL for the TCS client to connect to +func discoverPollEndpoint(containerInstanceArn string, ecsClient api.ECSClient) (string, error) { + tcsEndpoint, err := ecsClient.DiscoverTelemetryEndpoint(containerInstanceArn) + if err != nil { + logger.Error("tcs: unable to discover poll endpoint", logger.Fields{ + field.Error: err, + }) + } + return tcsEndpoint, err +} + +func isContainerHealthMetricsDisabled(cfg *config.Config) (bool, error) { + if cfg != nil { + return cfg.DisableMetrics.Enabled() && cfg.DisableDockerHealthCheck.Enabled(), nil + } + return false, errors.New("config is empty in the tcs session parameter") } diff --git a/agent/tcs/handler/types.go b/agent/tcs/handler/types.go index 2666f945611..940267fba25 100644 --- a/agent/tcs/handler/types.go +++ b/agent/tcs/handler/types.go @@ -37,7 +37,6 @@ type TelemetrySessionParams struct { CredentialProvider *credentials.Credentials Cfg *config.Config DeregisterInstanceEventStream *eventstream.EventStream - AcceptInvalidCert bool ECSClient api.ECSClient TaskEngine engine.TaskEngine StatsEngine *stats.DockerStatsEngine diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go index 669e140f92f..81cc754624b 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" + "github.com/cihub/seelog" "github.com/pborman/uuid" ) @@ -100,7 +101,7 @@ func New(url string, // AddRequestHandler). All request handlers should be added prior to making this // call as unhandled requests will be discarded.
func (cs *tcsClientServer) Serve(ctx context.Context) error { - seelog.Debug("TCS client starting websocket poll loop") + logger.Debug("TCS client starting websocket poll loop") if !cs.IsReady() { return fmt.Errorf("tcs client: websocket not ready for connections") } @@ -124,7 +125,7 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { case <-ctx.Done(): return case metric := <-cs.metrics: - seelog.Debugf("received telemetry message in metricsChannel") + logger.Debug("received telemetry message in metricsChannel") err := cs.publishMetricsOnce(metric) if err != nil { logger.Warn("Error publishing metrics", logger.Fields{ @@ -132,10 +133,12 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { }) } case health := <-cs.health: - seelog.Debugf("received health message in healthChannel") + logger.Debug("received health message in healthChannel") err := cs.publishHealthOnce(health) if err != nil { - seelog.Warnf("Error publishing metrics: %v", err) + logger.Warn("Error publishing metrics", logger.Fields{ + field.Error: err, + }) } } } @@ -151,7 +154,7 @@ func (cs *tcsClientServer) publishMetricsOnce(message ecstcs.TelemetryMessage) e // Make the publish metrics request to the backend. for _, request := range requests { - seelog.Debugf("making publish metrics request") + logger.Debug("making publish metrics request") err = cs.MakeRequest(request) if err != nil { return err @@ -169,7 +172,7 @@ func (cs *tcsClientServer) publishHealthOnce(health ecstcs.HealthMessage) error } // Make the publish metrics request to the backend. for _, request := range requests { - seelog.Debugf("making publish health metrics request") + logger.Debug("making publish health metrics request") err = cs.MakeRequest(request) if err != nil { return err @@ -283,7 +286,7 @@ func (cs *tcsClientServer) healthToPublishHealthRequests(health ecstcs.HealthMes metadata, taskHealthMetrics := health.Metadata, health.HealthMetrics if metadata == nil || taskHealthMetrics == nil { - seelog.Debug("No container health metrics to report") + logger.Debug("No container health metrics to report") return nil, nil } @@ -377,7 +380,7 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { // handles, pertain to the health of the tasks that are running on this // container instance. if cs.pullInstanceStatusTicker == nil { - seelog.Debug("Skipping publishing container instance statuses. Publish ticker is uninitialized") + logger.Debug("Skipping publishing container instance statuses. Publish ticker is uninitialized") return } @@ -387,12 +390,14 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { if !cs.doctor.HasStatusBeenReported() { err := cs.publishInstanceStatusOnce() if err != nil { - seelog.Warnf("Unable to publish instance status: %v", err) + logger.Warn("Unable to publish instance status", logger.Fields{ + field.Error: err, + }) } else { cs.doctor.SetStatusReported(true) } } else { - seelog.Debug("Skipping publishing container instance status message that was already sent") + logger.Debug("Skipping publishing container instance status message that was already sent") } case <-ctx.Done(): return diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go new file mode 100644 index 00000000000..7abfb2700c7 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go @@ -0,0 +1,257 @@ +// Copyright Amazon.com Inc. 
or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +//lint:file-ignore U1000 Ignore unused metricsFactory field as it is only used by Fargate + +package tcshandler + +import ( + "context" + "io" + "net/url" + "strings" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + tcsclient "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/cihub/seelog" +) + +const ( + deregisterContainerInstanceHandler = "TCSDeregisterContainerInstanceHandler" + ContainerRuntimeDocker = "Docker" +) + +// TelemetrySession defines an interface for handler's long-lived connection with TCS. +type TelemetrySession interface { + StartTelemetrySession(context.Context, string) error + Start(context.Context) error +} + +// telemetrySession is the base session params type which contains all the parameters required to start a tcs session +type telemetrySession struct { + containerInstanceArn string + cluster string + agentVersion string + agentHash string + containerRuntimeVersion string + endpoint string + disableContainerHealthMetrics bool + credentialsProvider *credentials.Credentials + cfg *wsclient.WSClientMinAgentConfig + deregisterInstanceEventStream *eventstream.EventStream + heartbeatTimeout time.Duration + heartbeatJitterMax time.Duration + disconnectTimeout time.Duration + disconnectJitterMax time.Duration + metricsFactory metrics.EntryFactory + metricsChannel <-chan ecstcs.TelemetryMessage + healthChannel <-chan ecstcs.HealthMessage + doctor *doctor.Doctor +} + +func NewTelemetrySession( + containerInstanceArn string, + cluster string, + agentVersion string, + agentHash string, + containerRuntimeVersion string, + endpoint string, + disableContainerHealthMetrics bool, + credentialsProvider *credentials.Credentials, + cfg *wsclient.WSClientMinAgentConfig, + deregisterInstanceEventStream *eventstream.EventStream, + heartbeatTimeout time.Duration, + heartbeatJitterMax time.Duration, + disconnectTimeout time.Duration, + disconnectJitterMax time.Duration, + metricsFactory metrics.EntryFactory, + metricsChannel <-chan ecstcs.TelemetryMessage, + healthChannel <-chan ecstcs.HealthMessage, + doctor *doctor.Doctor, +) TelemetrySession { + return &telemetrySession{ + containerInstanceArn: containerInstanceArn, + cluster: cluster, + agentVersion: agentVersion, + agentHash: agentHash, + containerRuntimeVersion: containerRuntimeVersion, + endpoint: endpoint, + disableContainerHealthMetrics: disableContainerHealthMetrics, + credentialsProvider: credentialsProvider, + cfg: cfg, + deregisterInstanceEventStream: deregisterInstanceEventStream, + metricsChannel: 
metricsChannel, + healthChannel: healthChannel, + heartbeatTimeout: heartbeatTimeout, + heartbeatJitterMax: heartbeatJitterMax, + disconnectTimeout: disconnectTimeout, + disconnectJitterMax: disconnectJitterMax, + metricsFactory: metricsFactory, + doctor: doctor, + } +} + +// Start runs in a for loop to start the telemetry session with exponential backoff +func (session *telemetrySession) Start(ctx context.Context) error { + backoff := retry.NewExponentialBackoff(time.Second, 1*time.Minute, 0.2, 2) + for { + tcsError := session.StartTelemetrySession(ctx, session.endpoint) + switch tcsError { + case context.Canceled, context.DeadlineExceeded: + return tcsError + case io.EOF, nil: + logger.Info("TCS Websocket connection closed for a valid reason") + backoff.Reset() + default: + seelog.Errorf("Error: lost websocket connection with ECS Telemetry service (TCS): %v", tcsError) + time.Sleep(backoff.Duration()) + } + } +} + +// StartTelemetrySession creates a session with the backend and handles requests. +func (session *telemetrySession) StartTelemetrySession(ctx context.Context, endpoint string) error { + if session.disableContainerHealthMetrics { + logger.Warn("Metrics were disabled, not starting the telemetry session") + return nil + } + + wsRWTimeout := 2*session.heartbeatTimeout + session.heartbeatJitterMax + + var containerRuntime string + if session.cfg.IsDocker { + containerRuntime = ContainerRuntimeDocker + } + + tcsEndpointUrl := formatURL(endpoint, session.cluster, session.containerInstanceArn, session.agentVersion, + session.agentHash, containerRuntime, session.containerRuntimeVersion) + client := tcsclient.New(tcsEndpointUrl, session.cfg, session.doctor, session.disableContainerHealthMetrics, tcsclient.DefaultContainerMetricsPublishInterval, + session.credentialsProvider, wsRWTimeout, session.metricsChannel, session.healthChannel) + defer client.Close() + + if session.deregisterInstanceEventStream != nil { + err := session.deregisterInstanceEventStream.Subscribe(deregisterContainerInstanceHandler, client.Disconnect) + if err != nil { + return err + } + defer session.deregisterInstanceEventStream.Unsubscribe(deregisterContainerInstanceHandler) + } + err := client.Connect() + if err != nil { + logger.Error("Error connecting to TCS", logger.Fields{ + field.Error: err, + }) + return err + } + logger.Info("Connected to TCS endpoint") + // start a timer and listen for tcs heartbeats/acks. The timer is reset when + // we receive a heartbeat from the server or when a published metrics message + // is acked. + timer := time.NewTimer(retry.AddJitter(session.heartbeatTimeout, session.heartbeatJitterMax)) + defer timer.Stop() + client.AddRequestHandler(heartbeatHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishMetricHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishHealthMetricHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishInstanceStatusHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.SetAnyRequestHandler(anyMessageHandler(client, wsRWTimeout)) + serveC := make(chan error, 1) + go func() { + serveC <- client.Serve(ctx) + }() + select { + case <-ctx.Done(): + // outer context done, agent is exiting + client.Disconnect() + case <-timer.C: + seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting") + client.Disconnect() + case err := <-serveC: + return err + } + return nil +} + +// heartbeatHandler resets the heartbeat timer when a HeartbeatMessage is received from tcs. +func heartbeatHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.HeartbeatMessage) { + return func(*ecstcs.HeartbeatMessage) { + logger.Debug("Received HeartbeatMessage from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishMetricHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes a metric message. +func ackPublishMetricHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishMetric) { + return func(*ecstcs.AckPublishMetric) { + logger.Debug("Received AckPublishMetric from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishHealthMetricHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes a health message. +func ackPublishHealthMetricHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishHealth) { + return func(*ecstcs.AckPublishHealth) { + logger.Debug("Received ACKPublishHealth from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishInstanceStatusHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes an instance status message. +func ackPublishInstanceStatusHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishInstanceStatus) { + return func(*ecstcs.AckPublishInstanceStatus) { + logger.Debug("Received AckPublishInstanceStatus from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// anyMessageHandler handles any server message. Any server message means the +// connection is active. +func anyMessageHandler(client wsclient.ClientServer, wsRWTimeout time.Duration) func(interface{}) { + return func(interface{}) { + logger.Trace("TCS activity occurred") + // Reset read deadline as there's activity on the channel + if err := client.SetReadDeadline(time.Now().Add(wsRWTimeout)); err != nil { + logger.Warn("Unable to extend read deadline for TCS connection", logger.Fields{ + field.Error: err, + }) + } + } +} + +// formatURL returns a formatted URL for the TCS endpoint.
+func formatURL(endpoint, cluster, containerInstance, agentVersion, agentHash, containerRuntime, containerRuntimeVersion string) string { + tcsURL := endpoint + if !strings.HasSuffix(tcsURL, "/") { + tcsURL += "/" + } + query := url.Values{} + query.Set("cluster", cluster) + query.Set("containerInstance", containerInstance) + query.Set("agentVersion", agentVersion) + query.Set("agentHash", agentHash) + if containerRuntime == ContainerRuntimeDocker && containerRuntimeVersion != "" { + query.Set("dockerVersion", containerRuntimeVersion) + } + return tcsURL + "ws?" + query.Encode() +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 04f120fda98..3515b41a843 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -27,6 +27,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request github.com/aws/amazon-ecs-agent/ecs-agent/logger/field github.com/aws/amazon-ecs-agent/ecs-agent/metrics github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client +github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs github.com/aws/amazon-ecs-agent/ecs-agent/tmds github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response diff --git a/ecs-agent/tcs/client/client.go b/ecs-agent/tcs/client/client.go index 669e140f92f..81cc754624b 100644 --- a/ecs-agent/tcs/client/client.go +++ b/ecs-agent/tcs/client/client.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" + "github.com/cihub/seelog" "github.com/pborman/uuid" ) @@ -100,7 +101,7 @@ func New(url string, // AddRequestHandler). All request handlers should be added prior to making this // call as unhandled requests will be discarded. func (cs *tcsClientServer) Serve(ctx context.Context) error { - seelog.Debug("TCS client starting websocket poll loop") + logger.Debug("TCS client starting websocket poll loop") if !cs.IsReady() { return fmt.Errorf("tcs client: websocket not ready for connections") } @@ -124,7 +125,7 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { case <-ctx.Done(): return case metric := <-cs.metrics: - seelog.Debugf("received telemetry message in metricsChannel") + logger.Debug("received telemetry message in metricsChannel") err := cs.publishMetricsOnce(metric) if err != nil { logger.Warn("Error publishing metrics", logger.Fields{ @@ -132,10 +133,12 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { }) } case health := <-cs.health: - seelog.Debugf("received health message in healthChannel") + logger.Debug("received health message in healthChannel") err := cs.publishHealthOnce(health) if err != nil { - seelog.Warnf("Error publishing metrics: %v", err) + logger.Warn("Error publishing metrics", logger.Fields{ + field.Error: err, + }) } } } @@ -151,7 +154,7 @@ func (cs *tcsClientServer) publishMetricsOnce(message ecstcs.TelemetryMessage) e // Make the publish metrics request to the backend. for _, request := range requests { - seelog.Debugf("making publish metrics request") + logger.Debug("making publish metrics request") err = cs.MakeRequest(request) if err != nil { return err @@ -169,7 +172,7 @@ func (cs *tcsClientServer) publishHealthOnce(health ecstcs.HealthMessage) error } // Make the publish metrics request to the backend. 
for _, request := range requests { - seelog.Debugf("making publish health metrics request") + logger.Debug("making publish health metrics request") err = cs.MakeRequest(request) if err != nil { return err @@ -283,7 +286,7 @@ func (cs *tcsClientServer) healthToPublishHealthRequests(health ecstcs.HealthMes metadata, taskHealthMetrics := health.Metadata, health.HealthMetrics if metadata == nil || taskHealthMetrics == nil { - seelog.Debug("No container health metrics to report") + logger.Debug("No container health metrics to report") return nil, nil } @@ -377,7 +380,7 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { // handles, pertain to the health of the tasks that are running on this // container instance. if cs.pullInstanceStatusTicker == nil { - seelog.Debug("Skipping publishing container instance statuses. Publish ticker is uninitialized") + logger.Debug("Skipping publishing container instance statuses. Publish ticker is uninitialized") return } @@ -387,12 +390,14 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { if !cs.doctor.HasStatusBeenReported() { err := cs.publishInstanceStatusOnce() if err != nil { - seelog.Warnf("Unable to publish instance status: %v", err) + logger.Warn("Unable to publish instance status", logger.Fields{ + field.Error: err, + }) } else { cs.doctor.SetStatusReported(true) } } else { - seelog.Debug("Skipping publishing container instance status message that was already sent") + logger.Debug("Skipping publishing container instance status message that was already sent") } case <-ctx.Done(): return diff --git a/ecs-agent/tcs/client/client_test.go b/ecs-agent/tcs/client/client_test.go index 68c6f9d47e7..0611f05d844 100644 --- a/ecs-agent/tcs/client/client_test.go +++ b/ecs-agent/tcs/client/client_test.go @@ -34,7 +34,6 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" mock_wsconn "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/mock" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/golang/mock/gomock" @@ -42,20 +41,16 @@ import ( ) const ( - testPublishMetricsInterval = 1 * time.Second - testMessageId = "testMessageId" - testCluster = "default" - testContainerInstance = "containerInstance" - rwTimeout = time.Second - testPublishMetricRequestSizeLimit = 1024 - testTelemetryChannelDefaultBufferSize = 10 - testIncludeScStats = true - testNotIncludeScStats = false -) - -const ( - TEST_CLUSTER = "test-cluster" - TEST_INSTANCE_ARN = "test-instance-arn" + testPublishMetricsInterval = 1 * time.Second + testMessageId = "testMessageId" + testCluster = "default" + testContainerInstance = "containerInstance" + rwTimeout = time.Second + testPublishMetricRequestSizeLimitSC = 1024 + testPublishMetricRequestSizeLimitNonSC = 220 + testTelemetryChannelDefaultBufferSize = 10 + testIncludeScStats = true + testNotIncludeScStats = false ) type trueHealthcheck struct{} @@ -106,53 +101,53 @@ var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test var emptyDoctor, _ = doctor.NewDoctor([]doctor.Healthcheck{}, "test-cluster", "this:is:an:instance:arn") -type mockStatsEngine struct{} +type mockStatsSource struct{} -func (*mockStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { +func (*mockStatsSource) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { return nil, nil, 
fmt.Errorf("uninitialized") } -func (*mockStatsEngine) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { +func (*mockStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { return nil, nil, nil } -func (*mockStatsEngine) GetPublishServiceConnectTickerInterval() int32 { +func (*mockStatsSource) GetPublishServiceConnectTickerInterval() int32 { return 0 } -func (*mockStatsEngine) SetPublishServiceConnectTickerInterval(counter int32) { +func (*mockStatsSource) SetPublishServiceConnectTickerInterval(counter int32) { return } -func (*mockStatsEngine) GetPublishMetricsTicker() *time.Ticker { +func (*mockStatsSource) GetPublishMetricsTicker() *time.Ticker { return time.NewTicker(DefaultContainerMetricsPublishInterval) } -type emptyStatsEngine struct{} +type emptyStatsSource struct{} -func (*emptyStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { +func (*emptyStatsSource) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { return nil, nil, fmt.Errorf("empty stats") } -func (*emptyStatsEngine) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { +func (*emptyStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { return nil, nil, nil } -func (*emptyStatsEngine) GetPublishServiceConnectTickerInterval() int32 { +func (*emptyStatsSource) GetPublishServiceConnectTickerInterval() int32 { return 0 } -func (*emptyStatsEngine) SetPublishServiceConnectTickerInterval(counter int32) { +func (*emptyStatsSource) SetPublishServiceConnectTickerInterval(counter int32) { return } -func (*emptyStatsEngine) GetPublishMetricsTicker() *time.Ticker { +func (*emptyStatsSource) GetPublishMetricsTicker() *time.Ticker { return time.NewTicker(DefaultContainerMetricsPublishInterval) } -type idleStatsEngine struct{} +type idleStatsSource struct{} -func (*idleStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { +func (*idleStatsSource) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { metadata := &ecstcs.MetricsMetadata{ Cluster: aws.String(testCluster), ContainerInstance: aws.String(testContainerInstance), @@ -162,27 +157,27 @@ func (*idleStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ec return metadata, []*ecstcs.TaskMetric{}, nil } -func (*idleStatsEngine) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { +func (*idleStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { return nil, nil, nil } -func (*idleStatsEngine) GetPublishServiceConnectTickerInterval() int32 { +func (*idleStatsSource) GetPublishServiceConnectTickerInterval() int32 { return 0 } -func (*idleStatsEngine) SetPublishServiceConnectTickerInterval(counter int32) { +func (*idleStatsSource) SetPublishServiceConnectTickerInterval(counter int32) { return } -func (*idleStatsEngine) GetPublishMetricsTicker() *time.Ticker { +func (*idleStatsSource) GetPublishMetricsTicker() *time.Ticker { return time.NewTicker(DefaultContainerMetricsPublishInterval) } -type nonIdleStatsEngine struct { +type nonIdleStatsSource struct { numTasks int } -func (engine *nonIdleStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { +func (engine *nonIdleStatsSource) 
GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { metadata := &ecstcs.MetricsMetadata{ Cluster: aws.String(testCluster), ContainerInstance: aws.String(testContainerInstance), @@ -198,31 +193,31 @@ func (engine *nonIdleStatsEngine) GetInstanceMetrics(includeServiceConnectStats return metadata, taskMetrics, nil } -func (*nonIdleStatsEngine) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { +func (*nonIdleStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { return nil, nil, nil } -func (*nonIdleStatsEngine) GetPublishServiceConnectTickerInterval() int32 { +func (*nonIdleStatsSource) GetPublishServiceConnectTickerInterval() int32 { return 0 } -func (*nonIdleStatsEngine) SetPublishServiceConnectTickerInterval(counter int32) { +func (*nonIdleStatsSource) SetPublishServiceConnectTickerInterval(counter int32) { return } -func (*nonIdleStatsEngine) GetPublishMetricsTicker() *time.Ticker { +func (*nonIdleStatsSource) GetPublishMetricsTicker() *time.Ticker { return time.NewTicker(DefaultContainerMetricsPublishInterval) } -func newNonIdleStatsEngine(numTasks int) *nonIdleStatsEngine { - return &nonIdleStatsEngine{numTasks: numTasks} +func newNonIdleStatsSource(numTasks int) *nonIdleStatsSource { + return &nonIdleStatsSource{numTasks: numTasks} } -type serviceConnectStatsEngine struct { +type serviceConnectStatsSource struct { numTasks int } -func (engine *serviceConnectStatsEngine) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { +func (engine *serviceConnectStatsSource) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { metadata := &ecstcs.MetricsMetadata{ Cluster: aws.String(testCluster), ContainerInstance: aws.String(testContainerInstance), @@ -266,7 +261,7 @@ func (engine *serviceConnectStatsEngine) GetInstanceMetrics(includeServiceConnec metricValue := 3.0 var metricCount int64 = 1 - // generate a task metric with size more than testPublishMetricRequestSizeLimit i.e 1kB + // generate a task metric with size more than testPublishMetricRequestSizeLimitSC, i.e. 1kB generalMetric := ecstcs.GeneralMetric{ MetricName: &metricName, MetricValues: []*float64{&metricValue}, @@ -295,24 +290,24 @@ func (engine *serviceConnectStatsEngine) GetInstanceMetrics(includeServiceConnec return metadata, taskMetrics, nil } -func (*serviceConnectStatsEngine) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { +func (*serviceConnectStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) { return nil, nil, nil } -func (*serviceConnectStatsEngine) GetPublishServiceConnectTickerInterval() int32 { +func (*serviceConnectStatsSource) GetPublishServiceConnectTickerInterval() int32 { return 0 } -func (*serviceConnectStatsEngine) SetPublishServiceConnectTickerInterval(counter int32) { +func (*serviceConnectStatsSource) SetPublishServiceConnectTickerInterval(counter int32) { return } -func (*serviceConnectStatsEngine) GetPublishMetricsTicker() *time.Ticker { +func (*serviceConnectStatsSource) GetPublishMetricsTicker() *time.Ticker { return time.NewTicker(DefaultContainerMetricsPublishInterval) } -func newServiceConnectStatsEngine(numTasks int) *serviceConnectStatsEngine { - return &serviceConnectStatsEngine{numTasks: numTasks} +func newServiceConnectStatsSource(numTasks int) *serviceConnectStatsSource { + return
&serviceConnectStatsSource{numTasks: numTasks} } func TestPayloadHandlerCalled(t *testing.T) { @@ -320,7 +315,7 @@ func TestPayloadHandlerCalled(t *testing.T) { defer ctrl.Finish() conn := mock_wsconn.NewMockWebsocketConn(ctrl) - cs := testCS(conn) + cs := testCS(conn, nil, nil) ctx, _ := context.WithCancel(context.TODO()) @@ -357,7 +352,7 @@ func TestPublishMetricsRequest(t *testing.T) { // TODO: should use explicit values conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()) - cs := testCS(conn) + cs := testCS(conn, nil, nil) defer cs.Close() err := cs.MakeRequest(&ecstcs.PublishMetricsRequest{}) if err != nil { @@ -365,10 +360,10 @@ } } -func TestPublishOnceIdleStatsEngine(t *testing.T) { +func TestMetricsToPublishMetricRequestsIdleStatsSource(t *testing.T) { cs := tcsClientServer{} - mockEngine := idleStatsEngine{} - metadata, taskMetrics, _ := mockEngine.GetInstanceMetrics(testNotIncludeScStats) + statsSource := idleStatsSource{} + metadata, taskMetrics, _ := statsSource.GetInstanceMetrics(testNotIncludeScStats) requests, err := cs.metricsToPublishMetricRequests(ecstcs.TelemetryMessage{ Metadata: metadata, TaskMetrics: taskMetrics, @@ -385,16 +380,68 @@ } } -func TestPublishOnceNonIdleStatsEngine(t *testing.T) { +// TestMetricsToPublishMetricRequestsNonIdleStatsSourcePaginationWithTaskNumber checks the correct pagination behavior +// due to number of tasks +func TestMetricsToPublishMetricRequestsNonIdleStatsSourcePaginationWithTaskNumber(t *testing.T) { expectedRequests := 3 // Creates 21 task metrics, which translate to 3 batches, // {[Task1, Task2, ...Task10], [Task11, Task12, ...Task20], [Task21]} numTasks := (tasksInMetricMessage * (expectedRequests - 1)) + 1 cs := tcsClientServer{} - mockEngine := nonIdleStatsEngine{ + statsSource := nonIdleStatsSource{ + numTasks: numTasks, + } + metadata, taskMetrics, err := statsSource.GetInstanceMetrics(testNotIncludeScStats) + requests, err := cs.metricsToPublishMetricRequests(ecstcs.TelemetryMessage{ + Metadata: metadata, + TaskMetrics: taskMetrics, + }) + if err != nil { + t.Fatal("Error creating publishMetricRequests: ", err) + } + taskArns := make(map[string]bool) + for _, request := range requests { + for _, taskMetric := range request.TaskMetrics { + _, exists := taskArns[*taskMetric.TaskArn] + if exists { + t.Fatal("Duplicate task arn in requests: ", *taskMetric.TaskArn) + } + taskArns[*taskMetric.TaskArn] = true + } + } + if len(requests) != expectedRequests { + t.Errorf("Expected %d requests, got %d", expectedRequests, len(requests)) + } + lastRequest := requests[expectedRequests-1] + if !*lastRequest.Metadata.Fin { + t.Error("Fin not set to true in last request") + } + requests = requests[:(expectedRequests - 1)] + for i, request := range requests { + if *request.Metadata.Fin { + t.Errorf("Fin set to true in request %d/%d", i, (expectedRequests - 1)) + } + } +} + +// TestMetricsToPublishMetricRequestsNonIdleStatsSourcePaginationWithMetricsSize checks the correct pagination behavior +// due to metric size limit +func TestMetricsToPublishMetricRequestsNonIdleStatsSourcePaginationWithMetricsSize(t *testing.T) { + tempLimit := publishMetricRequestSizeLimit + publishMetricRequestSizeLimit = testPublishMetricRequestSizeLimitNonSC + defer func() { + publishMetricRequestSizeLimit = tempLimit + }() + + expectedRequests := 2 + // Creates 3 task metrics, which translate to 2 batches, + // {[Task1, Task2], [Task3]} + numTasks := 3 + cs :=
tcsClientServer{} + statsSource := nonIdleStatsSource{ numTasks: numTasks, } - metadata, taskMetrics, err := mockEngine.GetInstanceMetrics(testNotIncludeScStats) + metadata, taskMetrics, err := statsSource.GetInstanceMetrics(testNotIncludeScStats) requests, err := cs.metricsToPublishMetricRequests(ecstcs.TelemetryMessage{ Metadata: metadata, TaskMetrics: taskMetrics, @@ -427,9 +474,9 @@ func TestPublishOnceNonIdleStatsEngine(t *testing.T) { } } -func TestPublishServiceConnectStatsEngine(t *testing.T) { +func TestMetricsToPublishMetricRequestsServiceConnectStatsSource(t *testing.T) { tempLimit := publishMetricRequestSizeLimit - publishMetricRequestSizeLimit = testPublishMetricRequestSizeLimit + publishMetricRequestSizeLimit = testPublishMetricRequestSizeLimitSC defer func() { publishMetricRequestSizeLimit = tempLimit }() @@ -454,8 +501,8 @@ func TestPublishServiceConnectStatsEngine(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cs := tcsClientServer{} - mockEngine := newServiceConnectStatsEngine(tc.numTasks) - metadata, taskMetrics, _ := mockEngine.GetInstanceMetrics(testIncludeScStats) + statsSource := newServiceConnectStatsSource(tc.numTasks) + metadata, taskMetrics, _ := statsSource.GetInstanceMetrics(testIncludeScStats) requests, err := cs.metricsToPublishMetricRequests(ecstcs.TelemetryMessage{ Metadata: metadata, TaskMetrics: taskMetrics, @@ -491,13 +538,13 @@ func TestPublishServiceConnectStatsEngine(t *testing.T) { } } -func testCS(conn *mock_wsconn.MockWebsocketConn) wsclient.ClientServer { +func testCS(conn *mock_wsconn.MockWebsocketConn, metricsMessages <-chan ecstcs.TelemetryMessage, healthMessages <-chan ecstcs.HealthMessage) wsclient.ClientServer { cfg := &wsclient.WSClientMinAgentConfig{ AWSRegion: "us-east-1", AcceptInsecureCert: true, } cs := New("https://aws.amazon.com/ecs", cfg, emptyDoctor, false, testPublishMetricsInterval, - testCreds, rwTimeout, nil, nil).(*tcsClientServer) + testCreds, rwTimeout, metricsMessages, healthMessages).(*tcsClientServer) cs.SetConnection(conn) return cs } @@ -509,7 +556,7 @@ func TestCloseClientServer(t *testing.T) { defer ctrl.Finish() conn := mock_wsconn.NewMockWebsocketConn(ctrl) - cs := testCS(conn) + cs := testCS(conn, nil, nil) gomock.InOrder( conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil), @@ -530,7 +577,7 @@ func TestAckPublishHealthHandlerCalled(t *testing.T) { defer ctrl.Finish() conn := mock_wsconn.NewMockWebsocketConn(ctrl) - cs := testCS(conn) + cs := testCS(conn, nil, nil) ctx, _ := context.WithCancel(context.TODO()) @@ -623,7 +670,7 @@ func TestSessionClosed(t *testing.T) { defer ctrl.Finish() conn := mock_wsconn.NewMockWebsocketConn(ctrl) - cs := testCS(conn) + cs := testCS(conn, nil, nil) ctx, _ := context.WithCancel(context.TODO()) @@ -692,7 +739,7 @@ func TestGetInstanceStatuses(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - newDoctor, _ := doctor.NewDoctor(tc.checks, TEST_CLUSTER, TEST_INSTANCE_ARN) + newDoctor, _ := doctor.NewDoctor(tc.checks, testCluster, testContainerInstance) cs := tcsClientServer{ doctor: newDoctor, } @@ -749,7 +796,7 @@ func TestGetPublishInstanceStatusRequest(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - newDoctor, _ := doctor.NewDoctor(tc.checks, TEST_CLUSTER, TEST_INSTANCE_ARN) + newDoctor, _ := doctor.NewDoctor(tc.checks, testCluster, testContainerInstance) cs := tcsClientServer{ doctor: newDoctor, } @@ -757,8 +804,8 @@ func TestGetPublishInstanceStatusRequest(t *testing.T) 
{ // note: setting RequestId and Timestamp to nil so I can make the comparison metadata := &ecstcs.InstanceStatusMetadata{ - Cluster: aws.String(TEST_CLUSTER), - ContainerInstance: aws.String(TEST_INSTANCE_ARN), + Cluster: aws.String(testCluster), + ContainerInstance: aws.String(testContainerInstance), RequestId: nil, } @@ -786,7 +833,7 @@ func TestAckPublishInstanceStatusHandlerCalled(t *testing.T) { defer ctrl.Finish() conn := mock_wsconn.NewMockWebsocketConn(ctrl) - cs := testCS(conn) + cs := testCS(conn, nil, nil) ctx, _ := context.WithCancel(context.TODO()) @@ -811,3 +858,53 @@ t.Log("Waiting for handler to return payload.") <-handledPayload } + +func TestEmptyChannelNonBlocking(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx, cancel := context.WithCancel(context.TODO()) + + telemetryMessages := make(chan ecstcs.TelemetryMessage, 10) + healthMessages := make(chan ecstcs.HealthMessage, 10) + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + go cancelAfterWait(cancel) + + // verify publishMessages returns (empty channels) after context cancels + cs.publishMessages(ctx) + + // verify channels remain empty + assert.Len(t, telemetryMessages, 0) + assert.Len(t, healthMessages, 0) +} + +func cancelAfterWait(cancel context.CancelFunc) { + time.Sleep(5 * time.Second) + cancel() +} + +func TestInvalidFormatMessageOnChannel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx, _ := context.WithCancel(context.TODO()) + + telemetryMessages := make(chan ecstcs.TelemetryMessage, 10) + healthMessages := make(chan ecstcs.HealthMessage, 10) + + // the channel type-checks messages on send, so we can only test the nil attribute case + telemetryMessages <- ecstcs.TelemetryMessage{} + healthMessages <- ecstcs.HealthMessage{} + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + + // verify no request is made for the two ill-formed messages; the expectation + // must be registered before publishMessages starts consuming + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Times(0) + go cs.publishMessages(ctx) + time.Sleep(1 * time.Second) // wait for messages to be polled + + // verify messages are polled out + assert.Len(t, telemetryMessages, 0) + assert.Len(t, healthMessages, 0) +} diff --git a/ecs-agent/tcs/handler/handler.go b/ecs-agent/tcs/handler/handler.go new file mode 100644 index 00000000000..529c8d791b8 --- /dev/null +++ b/ecs-agent/tcs/handler/handler.go @@ -0,0 +1,252 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License.
+ +//lint:file-ignore U1000 Ignore unused metricsFactory field as it is only used by Fargate + +package tcshandler + +import ( + "context" + "io" + "net/url" + "strings" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + tcsclient "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/cihub/seelog" +) + +const ( + deregisterContainerInstanceHandler = "TCSDeregisterContainerInstanceHandler" + ContainerRuntimeDocker = "Docker" +) + +// TelemetrySession defines an interface for handler's long-lived connection with TCS. +type TelemetrySession interface { + StartTelemetrySession(context.Context, string) error + Start(context.Context) error +} + +// telemetrySession is the base session params type which contains all the parameters required to start a tcs session +type telemetrySession struct { + containerInstanceArn string + cluster string + agentVersion string + agentHash string + containerRuntimeVersion string + endpoint string + disableMetrics bool + credentialsProvider *credentials.Credentials + cfg *wsclient.WSClientMinAgentConfig + deregisterInstanceEventStream *eventstream.EventStream + heartbeatTimeout time.Duration + heartbeatJitterMax time.Duration + disconnectTimeout time.Duration + disconnectJitterMax time.Duration + metricsFactory metrics.EntryFactory + metricsChannel <-chan ecstcs.TelemetryMessage + healthChannel <-chan ecstcs.HealthMessage + doctor *doctor.Doctor +} + +func NewTelemetrySession( + containerInstanceArn string, + cluster string, + agentVersion string, + agentHash string, + containerRuntimeVersion string, + endpoint string, + disableMetrics bool, + credentialsProvider *credentials.Credentials, + cfg *wsclient.WSClientMinAgentConfig, + deregisterInstanceEventStream *eventstream.EventStream, + heartbeatTimeout time.Duration, + heartbeatJitterMax time.Duration, + disconnectTimeout time.Duration, + disconnectJitterMax time.Duration, + metricsFactory metrics.EntryFactory, + metricsChannel <-chan ecstcs.TelemetryMessage, + healthChannel <-chan ecstcs.HealthMessage, + doctor *doctor.Doctor, +) TelemetrySession { + return &telemetrySession{ + containerInstanceArn: containerInstanceArn, + cluster: cluster, + agentVersion: agentVersion, + agentHash: agentHash, + containerRuntimeVersion: containerRuntimeVersion, + endpoint: endpoint, + disableMetrics: disableMetrics, + credentialsProvider: credentialsProvider, + cfg: cfg, + deregisterInstanceEventStream: deregisterInstanceEventStream, + metricsChannel: metricsChannel, + healthChannel: healthChannel, + heartbeatTimeout: heartbeatTimeout, + heartbeatJitterMax: heartbeatJitterMax, + disconnectTimeout: disconnectTimeout, + disconnectJitterMax: disconnectJitterMax, + metricsFactory: metricsFactory, + doctor: doctor, + } +} + +// Start runs in a for loop to start the telemetry session with exponential backoff +func (session *telemetrySession) Start(ctx context.Context) error { + backoff := retry.NewExponentialBackoff(time.Second, 1*time.Minute, 0.2, 2) + for { + tcsError := session.StartTelemetrySession(ctx, session.endpoint) + switch tcsError { + case context.Canceled,
context.DeadlineExceeded: + return tcsError + case io.EOF, nil: + logger.Info("TCS Websocket connection closed for a valid reason") + backoff.Reset() + default: + seelog.Errorf("Error: lost websocket connection with ECS Telemetry service (TCS): %v", tcsError) + time.Sleep(backoff.Duration()) + } + } +} + +// StartTelemetrySession creates a session with the backend and handles requests. +func (session *telemetrySession) StartTelemetrySession(ctx context.Context, endpoint string) error { + wsRWTimeout := 2*session.heartbeatTimeout + session.heartbeatJitterMax + + var containerRuntime string + if session.cfg.IsDocker { + containerRuntime = ContainerRuntimeDocker + } + + tcsEndpointUrl := formatURL(endpoint, session.cluster, session.containerInstanceArn, session.agentVersion, + session.agentHash, containerRuntime, session.containerRuntimeVersion) + client := tcsclient.New(tcsEndpointUrl, session.cfg, session.doctor, session.disableMetrics, tcsclient.DefaultContainerMetricsPublishInterval, + session.credentialsProvider, wsRWTimeout, session.metricsChannel, session.healthChannel) + defer client.Close() + + if session.deregisterInstanceEventStream != nil { + err := session.deregisterInstanceEventStream.Subscribe(deregisterContainerInstanceHandler, client.Disconnect) + if err != nil { + return err + } + defer session.deregisterInstanceEventStream.Unsubscribe(deregisterContainerInstanceHandler) + } + err := client.Connect() + if err != nil { + logger.Error("Error connecting to TCS", logger.Fields{ + field.Error: err, + }) + return err + } + logger.Info("Connected to TCS endpoint") + // start a timer and listen for tcs heartbeats/acks. The timer is reset when + // we receive a heartbeat from the server or when a published metrics message + // is acked. + timer := time.NewTimer(retry.AddJitter(session.heartbeatTimeout, session.heartbeatJitterMax)) + defer timer.Stop() + client.AddRequestHandler(heartbeatHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishMetricHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishHealthMetricHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.AddRequestHandler(ackPublishInstanceStatusHandler(timer, session.heartbeatTimeout, session.heartbeatJitterMax)) + client.SetAnyRequestHandler(anyMessageHandler(client, wsRWTimeout)) + serveC := make(chan error, 1) + go func() { + serveC <- client.Serve(ctx) + }() + select { + case <-ctx.Done(): + // outer context done, agent is exiting + client.Disconnect() + case <-timer.C: + seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting") + client.Disconnect() + case err := <-serveC: + return err + } + return nil +} + +// heartbeatHandler resets the heartbeat timer when a HeartbeatMessage is received from tcs. +func heartbeatHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.HeartbeatMessage) { + return func(*ecstcs.HeartbeatMessage) { + logger.Debug("Received HeartbeatMessage from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishMetricHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes a metric message.
+func ackPublishMetricHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishMetric) { + return func(*ecstcs.AckPublishMetric) { + logger.Debug("Received AckPublishMetric from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishHealthMetricHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes a health message. +func ackPublishHealthMetricHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishHealth) { + return func(*ecstcs.AckPublishHealth) { + logger.Debug("Received ACKPublishHealth from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// ackPublishInstanceStatusHandler consumes the ack message from the backend. The backend sends +// the ack each time it processes an instance status message. +func ackPublishInstanceStatusHandler(timer *time.Timer, heartbeatTimeout, heartbeatJitter time.Duration) func(*ecstcs.AckPublishInstanceStatus) { + return func(*ecstcs.AckPublishInstanceStatus) { + logger.Debug("Received AckPublishInstanceStatus from tcs") + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// anyMessageHandler handles any server message. Any server message means the +// connection is active. +func anyMessageHandler(client wsclient.ClientServer, wsRWTimeout time.Duration) func(interface{}) { + return func(interface{}) { + logger.Trace("TCS activity occurred") + // Reset read deadline as there's activity on the channel + if err := client.SetReadDeadline(time.Now().Add(wsRWTimeout)); err != nil { + logger.Warn("Unable to extend read deadline for TCS connection", logger.Fields{ + field.Error: err, + }) + } + } +} + +// formatURL returns a formatted URL for the TCS endpoint. +func formatURL(endpoint, cluster, containerInstance, agentVersion, agentHash, containerRuntime, containerRuntimeVersion string) string { + tcsURL := endpoint + if !strings.HasSuffix(tcsURL, "/") { + tcsURL += "/" + } + query := url.Values{} + query.Set("cluster", cluster) + query.Set("containerInstance", containerInstance) + query.Set("agentVersion", agentVersion) + query.Set("agentHash", agentHash) + if containerRuntime == ContainerRuntimeDocker && containerRuntimeVersion != "" { + query.Set("dockerVersion", containerRuntimeVersion) + } + return tcsURL + "ws?" + query.Encode() +} diff --git a/ecs-agent/tcs/handler/handler_test.go b/ecs-agent/tcs/handler/handler_test.go new file mode 100644 index 00000000000..571ec93bdc5 --- /dev/null +++ b/ecs-agent/tcs/handler/handler_test.go @@ -0,0 +1,490 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License.
diff --git a/ecs-agent/tcs/handler/handler_test.go b/ecs-agent/tcs/handler/handler_test.go
new file mode 100644
index 00000000000..571ec93bdc5
--- /dev/null
+++ b/ecs-agent/tcs/handler/handler_test.go
@@ -0,0 +1,490 @@
+//go:build unit
+// +build unit
+
+// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"). You may
+// not use this file except in compliance with the License. A copy of the
+// License is located at
+//
+//	http://aws.amazon.com/apache2.0/
+//
+// or in the "license" file accompanying this file. This file is distributed
+// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+// express or implied. See the License for the specific language governing
+// permissions and limitations under the License.
+
+package tcshandler
+
+import (
+	"context"
+	"errors"
+	"io"
+	"math/rand"
+	"net/url"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
+	"github.com/aws/amazon-ecs-agent/ecs-agent/eventstream"
+	tcsclient "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client"
+	"github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs"
+	"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
+	wsmock "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock/utils"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/golang/mock/gomock"
+	"github.com/gorilla/websocket"
+	"github.com/stretchr/testify/assert"
+)
+
+const (
+	testTaskArn                           = "arn:aws:ecs:us-east-1:123:task/def"
+	testTaskDefinitionFamily              = "task-def"
+	testClusterArn                        = "arn:aws:ecs:us-east-1:123:cluster/default"
+	testInstanceArn                       = "arn:aws:ecs:us-east-1:123:container-instance/abc"
+	testMessageId                         = "testMessageId"
+	testPublishMetricsInterval            = 1 * time.Second
+	testSendMetricsToChannelWaitTime      = 100 * time.Millisecond
+	testTelemetryChannelDefaultBufferSize = 10
+	testDockerEndpoint                    = "testEndpoint"
+	testAgentVersion                      = "testAgentVersion"
+	testAgentHash                         = "testAgentHash"
+	testContainerRuntimeVersion           = "testContainerRuntimeVersion"
+	testHeartbeatTimeout                  = 1 * time.Minute
+	testHeartbeatJitter                   = 1 * time.Minute
+	testDisconnectionTimeout              = 15 * time.Minute
+	testDisconnectionJitter               = 30 * time.Minute
+)
+
+// mockStatsSource stands in for the agent's stats engine: it feeds telemetry
+// and health messages into the channels that the TCS client drains.
+type mockStatsSource struct {
+	metricsChannel       chan<- ecstcs.TelemetryMessage
+	healthChannel        chan<- ecstcs.HealthMessage
+	publishMetricsTicker *time.Ticker
+}
+
+var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test-token")
+
+var testCfg = &wsclient.WSClientMinAgentConfig{
+	AWSRegion:          "us-east-1",
+	AcceptInsecureCert: true,
+	DockerEndpoint:     testDockerEndpoint,
+	IsDocker:           true,
+}
+
+var emptyDoctor, _ = doctor.NewDoctor([]doctor.Healthcheck{}, "test-cluster", "this:is:an:instance:arn")
+
+func (*mockStatsSource) GetInstanceMetrics(includeServiceConnectStats bool) (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) {
+	req := createPublishMetricsRequest()
+	return req.Metadata, req.TaskMetrics, nil
+}
+
+func (*mockStatsSource) GetTaskHealthMetrics() (*ecstcs.HealthMetadata, []*ecstcs.TaskHealth, error) {
+	return nil, nil, nil
+}
+
+func (*mockStatsSource) GetPublishServiceConnectTickerInterval() int32 {
+	return 0
+}
+
+func (*mockStatsSource) SetPublishServiceConnectTickerInterval(counter int32) {
+}
+
+func (source *mockStatsSource) GetPublishMetricsTicker() *time.Ticker {
+	return source.publishMetricsTicker
+}
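Editor's note: a minimal standalone sketch of the producer/consumer shape this mock implements. A ticker-driven producer pushes messages into a buffered channel until its context is cancelled and then closes the channel, mirroring SimulateMetricsPublishToChannel below; the consumer plays the role of the TCS client. All names are illustrative:

package main

import (
	"context"
	"fmt"
	"time"
)

type telemetryMsg struct{ seq int }

// produce pushes a message on every tick until the context is cancelled,
// then closes the channel so the consumer's range loop terminates.
func produce(ctx context.Context, out chan<- telemetryMsg) {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	defer close(out)
	seq := 0
	for {
		select {
		case <-ticker.C:
			seq++
			out <- telemetryMsg{seq: seq}
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	msgs := make(chan telemetryMsg, 10) // buffered, like the telemetry channel in the tests
	go produce(ctx, msgs)

	// The "TCS client" side: drain until the producer closes the channel.
	for m := range msgs {
		fmt.Println("published metric", m.seq)
	}
}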
+
+// SimulateMetricsPublishToChannel simulates the behavior of `StartMetricsPublish` in DockerStatsSource,
+// which feeds metrics through a channel to the TCS client. At least one valid metric has to be sent,
+// otherwise no request is made to the mock server in TestStartTelemetrySession, and
+// `request := <-requestChan` blocks forever.
+func (source *mockStatsSource) SimulateMetricsPublishToChannel(ctx context.Context) {
+	source.publishMetricsTicker = time.NewTicker(testPublishMetricsInterval)
+	for {
+		select {
+		case <-source.publishMetricsTicker.C:
+			source.metricsChannel <- ecstcs.TelemetryMessage{
+				Metadata: &ecstcs.MetricsMetadata{
+					Cluster:           aws.String(testClusterArn),
+					ContainerInstance: aws.String(testInstanceArn),
+					Fin:               aws.Bool(false),
+					Idle:              aws.Bool(false),
+					MessageId:         aws.String(testMessageId),
+				},
+				TaskMetrics: []*ecstcs.TaskMetric{
+					{},
+				},
+			}
+
+			source.healthChannel <- ecstcs.HealthMessage{
+				Metadata:      &ecstcs.HealthMetadata{},
+				HealthMetrics: []*ecstcs.TaskHealth{},
+			}
+
+		case <-ctx.Done():
+			close(source.metricsChannel)
+			close(source.healthChannel)
+			return
+		}
+	}
+}
+
+func TestFormatURL(t *testing.T) {
+	endpoint := "http://127.0.0.1/"
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	wsurl := formatURL(endpoint, testClusterArn, testInstanceArn, testAgentVersion, testAgentHash,
+		ContainerRuntimeDocker, testContainerRuntimeVersion)
+	parsed, err := url.Parse(wsurl)
+	assert.NoError(t, err, "should be able to parse URL")
+	assert.Equal(t, "/ws", parsed.Path, "wrong path")
+	assert.Equal(t, testClusterArn, parsed.Query().Get("cluster"), "wrong cluster")
+	assert.Equal(t, testInstanceArn, parsed.Query().Get("containerInstance"), "wrong container instance")
+	assert.Equal(t, testAgentVersion, parsed.Query().Get("agentVersion"), "wrong agent version")
+	assert.Equal(t, testAgentHash, parsed.Query().Get("agentHash"), "wrong agent hash")
+	assert.Equal(t, testContainerRuntimeVersion, parsed.Query().Get("dockerVersion"), "wrong docker version")
+}
+
+func TestStartTelemetrySession(t *testing.T) {
+	// Start the test server.
+	closeWS := make(chan []byte)
+	server, serverChan, requestChan, serverErr, err := wsmock.GetMockServer(closeWS)
+	server.StartTLS()
+	defer server.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	telemetryMessages := make(chan ecstcs.TelemetryMessage, testTelemetryChannelDefaultBufferSize)
+	healthMessages := make(chan ecstcs.HealthMessage, testTelemetryChannelDefaultBufferSize)
+
+	wait := &sync.WaitGroup{}
+	ctx, cancel := context.WithCancel(context.Background())
+	wait.Add(1)
+	go func() {
+		select {
+		case sErr := <-serverErr:
+			t.Error(sErr)
+		case <-ctx.Done():
+		}
+		wait.Done()
+	}()
+	defer func() {
+		closeSocket(closeWS)
+		close(serverChan)
+	}()
+
+	deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", context.Background())
+
+	mockSource := &mockStatsSource{
+		metricsChannel: telemetryMessages,
+		healthChannel:  healthMessages,
+	}
+
+	session := NewTelemetrySession(
+		testInstanceArn,
+		testClusterArn,
+		testAgentVersion,
+		testAgentHash,
+		testContainerRuntimeVersion,
+		server.URL,
+		false,
+		testCreds,
+		testCfg,
+		deregisterInstanceEventStream,
+		testHeartbeatTimeout,
+		testHeartbeatJitter,
+		testDisconnectionTimeout,
+		testDisconnectionJitter,
+		nil,
+		telemetryMessages,
+		healthMessages,
+		emptyDoctor,
+	)
+
+	// Start a session with the test server.
+	go session.StartTelemetrySession(ctx, server.URL)
+
+	// Wait for 100 ms to make sure the session is ready to receive messages from the channel.
+	time.Sleep(testSendMetricsToChannelWaitTime)
+	go mockSource.SimulateMetricsPublishToChannel(ctx)
+
+	// StartTelemetrySession internally publishes the metrics that the
+	// mockStatsSource object feeds into the channel.
+	time.Sleep(testPublishMetricsInterval * 2)
+
+	// Read the request channel to get the metric data published to the server.
+	request := <-requestChan
+	cancel()
+	wait.Wait()
+	// Drain any further requests so the mock server does not block.
+	go func() {
+		for {
+			select {
+			case <-requestChan:
+			}
+		}
+	}()
+
+	// Extract the payload from the recorded request.
+	payload, err := getPayloadFromRequest(request)
+	if err != nil {
+		t.Fatal("Error decoding payload: ", err)
+	}
+
+	// Decode and verify the metric data.
+	_, responseType, err := wsclient.DecodeData([]byte(payload), tcsclient.NewTCSDecoder())
+	if err != nil {
+		t.Fatal("error decoding data: ", err)
+	}
+	if responseType != "PublishMetricsRequest" {
+		t.Fatal("Unexpected responseType: ", responseType)
+	}
+}
+
+func TestSessionConnectionClosedByRemote(t *testing.T) {
+	// Start the test server.
+	closeWS := make(chan []byte)
+	server, serverChan, _, serverErr, err := wsmock.GetMockServer(closeWS)
+	server.StartTLS()
+	defer server.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	go func() {
+		serr := <-serverErr
+		if !websocket.IsCloseError(serr, websocket.CloseNormalClosure) {
+			t.Error(serr)
+		}
+	}()
+	sleepBeforeClose := 10 * time.Millisecond
+	go func() {
+		time.Sleep(sleepBeforeClose)
+		closeSocket(closeWS)
+		close(serverChan)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", ctx)
+	deregisterInstanceEventStream.StartListening()
+	defer cancel()
+
+	telemetryMessages := make(chan ecstcs.TelemetryMessage, testTelemetryChannelDefaultBufferSize)
+	healthMessages := make(chan ecstcs.HealthMessage, testTelemetryChannelDefaultBufferSize)
+
+	session := NewTelemetrySession(
+		testInstanceArn,
+		testClusterArn,
+		testAgentVersion,
+		testAgentHash,
+		testContainerRuntimeVersion,
+		server.URL,
+		false,
+		testCreds,
+		testCfg,
+		deregisterInstanceEventStream,
+		testHeartbeatTimeout,
+		testHeartbeatJitter,
+		testDisconnectionTimeout,
+		testDisconnectionJitter,
+		nil,
+		telemetryMessages,
+		healthMessages,
+		emptyDoctor,
+	)
+
+	// Start a session with the test server; the remote close should surface as io.EOF.
+	err = session.StartTelemetrySession(ctx, server.URL)
+	if err != io.EOF {
+		t.Error("Expected io.EOF on closed connection, got: ", err)
+	}
+}
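Editor's note: the two tests above and the inactivity test below together exercise the three exit paths of StartTelemetrySession: caller's context cancelled, remote close surfacing as an error from Serve, and the inactivity watchdog firing. A condensed sketch of that three-way race (illustrative only, not the agent's code):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// runOnce mirrors the shape of StartTelemetrySession's exit paths: it races
// the serve goroutine against the caller's context and an inactivity watchdog.
func runOnce(ctx context.Context, serve func() error, inactivity time.Duration) error {
	timer := time.NewTimer(inactivity)
	defer timer.Stop()

	serveC := make(chan error, 1)
	go func() { serveC <- serve() }()

	select {
	case <-ctx.Done():
		return nil // agent is exiting: disconnect, report no error
	case <-timer.C:
		return nil // inactivity timeout: disconnect, let the caller reconnect
	case err := <-serveC:
		return err // remote close or read error surfaces here (e.g. io.EOF)
	}
}

func main() {
	// Exit path 1: the "server" closes the connection, so serve returns an error.
	err := runOnce(context.Background(), func() error {
		time.Sleep(10 * time.Millisecond)
		return errors.New("websocket: close 1000 (normal)")
	}, time.Second)
	fmt.Println("remote close:", err)

	// Exit path 2: nothing happens, so the inactivity watchdog fires with a nil error.
	err = runOnce(context.Background(), func() error {
		select {} // serve never returns
	}, 20*time.Millisecond)
	fmt.Println("inactivity:", err)
}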
+
+// TestConnectionInactiveTimeout tests that the TCS client disconnects when the
+// connection has been inactive for too long, e.g. after losing network connectivity.
+func TestConnectionInactiveTimeout(t *testing.T) {
+	// Start the test server.
+	closeWS := make(chan []byte)
+	server, _, requestChan, serverErr, err := wsmock.GetMockServer(closeWS)
+	server.StartTLS()
+	defer server.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Drain requests so the mock server does not block.
+	go func() {
+		for {
+			select {
+			case <-requestChan:
+			}
+		}
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", ctx)
+	deregisterInstanceEventStream.StartListening()
+	defer cancel()
+
+	telemetryMessages := make(chan ecstcs.TelemetryMessage, testTelemetryChannelDefaultBufferSize)
+	healthMessages := make(chan ecstcs.HealthMessage, testTelemetryChannelDefaultBufferSize)
+
+	session := NewTelemetrySession(
+		testInstanceArn,
+		testClusterArn,
+		testAgentVersion,
+		testAgentHash,
+		testContainerRuntimeVersion,
+		server.URL,
+		false,
+		testCreds,
+		testCfg,
+		deregisterInstanceEventStream,
+		50*time.Millisecond,
+		100*time.Millisecond,
+		testDisconnectionTimeout,
+		testDisconnectionJitter,
+		nil,
+		telemetryMessages,
+		healthMessages,
+		emptyDoctor,
+	)
+
+	// Start a session with the test server and wait for the inactivity timeout to fire.
+	err = session.StartTelemetrySession(ctx, server.URL)
+	assert.NoError(t, err, "disconnecting on inactivity timeout should not return an error")
+
+	assert.True(t, websocket.IsCloseError(<-serverErr, websocket.CloseAbnormalClosure),
+		"server should observe an abnormal closure after the client disconnects")
+
+	closeSocket(closeWS)
+}
+
+func getPayloadFromRequest(request string) (string, error) {
+	lines := strings.Split(request, "\r\n")
+	if len(lines) > 0 {
+		return lines[len(lines)-1], nil
+	}
+
+	return "", errors.New("Could not get payload")
+}
+
+// closeSocket tells the server to send a close frame. This lets us test
+// what happens if the connection is closed by the remote server.
+func closeSocket(ws chan<- []byte) {
+	ws <- websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
+	close(ws)
+}
+
+func createPublishMetricsRequest() *ecstcs.PublishMetricsRequest {
+	cluster := testClusterArn
+	ci := testInstanceArn
+	taskArn := testTaskArn
+	taskDefinitionFamily := testTaskDefinitionFamily
+	fval := rand.Float64()
+	ival := rand.Int63n(10)
+	ts := time.Now()
+	idle := false
+	messageId := testMessageId
+	return &ecstcs.PublishMetricsRequest{
+		Metadata: &ecstcs.MetricsMetadata{
+			Cluster:           &cluster,
+			ContainerInstance: &ci,
+			Idle:              &idle,
+			MessageId:         &messageId,
+		},
+		TaskMetrics: []*ecstcs.TaskMetric{
+			{
+				ContainerMetrics: []*ecstcs.ContainerMetric{
+					{
+						CpuStatsSet: &ecstcs.CWStatsSet{
+							Max:         &fval,
+							Min:         &fval,
+							SampleCount: &ival,
+							Sum:         &fval,
+						},
+						MemoryStatsSet: &ecstcs.CWStatsSet{
+							Max:         &fval,
+							Min:         &fval,
+							SampleCount: &ival,
+							Sum:         &fval,
+						},
+					},
+				},
+				TaskArn:              &taskArn,
+				TaskDefinitionFamily: &taskDefinitionFamily,
+			},
+		},
+		Timestamp: &ts,
+	}
+}
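Editor's note: getPayloadFromRequest relies on the mock server recording the raw request with CRLF-separated framing, so the payload is simply the last segment. A toy illustration (the frame content here is invented):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Invented frame: header-like lines, a blank line, then the payload.
	request := "GET /ws HTTP/1.1\r\nHost: example\r\n\r\n{\"type\":\"PublishMetricsRequest\"}"
	lines := strings.Split(request, "\r\n")
	fmt.Println(lines[len(lines)-1]) // {"type":"PublishMetricsRequest"}
}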
+
+func TestStartTelemetrySessionMetricsChannelPauseWhenClientClosed(t *testing.T) {
+	telemetryMessages := make(chan ecstcs.TelemetryMessage, testTelemetryChannelDefaultBufferSize)
+	healthMessages := make(chan ecstcs.HealthMessage, testTelemetryChannelDefaultBufferSize)
+
+	// Start the test server.
+	closeWS := make(chan []byte)
+	server, _, _, _, _ := wsmock.GetMockServer(closeWS)
+	server.StartTLS()
+	defer server.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", context.Background())
+	deregisterInstanceEventStream.StartListening()
+
+	session := NewTelemetrySession(
+		testInstanceArn,
+		testClusterArn,
+		testAgentVersion,
+		testAgentHash,
+		testContainerRuntimeVersion,
+		server.URL,
+		false,
+		testCreds,
+		testCfg,
+		deregisterInstanceEventStream,
+		testHeartbeatTimeout,
+		testHeartbeatJitter,
+		testDisconnectionTimeout,
+		testDisconnectionJitter,
+		nil,
+		telemetryMessages,
+		healthMessages,
+		emptyDoctor,
+	)
+
+	go session.StartTelemetrySession(ctx, server.URL)
+	telemetryMessages <- ecstcs.TelemetryMessage{}
+	// Wait until the TCS client is up and polling messages.
+	for len(telemetryMessages) != 0 {
+		time.Sleep(1 * time.Second)
+	}
+
+	cancel()
+	// Wait until the TCS client has stopped (returned from StartTelemetrySession).
+	time.Sleep(5 * time.Second)
+
+	// Send messages while the TCS client is closed and verify that they are not
+	// polled but stay in the channel.
+	for it := 0; it < testTelemetryChannelDefaultBufferSize; it++ {
+		telemetryMessages <- ecstcs.TelemetryMessage{}
+	}
+
+	// Check that the messages filled the channel. Once the channel is full, any
+	// further messages are dropped rather than blocking the sender.
+	assert.Len(t, telemetryMessages, testTelemetryChannelDefaultBufferSize)
+
+	// Simulate a retry after backoff.
+	newCtx, newCancel := context.WithCancel(context.Background())
+	defer newCancel()
+	go session.StartTelemetrySession(newCtx, server.URL)
+	// The test times out if messages do not resume flowing after the TCS client restarts.
+	for len(telemetryMessages) != 0 {
+		time.Sleep(1 * time.Second)
+	}
+}
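Editor's note: the drop behavior described in the comment above presumably comes from a non-blocking send on the producer side; the usual Go idiom is a select with a default branch. A sketch of that idiom, not the agent's actual code:

package main

import "fmt"

type telemetryMsg struct{ seq int }

func main() {
	ch := make(chan telemetryMsg, 3) // small buffer to force drops

	dropped := 0
	for i := 1; i <= 5; i++ {
		select {
		case ch <- telemetryMsg{seq: i}:
		default:
			dropped++ // buffer full and no consumer: drop rather than block
		}
	}
	fmt.Println("buffered:", len(ch), "dropped:", dropped) // buffered: 3 dropped: 2
}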