Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] Adds global endpoint manager policy and links GEM to client #22223

Merged
merged 13 commits into from
Jan 17, 2024
44 changes: 40 additions & 4 deletions sdk/data/azcosmos/cosmos_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
)

const (
apiVersion = "2020-11-05"
)

// Client is used to interact with the Azure Cosmos DB database service.
type Client struct {
endpoint string
pipeline azruntime.Pipeline
gem *globalEndpointManager
}

// Endpoint used to create the client.
Expand All @@ -36,7 +41,15 @@ func (c *Client) Endpoint() string {
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), o)}, nil
preferredRegions := []string{}
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), gem, o), gem: gem}, nil
}

// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
Expand All @@ -48,7 +61,16 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o)}, nil
preferredRegions := []string{}
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0)
if err != nil {
return nil, err
}

return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), gem, o), gem: gem}, nil
}

// NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration.
Expand Down Expand Up @@ -87,7 +109,7 @@ func NewClientFromConnectionString(connectionString string, o *ClientOptions) (*
return NewClientWithKey(endpoint, cred, o)
}

func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
Expand All @@ -98,6 +120,7 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip
&headerPolicies{
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
},
&globalEndpointManagerPolicy{gem: gem},
},
PerRetry: []policy.Policy{
authPolicy,
Expand All @@ -106,6 +129,19 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip
&options.ClientOptions)
}

func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
return azruntime.NewPipeline("azcosmos", serviceLibVersion,
azruntime.PipelineOptions{
PerRetry: []policy.Policy{
authPolicy,
},
},
&options.ClientOptions)
}

func createScopeFromEndpoint(endpoint string) ([]string, error) {
u, err := url.Parse(endpoint)
if err != nil {
Expand Down Expand Up @@ -394,7 +430,7 @@ func (c *Client) createRequest(
}

req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, "2020-11-05")
req.Raw().Header.Set(headerXmsVersion, apiVersion)
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)

req.SetOperationValue(operationContext)
Expand Down
4 changes: 2 additions & 2 deletions sdk/data/azcosmos/cosmos_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ func TestCreateRequest(t *testing.T) {
t.Errorf("Expected %v, but got %v", "", req.Raw().Header.Get(headerXmsDate))
}

if req.Raw().Header.Get(headerXmsVersion) != "2020-11-05" {
t.Errorf("Expected %v, but got %v", "2020-11-05", req.Raw().Header.Get(headerXmsVersion))
if req.Raw().Header.Get(headerXmsVersion) != apiVersion {
t.Errorf("Expected %v, but got %v", apiVersion, req.Raw().Header.Get(headerXmsVersion))
}

if req.Raw().Header.Get(cosmosHeaderSDKSupportedCapabilities) != supportedCapabilitiesHeaderValue {
Expand Down
40 changes: 26 additions & 14 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ import (
const defaultUnavailableLocationRefreshInterval = 5 * time.Minute

type globalEndpointManager struct {
client *Client
clientEndpoint string
pipeline azruntime.Pipeline
preferredLocations []string
locationCache *locationCache
refreshTimeInterval time.Duration
gemMutex sync.Mutex
lastUpdateTime time.Time
}

func newGlobalEndpointManager(client *Client, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(client.endpoint)
func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(clientEndpoint)
if err != nil {
return &globalEndpointManager{}, err
}
Expand All @@ -36,7 +37,8 @@ func newGlobalEndpointManager(client *Client, preferredLocations []string, refre
}

gem := &globalEndpointManager{
client: client,
clientEndpoint: clientEndpoint,
pipeline: pipeline,
preferredLocations: preferredLocations,
locationCache: newLocationCache(preferredLocations, *endpoint),
refreshTimeInterval: refreshTimeInterval,
Expand Down Expand Up @@ -110,24 +112,34 @@ func (gem *globalEndpointManager) GetAccountProperties(ctx context.Context) (acc
resourceAddress: "",
}

path, err := generatePathForNameBased(resourceTypeDatabaseAccount, "", false)
ctxt, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
req, err := azruntime.NewRequest(ctxt, http.MethodGet, gem.clientEndpoint)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to generate path for name-based request: %v", err)
return accountProperties{}, err
}

ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
azResponse, err := gem.client.sendGetRequest(path, ctx, operationContext, nil, nil)
cancel()
req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, apiVersion)
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)

req.SetOperationValue(operationContext)

azResponse, err := gem.pipeline.Do(req)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to retrieve account properties: %v", err)
return accountProperties{}, err
}

properties, err := newAccountProperties(azResponse)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
successResponse := (azResponse.StatusCode >= 200 && azResponse.StatusCode < 300)
if successResponse {
properties, err := newAccountProperties(azResponse)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
}
return properties, nil
}

return properties, nil
return accountProperties{}, newCosmosError(azResponse)
}

func newAccountProperties(azResponse *http.Response) (accountProperties, error) {
Expand Down
25 changes: 25 additions & 0 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"context"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

type globalEndpointManagerPolicy struct {
gem *globalEndpointManager
}

func (p *globalEndpointManagerPolicy) Do(req *policy.Request) (*http.Response, error) {
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
shouldRefresh := p.gem.ShouldRefresh()
if shouldRefresh {
go func() {
_ = p.gem.Update(context.Background())
}()
}
return req.Next()
}
51 changes: 16 additions & 35 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

writeEndpoints, err := gem.GetWriteEndpoints()
Expand All @@ -50,6 +46,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {
expectedWriteEndpoints := []url.URL{
*serverEndpoint,
}

assert.Equal(t, expectedWriteEndpoints, writeEndpoints)
}

Expand All @@ -60,11 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

readEndpoints, err := gem.GetReadEndpoints()
Expand All @@ -88,12 +81,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForRead(*endpoint)
Expand All @@ -112,12 +103,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForWrite(*endpoint)
Expand All @@ -130,7 +119,6 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {
func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))

westRegion := accountRegion{
Name: "West US",
Expand All @@ -144,19 +132,17 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
}

jsonString, err := json.Marshal(properties)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)

srv.SetResponse(mock.WithStatusCode(200))
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Minute)
serverEndpoint, err := url.Parse(srv.URL())
assert.NoError(t, err)

serverEndpoint, err := url.Parse(srv.URL())
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

err = gem.Update(context.Background())
Expand All @@ -175,11 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -212,13 +194,13 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) {
mockLc.useMultipleWriteLocations = true

mockGem := globalEndpointManager{
client: client,
clientEndpoint: client.endpoint,
preferredLocations: preferredRegions,
locationCache: mockLc,
refreshTimeInterval: 5 * time.Minute,
}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

// Multiple locations should be false for default GEM
Expand Down Expand Up @@ -254,9 +236,8 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) {
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv})
client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Second)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second)
assert.NoError(t, err)

// Call update concurrently and see how many times the policy gets called
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
preferredRegions := []string{}
emulatorRegion := accountRegion{Name: emulatorRegionName, Endpoint: "https://127.0.0.1:8081/"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation)
assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation)

//update and assert available locations are now populated in location cache
// Run Update() and assert available locations are now populated in location cache
err = gem.Update(context.Background())
assert.NoError(t, err)
locationInfo = gem.locationCache.locationInfo
Expand All @@ -73,3 +73,34 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1)
assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1)
}

func TestGlobalEndpointManagerPolicyEmulator(t *testing.T) {
ealsur marked this conversation as resolved.
Show resolved Hide resolved
emulatorTests := newEmulatorTests(t)
client := emulatorTests.getClient(t)
emulatorRegionName := "South Central US"

// Assert location cache is not populated until update() is called within the policy
locationInfo := client.gem.locationCache.locationInfo
availableLocation := []string{}
availableEndpointsByLocation := map[string]url.URL{}

assert.Equal(t, locationInfo.availReadLocations, availableLocation)
assert.Equal(t, locationInfo.availWriteLocations, availableLocation)
assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation)
assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation)

// Assert that information gets populated by the gem policy after running an http request (read item)
db, _ := client.NewDatabase("database_id")
container, _ := db.NewContainer("container_id")
_, err := container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil)
assert.Error(t, err)

locationInfo = client.gem.locationCache.locationInfo

assert.Equal(t, len(locationInfo.availReadLocations), len(availableLocation)+1)
assert.Equal(t, len(locationInfo.availWriteLocations), len(availableLocation)+1)
assert.Equal(t, locationInfo.availWriteLocations[0], emulatorRegionName)
assert.Equal(t, locationInfo.availReadLocations[0], emulatorRegionName)
assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1)
assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1)
}
Loading