-
Notifications
You must be signed in to change notification settings - Fork 22
/
embed_vertex.go
157 lines (130 loc) · 4.19 KB
/
embed_vertex.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package chromem
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
)
// EmbeddingModelVertex is the name of a Vertex AI embedding model. Its string
// value is interpolated as the model segment of the ":predict" URL that
// NewEmbeddingFuncVertex builds.
type EmbeddingModelVertex string

// Model names that can be passed to NewEmbeddingFuncVertex.
const (
	// English models.
	EmbeddingModelVertexEnglishV1 EmbeddingModelVertex = "textembedding-gecko@001"
	EmbeddingModelVertexEnglishV2 EmbeddingModelVertex = "textembedding-gecko@002"
	EmbeddingModelVertexEnglishV3 EmbeddingModelVertex = "textembedding-gecko@003"
	EmbeddingModelVertexEnglishV4 EmbeddingModelVertex = "text-embedding-004"

	// Multilingual models.
	EmbeddingModelVertexMultilingualV1 EmbeddingModelVertex = "textembedding-gecko-multilingual@001"
	EmbeddingModelVertexMultilingualV2 EmbeddingModelVertex = "text-multilingual-embedding-002"
)
// baseURLVertex is the API endpoint used when the caller does not configure
// one of their own.
const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1"

// vertexOptions bundles the optional settings of the Vertex embedding
// function. It is only ever modified through VertexOption values.
type vertexOptions struct {
	// apiEndpoint is the URL prefix that the request path is appended to.
	apiEndpoint string
	// autoTruncate is sent as the "autoTruncate" request parameter.
	autoTruncate bool
}

// defaultVertexOptions returns a freshly allocated option set that points at
// baseURLVertex and has auto-truncation switched off.
func defaultVertexOptions() *vertexOptions {
	defaults := new(vertexOptions)
	defaults.apiEndpoint = baseURLVertex
	defaults.autoTruncate = false
	return defaults
}

// VertexOption is a functional option that adjusts one setting of the Vertex
// embedding function.
type VertexOption func(*vertexOptions)

// WithVertexAPIEndpoint returns an option that replaces the API endpoint with
// apiEndpoint.
func WithVertexAPIEndpoint(apiEndpoint string) VertexOption {
	setEndpoint := func(cfg *vertexOptions) {
		cfg.apiEndpoint = apiEndpoint
	}
	return setEndpoint
}

// WithVertexAutoTruncate returns an option that sets the value sent as the
// "autoTruncate" request parameter.
func WithVertexAutoTruncate(autoTruncate bool) VertexOption {
	setTruncate := func(cfg *vertexOptions) {
		cfg.autoTruncate = autoTruncate
	}
	return setTruncate
}
// vertexResponse is the subset of the ":predict" response body that gets
// decoded: the list of predictions.
type vertexResponse struct {
	Predictions []vertexPrediction `json:"predictions"`
}

// vertexPrediction is one element of the response's "predictions" array.
type vertexPrediction struct {
	Embeddings vertexEmbeddings `json:"embeddings"`
}

// vertexEmbeddings carries the embedding vector of a single prediction.
type vertexEmbeddings struct {
	Values []float32 `json:"values"`
	// The API returns more fields here, but only the embedding values are
	// needed.
}
// NewEmbeddingFuncVertex returns an EmbeddingFunc that creates the embedding
// of a text by POSTing it to the ":predict" endpoint of a Vertex AI model.
//
// apiKey is sent as the bearer token of the "Authorization" header, project is
// interpolated into the request path and model selects the embedding model.
// opts adjust the optional settings; nil options are skipped and an empty API
// endpoint falls back to baseURLVertex.
//
// Note that the location segment of the request path is always "us-central1",
// even when a different API endpoint is configured.
//
// The HTTP client has no timeout of its own, so callers bound a request via
// the context they pass to the returned function. The returned function
// reports an error when the request cannot be built or sent, when the API
// answers with a status other than 200 (the error then carries the status and
// a bounded excerpt of the response body), when the body cannot be read or
// decoded, or when it contains no embedding values.
//
// The first response that contains an embedding decides, once for all calls
// of the returned function, whether the vectors count as normalized. If they
// do not, every vector is passed through normalizeVector before it is
// returned.
func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc {
	// Upper bound on how much of an error response body is copied into the
	// returned error.
	const maxErrorBodyBytes = 4096

	cfg := defaultVertexOptions()
	for _, opt := range opts {
		// Calling a nil option would panic, so skip it.
		if opt != nil {
			opt(cfg)
		}
	}
	if cfg.apiEndpoint == "" {
		cfg.apiEndpoint = baseURLVertex
	}

	// The URL only depends on values that are fixed from here on, so it is
	// built once instead of on every call.
	fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.apiEndpoint, project, model)

	// We don't set a default timeout here, although it's usually a good idea.
	// In our case though, the library user can set the timeout on the context,
	// and it might have to be a long timeout, depending on the text length.
	client := &http.Client{}

	var (
		checkNormalized sync.Once
		// alreadyNormalized is written exactly once inside checkNormalized.Do
		// and only read after Do has returned.
		alreadyNormalized bool
	)

	return func(ctx context.Context, text string) ([]float32, error) {
		b := map[string]any{
			"instances": []map[string]any{
				{
					"content": text,
				},
			},
			"parameters": map[string]any{
				"autoTruncate": cfg.autoTruncate,
			},
		}

		// Prepare the request body.
		reqBody, err := json.Marshal(b)
		if err != nil {
			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
		}

		// Create the request. Creating it with context is important for a timeout
		// to be possible, because the client is configured without a timeout.
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(reqBody))
		if err != nil {
			return nil, fmt.Errorf("couldn't create request: %w", err)
		}
		req.Header.Set("Accept", "application/json")
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+apiKey)

		// Send the request.
		resp, err := client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("couldn't send request: %w", err)
		}
		defer resp.Body.Close()

		// Check the response status.
		if resp.StatusCode != http.StatusOK {
			// Best effort: the body usually explains the failure. A read error
			// is ignored on purpose, because the status alone is still a
			// useful error.
			detail, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
			if detail = bytes.TrimSpace(detail); len(detail) > 0 {
				return nil, fmt.Errorf("error response from the embedding API: %s: %s", resp.Status, detail)
			}
			return nil, errors.New("error response from the embedding API: " + resp.Status)
		}

		// Read and decode the response body.
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("couldn't read response body: %w", err)
		}
		var embeddingResponse vertexResponse
		if err := json.Unmarshal(body, &embeddingResponse); err != nil {
			return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
		}

		// Check if the response contains embeddings.
		if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 {
			return nil, errors.New("no embeddings found in the response")
		}

		v := embeddingResponse.Predictions[0].Embeddings.Values
		// Only the first vector is inspected; its result is reused for every
		// later call.
		checkNormalized.Do(func() {
			alreadyNormalized = isNormalized(v)
		})
		if !alreadyNormalized {
			v = normalizeVector(v)
		}

		return v, nil
	}
}