From 4e3e3ec37692c105fb3943c215d2c340622d81e9 Mon Sep 17 00:00:00 2001 From: happy-qiao <159568575+happy-qiao@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:46:23 -0700 Subject: [PATCH] vertexai(test): Run corpora test in goroutines to reduce test runtime (#10841) --- vertexai/genai/tokenizer/corpora_test.go | 80 ++++++++++++++---------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/vertexai/genai/tokenizer/corpora_test.go b/vertexai/genai/tokenizer/corpora_test.go index ca41d2818782..42c79bf8d8b0 100644 --- a/vertexai/genai/tokenizer/corpora_test.go +++ b/vertexai/genai/tokenizer/corpora_test.go @@ -23,9 +23,11 @@ import ( "net/http" "os" "strings" + "sync" "testing" "cloud.google.com/go/vertexai/genai" + "golang.org/x/text/encoding" "golang.org/x/text/encoding/charmap" "golang.org/x/text/encoding/japanese" @@ -241,47 +243,57 @@ func TestCountTokensWithCorpora(t *testing.T) { model := client.GenerativeModel(defaultModel) ucr := newUdhrCorpus() + tok, err := New(defaultModel) + if err != nil { + log.Fatal(err) + } + corporaURL := "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/udhr.zip" - files, err := corporaGenerator(corporaURL) + corporaFiles, err := corporaGenerator(corporaURL) if err != nil { t.Fatalf("Failed to generate corpora: %v", err) } - // Iterate over files generated by the generator function - for _, fileInfo := range files { - if ucr.shouldSkip(fileInfo.Name) { - fmt.Printf("Skipping file: %s\n", fileInfo.Name) - continue - } - - enc, found := ucr.getEncoding(fileInfo.Name) - if !found { - fmt.Printf("No encoding found for file: %s\n", fileInfo.Name) - continue - } - - decodedContent, err := decodeBytes(enc, fileInfo.Content) - if err != nil { - log.Fatalf("Failed to decode bytes: %v", err) - } + // Process at most 10 corpora concurrently + workLimiter := make(chan struct{}, 10) + defer close(workLimiter) + var wg sync.WaitGroup + for _, corpora := range corporaFiles { + wg.Add(1) + go 
func(corpora corporaInfo) { + workLimiter <- struct{}{} + defer func() { + <-workLimiter + wg.Done() + }() + if ucr.shouldSkip(corpora.Name) { + log.Printf("Skipping file: %s\n", corpora.Name) + return + } - tok, err := New(defaultModel) - if err != nil { - log.Fatal(err) - } + enc, found := ucr.getEncoding(corpora.Name) + if !found { + log.Printf("No encoding found for file: %s\n", corpora.Name) + return + } - localNtoks, err := tok.CountTokens(genai.Text(decodedContent)) - if err != nil { - log.Fatal(err) - } - remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent)) - if err != nil { - log.Fatal(fileInfo.Name, err) - } - if localNtoks.TotalTokens != remoteNtoks.TotalTokens { - t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks) - } + decodedContent, err := decodeBytes(enc, corpora.Content) + if err != nil { + log.Fatalf("Failed to decode bytes: %v", err) + } + localNtoks, err := tok.CountTokens(genai.Text(decodedContent)) + if err != nil { + log.Fatal(err) + } + remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent)) + if err != nil { + log.Fatal(corpora.Name, err) + } + if localNtoks.TotalTokens != remoteNtoks.TotalTokens { + t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks) + } + }(corpora) } - + wg.Wait() }