diff --git a/internal/session/session.go b/internal/session/session.go index 91f95334..1e2ca45c 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/argoproj-labs/argocd-agent/pkg/types" + "k8s.io/apimachinery/pkg/api/validation" ) /* @@ -26,11 +27,31 @@ Package session contains various functions to access and manipulate session data. */ +// ClientIdFromContext returns the client ID stored in context ctx. If there +// is no client ID in the context, or the client ID is invalid, returns an +// error. func ClientIdFromContext(ctx context.Context) (string, error) { val := ctx.Value(types.ContextAgentIdentifier) - if clientId, ok := val.(string); !ok { + clientId, ok := val.(string) + if !ok { return "", fmt.Errorf("no client identifier found in context") - } else { - return clientId, nil } + if !IsValidClientId(clientId) { + return "", fmt.Errorf("invalid client identifier: %s", clientId) + } + return clientId, nil +} + +// ClientIdToContext returns a copy of context ctx with the clientId stored +func ClientIdToContext(ctx context.Context, clientId string) context.Context { + return context.WithValue(ctx, types.ContextAgentIdentifier, clientId) +} + +// IsValidClientId returns true if the string s is considered a valid client +// identifier. +func IsValidClientId(s string) bool { + if errs := validation.NameIsDNSSubdomain(s, false); len(errs) > 0 { + return false + } + return true } diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 00000000..a5846f51 --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ClientIdFromContext(t *testing.T) { + t.Run("Successfully extract client ID", func(t *testing.T) { + ctx := ClientIdToContext(context.Background(), "agent") + a, err := ClientIdFromContext(ctx) + assert.NoError(t, err) + assert.Equal(t, "agent", a) + }) + t.Run("No client ID in context", func(t *testing.T) { + a, err := ClientIdFromContext(context.Background()) + assert.ErrorContains(t, err, "no client identifier") + assert.Empty(t, a) + }) + t.Run("Invalid client ID in context", func(t *testing.T) { + ctx := ClientIdToContext(context.Background(), "ag_ent") + a, err := ClientIdFromContext(ctx) + assert.ErrorContains(t, err, "invalid client identifier") + assert.Empty(t, a) + }) +} diff --git a/principal/apis/eventstream/eventstream.go b/principal/apis/eventstream/eventstream.go index ccc47cb6..2c02c3d8 100644 --- a/principal/apis/eventstream/eventstream.go +++ b/principal/apis/eventstream/eventstream.go @@ -23,8 +23,8 @@ import ( "github.com/argoproj-labs/argocd-agent/internal/event" "github.com/argoproj-labs/argocd-agent/internal/queue" + "github.com/argoproj-labs/argocd-agent/internal/session" "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/eventstreamapi" - "github.com/argoproj-labs/argocd-agent/pkg/types" "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" @@ -136,7 +136,7 @@ func (s *Server) newClientConnection(ctx context.Context, timeout time.Duration) c := &client{} c.wg = &sync.WaitGroup{} - agentName, err := agentName(ctx) + agentName, err := session.ClientIdFromContext(ctx) if err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } @@ -160,17 +160,6 @@ func (s *Server) newClientConnection(ctx context.Context, timeout time.Duration) return c, nil } -// agentName gets the agent name from the context ctx. If no agent identifier -// could be found in the context, returns an error. -func agentName(ctx context.Context) (string, error) { - agentName, ok := ctx.Value(types.ContextAgentIdentifier).(string) - if !ok { - return "", fmt.Errorf("invalid context: no agent name") - } - // TODO: check agentName for validity - return agentName, nil -} - // onDisconnect must be called whenever client c disconnects from the stream func (s *Server) onDisconnect(c *client) { c.lock.Lock() @@ -379,7 +368,7 @@ func (s *Server) Push(pushs eventstreamapi.EventStream_PushServer) error { } defer cancel() - agentName, err := agentName(ctx) + agentName, err := session.ClientIdFromContext(ctx) if err != nil { return status.Error(codes.InvalidArgument, err.Error()) } diff --git a/principal/auth.go b/principal/auth.go index 4634ce68..8e6db3d6 100644 --- a/principal/auth.go +++ b/principal/auth.go @@ -21,6 +21,7 @@ import ( "github.com/argoproj-labs/argocd-agent/internal/auth" "github.com/argoproj-labs/argocd-agent/internal/grpcutil" + "github.com/argoproj-labs/argocd-agent/internal/session" "github.com/argoproj-labs/argocd-agent/pkg/types" middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "google.golang.org/grpc" @@ -117,7 +118,7 @@ func (s *Server) authenticate(ctx context.Context) (context.Context, error) { // claims at this point is validated and we can propagate values to the // context. - authCtx := context.WithValue(ctx, types.ContextAgentIdentifier, agentInfo.ClientID) + authCtx := session.ClientIdToContext(ctx, agentInfo.ClientID) if !s.queues.HasQueuePair(agentInfo.ClientID) { logCtx.Tracef("Creating a new queue pair for client %s", agentInfo.ClientID) if err := s.queues.Create(agentInfo.ClientID); err != nil {