diff --git a/trino/etc/config.properties b/trino/etc/config.properties index d4d09aa..4dbf12d 100644 --- a/trino/etc/config.properties +++ b/trino/etc/config.properties @@ -12,3 +12,5 @@ http-server.https.port=8443 http-server.authentication.allow-insecure-over-http=true http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem internal-communication.shared-secret=gotrino + +query.max-length=5000043 diff --git a/trino/etc/jvm.config b/trino/etc/jvm.config index bb4dca9..cdf7bde 100644 --- a/trino/etc/jvm.config +++ b/trino/etc/jvm.config @@ -1,4 +1,4 @@ --Xmx1G +-Xmx4G -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit diff --git a/trino/integration_test.go b/trino/integration_test.go index d78c1ae..0442229 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -34,6 +34,7 @@ import ( "math/big" "net/http" "os" + "strconv" "strings" "testing" "time" @@ -75,6 +76,9 @@ func TestMain(m *testing.M) { flag.Parse() DefaultQueryTimeout = *integrationServerQueryTimeout DefaultCancelQueryTimeout = *integrationServerQueryTimeout + if *trinoImageTagFlag == "" { + *trinoImageTagFlag = "latest" + } var err error if *integrationServerFlag == "" && !testing.Short() { @@ -97,9 +101,6 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalf("Could not generate TLS certificates: %s", err) } - if *trinoImageTagFlag == "" { - *trinoImageTagFlag = "latest" - } resource, err = pool.RunWithOptions(&dt.RunOptions{ Name: name, Repository: "trinodb/trino", @@ -1112,3 +1113,29 @@ func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) { }) } } + +func TestIntegrationLargeQuery(t *testing.T) { + version, err := strconv.Atoi(*trinoImageTagFlag) + if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 418) { + t.Skip("Skipping test when not using Trino 418 or later.") + } + dsn := *integrationServerFlag + dsn += "?explicitPrepare=false" + db := integrationOpen(t, dsn) + defer db.Close() + rows, err := db.Query("SELECT ?, '"+strings.Repeat("a", 5000000)+"'", 42) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + count := 0 + for rows.Next() { + count++ + } + if rows.Err() != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("not enough rows returned:", count) + } +} diff --git a/trino/trino.go b/trino/trino.go index 5204717..528381b 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -141,6 +141,7 @@ const ( sslCertPathConfig = "SSLCertPath" sslCertConfig = "SSLCert" accessTokenConfig = "accessToken" + explicitPrepareConfig = "explicitPrepare" ) var ( @@ -282,6 +283,7 @@ type Conn struct { kerberosRemoteServiceName string progressUpdater ProgressUpdater progressUpdaterPeriod queryProgressCallbackPeriod + useExplicitPrepare bool } var ( @@ -298,6 +300,10 @@ func newConn(dsn string) (*Conn, error) { query := serverURL.Query() kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig)) + useExplicitPrepare := true + if query.Get(explicitPrepareConfig) != "" { + useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig)) + } var kerberosClient *client.Client @@ -356,6 +362,7 @@ func newConn(dsn string) (*Conn, error) { kerberosClient: kerberosClient, kerberosEnabled: kerberosEnabled, kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig), + useExplicitPrepare: useExplicitPrepare, } var user string @@ -867,7 +874,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt hs.Add(arg.Name, headerValue) } else { - if hs.Get(preparedStatementHeader) == "" { + if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" { for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) { hs.Add(preparedStatementHeader, v) } @@ -880,7 +887,11 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt return nil, ErrInvalidProgressCallbackHeader } if len(ss) > 0 { - query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") + if st.conn.useExplicitPrepare { + query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") + } else { + query = "EXECUTE IMMEDIATE " + formatStringLiteral(st.query) + " USING " + strings.Join(ss, ", ") + } } } @@ -1028,6 +1039,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt return &sr, handleResponseError(resp.StatusCode, sr.Error) } +func formatStringLiteral(query string) string { + return "'" + strings.ReplaceAll(query, "'", "''") + "'" +} + type driverRows struct { ctx context.Context stmt *driverStmt