diff --git a/trino/etc/config.properties b/trino/etc/config.properties index 6221b02..3eac336 100644 --- a/trino/etc/config.properties +++ b/trino/etc/config.properties @@ -4,3 +4,11 @@ node-scheduler.include-coordinator=true http-server.http.port=8080 discovery-server.enabled=true discovery.uri=http://localhost:8080 + +http-server.authentication.type=JWT +http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem +http-server.https.enabled=true +http-server.https.port=8443 +http-server.authentication.allow-insecure-over-http=true +http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem +internal-communication.shared-secret=gotrino diff --git a/trino/etc/secrets/.gitignore b/trino/etc/secrets/.gitignore new file mode 100644 index 0000000..cfaad76 --- /dev/null +++ b/trino/etc/secrets/.gitignore @@ -0,0 +1 @@ +*.pem diff --git a/trino/integration_test.go b/trino/integration_test.go index b7dbdf3..4d1df97 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -16,12 +16,20 @@ package trino import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "database/sql" "database/sql/driver" + "encoding/pem" "errors" "flag" + "fmt" "io" "log" + "math/big" "net/http" "os" "strings" @@ -55,6 +63,7 @@ var ( false, "do not delete containers on exit", ) + tlsServer = "" ) func TestMain(m *testing.M) { @@ -79,6 +88,10 @@ func TestMain(m *testing.M) { resource, ok = pool.ContainerByName(name) if !ok { + err = generateCerts(wd + "/etc/secrets") + if err != nil { + log.Fatalf("Could not generate TLS certificates: %s", err) + } if *trinoImageTagFlag == "" { *trinoImageTagFlag = "latest" } @@ -87,6 +100,10 @@ func TestMain(m *testing.M) { Repository: "trinodb/trino", Tag: *trinoImageTagFlag, Mounts: []string{wd + "/etc:/etc/trino"}, + ExposedPorts: []string{ + "8080/tcp", + "8443/tcp", + }, }) if err != nil { log.Fatalf("Could not start resource: %s", err) @@ -106,6 +123,12 @@ func TestMain(m *testing.M) { log.Fatalf("Timed out waiting for container to get ready: %s", err) } *integrationServerFlag = "http://test@localhost:" + resource.GetPort("8080/tcp") + tlsServer = "https://test@localhost:" + resource.GetPort("8443/tcp") + + http.DefaultTransport.(*http.Transport).TLSClientConfig, err = getTLSConfig(wd + "/etc/secrets") + if err != nil { + log.Fatalf("Failed to set the default TLS config: %s", err) + } } code := m.Run() @@ -120,6 +143,104 @@ func TestMain(m *testing.M) { os.Exit(code) } +func generateCerts(dir string) error { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate private key: %w", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Trino Software Foundation"}, + }, + DNSNames: []string{"localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return fmt.Errorf("unable to marshal private key: %w", err) + } + privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes} + err = writePEM(dir+"/private_key.pem", privBlock) + if err != nil { + return err + } + + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + if err != nil { + return fmt.Errorf("unable to marshal public key: %w", err) + } + pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes} + err = writePEM(dir+"/public_key.pem", pubBlock) + if err != nil { + return err + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } + certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes} + err = writePEM(dir+"/certificate.pem", certBlock) + if err != nil { + return err + } + + err = writePEM(dir+"/certificate_with_key.pem", certBlock, privBlock, pubBlock) + if err != nil { + return err + } + + return nil +} + +func writePEM(filename string, blocks ...*pem.Block) error { + // all files are world-readable, so they can be read inside the Trino container + out, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to open %s for writing: %w", filename, err) + } + for _, block := range blocks { + if err := pem.Encode(out, block); err != nil { + return fmt.Errorf("failed to write %s data to %s: %w", block.Type, filename, err) + } + } + if err := out.Close(); err != nil { + return fmt.Errorf("error closing %s: %w", filename, err) + } + return nil +} + +func getTLSConfig(dir string) (*tls.Config, error) { + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to read the system cert pool: %s", err) + } + caCertPEM, err := os.ReadFile(dir + "/certificate.pem") + if err != nil { + return nil, fmt.Errorf("failed to read the certificate: %s", err) + } + ok := certPool.AppendCertsFromPEM(caCertPEM) + if !ok { + return nil, fmt.Errorf("failed to parse the certificate: %s", err) + } + return &tls.Config{ + RootCAs: certPool, + }, nil +} + // integrationOpen opens a connection to the integration test server. func integrationOpen(t *testing.T, dsn ...string) *sql.DB { if testing.Short() {