From 0d4abb902a64b49179c2720ab95d6bc495d4f0a8 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Fri, 28 Apr 2023 16:49:11 +0000 Subject: [PATCH] Consume TMDS init function from ecs-agent module --- agent/config/config.go | 3 - agent/config/config_unix.go | 3 +- agent/config/config_windows.go | 3 +- agent/config/config_windows_test.go | 3 +- agent/go.mod | 2 +- agent/handlers/introspection_server_setup.go | 3 +- agent/handlers/task_server_setup.go | 42 ++--- agent/handlers/task_server_setup_test.go | 104 ++++++++----- .../tmds/logging}/logging_handler.go | 2 +- .../amazon-ecs-agent/ecs-agent/tmds/server.go | 144 ++++++++++++++++++ .../ecs-agent/tmds/utils/mux/mux.go | 34 +++++ agent/vendor/modules.txt | 4 + 12 files changed, 277 insertions(+), 70 deletions(-) rename agent/{handlers => vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging}/logging_handler.go (98%) create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux/mux.go diff --git a/agent/config/config.go b/agent/config/config.go index 38f0d0f4b5b..bdd4381dec7 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -44,9 +44,6 @@ const ( // AgentIntrospectionPort is used to serve the metadata about the agent and to query the tasks being managed by the agent. AgentIntrospectionPort = 51678 - // AgentCredentialsPort is used to serve the credentials for tasks. - AgentCredentialsPort = 51679 - // AgentPrometheusExpositionPort is used to expose Prometheus metrics that can be scraped by a Prometheus server AgentPrometheusExpositionPort = 51680 diff --git a/agent/config/config_unix.go b/agent/config/config_unix.go index 6c1254d35b2..6e068238e58 100644 --- a/agent/config/config_unix.go +++ b/agent/config/config_unix.go @@ -23,6 +23,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/dockerclient" "github.com/aws/amazon-ecs-agent/agent/utils" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" ) const ( @@ -61,7 +62,7 @@ const ( func DefaultConfig() Config { return Config{ DockerEndpoint: "unix:///var/run/docker.sock", - ReservedPorts: []uint16{SSHPort, DockerReservedPort, DockerReservedSSLPort, AgentIntrospectionPort, AgentCredentialsPort}, + ReservedPorts: []uint16{SSHPort, DockerReservedPort, DockerReservedSSLPort, AgentIntrospectionPort, tmds.Port}, ReservedPortsUDP: []uint16{}, DataDir: "/data/", DataDirOnHost: "/var/lib/ecs", diff --git a/agent/config/config_windows.go b/agent/config/config_windows.go index 8f73140a6e7..fc12be02cf4 100644 --- a/agent/config/config_windows.go +++ b/agent/config/config_windows.go @@ -25,6 +25,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/dockerclient" "github.com/aws/amazon-ecs-agent/agent/utils" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" "github.com/cihub/seelog" "github.com/hectane/go-acl/api" @@ -103,7 +104,7 @@ func DefaultConfig() Config { DockerReservedPort, DockerReservedSSLPort, AgentIntrospectionPort, - AgentCredentialsPort, + tmds.Port, rdpPort, rpcPort, smbPort, diff --git a/agent/config/config_windows_test.go b/agent/config/config_windows_test.go index f2c4a323678..5a16237ef53 100644 --- a/agent/config/config_windows_test.go +++ b/agent/config/config_windows_test.go @@ -26,6 +26,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/dockerclient" "github.com/aws/amazon-ecs-agent/agent/ec2" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" "github.com/hectane/go-acl/api" "github.com/stretchr/testify/assert" @@ -83,7 +84,7 @@ func TestConfigIAMTaskRolesReserves80(t *testing.T) { DockerReservedPort, DockerReservedSSLPort, AgentIntrospectionPort, - AgentCredentialsPort, + tmds.Port, rdpPort, rpcPort, smbPort, diff --git a/agent/go.mod b/agent/go.mod index d9975b23afe..86f678d9809 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -11,7 +11,6 @@ require ( github.com/containernetworking/cni v0.8.1 github.com/containernetworking/plugins v0.9.1 github.com/deniswernert/udev v0.0.0-20170418162847-a12666f7b5a1 - github.com/didip/tollbooth v4.0.2+incompatible github.com/docker/docker v20.10.23+incompatible github.com/docker/go-connections v0.4.0 github.com/docker/go-units v0.4.0 @@ -46,6 +45,7 @@ require ( github.com/containerd/continuity v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/didip/tollbooth v4.0.2+incompatible // indirect github.com/docker/distribution v2.8.1+incompatible // indirect github.com/godbus/dbus/v5 v5.0.6 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/agent/handlers/introspection_server_setup.go b/agent/handlers/introspection_server_setup.go index 3ca0ecc7ede..d0646182127 100644 --- a/agent/handlers/introspection_server_setup.go +++ b/agent/handlers/introspection_server_setup.go @@ -27,6 +27,7 @@ import ( handlersutils "github.com/aws/amazon-ecs-agent/agent/handlers/utils" v1 "github.com/aws/amazon-ecs-agent/agent/handlers/v1" "github.com/aws/amazon-ecs-agent/agent/utils/retry" + logginghandler "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging" "github.com/cihub/seelog" ) @@ -81,7 +82,7 @@ func introspectionServerSetup(containerInstanceArn *string, taskEngine handlersu // Log all requests and then pass through to serverMux loggingServeMux := http.NewServeMux() - loggingServeMux.Handle("/", LoggingHandler{serverMux}) + loggingServeMux.Handle("/", logginghandler.NewLoggingHandler(serverMux)) wTimeout := writeTimeout if cfg.EnableRuntimeStats.Enabled() { diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index ca846817ba3..3ad6d3f8d70 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -16,7 +16,6 @@ package handlers import ( "context" "net/http" - "strconv" "time" "github.com/aws/amazon-ecs-agent/agent/api" @@ -24,7 +23,6 @@ import ( "github.com/aws/amazon-ecs-agent/agent/credentials" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" agentAPITaskProtectionV1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" - handlersutils "github.com/aws/amazon-ecs-agent/agent/handlers/utils" v1 "github.com/aws/amazon-ecs-agent/agent/handlers/v1" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" @@ -33,8 +31,8 @@ import ( "github.com/aws/amazon-ecs-agent/agent/stats" "github.com/aws/amazon-ecs-agent/agent/utils/retry" auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" "github.com/cihub/seelog" - "github.com/didip/tollbooth" "github.com/gorilla/mux" ) @@ -61,7 +59,8 @@ func taskServerSetup(credentialsManager credentials.Manager, vpcID string, containerInstanceArn string, apiEndpoint string, - acceptInsecureCert bool) *http.Server { + acceptInsecureCert bool) (*http.Server, error) { + muxRouter := mux.NewRouter() // Set this to false so that for request like "//v3//metadata/task" @@ -79,28 +78,13 @@ func taskServerSetup(credentialsManager credentials.Manager, agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, region, apiEndpoint, acceptInsecureCert) - limiter := tollbooth.NewLimiter(float64(steadyStateRate), nil) - limiter.SetOnLimitReached(handlersutils.LimitReachedHandler(auditLogger)) - limiter.SetBurst(burstRate) - - // Log all requests and then pass through to muxRouter. - loggingMuxRouter := mux.NewRouter() - - // rootPath is a path for any traffic to this endpoint, "root" mux name will not be used. - rootPath := "/" + handlersutils.ConstructMuxVar("root", handlersutils.AnythingRegEx) - loggingMuxRouter.Handle(rootPath, tollbooth.LimitHandler( - limiter, NewLoggingHandler(muxRouter))) - - loggingMuxRouter.SkipClean(false) - - server := http.Server{ - Addr: "127.0.0.1:" + strconv.Itoa(config.AgentCredentialsPort), - Handler: loggingMuxRouter, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - } - - return &server + return tmds.NewServer(auditLogger, + tmds.WithRouter(muxRouter), + tmds.WithListenAddress(tmds.AddressIPv4()), + tmds.WithReadTimeout(readTimeout), + tmds.WithWriteTimeout(writeTimeout), + tmds.WithSteadyStateRate(float64(steadyStateRate)), + tmds.WithBurstRate(burstRate)) } // v2HandlersSetup adds all handlers in v2 package to the mux router. @@ -200,9 +184,13 @@ func ServeTaskHTTPEndpoint( auditLogger := audit.NewAuditLog(containerInstanceArn, cfg, logger) - server := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, cfg.AWSRegion, statsEngine, + server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, cfg.AWSRegion, statsEngine, cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, availabilityZone, vpcID, containerInstanceArn, cfg.APIEndpoint, cfg.AcceptInsecureCert) + if err != nil { + seelog.Criticalf("Failed to set up Task Metadata Server: %v", err) + return + } go func() { <-ctx.Done() diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 2a3cd5942eb..7b6c83d78ce 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -664,9 +664,10 @@ func testErrorResponsesFromServer(t *testing.T, path string, expectedErrorMessag credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, + server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, "", true) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", path, nil) @@ -700,9 +701,11 @@ func getResponseForCredentialsRequest(t *testing.T, expectedStatus int, credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, + server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, "", true) + require.NoError(t, err) + recorder := httptest.NewRecorder() creds, ok := getCredentials() @@ -769,9 +772,10 @@ func TestV2TaskMetadata(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(task, true), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", tc.path, nil) req.RemoteAddr = remoteIP + ":" + remotePort @@ -855,9 +859,10 @@ func TestV2TaskWithTagsMetadata(t *testing.T) { }, }, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v2BaseMetadataWithTagsPath, nil) req.RemoteAddr = remoteIP + ":" + remotePort @@ -887,9 +892,11 @@ func TestV2ContainerMetadata(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByID(containerID).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) + recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v2BaseMetadataPath+"/"+containerID, nil) req.RemoteAddr = remoteIP + ":" + remotePort @@ -918,9 +925,10 @@ func TestV2ContainerStats(t *testing.T) { state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v2BaseStatsPath+"/"+containerID, nil) req.RemoteAddr = remoteIP + ":" + remotePort @@ -968,9 +976,10 @@ func TestV2TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", tc.path, nil) req.RemoteAddr = remoteIP + ":" + remotePort @@ -1003,9 +1012,10 @@ func TestV3TaskMetadata(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1034,9 +1044,10 @@ func TestV3BridgeTaskMetadata(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true), state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1064,9 +1075,10 @@ func TestV3BridgeContainerMetadata(t *testing.T) { state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID, nil) server.Handler.ServeHTTP(recorder, req) @@ -1136,9 +1148,10 @@ func TestV3TaskMetadataWithTags(t *testing.T) { }, nil), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/taskWithTags", nil) server.Handler.ServeHTTP(recorder, req) @@ -1165,9 +1178,10 @@ func TestV3ContainerMetadata(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByID(containerID).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID, nil) server.Handler.ServeHTTP(recorder, req) @@ -1203,9 +1217,10 @@ func TestV3TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/task/stats", nil) server.Handler.ServeHTTP(recorder, req) @@ -1237,9 +1252,10 @@ func TestV3ContainerStats(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/stats", nil) server.Handler.ServeHTTP(recorder, req) @@ -1267,9 +1283,10 @@ func TestV3ContainerAssociations(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType, nil) server.Handler.ServeHTTP(recorder, req) @@ -1296,9 +1313,10 @@ func TestV3ContainerAssociation(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) server.Handler.ServeHTTP(recorder, req) @@ -1325,9 +1343,10 @@ func TestV4TaskMetadata(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1359,9 +1378,10 @@ func TestV4TaskMetadataWithPulledContainers(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(pulledTask, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(pulledContainerNameToDockerContainer, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1390,9 +1410,10 @@ func TestV4ContainerMetadata(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByID(containerID).Return(task, true).Times(2), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID, nil) server.Handler.ServeHTTP(recorder, req) @@ -1470,9 +1491,10 @@ func TestV4TaskMetadataWithTags(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/taskWithTags", nil) server.Handler.ServeHTTP(recorder, req) @@ -1506,9 +1528,10 @@ func TestV4BridgeTaskMetadata(t *testing.T) { state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1542,9 +1565,10 @@ func TestV4BridgeTaskMetadataAllowMissingContainerNetwork(t *testing.T) { state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task", nil) server.Handler.ServeHTTP(recorder, req) @@ -1572,9 +1596,10 @@ func TestV4BridgeContainerMetadata(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID, nil) server.Handler.ServeHTTP(recorder, req) @@ -1612,9 +1637,10 @@ func TestV4TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task/stats", nil) server.Handler.ServeHTTP(recorder, req) @@ -1646,9 +1672,10 @@ func TestV4ContainerStats(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/stats", nil) server.Handler.ServeHTTP(recorder, req) @@ -1676,9 +1703,10 @@ func TestV4ContainerAssociations(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType, nil) server.Handler.ServeHTTP(recorder, req) @@ -1705,9 +1733,10 @@ func TestV4ContainerAssociation(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) server.Handler.ServeHTTP(recorder, req) @@ -1731,9 +1760,10 @@ func TestTaskHTTPEndpoint301Redirect(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) for testPath, expectedPath := range testPathsMap { t.Run(fmt.Sprintf("Test path: %s", testPath), func(t *testing.T) { @@ -1773,9 +1803,10 @@ func TestTaskHTTPEndpointErrorCode404(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) for _, testPath := range testPaths { t.Run(fmt.Sprintf("Test path: %s", testPath), func(t *testing.T) { @@ -1812,9 +1843,10 @@ func TestTaskHTTPEndpointErrorCode400(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) for _, testPath := range testPaths { t.Run(fmt.Sprintf("Test path: %s", testPath), func(t *testing.T) { @@ -1850,9 +1882,10 @@ func TestTaskHTTPEndpointErrorCode500(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) for _, testPath := range testPaths { t.Run(fmt.Sprintf("Test path: %s", testPath), func(t *testing.T) { @@ -1919,9 +1952,10 @@ func TestV4TaskNotFoundError404(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) state.EXPECT().TaskARNByV3EndpointID(gomock.Any()).Return("", tc.taskFound).AnyTimes() state.EXPECT().DockerIDByV3EndpointID(gomock.Any()).Return("", false).AnyTimes() @@ -1974,9 +2008,10 @@ func TestV4Unexpected500Error(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) // Initial lookups succeed state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true).AnyTimes() @@ -2019,9 +2054,10 @@ func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, ) // Set up the server - server := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) // Prepare the request var requestReader io.Reader = nil diff --git a/agent/handlers/logging_handler.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging/logging_handler.go similarity index 98% rename from agent/handlers/logging_handler.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging/logging_handler.go index 6db8697b444..033b13779ae 100644 --- a/agent/handlers/logging_handler.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging/logging_handler.go @@ -11,7 +11,7 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. -package handlers +package logging import ( "net/http" diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go new file mode 100644 index 00000000000..fa3de539a04 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go @@ -0,0 +1,144 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package tmds + +import ( + "errors" + "fmt" + "net/http" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging" + muxutils "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux" + + "github.com/didip/tollbooth" + "github.com/gorilla/mux" +) + +const ( + // TMDS IP and port + IPv4 = "127.0.0.1" + Port = 51679 +) + +// IPv4 address for TMDS +func AddressIPv4() string { + return fmt.Sprintf("%s:%d", IPv4, Port) +} + +// Configuration for TMDS +type Config struct { + listenAddress string // http server listen address + readTimeout time.Duration // http server read timeout + writeTimeout time.Duration // http server write timeout + steadyStateRate float64 // steady request rate limit + burstRate int // burst request rate limit + router *mux.Router // router with routes configured +} + +// Function type for updating TMDS config +type ConfigOpt func(*Config) + +// Set TMDS listen address +func WithListenAddress(listenAddr string) ConfigOpt { + return func(c *Config) { + c.listenAddress = listenAddr + } +} + +// Set TMDS read timeout +func WithReadTimeout(readTimeout time.Duration) ConfigOpt { + return func(c *Config) { + c.readTimeout = readTimeout + } +} + +// Set TMDS write timeout +func WithWriteTimeout(writeTimeout time.Duration) ConfigOpt { + return func(c *Config) { + c.writeTimeout = writeTimeout + } +} + +// Set TMDS steady request rate limit +func WithSteadyStateRate(steadyStateRate float64) ConfigOpt { + return func(c *Config) { + c.steadyStateRate = steadyStateRate + } +} + +// Set TMDS burst request rate limit +func WithBurstRate(burstRate int) ConfigOpt { + return func(c *Config) { + c.burstRate = burstRate + } +} + +// Set TMDS router +func WithRouter(router *mux.Router) ConfigOpt { + return func(c *Config) { + c.router = router + } +} + +// Create a new HTTP Task Metadata Server (TMDS) +func NewServer(auditLogger audit.AuditLogger, options ...ConfigOpt) (*http.Server, error) { + config := new(Config) + for _, opt := range options { + opt(config) + } + + return setup(auditLogger, config) +} + +func setup(auditLogger audit.AuditLogger, config *Config) (*http.Server, error) { + if config.router == nil { + return nil, errors.New("router cannot be nil") + } + + // Define a reqeuest rate limiter + limiter := tollbooth. + NewLimiter(config.steadyStateRate, nil). + SetOnLimitReached(limitReachedHandler(auditLogger)). + SetBurst(config.burstRate) + + // Log all requests and then pass through to muxRouter. + loggingMuxRouter := mux.NewRouter() + + // rootPath is a path for any traffic to this endpoint + rootPath := "/" + muxutils.ConstructMuxVar("root", muxutils.AnythingRegEx) + loggingMuxRouter.Handle(rootPath, tollbooth.LimitHandler( + limiter, logging.NewLoggingHandler(config.router))) + + // explicitly enable path cleaning + loggingMuxRouter.SkipClean(false) + + return &http.Server{ + Addr: config.listenAddress, + Handler: loggingMuxRouter, + ReadTimeout: config.readTimeout, + WriteTimeout: config.writeTimeout, + }, nil +} + +// LimitReachedHandler logs the throttled request in the credentials audit log +func limitReachedHandler(auditLogger audit.AuditLogger) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + logRequest := request.LogRequest{ + Request: r, + } + auditLogger.Log(logRequest, http.StatusTooManyRequests, "") + } +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux/mux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux/mux.go new file mode 100644 index 00000000000..fd83d39e441 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux/mux.go @@ -0,0 +1,34 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package mux + +const ( + // AnythingRegEx is a regex pattern that matches anything. + AnythingRegEx = ".*" + + // AnythingButSlashRegEx is a regex pattern that matches any string without slash. + AnythingButSlashRegEx = "[^/]*" + + // AnythingButEmptyRegEx is a regex pattern that matches anything but an empty string. + AnythingButEmptyRegEx = ".+" +) + +// ConstructMuxVar constructs the mux var that is used in the gorilla/mux styled +// path, example: {id}, {id:[0-9]+}. +func ConstructMuxVar(name string, pattern string) string { + if pattern == "" { + return "{" + name + "}" + } + + return "{" + name + ":" + pattern + "}" +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 8239737f285..afb6eec44b0 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -7,9 +7,13 @@ github.com/Microsoft/go-winio/pkg/guid github.com/Microsoft/hcsshim/osversion # github.com/aws/amazon-ecs-agent/ecs-agent v0.0.0 => ../ecs-agent ## explicit; go 1.19 +github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/mocks github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request +github.com/aws/amazon-ecs-agent/ecs-agent/tmds +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux # github.com/aws/aws-sdk-go v1.36.0 ## explicit; go 1.11 github.com/aws/aws-sdk-go/aws