From 6c296cc3e89d6e9ddac5c737f7946ccebce828ed Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 27 Sep 2024 15:20:36 -0700 Subject: [PATCH] Added encoder --- lib/pgvector.dart | 1 + lib/src/postgres.dart | 32 ++++++++++++++++++++++++++++++++ pubspec.yaml | 2 +- test/postgres_test.dart | 24 +++++++++++++----------- 4 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 lib/src/postgres.dart diff --git a/lib/pgvector.dart b/lib/pgvector.dart index 4929ab7..748e282 100644 --- a/lib/pgvector.dart +++ b/lib/pgvector.dart @@ -3,3 +3,4 @@ export 'src/halfvec.dart' show HalfVector; export 'src/pgvector.dart' show pgvector; export 'src/sparsevec.dart' show SparseVector; export 'src/vector.dart' show Vector; +export 'src/postgres.dart' show PgvectorEncoder; diff --git a/lib/src/postgres.dart b/lib/src/postgres.dart new file mode 100644 index 0000000..e6b7b64 --- /dev/null +++ b/lib/src/postgres.dart @@ -0,0 +1,32 @@ +import 'dart:convert'; +import 'package:postgres/postgres.dart'; +import 'bit.dart'; +import 'halfvec.dart'; +import 'sparsevec.dart'; +import 'vector.dart'; + +EncodedValue? PgvectorEncoder(TypedValue input, CodecContext context) { + final value = input.value; + + if (value is Vector) { + final v = value as Vector; + return EncodedValue.binary(v.toBinary()); + } + + if (value is HalfVector) { + final v = value as HalfVector; + return EncodedValue.text(utf8.encode(v.toString())); + } + + if (value is Bit) { + final v = value as Bit; + return EncodedValue.binary(v.toBinary()); + } + + if (value is SparseVector) { + final v = value as SparseVector; + return EncodedValue.binary(v.toBinary()); + } + + return null; +} diff --git a/pubspec.yaml b/pubspec.yaml index a8b62ee..9e18610 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -7,8 +7,8 @@ environment: sdk: ^3.0.0 dependencies: + postgres: ^3.4.0 dev_dependencies: lints: ^2.0.0 - postgres: ^3.0.0 test: ^1.21.0 diff --git a/test/postgres_test.dart b/test/postgres_test.dart index 3f10b91..8a8a837 100644 --- a/test/postgres_test.dart +++ b/test/postgres_test.dart @@ -11,7 +11,9 @@ void main() { port: 5432, database: 'pgvector_dart_test', username: Platform.environment['USER']), - settings: ConnectionSettings(sslMode: SslMode.disable)); + settings: ConnectionSettings( + sslMode: SslMode.disable, + typeRegistry: TypeRegistry(encoders: [PgvectorEncoder]))); await connection.execute('CREATE EXTENSION IF NOT EXISTS vector'); await connection.execute('DROP TABLE IF EXISTS items'); @@ -23,25 +25,25 @@ void main() { Sql.named( 'INSERT INTO items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES (@embedding1, @half_embedding1, @binary_embedding1, @sparse_embedding1), (@embedding2, @half_embedding2, @binary_embedding2, @sparse_embedding2), (@embedding3, @half_embedding3, @binary_embedding3, @sparse_embedding3)'), parameters: { - 'embedding1': Vector([1, 1, 1]).toString(), - 'embedding2': Vector([2, 2, 2]).toString(), - 'embedding3': Vector([1, 1, 2]).toString(), - 'half_embedding1': HalfVector([1, 1, 1]).toString(), - 'half_embedding2': HalfVector([2, 2, 2]).toString(), - 'half_embedding3': HalfVector([1, 1, 2]).toString(), + 'embedding1': Vector([1, 1, 1]), + 'embedding2': Vector([2, 2, 2]), + 'embedding3': Vector([1, 1, 2]), + 'half_embedding1': HalfVector([1, 1, 1]), + 'half_embedding2': HalfVector([2, 2, 2]), + 'half_embedding3': HalfVector([1, 1, 2]), 'binary_embedding1': '000', 'binary_embedding2': '101', 'binary_embedding3': '111', - 'sparse_embedding1': SparseVector([1, 1, 1]).toString(), - 'sparse_embedding2': SparseVector([2, 2, 2]).toString(), - 'sparse_embedding3': SparseVector([1, 1, 2]).toString() + 'sparse_embedding1': SparseVector([1, 1, 1]), + 'sparse_embedding2': SparseVector([2, 2, 2]), + 'sparse_embedding3': SparseVector([1, 1, 2]) }); List> results = await connection.execute( Sql.named( 'SELECT id, embedding, binary_embedding, sparse_embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5'), parameters: { - 'embedding': Vector([1, 1, 1]).toString() + 'embedding': Vector([1, 1, 1]) }); expect(results.map((r) => r[0]), equals([1, 3, 2])); expect(Vector.fromBinary(results[1][1].bytes), equals(Vector([1, 1, 2])));