From 7891bad7e2bbcf88d79cb8ecccfcb8a3495f500d Mon Sep 17 00:00:00 2001 From: wangjian Date: Tue, 18 Jun 2024 16:47:30 +0800 Subject: [PATCH 1/2] fix: vertex do not return usage info --- llms/googleai/vertex/vertex.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/llms/googleai/vertex/vertex.go b/llms/googleai/vertex/vertex.go index 29e8ab107..afb50551b 100644 --- a/llms/googleai/vertex/vertex.go +++ b/llms/googleai/vertex/vertex.go @@ -113,7 +113,7 @@ func (g *Vertex) GenerateContent( } // convertCandidates converts a sequence of genai.Candidate to a response. -func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, error) { +func convertCandidates(candidates []*genai.Candidate, usage *genai.UsageMetadata) (*llms.ContentResponse, error) { var contentResponse llms.ContentResponse var toolCalls []llms.ToolCall @@ -149,6 +149,11 @@ func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, er metadata := make(map[string]any) metadata[CITATIONS] = candidate.CitationMetadata metadata[SAFETY] = candidate.SafetyRatings + if usage != nil { + metadata["input_tokens"] = usage.PromptTokenCount + metadata["output_tokens"] = usage.CandidatesTokenCount + metadata["total_tokens"] = usage.TotalTokenCount + } contentResponse.Choices = append(contentResponse.Choices, &llms.ContentChoice{ @@ -257,7 +262,7 @@ func generateFromSingleMessage( if len(resp.Candidates) == 0 { return nil, ErrNoContentInResponse } - return convertCandidates(resp.Candidates) + return convertCandidates(resp.Candidates, resp.UsageMetadata) } iter := model.GenerateContentStream(ctx, convertedParts...) return convertAndStreamFromIterator(ctx, iter, opts) @@ -296,7 +301,7 @@ func generateFromMessages( if len(resp.Candidates) == 0 { return nil, ErrNoContentInResponse } - return convertCandidates(resp.Candidates) + return convertCandidates(resp.Candidates, resp.UsageMetadata) } iter := session.SendMessageStream(ctx, reqContent.Parts...) return convertAndStreamFromIterator(ctx, iter, opts) @@ -315,6 +320,7 @@ func convertAndStreamFromIterator( candidate := &genai.Candidate{ Content: &genai.Content{}, } + var usage *genai.UsageMetadata DoStream: for { resp, err := iter.Next() @@ -339,6 +345,10 @@ DoStream: candidate.SafetyRatings = respCandidate.SafetyRatings candidate.CitationMetadata = respCandidate.CitationMetadata + if resp.UsageMetadata != nil { + usage = resp.UsageMetadata + } + for _, part := range respCandidate.Content.Parts { if text, ok := part.(genai.Text); ok { if opts.StreamingFunc(ctx, []byte(text)) != nil { @@ -348,7 +358,7 @@ DoStream: } } - return convertCandidates([]*genai.Candidate{candidate}) + return convertCandidates([]*genai.Candidate{candidate}, usage) } // convertTools converts from a list of langchaingo tools to a list of genai From 691e99ed3abc9de9c2859dcebbaf6342f8165135 Mon Sep 17 00:00:00 2001 From: wangjian Date: Tue, 18 Jun 2024 16:53:13 +0800 Subject: [PATCH 2/2] test: add testcase for gemini return usage info --- llms/googleai/shared_test/shared_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llms/googleai/shared_test/shared_test.go b/llms/googleai/shared_test/shared_test.go index 4bdaa8bb0..2495e9810 100644 --- a/llms/googleai/shared_test/shared_test.go +++ b/llms/googleai/shared_test/shared_test.go @@ -134,6 +134,8 @@ func testMultiContentText(t *testing.T, llm llms.Model) { assert.NotEmpty(t, rsp.Choices) c1 := rsp.Choices[0] assert.Regexp(t, "(?i)dog|carnivo|canid|canine", c1.Content) + assert.Contains(t, c1.GenerationInfo, "output_tokens") + assert.NotZero(t, c1.GenerationInfo["output_tokens"]) } func testMultiContentTextUsingTextParts(t *testing.T, llm llms.Model) { @@ -317,6 +319,8 @@ func testWithStreaming(t *testing.T, llm llms.Model) { c1 := rsp.Choices[0] assert.Regexp(t, "dog|canid", strings.ToLower(c1.Content)) assert.Regexp(t, "dog|canid", strings.ToLower(sb.String())) + assert.Contains(t, c1.GenerationInfo, "output_tokens") + assert.NotZero(t, c1.GenerationInfo["output_tokens"]) } func testTools(t *testing.T, llm llms.Model) { @@ -354,6 +358,8 @@ func testTools(t *testing.T, llm llms.Model) { assert.NotEmpty(t, resp.Choices) c1 := resp.Choices[0] + assert.Contains(t, c1.GenerationInfo, "output_tokens") + assert.NotZero(t, c1.GenerationInfo["output_tokens"]) // Update chat history with assistant's response, with its tool calls. assistantResp := llms.MessageContent{