-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathent_test.go
89 lines (79 loc) · 2.16 KB
/
ent_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package pgvector_test
import (
"context"
"reflect"
"testing"
"entgo.io/ent/dialect/sql"
_ "github.com/lib/pq"
"github.com/pgvector/pgvector-go"
"github.com/pgvector/pgvector-go/ent"
)
// TestEnt exercises the pgvector ent integration against a local Postgres
// database (postgres://localhost/pgvector_go_test). It recreates the items
// table from the ent schema and inserts three rows carrying embedding,
// half_embedding, binary_embedding and sparse_embedding values: one via
// Create and two via CreateBulk. It then checks that a query ordered by the
// `<->` operator against the first row's embedding returns the rows in the
// expected order, and that the embeddings are read back intact.
func TestEnt(t *testing.T) {
	ctx := context.Background()

	client, err := ent.Open("postgres", "postgres://localhost/pgvector_go_test?sslmode=disable")
	if err != nil {
		t.Fatalf("opening ent client: %v", err)
	}
	defer client.Close()

	// Ensure the extension exists before creating the schema; presumably the
	// ent schema declares pgvector column types (vector/halfvec/sparsevec).
	if _, err := client.ExecContext(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
		t.Fatalf("creating vector extension: %v", err)
	}
	// Start from a fresh table. The ID assertions below assume the inserted
	// rows get IDs 1, 2, 3 in insertion order.
	if _, err := client.ExecContext(ctx, "DROP TABLE IF EXISTS items"); err != nil {
		t.Fatalf("dropping items table: %v", err)
	}
	if err := client.Schema.Create(ctx); err != nil {
		t.Fatalf("creating schema: %v", err)
	}

	embedding := pgvector.NewVector([]float32{1, 1, 1})
	halfEmbedding := pgvector.NewHalfVector([]float32{1, 1, 1})
	binaryEmbedding := "000"
	sparseEmbedding := pgvector.NewSparseVector([]float32{1, 1, 1})
	if _, err := client.Item.Create().
		SetEmbedding(embedding).
		SetHalfEmbedding(halfEmbedding).
		SetBinaryEmbedding(binaryEmbedding).
		SetSparseEmbedding(sparseEmbedding).
		Save(ctx); err != nil {
		t.Fatalf("creating item: %v", err)
	}

	if _, err := client.Item.CreateBulk(
		client.Item.Create().
			SetEmbedding(pgvector.NewVector([]float32{2, 2, 2})).
			SetHalfEmbedding(pgvector.NewHalfVector([]float32{2, 2, 2})).
			SetBinaryEmbedding("101").
			SetSparseEmbedding(pgvector.NewSparseVector([]float32{2, 2, 2})),
		client.Item.Create().
			SetEmbedding(pgvector.NewVector([]float32{1, 1, 2})).
			SetHalfEmbedding(pgvector.NewHalfVector([]float32{1, 1, 2})).
			SetBinaryEmbedding("111").
			SetSparseEmbedding(pgvector.NewSparseVector([]float32{1, 1, 2})),
	).Save(ctx); err != nil {
		t.Fatalf("bulk-creating items: %v", err)
	}

	items, err := client.Item.
		Query().
		Order(func(s *sql.Selector) {
			s.OrderExpr(sql.ExprP("embedding <-> $1", embedding))
		}).
		Limit(5).
		All(ctx)
	if err != nil {
		t.Fatalf("querying nearest items: %v", err)
	}

	// Euclidean distances from (1,1,1):
	//   id 1 (1,1,1) -> 0
	//   id 3 (1,1,2) -> 1
	//   id 2 (2,2,2) -> sqrt(3)
	// Guard the length first so a short result fails cleanly instead of
	// panicking on the index expressions below.
	if len(items) != 3 {
		t.Fatalf("got %d items, want 3", len(items))
	}
	if items[0].ID != 1 || items[1].ID != 3 || items[2].ID != 2 {
		t.Errorf("got IDs [%v %v %v], want [1 3 2]", items[0].ID, items[1].ID, items[2].ID)
	}

	// The second-nearest row was inserted with all-(1,1,2) embeddings and bits "111".
	second := items[1]
	want := []float32{1, 1, 2}
	if got := second.Embedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("Embedding = %v, want %v", got, want)
	}
	if got := second.HalfEmbedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("HalfEmbedding = %v, want %v", got, want)
	}
	if got := second.BinaryEmbedding; got != "111" {
		t.Errorf("BinaryEmbedding = %v, want 111", got)
	}
	if got := second.SparseEmbedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("SparseEmbedding = %v, want %v", got, want)
	}
}