From e7d99838689c4d26f91fd014ffbfc9ee4deaf31f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Tue, 30 Apr 2024 10:59:02 +0200 Subject: [PATCH] Add integration test for JWT auth --- go.mod | 1 + go.sum | 2 ++ trino/integration_test.go | 58 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/go.mod b/go.mod index 32017a6..ce579b8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/trinodb/trino-go-client go 1.21 require ( + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/ory/dockertest/v3 v3.10.0 github.com/stretchr/testify v1.9.0 gopkg.in/jcmturner/gokrb5.v6 v6.1.1 diff --git a/go.sum b/go.sum index d4e74a4..e385be0 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfC github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= diff --git a/trino/integration_test.go b/trino/integration_test.go index 4d1df97..aa41c50 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -36,6 +36,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" dt "github.com/ory/dockertest/v3" ) @@ -875,6 +876,63 @@ func TestIntegrationQueryContextCancellation(t *testing.T) { } } +func TestIntegrationAccessToken(t *testing.T) { + if tlsServer == "" { + t.Skip("Skipping access token test when using a custom integration server.") + } + + accessToken, err := generateToken() + if err != nil { + t.Fatal(err) + } + + dsn := tlsServer + "?accessToken=" + accessToken + + db := integrationOpen(t, dsn) + + defer db.Close() + rows, err := db.Query("SHOW CATALOGS") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + count := 0 + for rows.Next() { + count++ + } + if count < 1 { + t.Fatal("not enough rows returned:", count) + } +} + +func generateToken() (string, error) { + privateKeyPEM, err := os.ReadFile("etc/secrets/private_key.pem") + if err != nil { + return "", fmt.Errorf("error reading private key file: %w", err) + } + + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyPEM) + if err != nil { + return "", fmt.Errorf("error parsing private key: %w", err) + } + + // Subject must be 'test' + claims := jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * 365 * time.Hour)), + Issuer: "gotrino", + Subject: "test", + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedToken, err := token.SignedString(privateKey) + + if err != nil { + return "", fmt.Errorf("error generating token: %w", err) + } + + return signedToken, nil +} + func contextSleep(ctx context.Context, d time.Duration) error { timer := time.NewTimer(100 * time.Millisecond) select {