From 467db65cc0c143f32645ef352dbb6e97875173f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Sun, 6 Oct 2024 13:24:43 +0200 Subject: [PATCH] Support EXECUTE IMMEDIATE Use EXECUTE IMMEDIATE sent in the HTTP request body, instead of putting the query text in HTTP headers. This should allow sending large query text. It can be enabled by setting the `explicitPrepare` option to false in the connection string. --- trino/etc/config.properties | 2 ++ trino/etc/jvm.config | 2 +- trino/integration_test.go | 33 ++++++++++++++++++++++++++++++--- trino/trino.go | 19 +++++++++++++++++-- 4 files changed, 50 insertions(+), 6 deletions(-) diff --git a/trino/etc/config.properties b/trino/etc/config.properties index d4d09aa..4dbf12d 100644 --- a/trino/etc/config.properties +++ b/trino/etc/config.properties @@ -12,3 +12,5 @@ http-server.https.port=8443 http-server.authentication.allow-insecure-over-http=true http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem internal-communication.shared-secret=gotrino + +query.max-length=5000043 diff --git a/trino/etc/jvm.config b/trino/etc/jvm.config index bb4dca9..cdf7bde 100644 --- a/trino/etc/jvm.config +++ b/trino/etc/jvm.config @@ -1,4 +1,4 @@ --Xmx1G +-Xmx4G -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit diff --git a/trino/integration_test.go b/trino/integration_test.go index d78c1ae..0442229 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -34,6 +34,7 @@ import ( "math/big" "net/http" "os" + "strconv" "strings" "testing" "time" @@ -75,6 +76,9 @@ func TestMain(m *testing.M) { flag.Parse() DefaultQueryTimeout = *integrationServerQueryTimeout DefaultCancelQueryTimeout = *integrationServerQueryTimeout + if *trinoImageTagFlag == "" { + *trinoImageTagFlag = "latest" + } var err error if *integrationServerFlag == "" && !testing.Short() { @@ -97,9 +101,6 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalf("Could not generate TLS certificates: %s", err) } - if *trinoImageTagFlag == "" { - *trinoImageTagFlag = "latest" - } resource, err = pool.RunWithOptions(&dt.RunOptions{ Name: name, Repository: "trinodb/trino", @@ -1112,3 +1113,29 @@ func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) { }) } } + +func TestIntegrationLargeQuery(t *testing.T) { + version, err := strconv.Atoi(*trinoImageTagFlag) + if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 418) { + t.Skip("Skipping test when not using Trino 418 or later.") + } + dsn := *integrationServerFlag + dsn += "?explicitPrepare=false" + db := integrationOpen(t, dsn) + defer db.Close() + rows, err := db.Query("SELECT ?, '"+strings.Repeat("a", 5000000)+"'", 42) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + count := 0 + for rows.Next() { + count++ + } + if rows.Err() != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("not enough rows returned:", count) + } +} diff --git a/trino/trino.go b/trino/trino.go index 5204717..528381b 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -141,6 +141,7 @@ const ( sslCertPathConfig = "SSLCertPath" sslCertConfig = "SSLCert" accessTokenConfig = "accessToken" + explicitPrepareConfig = "explicitPrepare" ) var ( @@ -282,6 +283,7 @@ type Conn struct { kerberosRemoteServiceName string progressUpdater ProgressUpdater progressUpdaterPeriod queryProgressCallbackPeriod + useExplicitPrepare bool } var ( @@ -298,6 +300,10 @@ func newConn(dsn string) (*Conn, error) { query := serverURL.Query() kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig)) + useExplicitPrepare := true + if query.Get(explicitPrepareConfig) != "" { + useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig)) + } var kerberosClient *client.Client @@ -356,6 +362,7 @@ func newConn(dsn string) (*Conn, error) { kerberosClient: kerberosClient, kerberosEnabled: kerberosEnabled, kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig), + useExplicitPrepare: useExplicitPrepare, } var user string @@ -867,7 +874,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt hs.Add(arg.Name, headerValue) } else { - if hs.Get(preparedStatementHeader) == "" { + if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" { for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) { hs.Add(preparedStatementHeader, v) } @@ -880,7 +887,11 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt return nil, ErrInvalidProgressCallbackHeader } if len(ss) > 0 { - query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") + if st.conn.useExplicitPrepare { + query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") + } else { + query = "EXECUTE IMMEDIATE " + formatStringLiteral(st.query) + " USING " + strings.Join(ss, ", ") + } } } @@ -1028,6 +1039,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt return &sr, handleResponseError(resp.StatusCode, sr.Error) } +func formatStringLiteral(query string) string { + return "'" + strings.ReplaceAll(query, "'", "''") + "'" +} + type driverRows struct { ctx context.Context stmt *driverStmt