From d541d8e4f6699db6216a0aeee5a7a5c23ed8a3b7 Mon Sep 17 00:00:00 2001 From: Eytan Kidron Date: Thu, 14 Sep 2023 15:30:36 -0400 Subject: [PATCH] feat(idtoken): add ParsePayload returning unvalidated token payload (#2136) --- idtoken/validate.go | 19 ++++++++++++++++++- idtoken/validate_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/idtoken/validate.go b/idtoken/validate.go index 5f89b9b630e..47d314f50e5 100644 --- a/idtoken/validate.go +++ b/idtoken/validate.go @@ -122,6 +122,23 @@ func Validate(ctx context.Context, idToken string, audience string) (*Payload, e return defaultValidator.validate(ctx, idToken, audience) } +// ParsePayload parses the given token and returns its payload. +// +// Warning: This function does not validate the token prior to parsing it. +// +// ParsePayload is primarily meant to be used to inspect a token's payload. This is +// useful when validation fails and the payload needs to be inspected. +// +// Note: A successful Validate() invocation with the same token will return an +// identical payload. +func ParsePayload(idToken string) (*Payload, error) { + jwt, err := parseJWT(idToken) + if err != nil { + return nil, err + } + return jwt.parsedPayload() +} + func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) { jwt, err := parseJWT(idToken) if err != nil { @@ -145,7 +162,7 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin } if now().Unix() > payload.Expires { - return nil, fmt.Errorf("idtoken: token expired") + return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", now().Unix(), payload.Expires) } switch header.Algorithm { diff --git a/idtoken/validate_test.go b/idtoken/validate_test.go index 46528db2747..6c254c7c783 100644 --- a/idtoken/validate_test.go +++ b/idtoken/validate_test.go @@ -231,6 +231,39 @@ func TestValidateES256(t *testing.T) { } } +func TestParsePayload(t *testing.T) { + idToken, _ := createRS256JWT(t) + tests := []struct { + name string + token string + wantPayloadAudience string + wantErr bool + }{{ + name: "valid token", + token: idToken, + wantPayloadAudience: testAudience, + }, { + name: "unparseable token", + token: "aaa.bbb.ccc", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := ParsePayload(tt.token) + gotErr := err != nil + if gotErr != tt.wantErr { + t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr) + } + if tt.wantPayloadAudience != "" { + if payload == nil || payload.Audience != tt.wantPayloadAudience { + t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience) + } + } + }) + } +} + func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) { t.Helper() token := commonToken(t, "ES256")