Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pg_get_serial_sequence(text,text) #969

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/functions/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ func Init() {
initPgIsInRecovery()
initPgPostmasterStartTime()
initPgRelationSize()
initPgGetSerialSequence()
initPgStatGetNumscans()
initPgTableIsVisible()
initPgTableSize()
Expand Down
117 changes: 117 additions & 0 deletions server/functions/pg_get_serial_sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"
"strings"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// initPgGetSerialSequence registers the functions to the catalog.
func initPgGetSerialSequence() {
framework.RegisterFunction(pg_get_serial_sequence_text_text)
}

// pg_get_serial_sequence_text_text represents the PostgreSQL function of the same name, taking the same parameters.
var pg_get_serial_sequence_text_text = framework.Function2{
Name: "pg_get_serial_sequence",
Return: pgtypes.Text,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Text, pgtypes.Text},
Variadic: false,
IsNonDeterministic: false,
Strict: true,
Callable: func(ctx *sql.Context, paramsAndReturn [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
tableName := val1.(string)
columnName := val2.(string)

// Parse out the schema if one was supplied
var err error
schemaName := ""
if strings.Contains(tableName, ".") {
// TODO: parseRelationName() will return the first schema from the search_path if one is not included
// in the relation name, but that doesn't mean it's the correct schema. It should be updated to
// not return any schema name if one wasn't explicitly specified, then we should search for the
// table on the search_path and find the first schema that contains a table with that name.
schemaName, tableName, err = parseRelationName(ctx, tableName)
if err != nil {
return nil, err
}
}

// Resolve the table's schema if it wasn't specified
if schemaName == "" {
doltSession := dsess.DSessFromSess(ctx.Session)
roots, ok := doltSession.GetRoots(ctx, ctx.GetCurrentDatabase())
if !ok {
return nil, fmt.Errorf("unable to get roots")
}
foundTableName, _, ok, err := resolve.TableWithSearchPath(ctx, roots.Working, tableName)
if err != nil {
return nil, err
}
if !ok {
return nil, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
schemaName = foundTableName.Schema
}

// Validate the full schema + table name and grab the columns
table, err := core.GetSqlTableFromContext(ctx, "", doltdb.TableName{
Schema: schemaName,
Name: tableName,
})
if err != nil {
return nil, err
}
if table == nil {
return nil, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
tableSchema := table.Schema()

// Find the column in the table's schema
columnIndex := tableSchema.IndexOfColName(columnName)
if columnIndex < 0 {
return nil, fmt.Errorf(`column "%s" of relation "%s" does not exist`, columnName, tableName)
}
column := tableSchema[columnIndex]

// Find any sequence associated with the column
sequenceCollection, err := core.GetSequencesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
sequences := sequenceCollection.GetSequencesWithTable(doltdb.TableName{
Name: tableName,
Schema: schemaName,
})
for _, sequence := range sequences {
if sequence.OwnerColumn == column.Name {
// pg_get_serial_sequence() always includes the schema name in its output
return schemaName + "." + sequence.Name, nil
}
}

return nil, nil
},
}
52 changes: 52 additions & 0 deletions testing/go/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,58 @@ func TestSystemInformationFunctions(t *testing.T) {
},
},
},
{
Name: "pg_get_serial_sequence",
SetUpScript: []string{
`create table t0 (id INTEGER NOT NULL PRIMARY KEY);`,
`create table t1 (id SERIAL PRIMARY KEY);`,
`create sequence t2_id_seq START 1 INCREMENT 3;`,
`create table t2 (id INTEGER NOT NULL DEFAULT nextval('t2_id_seq'));`,
// TODO: ALTER SEQUENCE OWNED BY is not supported yet. When the sequence is created
fulghum marked this conversation as resolved.
Show resolved Hide resolved
// explicitly, separate from the column, the owner must be udpated before
// pg_get_serial_sequence() will identify it.
//`ALTER SEQUENCE t2_id_seq OWNED BY t2.id;`,
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT pg_get_serial_sequence('doesnotexist.t1', 'id');`,
ExpectedErr: "does not exist",
},
{
Query: `SELECT pg_get_serial_sequence('doesnotexist', 'id');`,
ExpectedErr: "does not exist",
},
{
Query: `SELECT pg_get_serial_sequence('t0', 'doesnotexist');`,
ExpectedErr: "does not exist",
},
{
// No sequence for column returns null
Query: `SELECT pg_get_serial_sequence('t0', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{nil}},
},
{
Query: `SELECT pg_get_serial_sequence('public.t1', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t1_id_seq"}},
},
{
// Test with no schema specified
Query: `SELECT pg_get_serial_sequence('t1', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t1_id_seq"}},
},
{
// TODO: This test shouldn't pass until we're able to use
// ALTER SEQUENCE OWNED BY to set the owning column.
Skip: true,
Query: `SELECT pg_get_serial_sequence('t2', 'id');`,
Cols: []string{"pg_get_serial_sequence"},
Expected: []sql.Row{{"public.t2_id_seq"}},
},
},
},
})
}

Expand Down
Loading