Skip to content

Commit

Permalink
Merge pull request #3041 from le-vlad/pg_cdc_types
Browse files Browse the repository at this point in the history
Pg cdc types
  • Loading branch information
rockwotj authored Nov 28, 2024
2 parents 4faad33 + d7cac5f commit 0e56401
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ All notable changes to this project will be documented in this file.
### Added

- Add support for `spanner` driver to SQL plugins. (@yufeng-deng)
- Add support for complex database types (JSONB, TEXT[], INET, TSVECTOR, TSRANGE, POINT, INTEGER[]) for `pg_stream` input. (@le-vlad)

### Fixed

- Fixed `pg_stream` issue with discrepancies between replication and snapshot streaming for `UUID` type (@le-vlad)

### Changed

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgtype v1.14.3 // indirect
github.com/jackc/pgtype v1.14.3
github.com/jackc/puddle v1.3.0 // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
Expand Down
130 changes: 130 additions & 0 deletions internal/impl/postgresql/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version
return err
}

// Creating table with complex PG types
_, err = db.Exec(`CREATE TABLE complex_types_example (
id SERIAL PRIMARY KEY,
json_data JSONB,
tags TEXT[],
ip_addr INET,
search_text TSVECTOR,
time_range TSRANGE,
location POINT,
uuid_col UUID,
int_array INTEGER[]
);`)
if err != nil {
return err
}

_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS flights_composite_pks (
id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP,
Expand Down Expand Up @@ -467,6 +483,120 @@ file:
require.NoError(t, streamOut.StopWithin(time.Second*10))
}

func TestIntegrationPgCDCForPgOutputStreamComplexTypesPlugin(t *testing.T) {
integration.CheckSkip(t)
tmpDir := t.TempDir()
pool, err := dockertest.NewPool("")
require.NoError(t, err)

var (
resource *dockertest.Resource
db *sql.DB
)

resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16")
require.NoError(t, err)
require.NoError(t, resource.Expire(120))

hostAndPort := resource.GetHostPort("5432/tcp")
hostAndPortSplited := strings.Split(hostAndPort, ":")
password := "l]YLSc|4[i56%{gY"

// inserting data
_, err = db.Exec(`INSERT INTO complex_types_example (
json_data,
tags,
ip_addr,
search_text,
time_range,
location,
uuid_col,
int_array
) VALUES (
'{"name": "test", "value": 42}'::jsonb,
ARRAY['tag1', 'tag2', 'tag3'],
'192.168.1.1',
to_tsvector('english', 'The quick brown fox jumps over the lazy dog'),
tsrange('2024-01-01', '2024-12-31'),
point(45.5, -122.6),
'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
ARRAY[1, 2, 3, 4, 5]
);`)
require.NoError(t, err)

databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1])
template := fmt.Sprintf(`
pg_stream:
dsn: %s
slot_name: test_slot_native_decoder
snapshot_batch_size: 100
stream_snapshot: true
include_transaction_markers: false
schema: public
tables:
- complex_types_example
`, databaseURL)

cacheConf := fmt.Sprintf(`
label: pg_stream_cache
file:
directory: %v
`, tmpDir)

streamOutBuilder := service.NewStreamBuilder()
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`))
require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf))
require.NoError(t, streamOutBuilder.AddInputYAML(template))

var outBatches []string
var outBatchMut sync.Mutex
require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error {
msgBytes, err := mb[0].AsBytes()
require.NoError(t, err)
outBatchMut.Lock()
outBatches = append(outBatches, string(msgBytes))
outBatchMut.Unlock()
return nil
}))

streamOut, err := streamOutBuilder.Build()
require.NoError(t, err)

go func() {
err = streamOut.Run(context.Background())
require.NoError(t, err)
}()

assert.Eventually(t, func() bool {
outBatchMut.Lock()
defer outBatchMut.Unlock()
return len(outBatches) == 1
}, time.Second*25, time.Millisecond*100)

messageWithComplexTypes := outBatches[0]

// producing change to non-complex type to trigger replication and receive updated row so we can check the complex types again
// but after they have been produced by replication to ensure the consistency
_, err = db.Exec("UPDATE complex_types_example SET id = 2 WHERE id = 1")
require.NoError(t, err)

assert.Eventually(t, func() bool {
outBatchMut.Lock()
defer outBatchMut.Unlock()
return len(outBatches) == 2
}, time.Second*25, time.Millisecond*100)

// replacing update with insert to remove replication messages type differences
// so we will be checking only the data
lastMessage := outBatches[len(outBatches)-1]
lastMessage = strings.Replace(lastMessage, "update", "insert", 1)
messageWithComplexTypes = strings.Replace(messageWithComplexTypes, "\"table_snapshot_progress\":0,", "", 1)

require.Equal(t, messageWithComplexTypes, strings.Replace(lastMessage, ":2", ":1", 1))

require.NoError(t, streamOut.StopWithin(time.Second*10))
}

func TestIntegrationMultiplePostgresVersions(t *testing.T) {
integration.CheckSkip(t)
// running tests in the look to test different PostgreSQL versions
Expand Down
9 changes: 7 additions & 2 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,14 @@ func (s *Stream) processSnapshot() error {

var data = make(map[string]any)
for i, getter := range valueGetters {
data[columnNames[i]] = getter(scanArgs[i])
if data[columnNames[i]], err = getter(scanArgs[i]); err != nil {
return err
}

if _, ok := lastPrimaryKey[columnNames[i]]; ok {
lastPkVals[columnNames[i]] = getter(scanArgs[i])
if lastPkVals[columnNames[i]], err = getter(scanArgs[i]); err != nil {
return err
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
package pglogicalstream

import (
"errors"
"fmt"

"github.com/google/uuid"
pgtypes "github.com/jackc/pgtype"
"github.com/jackc/pgx/v5/pgtype"
)

Expand Down Expand Up @@ -148,7 +151,32 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM

func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) {
if dt, ok := mi.TypeForOID(dataType); ok {
return dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data)
val, err := dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data)
if err != nil {
return val, err
}

if dt.Name == "uuid" {
typesValueForUUID, ok := val.([16]uint8)
if !ok {
return nil, errors.New("unable to convert uuid to string. type casting failed")
}

return uuid.UUID(typesValueForUUID).String(), nil
}

if dt.Name == "tsrange" {
newArray := pgtypes.Tsrange{}
if err := newArray.Scan(data); err != nil {
return nil, err
}

vv, _ := newArray.Value()
return vv, err
}

return val, err
}

return string(data), nil
}
74 changes: 68 additions & 6 deletions internal/impl/postgresql/pglogicalstream/snapshotter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ package pglogicalstream
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"

"github.com/jackc/pgtype"

"errors"

_ "github.com/lib/pq"
Expand Down Expand Up @@ -141,24 +144,83 @@ func (s *Snapshotter) findAvgRowSize(ctx context.Context, table string) (sql.Nul
return avgRowSize, nil
}

func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]interface{}, []func(interface{}) interface{}) {
func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]interface{}, []func(interface{}) (interface{}, error)) {
scanArgs := make([]interface{}, len(columnTypes))
valueGetters := make([]func(interface{}) interface{}, len(columnTypes))
valueGetters := make([]func(interface{}) (interface{}, error), len(columnTypes))

for i, v := range columnTypes {
switch v.DatabaseTypeName() {
case "VARCHAR", "TEXT", "UUID", "TIMESTAMP":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullString).String }
valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullString).String, nil }
case "BOOL":
scanArgs[i] = new(sql.NullBool)
valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullBool).Bool }
valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullBool).Bool, nil }
case "INT4":
scanArgs[i] = new(sql.NullInt64)
valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullInt64).Int64 }
valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullInt64).Int64, nil }
case "JSONB":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) (interface{}, error) {
payload := v.(*sql.NullString).String
if payload == "" {
return payload, nil
}
var dst any
if err := json.Unmarshal([]byte(v.(*sql.NullString).String), &dst); err != nil {
return nil, err
}

return dst, nil
}
case "INET":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) (interface{}, error) {
inet := pgtype.Inet{}
val := v.(*sql.NullString).String
if err := inet.Scan(val); err != nil {
return nil, err
}

return inet.IPNet.String(), nil
}
case "TSRANGE":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) (interface{}, error) {
newArray := pgtype.Tsrange{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
return nil, err
}

vv, _ := newArray.Value()
return vv, nil
}
case "_INT4":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) (interface{}, error) {
newArray := pgtype.Int4Array{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
return nil, err
}

return newArray.Elements, nil
}
case "_TEXT":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) (interface{}, error) {
newArray := pgtype.TextArray{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
return nil, err
}

return newArray.Elements, nil
}
default:
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullString).String }
valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullString).String, nil }
}
}

Expand Down

0 comments on commit 0e56401

Please sign in to comment.