Skip to content

Commit

Permalink
Added encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 27, 2024
1 parent c523ae9 commit 6c296cc
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
1 change: 1 addition & 0 deletions lib/pgvector.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export 'src/halfvec.dart' show HalfVector;
export 'src/pgvector.dart' show pgvector;
export 'src/sparsevec.dart' show SparseVector;
export 'src/vector.dart' show Vector;
export 'src/postgres.dart' show PgvectorEncoder;
32 changes: 32 additions & 0 deletions lib/src/postgres.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import 'dart:convert';
import 'package:postgres/postgres.dart';
import 'bit.dart';
import 'halfvec.dart';
import 'sparsevec.dart';
import 'vector.dart';

EncodedValue? PgvectorEncoder(TypedValue input, CodecContext context) {
final value = input.value;

if (value is Vector) {
final v = value as Vector;
return EncodedValue.binary(v.toBinary());
}

if (value is HalfVector) {
final v = value as HalfVector;
return EncodedValue.text(utf8.encode(v.toString()));
}

if (value is Bit) {
final v = value as Bit;
return EncodedValue.binary(v.toBinary());
}

if (value is SparseVector) {
final v = value as SparseVector;
return EncodedValue.binary(v.toBinary());
}

return null;
}
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ environment:
sdk: ^3.0.0

dependencies:
postgres: ^3.4.0

dev_dependencies:
lints: ^2.0.0
postgres: ^3.0.0
test: ^1.21.0
24 changes: 13 additions & 11 deletions test/postgres_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ void main() {
port: 5432,
database: 'pgvector_dart_test',
username: Platform.environment['USER']),
settings: ConnectionSettings(sslMode: SslMode.disable));
settings: ConnectionSettings(
sslMode: SslMode.disable,
typeRegistry: TypeRegistry(encoders: [PgvectorEncoder])));

await connection.execute('CREATE EXTENSION IF NOT EXISTS vector');
await connection.execute('DROP TABLE IF EXISTS items');
Expand All @@ -23,25 +25,25 @@ void main() {
Sql.named(
'INSERT INTO items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES (@embedding1, @half_embedding1, @binary_embedding1, @sparse_embedding1), (@embedding2, @half_embedding2, @binary_embedding2, @sparse_embedding2), (@embedding3, @half_embedding3, @binary_embedding3, @sparse_embedding3)'),
parameters: {
'embedding1': Vector([1, 1, 1]).toString(),
'embedding2': Vector([2, 2, 2]).toString(),
'embedding3': Vector([1, 1, 2]).toString(),
'half_embedding1': HalfVector([1, 1, 1]).toString(),
'half_embedding2': HalfVector([2, 2, 2]).toString(),
'half_embedding3': HalfVector([1, 1, 2]).toString(),
'embedding1': Vector([1, 1, 1]),
'embedding2': Vector([2, 2, 2]),
'embedding3': Vector([1, 1, 2]),
'half_embedding1': HalfVector([1, 1, 1]),
'half_embedding2': HalfVector([2, 2, 2]),
'half_embedding3': HalfVector([1, 1, 2]),
'binary_embedding1': '000',
'binary_embedding2': '101',
'binary_embedding3': '111',
'sparse_embedding1': SparseVector([1, 1, 1]).toString(),
'sparse_embedding2': SparseVector([2, 2, 2]).toString(),
'sparse_embedding3': SparseVector([1, 1, 2]).toString()
'sparse_embedding1': SparseVector([1, 1, 1]),
'sparse_embedding2': SparseVector([2, 2, 2]),
'sparse_embedding3': SparseVector([1, 1, 2])
});

List<List<dynamic>> results = await connection.execute(
Sql.named(
'SELECT id, embedding, binary_embedding, sparse_embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5'),
parameters: {
'embedding': Vector([1, 1, 1]).toString()
'embedding': Vector([1, 1, 1])
});
expect(results.map((r) => r[0]), equals([1, 3, 2]));
expect(Vector.fromBinary(results[1][1].bytes), equals(Vector([1, 1, 2])));
Expand Down

0 comments on commit 6c296cc

Please sign in to comment.