Skip to content

Commit

Permalink
Support SET/RESET SESSION and PREPARE
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick authored and losipiuk committed Jul 11, 2022
1 parent d4b8686 commit a4d1a6d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
8 changes: 0 additions & 8 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,6 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
query string
err error
}{
{
query: "SET SESSION optimize_hash_generation=true",
err: ErrUnsupportedHeader,
},
{
query: "SET ROLE dummy",
err: errors.New(`trino: query failed (200 OK): "io.trino.spi.TrinoException: line 1:1: Role 'dummy' does not exist"`),
Expand All @@ -479,10 +475,6 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
query: "SET PATH dummy",
err: errors.New(`trino: query failed (200 OK): "io.trino.spi.TrinoException: SET PATH not supported by client"`),
},
{
query: "RESET SESSION optimize_hash_generation",
err: ErrUnsupportedHeader,
},
}
for _, c := range cases {
_, err := db.Query(c.query)
Expand Down
32 changes: 30 additions & 2 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ const (
trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role`
trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential`

trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare`
trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare`

KerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
kerberosPrincipalConfig = "KerberosPrincipal"
Expand All @@ -133,8 +136,6 @@ var (
}
unsupportedResponseHeaders = []string{
trinoSetPathHeader,
trinoSetSessionHeader,
trinoClearSessionHeader,
trinoSetRoleHeader,
}
)
Expand Down Expand Up @@ -462,6 +463,30 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
c.httpHeaders.Set(dst, v)
}
}
if v := resp.Header.Get(trinoAddedPrepareHeader); v != "" {
c.httpHeaders.Add(preparedStatementHeader, v)
}
if v := resp.Header.Get(trinoDeallocatedPrepareHeader); v != "" {
values := c.httpHeaders.Values(preparedStatementHeader)
c.httpHeaders.Del(preparedStatementHeader)
for _, v2 := range values {
if !strings.HasPrefix(v2, v+"=") {
c.httpHeaders.Add(preparedStatementHeader, v2)
}
}
}
if v := resp.Header.Get(trinoSetSessionHeader); v != "" {
c.httpHeaders.Add(trinoSessionHeader, v)
}
if v := resp.Header.Get(trinoClearSessionHeader); v != "" {
values := c.httpHeaders.Values(trinoSessionHeader)
c.httpHeaders.Del(trinoSessionHeader)
for _, v2 := range values {
if !strings.HasPrefix(v2, v+"=") {
c.httpHeaders.Add(trinoSessionHeader, v2)
}
}
}
for _, name := range unsupportedResponseHeaders {
if v := resp.Header.Get(name); v != "" {
return nil, ErrUnsupportedHeader
Expand Down Expand Up @@ -668,6 +693,9 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
hs.Add(arg.Name, headerValue)
} else {
if hs.Get(preparedStatementHeader) == "" {
for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) {
hs.Add(preparedStatementHeader, v)
}
hs.Add(preparedStatementHeader, preparedStatementName+"="+url.QueryEscape(st.query))
}
ss = append(ss, s)
Expand Down
44 changes: 44 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,50 @@ func TestQueryFailure(t *testing.T) {
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
}

func TestSession(t *testing.T) {
err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}})
if err != nil {
t.Fatal(err)
}
c := &Config{
ServerURI: "http://foobar@localhost:8080?custom_client=uncompressed",
SessionProperties: map[string]string{"query_priority": "1"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

db, err := sql.Open("trino", dsn)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

_, err = db.Exec("SET SESSION join_distribution_type='BROADCAST'")
require.NoError(t, err, "Failed executing query")

row := db.QueryRow("SHOW SESSION LIKE 'join_distribution_type'")
var name string
var value string
var defaultValue string
var typeName string
var description string
err = row.Scan(&name, &value, &defaultValue, &typeName, &description)
require.NoError(t, err, "Failed executing query")

assert.Equal(t, "BROADCAST", value)

_, err = db.Exec("RESET SESSION join_distribution_type")
require.NoError(t, err, "Failed executing query")

row = db.QueryRow("SHOW SESSION LIKE 'join_distribution_type'")
err = row.Scan(&name, &value, &defaultValue, &typeName, &description)
require.NoError(t, err, "Failed executing query")

assert.Equal(t, "AUTOMATIC", value)
}

func TestUnsupportedHeader(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(trinoSetRoleHeader, "foo")
Expand Down

0 comments on commit a4d1a6d

Please sign in to comment.