Skip to content

Commit

Permalink
Adding support for pg_get_serial_sequence(text,text)
Browse files Browse the repository at this point in the history
  • Loading branch information
fulghum committed Nov 15, 2024
1 parent b704225 commit c438685
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
1 change: 1 addition & 0 deletions server/functions/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func Init() {
initPgIsInRecovery()
initPgPostmasterStartTime()
initPgRelationSize()
initPgGetSerialSequence()
initPgStatGetNumscans()
initPgTableIsVisible()
initPgTableSize()
Expand Down
117 changes: 117 additions & 0 deletions server/functions/pg_get_serial_sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"
"strings"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// initPgGetSerialSequence registers the functions to the catalog.
func initPgGetSerialSequence() {
framework.RegisterFunction(pg_get_serial_sequence_text_text)
}

// pg_get_serial_sequence_text_text represents the PostgreSQL function of the same name, taking the same parameters.
var pg_get_serial_sequence_text_text = framework.Function2{
Name: "pg_get_serial_sequence",
Return: pgtypes.Text,
Parameters: [2]pgtypes.DoltgresType{pgtypes.Text, pgtypes.Text},
Variadic: false,
IsNonDeterministic: false,
Strict: true,
Callable: func(ctx *sql.Context, paramsAndReturn [3]pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
tableName := val1.(string)
columnName := val2.(string)

// Parse out the schema if one was supplied
var err error
schemaName := ""
if strings.Contains(tableName, ".") {
// TODO: parseRelationName() will return the first schema from the search_path if one is not included
// in the relation name, but that doesn't mean it's the correct schema. It should be updated to
// not return any schema name if one wasn't explicitly specified, then we should search for the
// table on the search_path and find the first schema that contains a table with that name.
schemaName, tableName, err = parseRelationName(ctx, tableName)
if err != nil {
return nil, err
}
}

// Resolve the table's schema if it wasn't specified
if schemaName == "" {
doltSession := dsess.DSessFromSess(ctx.Session)
roots, ok := doltSession.GetRoots(ctx, ctx.GetCurrentDatabase())
if !ok {
return nil, fmt.Errorf("unable to get roots")
}
foundTableName, _, ok, err := resolve.TableWithSearchPath(ctx, roots.Working, tableName)
if err != nil {
return nil, err
}
if !ok {
return nil, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
schemaName = foundTableName.Schema
}

// Validate the full schema + table name and grab the columns
table, err := core.GetSqlTableFromContext(ctx, "", doltdb.TableName{
Schema: schemaName,
Name: tableName,
})
if err != nil {
return nil, err
}
if table == nil {
return nil, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
tableSchema := table.Schema()

// Find the column in the table's schema
columnIndex := tableSchema.IndexOfColName(columnName)
if columnIndex < 0 {
return nil, fmt.Errorf(`column "%s" of relation "%s" does not exist`, columnName, tableName)
}
column := tableSchema[columnIndex]

// Find any sequence associated with the column
sequenceCollection, err := core.GetSequencesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
sequences := sequenceCollection.GetSequencesWithTable(doltdb.TableName{
Name: tableName,
Schema: schemaName,
})
for _, sequence := range sequences {
if sequence.OwnerColumn == column.Name {
// pg_get_serial_sequence() always includes the schema name in its output
return schemaName + "." + sequence.Name, nil
}
}

return nil, nil
},
}
52 changes: 52 additions & 0 deletions testing/go/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,58 @@ func TestSystemInformationFunctions(t *testing.T) {
},
},
},
{
Name: "pg_get_serial_sequence",
SetUpScript: []string{
`create table t0 (id INTEGER NOT NULL PRIMARY KEY);`,
`create table t1 (id SERIAL PRIMARY KEY);`,
`create sequence t2_id_seq START 1 INCREMENT 3;`,
`create table t2 (id INTEGER NOT NULL DEFAULT nextval('t2_id_seq'));`,
// TODO: ALTER SEQUENCE OWNED BY is not supported yet. When the sequence is created
// explicitly, separate from the column, the owner must be udpated before
// pg_get_serial_sequence() will identify it.
//`ALTER SEQUENCE t2_id_seq OWNED BY t2.id;`,
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT pg_get_serial_sequence('doesnotexist.t1', 'id');`,
ExpectedErr: "does not exist",
},
{
Query: `SELECT pg_get_serial_sequence('doesnotexist', 'id');`,
ExpectedErr: "does not exist",
},
{
Query: `SELECT pg_get_serial_sequence('t0', 'doesnotexist');`,
ExpectedErr: "does not exist",
},
{
// No sequence for column returns null
Query: `SELECT pg_get_serial_sequence('t0', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{nil}},
},
{
Query: `SELECT pg_get_serial_sequence('public.t1', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t1_id_seq"}},
},
{
// Test with no schema specified
Query: `SELECT pg_get_serial_sequence('t1', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t1_id_seq"}},
},
{
// TODO: This test shouldn't pass until we're able to use
// ALTER SEQUENCE OWNED BY to set the owning column.
Skip: true,
Query: `SELECT pg_get_serial_sequence('t2', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t2_id_seq"}},
},
},
},
})
}

Expand Down

0 comments on commit c438685

Please sign in to comment.