diff --git a/go/tasks/pluginmachinery/remote/mocks/remote_resource.go b/go/tasks/pluginmachinery/remote/mocks/remote_resource.go deleted file mode 100644 index 5770fab12..000000000 --- a/go/tasks/pluginmachinery/remote/mocks/remote_resource.go +++ /dev/null @@ -1,10 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import mock "github.com/stretchr/testify/mock" - -// RemoteResource is an autogenerated mock type for the RemoteResource type -type RemoteResource struct { - mock.Mock -} diff --git a/go/tasks/pluginmachinery/remote/mocks/resource.go b/go/tasks/pluginmachinery/remote/mocks/resource.go deleted file mode 100644 index 9d8d150f2..000000000 --- a/go/tasks/pluginmachinery/remote/mocks/resource.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - core "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - mock "github.com/stretchr/testify/mock" -) - -// Resource is an autogenerated mock type for the Resource type -type Resource struct { - mock.Mock -} - -type Resource_ID struct { - *mock.Call -} - -func (_m Resource_ID) Return(_a0 string) *Resource_ID { - return &Resource_ID{Call: _m.Call.Return(_a0)} -} - -func (_m *Resource) OnID() *Resource_ID { - c := _m.On("ID") - return &Resource_ID{Call: c} -} - -func (_m *Resource) OnIDMatch(matchers ...interface{}) *Resource_ID { - c := _m.On("ID", matchers...) - return &Resource_ID{Call: c} -} - -// ID provides a mock function with given fields: -func (_m *Resource) ID() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -type Resource_Status struct { - *mock.Call -} - -func (_m Resource_Status) Return(phase core.PhaseInfo, err error) *Resource_Status { - return &Resource_Status{Call: _m.Call.Return(phase, err)} -} - -func (_m *Resource) OnStatus() *Resource_Status { - c := _m.On("Status") - return &Resource_Status{Call: c} -} - -func (_m *Resource) OnStatusMatch(matchers ...interface{}) *Resource_Status { - c := _m.On("Status", matchers...) - return &Resource_Status{Call: c} -} - -// Status provides a mock function with given fields: -func (_m *Resource) Status() (core.PhaseInfo, error) { - ret := _m.Called() - - var r0 core.PhaseInfo - if rf, ok := ret.Get(0).(func() core.PhaseInfo); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(core.PhaseInfo) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/go/tasks/plugins/hive/config/config.go b/go/tasks/plugins/hive/config/config.go index 7d2b57277..47270b041 100644 --- a/go/tasks/plugins/hive/config/config.go +++ b/go/tasks/plugins/hive/config/config.go @@ -5,6 +5,9 @@ package config import ( "context" "net/url" + "time" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/remote" "github.com/lyft/flytestdlib/config" "github.com/lyft/flytestdlib/logger" @@ -45,11 +48,22 @@ var ( CommandAPIPath: MustParse("/api/v1.2/commands/"), AnalyzeLinkPath: MustParse("/v2/analyze"), TokenKey: "FLYTE_QUBOLE_CLIENT_TOKEN", - LruCacheSize: 2000, - Workers: 15, DefaultClusterLabel: "default", ClusterConfigs: []ClusterConfig{{PrimaryLabel: "default", Labels: []string{"default"}, Limit: 100, ProjectScopeQuotaProportionCap: 0.7, NamespaceScopeQuotaProportionCap: 0.7}}, DestinationClusterConfigs: []DestinationClusterConfig{}, + Caching: remote.CachingProperties{ + Size: 2000, + ResyncInterval: config.Duration{Duration: 20 * time.Second}, + Workers: 15, + }, + WriteRateLimiter: remote.RateLimiterProperties{ + QPS: 100, + Burst: 200, + }, + ReadRateLimiter: remote.RateLimiterProperties{ + QPS: 100, + Burst: 200, + }, } quboleConfigSection = pluginsConfig.MustRegisterSubSection(quboleConfigSectionKey, &defaultConfig) @@ -57,15 +71,16 @@ var ( // Qubole plugin configs type Config struct { - Endpoint config.URL `json:"endpoint" pflag:",Endpoint for qubole to use"` - CommandAPIPath config.URL `json:"commandApiPath" pflag:",API Path where commands can be launched on Qubole. Should be a valid url."` - AnalyzeLinkPath config.URL `json:"analyzeLinkPath" pflag:",URL path where queries can be visualized on qubole website. Should be a valid url."` - TokenKey string `json:"quboleTokenKey" pflag:",Name of the key where to find Qubole token in the secret manager."` - LruCacheSize int `json:"lruCacheSize" pflag:",Size of the AutoRefreshCache"` - Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` - DefaultClusterLabel string `json:"defaultClusterLabel" pflag:",The default cluster label. This will be used if label is not specified on the hive job."` - ClusterConfigs []ClusterConfig `json:"clusterConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` - DestinationClusterConfigs []DestinationClusterConfig `json:"destinationClusterConfigs" pflag:"-,A list configs specifying the destination service cluster for (project, domain)"` + Endpoint config.URL `json:"endpoint" pflag:",Endpoint for qubole to use"` + CommandAPIPath config.URL `json:"commandApiPath" pflag:",API Path where commands can be launched on Qubole. Should be a valid url."` + AnalyzeLinkPath config.URL `json:"analyzeLinkPath" pflag:",URL path where queries can be visualized on qubole website. Should be a valid url."` + TokenKey string `json:"quboleTokenKey" pflag:",Name of the key where to find Qubole token in the secret manager."` + ClusterConfigs []ClusterConfig `json:"clusterConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` + DefaultClusterLabel string `json:"defaultClusterLabel" pflag:",The default cluster label. This will be used if label is not specified on the hive job."` + DestinationClusterConfigs []DestinationClusterConfig `json:"destinationClusterConfigs" pflag:"-,A list configs specifying the destination service cluster for (project, domain)"` + ReadRateLimiter remote.RateLimiterProperties `json:"readRateLimiter" pflag:",Defines rate limiter properties for read actions (e.g. retrieve status)."` + WriteRateLimiter remote.RateLimiterProperties `json:"writeRateLimiter" pflag:",Defines rate limiter properties for write actions."` + Caching remote.CachingProperties `json:"caching" pflag:",Defines caching characteristics."` } // Retrieves the current config value or default. diff --git a/go/tasks/plugins/hive/config/config_flags.go b/go/tasks/plugins/hive/config/config_flags.go index 5b82f436d..f469db49a 100755 --- a/go/tasks/plugins/hive/config/config_flags.go +++ b/go/tasks/plugins/hive/config/config_flags.go @@ -45,8 +45,13 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "commandApiPath"), defaultConfig.CommandAPIPath.String(), "API Path where commands can be launched on Qubole. Should be a valid url.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "analyzeLinkPath"), defaultConfig.AnalyzeLinkPath.String(), "URL path where queries can be visualized on qubole website. Should be a valid url.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "quboleTokenKey"), defaultConfig.TokenKey, "Name of the key where to find Qubole token in the secret manager.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "lruCacheSize"), defaultConfig.LruCacheSize, "Size of the AutoRefreshCache") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "workers"), defaultConfig.Workers, "Number of parallel workers to refresh the cache") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultClusterLabel"), defaultConfig.DefaultClusterLabel, "The default cluster label. This will be used if label is not specified on the hive job.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "readRateLimiter.qps"), defaultConfig.ReadRateLimiter.QPS, "Defines the max rate of calls per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "readRateLimiter.burst"), defaultConfig.ReadRateLimiter.Burst, "Defines the maximum burst size.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writeRateLimiter.qps"), defaultConfig.WriteRateLimiter.QPS, "Defines the max rate of calls per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writeRateLimiter.burst"), defaultConfig.WriteRateLimiter.Burst, "Defines the maximum burst size.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "caching.size"), defaultConfig.Caching.Size, "Defines the maximum number of items to cache.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "caching.resyncInterval"), defaultConfig.Caching.ResyncInterval.String(), "Defines the sync interval.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "caching.workers"), defaultConfig.Caching.Workers, "Defines the number of workers to start up to process items.") return cmdFlags } diff --git a/go/tasks/plugins/hive/config/config_flags_test.go b/go/tasks/plugins/hive/config/config_flags_test.go index a7fdb6063..67be1f288 100755 --- a/go/tasks/plugins/hive/config/config_flags_test.go +++ b/go/tasks/plugins/hive/config/config_flags_test.go @@ -187,11 +187,11 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_lruCacheSize", func(t *testing.T) { + t.Run("Test_defaultClusterLabel", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { - assert.Equal(t, int(defaultConfig.LruCacheSize), vInt) + if vString, err := cmdFlags.GetString("defaultClusterLabel"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultClusterLabel), vString) } else { assert.FailNow(t, err.Error()) } @@ -200,20 +200,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("lruCacheSize", testValue) - if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.LruCacheSize) + cmdFlags.Set("defaultClusterLabel", testValue) + if vString, err := cmdFlags.GetString("defaultClusterLabel"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultClusterLabel) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_workers", func(t *testing.T) { + t.Run("Test_readRateLimiter.qps", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - assert.Equal(t, int(defaultConfig.Workers), vInt) + if vInt, err := cmdFlags.GetInt("readRateLimiter.qps"); err == nil { + assert.Equal(t, int(defaultConfig.ReadRateLimiter.QPS), vInt) } else { assert.FailNow(t, err.Error()) } @@ -222,20 +222,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("workers", testValue) - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Workers) + cmdFlags.Set("readRateLimiter.qps", testValue) + if vInt, err := cmdFlags.GetInt("readRateLimiter.qps"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ReadRateLimiter.QPS) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_defaultClusterLabel", func(t *testing.T) { + t.Run("Test_readRateLimiter.burst", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("defaultClusterLabel"); err == nil { - assert.Equal(t, string(defaultConfig.DefaultClusterLabel), vString) + if vInt, err := cmdFlags.GetInt("readRateLimiter.burst"); err == nil { + assert.Equal(t, int(defaultConfig.ReadRateLimiter.Burst), vInt) } else { assert.FailNow(t, err.Error()) } @@ -244,9 +244,119 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("defaultClusterLabel", testValue) - if vString, err := cmdFlags.GetString("defaultClusterLabel"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultClusterLabel) + cmdFlags.Set("readRateLimiter.burst", testValue) + if vInt, err := cmdFlags.GetInt("readRateLimiter.burst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ReadRateLimiter.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_writeRateLimiter.qps", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("writeRateLimiter.qps"); err == nil { + assert.Equal(t, int(defaultConfig.WriteRateLimiter.QPS), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("writeRateLimiter.qps", testValue) + if vInt, err := cmdFlags.GetInt("writeRateLimiter.qps"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WriteRateLimiter.QPS) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_writeRateLimiter.burst", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("writeRateLimiter.burst"); err == nil { + assert.Equal(t, int(defaultConfig.WriteRateLimiter.Burst), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("writeRateLimiter.burst", testValue) + if vInt, err := cmdFlags.GetInt("writeRateLimiter.burst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WriteRateLimiter.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_caching.size", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("caching.size"); err == nil { + assert.Equal(t, int(defaultConfig.Caching.Size), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("caching.size", testValue) + if vInt, err := cmdFlags.GetInt("caching.size"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Caching.Size) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_caching.resyncInterval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("caching.resyncInterval"); err == nil { + assert.Equal(t, string(defaultConfig.Caching.ResyncInterval.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.Caching.ResyncInterval.String() + + cmdFlags.Set("caching.resyncInterval", testValue) + if vString, err := cmdFlags.GetString("caching.resyncInterval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Caching.ResyncInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_caching.workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("caching.workers"); err == nil { + assert.Equal(t, int(defaultConfig.Caching.Workers), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("caching.workers", testValue) + if vInt, err := cmdFlags.GetInt("caching.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Caching.Workers) } else { assert.FailNow(t, err.Error()) diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go deleted file mode 100644 index e0c61f1d8..000000000 --- a/go/tasks/plugins/hive/execution_state.go +++ /dev/null @@ -1,458 +0,0 @@ -package hive - -import ( - "context" - "fmt" - "strconv" - "time" - - "github.com/lyft/flytestdlib/cache" - - idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" - - "github.com/lyft/flyteplugins/go/tasks/errors" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" - "github.com/lyft/flytestdlib/logger" -) - -type ExecutionPhase int - -const ( - PhaseNotStarted ExecutionPhase = iota - PhaseQueued // resource manager token gotten - PhaseSubmitted // Sent off to Qubole - - PhaseQuerySucceeded - PhaseQueryFailed -) - -func (p ExecutionPhase) String() string { - switch p { - case PhaseNotStarted: - return "PhaseNotStarted" - case PhaseQueued: - return "PhaseQueued" - case PhaseSubmitted: - return "PhaseSubmitted" - case PhaseQuerySucceeded: - return "PhaseQuerySucceeded" - case PhaseQueryFailed: - return "PhaseQueryFailed" - } - return "Bad Qubole execution phase" -} - -type ExecutionState struct { - Phase ExecutionPhase - - // This will store the command ID from Qubole - CommandID string `json:"command_id,omitempty"` - URI string `json:"uri,omitempty"` - - // This number keeps track of the number of failures within the sync function. Without this, what happens in - // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task - // level retries, just errors from hitting the Qubole API, inside the sync loop - SyncFailureCount int `json:"sync_failure_count,omitempty"` - - // In kicking off the Qubole command, this is the number of failures - CreationFailureCount int `json:"creation_failure_count,omitempty"` - - // The time the execution first requests for an allocation token - AllocationTokenRequestStartTime time.Time `json:"allocation_token_request_start_time,omitempty"` -} - -// This is the main state iteration -func HandleExecutionState(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, quboleClient client.QuboleClient, - executionsCache cache.AutoRefresh, cfg *config.Config, metrics QuboleHiveExecutorMetrics) (ExecutionState, error) { - - var transformError error - var newState ExecutionState - - switch currentState.Phase { - case PhaseNotStarted: - newState, transformError = GetAllocationToken(ctx, tCtx, currentState, metrics) - - case PhaseQueued: - newState, transformError = KickOffQuery(ctx, tCtx, currentState, quboleClient, executionsCache, cfg) - - case PhaseSubmitted: - newState, transformError = MonitorQuery(ctx, tCtx, currentState, executionsCache) - - case PhaseQuerySucceeded: - newState = currentState - transformError = nil - - case PhaseQueryFailed: - newState = currentState - transformError = nil - } - - return newState, transformError -} - -func MapExecutionStateToPhaseInfo(state ExecutionState, quboleClient client.QuboleClient) core.PhaseInfo { - var phaseInfo core.PhaseInfo - t := time.Now() - - switch state.Phase { - case PhaseNotStarted: - phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") - case PhaseQueued: - // TODO: Turn into config - if state.CreationFailureCount > 5 { - phaseInfo = core.PhaseInfoSystemRetryableFailure("QuboleFailure", "Too many creation attempts", nil) - } else { - phaseInfo = core.PhaseInfoQueued(t, uint32(state.CreationFailureCount), "Waiting for Qubole launch") - } - case PhaseSubmitted: - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, ConstructTaskInfo(state)) - - case PhaseQuerySucceeded: - phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) - - case PhaseQueryFailed: - phaseInfo = core.PhaseInfoRetryableFailure(errors.DownstreamSystemError, "Query failed", ConstructTaskInfo(state)) - } - - return phaseInfo -} - -func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { - return &idlCore.TaskLog{ - Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandID), - MessageFormat: idlCore.TaskLog_UNKNOWN, - Uri: e.URI, - } -} - -func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { - logs := make([]*idlCore.TaskLog, 0, 1) - t := time.Now() - if e.CommandID != "" { - logs = append(logs, ConstructTaskLog(e)) - return &core.TaskInfo{ - Logs: logs, - OccurredAt: &t, - } - } - - return nil -} - -func composeResourceNamespaceWithClusterPrimaryLabel(ctx context.Context, tCtx core.TaskExecutionContext) (core.ResourceNamespace, error) { - _, clusterLabelOverride, _, _, _, err := GetQueryInfo(ctx, tCtx) - if err != nil { - return "", err - } - clusterPrimaryLabel := getClusterPrimaryLabel(ctx, tCtx, clusterLabelOverride) - return core.ResourceNamespace(clusterPrimaryLabel), nil -} - -func createResourceConstraintsSpec(ctx context.Context, _ core.TaskExecutionContext, targetClusterPrimaryLabel core.ResourceNamespace) core.ResourceConstraintsSpec { - cfg := config.GetQuboleConfig() - constraintsSpec := core.ResourceConstraintsSpec{ - ProjectScopeResourceConstraint: nil, - NamespaceScopeResourceConstraint: nil, - } - if cfg.ClusterConfigs == nil { - logger.Infof(ctx, "No cluster config is found. Returning an empty resource constraints spec") - return constraintsSpec - } - for _, cluster := range cfg.ClusterConfigs { - if cluster.PrimaryLabel == string(targetClusterPrimaryLabel) { - constraintsSpec.ProjectScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(cluster.Limit) * cluster.ProjectScopeQuotaProportionCap)} - constraintsSpec.NamespaceScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(cluster.Limit) * cluster.NamespaceScopeQuotaProportionCap)} - break - } - } - logger.Infof(ctx, "Created a resource constraints spec: [%v]", constraintsSpec) - return constraintsSpec -} - -func GetAllocationToken(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, metric QuboleHiveExecutorMetrics) (ExecutionState, error) { - newState := ExecutionState{} - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - - clusterPrimaryLabel, err := composeResourceNamespaceWithClusterPrimaryLabel(ctx, tCtx) - if err != nil { - return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueID) - } - - resourceConstraintsSpec := createResourceConstraintsSpec(ctx, tCtx, clusterPrimaryLabel) - - allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, clusterPrimaryLabel, uniqueID, resourceConstraintsSpec) - if err != nil { - logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, err) - return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error requesting allocation token %s", uniqueID) - } - logger.Infof(ctx, "Allocation result for [%s] is [%s]", uniqueID, allocationStatus) - - // Emitting the duration this execution has been waiting for a token allocation - if currentState.AllocationTokenRequestStartTime.IsZero() { - newState.AllocationTokenRequestStartTime = time.Now() - } else { - newState.AllocationTokenRequestStartTime = currentState.AllocationTokenRequestStartTime - } - waitTime := time.Since(newState.AllocationTokenRequestStartTime) - metric.ResourceWaitTime.Observe(waitTime.Seconds()) - - if allocationStatus == core.AllocationStatusGranted { - metric.AllocationGranted.Inc(ctx) - newState.Phase = PhaseQueued - } else if allocationStatus == core.AllocationStatusExhausted { - metric.AllocationNotGranted.Inc(ctx) - newState.Phase = PhaseNotStarted - } else if allocationStatus == core.AllocationStatusNamespaceQuotaExceeded { - metric.AllocationNotGranted.Inc(ctx) - newState.Phase = PhaseNotStarted - } else { - return newState, errors.Errorf(errors.ResourceManagerFailure, "Got bad allocation result [%s] for token [%s]", - allocationStatus, uniqueID) - } - - return newState, nil -} - -func validateQuboleHiveJob(hiveJob plugins.QuboleHiveJob) error { - if hiveJob.Query == nil { - return errors.Errorf(errors.BadTaskSpecification, - "Query could not be found. Please ensure that you are at least on Flytekit version 0.3.0 or later.") - } - return nil -} - -// This function is the link between the output written by the SDK, and the execution side. It extracts the query -// out of the task template. -func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( - query string, cluster string, tags []string, timeoutSec uint32, taskName string, err error) { - - taskTemplate, err := tCtx.TaskReader().Read(ctx) - if err != nil { - return "", "", []string{}, 0, "", err - } - - hiveJob := plugins.QuboleHiveJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &hiveJob) - if err != nil { - return "", "", []string{}, 0, "", err - } - - if err := validateQuboleHiveJob(hiveJob); err != nil { - return "", "", []string{}, 0, "", err - } - - query = hiveJob.Query.GetQuery() - cluster = hiveJob.ClusterLabel - timeoutSec = hiveJob.Query.TimeoutSec - taskName = taskTemplate.Id.Name - tags = hiveJob.Tags - tags = append(tags, fmt.Sprintf("ns:%s", tCtx.TaskExecutionMetadata().GetNamespace())) - for k, v := range tCtx.TaskExecutionMetadata().GetLabels() { - tags = append(tags, fmt.Sprintf("%s:%s", k, v)) - } - logger.Debugf(ctx, "QueryInfo: query: [%v], cluster: [%v], timeoutSec: [%v], tags: [%v]", query, cluster, timeoutSec, tags) - return -} - -func mapLabelToPrimaryLabel(ctx context.Context, quboleCfg *config.Config, label string) (primaryLabel string, found bool) { - primaryLabel = quboleCfg.DefaultClusterLabel - found = false - - if label == "" { - logger.Debugf(ctx, "Input cluster label is an empty string; falling back to using the default primary label [%v]", label, primaryLabel) - return - } - - // Using a linear search because N is small and because of ClusterConfig's struct definition - // which is determined specifically for the readability of the corresponding configmap yaml file - for _, clusterCfg := range quboleCfg.ClusterConfigs { - for _, l := range clusterCfg.Labels { - if label != "" && l == label { - logger.Debugf(ctx, "Found the primary label [%v] for label [%v]", clusterCfg.PrimaryLabel, label) - primaryLabel, found = clusterCfg.PrimaryLabel, true - break - } - } - } - - if !found { - logger.Debugf(ctx, "Cannot find the primary cluster label for label [%v] in configmap; "+ - "falling back to using the default primary label [%v]", label, primaryLabel) - } - - return primaryLabel, found -} - -func mapProjectDomainToDestinationClusterLabel(ctx context.Context, tCtx core.TaskExecutionContext, quboleCfg *config.Config) (string, bool) { - tExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() - project := tExecID.NodeExecutionId.GetExecutionId().GetProject() - domain := tExecID.NodeExecutionId.GetExecutionId().GetDomain() - logger.Debugf(ctx, "No clusterLabelOverride. Finding the pre-defined cluster label for (project: %v, domain: %v)", project, domain) - // Using a linear search because N is small - for _, m := range quboleCfg.DestinationClusterConfigs { - if project == m.Project && domain == m.Domain { - logger.Debugf(ctx, "Found the pre-defined cluster label [%v] for (project: %v, domain: %v)", m.ClusterLabel, project, domain) - return m.ClusterLabel, true - } - } - - // This function finds the label, not primary label, so in the case where no mapping is found, this function should return an empty string - return "", false -} - -func getClusterPrimaryLabel(ctx context.Context, tCtx core.TaskExecutionContext, clusterLabelOverride string) string { - cfg := config.GetQuboleConfig() - - // If override is not empty and if it has a mapping, we return the mapped primary label - if clusterLabelOverride != "" { - if primaryLabel, found := mapLabelToPrimaryLabel(ctx, cfg, clusterLabelOverride); found { - return primaryLabel - } - } - - // If override is empty or if the override does not have a mapping, we return the primary label mapped using (project, domain) - if clusterLabel, found := mapProjectDomainToDestinationClusterLabel(ctx, tCtx, cfg); found { - primaryLabel, _ := mapLabelToPrimaryLabel(ctx, cfg, clusterLabel) - return primaryLabel - } - - // Else we return the default primary label - return cfg.DefaultClusterLabel -} - -func KickOffQuery(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, quboleClient client.QuboleClient, - cache cache.AutoRefresh, cfg *config.Config) (ExecutionState, error) { - - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - apiKey, err := tCtx.SecretManager().Get(ctx, cfg.TokenKey) - if err != nil { - return currentState, errors.Wrapf(errors.RuntimeFailure, err, "Failed to read token from secrets manager") - } - - query, clusterLabelOverride, tags, timeoutSec, taskName, err := GetQueryInfo(ctx, tCtx) - if err != nil { - return currentState, err - } - - clusterPrimaryLabel := getClusterPrimaryLabel(ctx, tCtx, clusterLabelOverride) - - taskExecutionIdentifier := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() - commandMetadata := client.CommandMetadata{TaskName: taskName, - Domain: taskExecutionIdentifier.GetTaskId().GetDomain(), - Project: taskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetProject(), - Labels: tCtx.TaskExecutionMetadata().GetLabels(), - AttemptNumber: taskExecutionIdentifier.GetRetryAttempt(), - MaxAttempts: tCtx.TaskExecutionMetadata().GetMaxAttempts(), - } - - cmdDetails, err := quboleClient.ExecuteHiveCommand(ctx, query, timeoutSec, - clusterPrimaryLabel, apiKey, tags, commandMetadata) - if err != nil { - // If we failed, we'll keep the NotStarted state - currentState.CreationFailureCount = currentState.CreationFailureCount + 1 - logger.Warnf(ctx, "Error creating Qubole query for %s, failure counts %d. Error: %s", uniqueID, currentState.CreationFailureCount, err) - } else { - // If we succeed, then store the command id returned from Qubole, and update our state. Also, add to the - // AutoRefreshCache so we start getting updates. - commandID := strconv.FormatInt(cmdDetails.ID, 10) - logger.Infof(ctx, "Created Qubole ID [%s] for token %s", commandID, uniqueID) - currentState.CommandID = commandID - currentState.Phase = PhaseSubmitted - currentState.URI = cmdDetails.URI.String() - - executionStateCacheItem := ExecutionStateCacheItem{ - ExecutionState: currentState, - Identifier: uniqueID, - } - - // The first time we put it in the cache, we know it won't have succeeded so we don't need to look at it - _, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) - if err != nil { - // This means that our cache has fundamentally broken... return a system error - logger.Errorf(ctx, "Cache failed to GetOrCreate for execution [%s] cache key [%s], owner [%s]. Error %s", - taskExecutionIdentifier, uniqueID, - tCtx.TaskExecutionMetadata().GetOwnerReference(), err) - return currentState, err - } - } - - return currentState, nil -} - -func MonitorQuery(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, cache cache.AutoRefresh) ( - ExecutionState, error) { - - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - executionStateCacheItem := ExecutionStateCacheItem{ - ExecutionState: currentState, - Identifier: uniqueID, - } - - cachedItem, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) - if err != nil { - // This means that our cache has fundamentally broken... return a system error - logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]. Error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, - tCtx.TaskExecutionMetadata().GetOwnerReference(), err) - return currentState, errors.Wrapf(errors.CacheFailed, err, "Error when GetOrCreate while monitoring") - } - - cachedExecutionState, ok := cachedItem.(ExecutionStateCacheItem) - if !ok { - logger.Errorf(ctx, "Error casting cache object into ExecutionState") - return currentState, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", cachedItem) - } - - // TODO: Add a couple of debug lines here - did it change or did it not? - - // If there were updates made to the state, we'll have picked them up automatically. Nothing more to do. - return cachedExecutionState.ExecutionState, nil -} - -func Abort(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, qubole client.QuboleClient, apiKey string) error { - // Cancel Qubole query if non-terminal state - if !InTerminalState(currentState) && currentState.CommandID != "" { - err := qubole.KillCommand(ctx, currentState.CommandID, apiKey) - if err != nil { - logger.Errorf(ctx, "Error terminating Qubole command in Finalize [%s]", err) - return err - } - } - return nil -} - -func Finalize(ctx context.Context, tCtx core.TaskExecutionContext, _ ExecutionState, metrics QuboleHiveExecutorMetrics) error { - // Release allocation token - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - clusterPrimaryLabel, err := composeResourceNamespaceWithClusterPrimaryLabel(ctx, tCtx) - if err != nil { - return errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when releasing allocation token %s", uniqueID) - } - - err = tCtx.ResourceManager().ReleaseResource(ctx, clusterPrimaryLabel, uniqueID) - - if err != nil { - metrics.ResourceReleaseFailed.Inc(ctx) - logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", uniqueID, err) - return err - } - metrics.ResourceReleased.Inc(ctx) - return nil -} - -func InTerminalState(e ExecutionState) bool { - return e.Phase == PhaseQuerySucceeded || e.Phase == PhaseQueryFailed -} - -func IsNotYetSubmitted(e ExecutionState) bool { - if e.Phase == PhaseNotStarted || e.Phase == PhaseQueued { - return true - } - return false -} diff --git a/go/tasks/plugins/hive/execution_state_test.go b/go/tasks/plugins/hive/execution_state_test.go deleted file mode 100644 index e22e05473..000000000 --- a/go/tasks/plugins/hive/execution_state_test.go +++ /dev/null @@ -1,447 +0,0 @@ -package hive - -import ( - "context" - "net/url" - "testing" - "time" - - "github.com/lyft/flytestdlib/contextutils" - "github.com/lyft/flytestdlib/promutils/labeled" - - "github.com/lyft/flytestdlib/promutils" - - idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" - - mocks2 "github.com/lyft/flytestdlib/cache/mocks" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" - pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" - quboleMocks "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client/mocks" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" -) - -func init() { - labeled.SetMetricKeys(contextutils.NamespaceKey) -} - -func TestInTerminalState(t *testing.T) { - var stateTests = []struct { - phase ExecutionPhase - isTerminal bool - }{ - {phase: PhaseNotStarted, isTerminal: false}, - {phase: PhaseQueued, isTerminal: false}, - {phase: PhaseSubmitted, isTerminal: false}, - {phase: PhaseQuerySucceeded, isTerminal: true}, - {phase: PhaseQueryFailed, isTerminal: true}, - } - - for _, tt := range stateTests { - t.Run(tt.phase.String(), func(t *testing.T) { - e := ExecutionState{Phase: tt.phase} - res := InTerminalState(e) - assert.Equal(t, tt.isTerminal, res) - }) - } -} - -func TestIsNotYetSubmitted(t *testing.T) { - var stateTests = []struct { - phase ExecutionPhase - isNotYetSubmitted bool - }{ - {phase: PhaseNotStarted, isNotYetSubmitted: true}, - {phase: PhaseQueued, isNotYetSubmitted: true}, - {phase: PhaseSubmitted, isNotYetSubmitted: false}, - {phase: PhaseQuerySucceeded, isNotYetSubmitted: false}, - {phase: PhaseQueryFailed, isNotYetSubmitted: false}, - } - - for _, tt := range stateTests { - t.Run(tt.phase.String(), func(t *testing.T) { - e := ExecutionState{Phase: tt.phase} - res := IsNotYetSubmitted(e) - assert.Equal(t, tt.isNotYetSubmitted, res) - }) - } -} - -func TestGetQueryInfo(t *testing.T) { - ctx := context.Background() - - taskTemplate := GetSingleHiveQueryTaskTemplate() - mockTaskReader := &mocks.TaskReader{} - mockTaskReader.On("Read", mock.Anything).Return(&taskTemplate, nil) - - mockTaskExecutionContext := mocks.TaskExecutionContext{} - mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) - - taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} - taskMetadata.On("GetNamespace").Return("myproject-staging") - taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) - mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) - - query, cluster, tags, timeout, taskName, err := GetQueryInfo(ctx, &mockTaskExecutionContext) - assert.NoError(t, err) - assert.Equal(t, "select 'one'", query) - assert.Equal(t, "default", cluster) - assert.Equal(t, []string{"flyte_plugin_test", "ns:myproject-staging", "sample:label"}, tags) - assert.Equal(t, 500, int(timeout)) - assert.Equal(t, "sample_hive_task_test_name", taskName) -} - -func TestValidateQuboleHiveJob(t *testing.T) { - hiveJob := plugins.QuboleHiveJob{ - ClusterLabel: "default", - Tags: []string{"flyte_plugin_test", "sample:label"}, - Query: nil, - } - err := validateQuboleHiveJob(hiveJob) - assert.Error(t, err) -} - -func TestConstructTaskLog(t *testing.T) { - expected := "https://wellness.qubole.com/v2/analyze?command_id=123" - u, err := url.Parse(expected) - assert.NoError(t, err) - taskLog := ConstructTaskLog(ExecutionState{CommandID: "123", URI: u.String()}) - assert.Equal(t, expected, taskLog.Uri) -} - -func TestConstructTaskInfo(t *testing.T) { - empty := ConstructTaskInfo(ExecutionState{}) - assert.Nil(t, empty) - - expected := "https://wellness.qubole.com/v2/analyze?command_id=123" - u, err := url.Parse(expected) - assert.NoError(t, err) - - e := ExecutionState{ - Phase: PhaseQuerySucceeded, - CommandID: "123", - SyncFailureCount: 0, - URI: u.String(), - } - - taskInfo := ConstructTaskInfo(e) - assert.Equal(t, "https://wellness.qubole.com/v2/analyze?command_id=123", taskInfo.Logs[0].Uri) -} - -func TestMapExecutionStateToPhaseInfo(t *testing.T) { - c := client.NewQuboleClient(config.GetQuboleConfig()) - t.Run("NotStarted", func(t *testing.T) { - e := ExecutionState{ - Phase: PhaseNotStarted, - } - phaseInfo := MapExecutionStateToPhaseInfo(e, c) - assert.Equal(t, core.PhaseNotReady, phaseInfo.Phase()) - }) - - t.Run("Queued", func(t *testing.T) { - e := ExecutionState{ - Phase: PhaseQueued, - CreationFailureCount: 0, - } - phaseInfo := MapExecutionStateToPhaseInfo(e, c) - assert.Equal(t, core.PhaseQueued, phaseInfo.Phase()) - - e = ExecutionState{ - Phase: PhaseQueued, - CreationFailureCount: 100, - } - phaseInfo = MapExecutionStateToPhaseInfo(e, c) - assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) - - }) - - t.Run("Submitted", func(t *testing.T) { - e := ExecutionState{ - Phase: PhaseSubmitted, - } - phaseInfo := MapExecutionStateToPhaseInfo(e, c) - assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) - }) -} - -func TestGetAllocationToken(t *testing.T) { - ctx := context.Background() - - t.Run("allocation granted", func(t *testing.T) { - tCtx := GetMockTaskExecutionContext() - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(core.AllocationStatusGranted, nil) - - mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} - mockMetrics := getQuboleHiveExecutorMetrics(promutils.NewTestScope()) - state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) - assert.NoError(t, err) - assert.Equal(t, PhaseQueued, state.Phase) - }) - - t.Run("exhausted", func(t *testing.T) { - tCtx := GetMockTaskExecutionContext() - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(core.AllocationStatusExhausted, nil) - - mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} - mockMetrics := getQuboleHiveExecutorMetrics(promutils.NewTestScope()) - state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) - assert.NoError(t, err) - assert.Equal(t, PhaseNotStarted, state.Phase) - }) - - t.Run("namespace exhausted", func(t *testing.T) { - tCtx := GetMockTaskExecutionContext() - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(core.AllocationStatusNamespaceQuotaExceeded, nil) - - mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} - mockMetrics := getQuboleHiveExecutorMetrics(promutils.NewTestScope()) - state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) - assert.NoError(t, err) - assert.Equal(t, PhaseNotStarted, state.Phase) - }) - - t.Run("Request start time, if empty in current state, should be set", func(t *testing.T) { - tCtx := GetMockTaskExecutionContext() - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(core.AllocationStatusNamespaceQuotaExceeded, nil) - - mockCurrentState := ExecutionState{} - mockMetrics := getQuboleHiveExecutorMetrics(promutils.NewTestScope()) - state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) - assert.NoError(t, err) - assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) - }) - - t.Run("Request start time, if already set in current state, should be maintained", func(t *testing.T) { - tCtx := GetMockTaskExecutionContext() - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(core.AllocationStatusGranted, nil) - - startTime := time.Now() - mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: startTime} - mockMetrics := getQuboleHiveExecutorMetrics(promutils.NewTestScope()) - state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) - assert.NoError(t, err) - assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) - assert.Equal(t, state.AllocationTokenRequestStartTime, startTime) - }) -} - -func TestAbort(t *testing.T) { - ctx := context.Background() - - t.Run("Terminate called when not in terminal state", func(t *testing.T) { - var x = false - mockQubole := &quboleMocks.QuboleClient{} - mockQubole.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { - return commandId == "123456" - }), mock.Anything).Run(func(_ mock.Arguments) { - x = true - }).Return(nil) - - err := Abort(ctx, GetMockTaskExecutionContext(), ExecutionState{Phase: PhaseSubmitted, CommandID: "123456"}, mockQubole, "fake-key") - assert.NoError(t, err) - assert.True(t, x) - }) - - t.Run("Terminate not called when in terminal state", func(t *testing.T) { - var x = false - mockQubole := &quboleMocks.QuboleClient{} - mockQubole.On("KillCommand", mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { - x = true - }).Return(nil) - - err := Abort(ctx, GetMockTaskExecutionContext(), ExecutionState{ - Phase: PhaseQuerySucceeded, - CommandID: "123456", - }, mockQubole, "fake-key") - assert.NoError(t, err) - assert.False(t, x) - }) -} - -func TestFinalize(t *testing.T) { - // Test that Finalize releases resources - ctx := context.Background() - tCtx := GetMockTaskExecutionContext() - state := ExecutionState{} - var called = false - mockResourceManager := tCtx.ResourceManager() - x := mockResourceManager.(*mocks.ResourceManager) - x.On("ReleaseResource", mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { - called = true - }).Return(nil) - - err := Finalize(ctx, tCtx, state, getQuboleHiveExecutorMetrics(promutils.NewTestScope())) - assert.NoError(t, err) - assert.True(t, called) -} - -func TestMonitorQuery(t *testing.T) { - ctx := context.Background() - tCtx := GetMockTaskExecutionContext() - state := ExecutionState{ - Phase: PhaseSubmitted, - } - var getOrCreateCalled = false - mockCache := &mocks2.AutoRefresh{} - mockCache.OnGetOrCreateMatch("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", mock.Anything).Return(ExecutionStateCacheItem{ - ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, - Identifier: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", - }, nil).Run(func(_ mock.Arguments) { - getOrCreateCalled = true - }) - - newState, err := MonitorQuery(ctx, tCtx, state, mockCache) - assert.NoError(t, err) - assert.True(t, getOrCreateCalled) - assert.Equal(t, PhaseQuerySucceeded, newState.Phase) -} - -func TestKickOffQuery(t *testing.T) { - ctx := context.Background() - tCtx := GetMockTaskExecutionContext() - - var quboleCalled = false - quboleCommandDetails := &client.QuboleCommandDetails{ - ID: int64(453298043), - Status: client.QuboleStatusWaiting, - } - mockQubole := &quboleMocks.QuboleClient{} - mockQubole.OnExecuteHiveCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { - quboleCalled = true - }).Return(quboleCommandDetails, nil) - - var getOrCreateCalled = false - mockCache := &mocks2.AutoRefresh{} - mockCache.OnGetOrCreate(mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { - getOrCreateCalled = true - }).Return(ExecutionStateCacheItem{}, nil) - - state := ExecutionState{} - newState, err := KickOffQuery(ctx, tCtx, state, mockQubole, mockCache, config.GetQuboleConfig()) - assert.NoError(t, err) - assert.Equal(t, PhaseSubmitted, newState.Phase) - assert.Equal(t, "453298043", newState.CommandID) - assert.True(t, getOrCreateCalled) - assert.True(t, quboleCalled) -} - -func createMockQuboleCfg() *config.Config { - return &config.Config{ - DefaultClusterLabel: "default", - ClusterConfigs: []config.ClusterConfig{ - {PrimaryLabel: "primary A", Labels: []string{"primary A", "A", "label A", "A-prod"}, Limit: 10}, - {PrimaryLabel: "primary B", Labels: []string{"B"}, Limit: 10}, - {PrimaryLabel: "primary C", Labels: []string{"C-prod"}, Limit: 1}, - }, - DestinationClusterConfigs: []config.DestinationClusterConfig{ - {Project: "project A", Domain: "domain X", ClusterLabel: "A-prod"}, - {Project: "project A", Domain: "domain Y", ClusterLabel: "A"}, - {Project: "project A", Domain: "domain Z", ClusterLabel: "B"}, - {Project: "project C", Domain: "domain X", ClusterLabel: "C-prod"}, - }, - } -} - -func Test_mapLabelToPrimaryLabel(t *testing.T) { - ctx := context.TODO() - mockQuboleCfg := createMockQuboleCfg() - - type args struct { - ctx context.Context - quboleCfg *config.Config - label string - } - tests := []struct { - name string - args args - want string - wantFound bool - }{ - {name: "Label has a mapping", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "A-prod"}, want: "primary A", wantFound: true}, - {name: "Label has a typo", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "a"}, want: DefaultClusterPrimaryLabel, wantFound: false}, - {name: "Label has a mapping 2", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "C-prod"}, want: "primary C", wantFound: true}, - {name: "Label has a typo 2", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "C_prod"}, want: DefaultClusterPrimaryLabel, wantFound: false}, - {name: "Label has a mapping 3", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "primary A"}, want: "primary A", wantFound: true}, - {name: "Label has no mapping", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "D"}, want: DefaultClusterPrimaryLabel, wantFound: false}, - {name: "Label is an empty string", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: ""}, want: DefaultClusterPrimaryLabel, wantFound: false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got, found := mapLabelToPrimaryLabel(tt.args.ctx, tt.args.quboleCfg, tt.args.label); got != tt.want || found != tt.wantFound { - t.Errorf("mapLabelToPrimaryLabel() = (%v, %v), want (%v, %v)", got, found, tt.want, tt.wantFound) - } - }) - } -} - -func createMockTaskExecutionContextWithProjectDomain(project string, domain string) *mocks.TaskExecutionContext { - mockTaskExecutionContext := mocks.TaskExecutionContext{} - taskExecID := &pluginsCoreMocks.TaskExecutionID{} - taskExecID.OnGetID().Return(idlCore.TaskExecutionIdentifier{ - NodeExecutionId: &idlCore.NodeExecutionIdentifier{ExecutionId: &idlCore.WorkflowExecutionIdentifier{ - Project: project, - Domain: domain, - Name: "random name", - }}, - }) - - taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} - taskMetadata.OnGetTaskExecutionID().Return(taskExecID) - mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) - return &mockTaskExecutionContext -} - -func Test_getClusterPrimaryLabel(t *testing.T) { - ctx := context.TODO() - err := config.SetQuboleConfig(createMockQuboleCfg()) - assert.Nil(t, err) - - type args struct { - ctx context.Context - tCtx core.TaskExecutionContext - clusterLabelOverride string - } - tests := []struct { - name string - args args - want string - }{ - {name: "Override is not empty + override has NO existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: "AAAA"}, want: "primary B"}, - {name: "Override is not empty + override has NO existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: "blh"}, want: DefaultClusterPrimaryLabel}, - {name: "Override is not empty + override has an existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain blah"), clusterLabelOverride: "C-prod"}, want: "primary C"}, - {name: "Override is not empty + override has an existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain A"), clusterLabelOverride: "C-prod"}, want: "primary C"}, - {name: "Override is empty + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain X"), clusterLabelOverride: ""}, want: "primary A"}, - {name: "Override is empty + project-domain has an existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: ""}, want: "primary B"}, - {name: "Override is empty + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, - {name: "Override is empty + project-domain has NO existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain X"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getClusterPrimaryLabel(tt.args.ctx, tt.args.tCtx, tt.args.clusterLabelOverride); got != tt.want { - t.Errorf("getClusterPrimaryLabel() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/go/tasks/plugins/hive/executions_cache.go b/go/tasks/plugins/hive/executions_cache.go deleted file mode 100644 index 3e7347ffa..000000000 --- a/go/tasks/plugins/hive/executions_cache.go +++ /dev/null @@ -1,172 +0,0 @@ -package hive - -import ( - "context" - "time" - - "k8s.io/client-go/util/workqueue" - - "github.com/lyft/flytestdlib/cache" - - "github.com/lyft/flyteplugins/go/tasks/errors" - stdErrors "github.com/lyft/flytestdlib/errors" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" - - "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flytestdlib/promutils" -) - -const ResyncDuration = 30 * time.Second - -const ( - BadQuboleReturnCodeError stdErrors.ErrorCode = "QUBOLE_RETURNED_UNKNOWN" -) - -type QuboleHiveExecutionsCache struct { - cache.AutoRefresh - quboleClient client.QuboleClient - secretManager core.SecretManager - scope promutils.Scope - cfg *config.Config -} - -func NewQuboleHiveExecutionsCache(ctx context.Context, quboleClient client.QuboleClient, - secretManager core.SecretManager, cfg *config.Config, scope promutils.Scope) (QuboleHiveExecutionsCache, error) { - - q := QuboleHiveExecutionsCache{ - quboleClient: quboleClient, - secretManager: secretManager, - scope: scope, - cfg: cfg, - } - autoRefreshCache, err := cache.NewAutoRefreshCache("qubole", q.SyncQuboleQuery, workqueue.DefaultControllerRateLimiter(), ResyncDuration, cfg.Workers, cfg.LruCacheSize, scope) - if err != nil { - logger.Errorf(ctx, "Could not create AutoRefreshCache in QuboleHiveExecutor. [%s]", err) - return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") - } - q.AutoRefresh = autoRefreshCache - return q, nil -} - -type ExecutionStateCacheItem struct { - ExecutionState - - // This ID is the cache key and so will need to be unique across all objects in the cache (it will probably be - // unique across all of Flyte) and needs to be deterministic. - // This will also be used as the allocation token for now. - Identifier string `json:"id"` -} - -func (e ExecutionStateCacheItem) ID() string { - return e.Identifier -} - -// This basically grab an updated status from the Qubole API and store it in the cache -// All other handling should be in the synchronous loop. -func (q *QuboleHiveExecutionsCache) SyncQuboleQuery(ctx context.Context, batch cache.Batch) ( - updatedBatch []cache.ItemSyncResponse, err error) { - - resp := make([]cache.ItemSyncResponse, 0, len(batch)) - for _, query := range batch { - // Cast the item back to the thing we want to work with. - executionStateCacheItem, ok := query.GetItem().(ExecutionStateCacheItem) - if !ok { - logger.Errorf(ctx, "Sync loop - Error casting cache object into ExecutionState") - return nil, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", batch[0].GetID()) - } - - if executionStateCacheItem.CommandID == "" { - logger.Warnf(ctx, "Sync loop - CommandID is blank for [%s] skipping", executionStateCacheItem.Identifier) - resp = append(resp, cache.ItemSyncResponse{ - ID: query.GetID(), - Item: query.GetItem(), - Action: cache.Unchanged, - }) - - continue - } - - logger.Debugf(ctx, "Sync loop - processing Hive job [%s] - cache key [%s]", - executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) - - quboleAPIKey, err := q.secretManager.Get(ctx, q.cfg.TokenKey) - if err != nil { - return nil, err - } - - if InTerminalState(executionStateCacheItem.ExecutionState) { - logger.Debugf(ctx, "Sync loop - Qubole id [%s] in terminal state [%s]", - executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) - - resp = append(resp, cache.ItemSyncResponse{ - ID: query.GetID(), - Item: query.GetItem(), - Action: cache.Unchanged, - }) - - continue - } - - // Get an updated status from Qubole - logger.Debugf(ctx, "Querying Qubole for %s - %s", executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) - commandStatus, err := q.quboleClient.GetCommandStatus(ctx, executionStateCacheItem.CommandID, quboleAPIKey) - if err != nil { - logger.Errorf(ctx, "Error from Qubole command %s", executionStateCacheItem.CommandID) - executionStateCacheItem.SyncFailureCount++ - // Make sure we don't return nil for the first argument, because that deletes it from the cache. - resp = append(resp, cache.ItemSyncResponse{ - ID: query.GetID(), - Item: executionStateCacheItem, - Action: cache.Update, - }) - - continue - } - - newExecutionPhase, err := QuboleStatusToExecutionPhase(commandStatus) - if err != nil { - return nil, err - } - - if newExecutionPhase > executionStateCacheItem.Phase { - logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandID, - executionStateCacheItem.Identifier, executionStateCacheItem.Phase, newExecutionPhase) - - executionStateCacheItem.Phase = newExecutionPhase - - resp = append(resp, cache.ItemSyncResponse{ - ID: query.GetID(), - Item: executionStateCacheItem, - Action: cache.Update, - }) - } - } - - return resp, nil -} - -// We need some way to translate results we get from Qubole, into a plugin phase -// NB: This function should only return plugin phases that are greater than (">") phases that represent states before -// the query was kicked off. That is, it will never make sense to go back to PhaseNotStarted, after we've -// submitted the query to Qubole. -func QuboleStatusToExecutionPhase(s client.QuboleStatus) (ExecutionPhase, error) { - switch s { - case client.QuboleStatusDone: - return PhaseQuerySucceeded, nil - case client.QuboleStatusCancelled: - return PhaseQueryFailed, nil - case client.QuboleStatusError: - return PhaseQueryFailed, nil - case client.QuboleStatusWaiting: - return PhaseSubmitted, nil - case client.QuboleStatusRunning: - return PhaseSubmitted, nil - case client.QuboleStatusUnknown: - return PhaseQueryFailed, errors.Errorf(BadQuboleReturnCodeError, "Qubole returned status Unknown") - default: - return PhaseQueryFailed, errors.Errorf(BadQuboleReturnCodeError, "default fallthrough case") - } -} diff --git a/go/tasks/plugins/hive/executions_cache_test.go b/go/tasks/plugins/hive/executions_cache_test.go deleted file mode 100644 index cc33365b3..000000000 --- a/go/tasks/plugins/hive/executions_cache_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package hive - -import ( - "context" - "testing" - - "github.com/lyft/flytestdlib/cache" - cacheMocks "github.com/lyft/flytestdlib/cache/mocks" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" - quboleMocks "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client/mocks" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" - - "github.com/lyft/flytestdlib/promutils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestQuboleHiveExecutionsCache_SyncQuboleQuery(t *testing.T) { - ctx := context.Background() - - t.Run("terminal state return unchanged", func(t *testing.T) { - mockCache := &cacheMocks.AutoRefresh{} - mockQubole := &quboleMocks.QuboleClient{} - testScope := promutils.NewTestScope() - - q := QuboleHiveExecutionsCache{ - AutoRefresh: mockCache, - quboleClient: mockQubole, - scope: testScope, - cfg: config.GetQuboleConfig(), - } - - state := ExecutionState{ - Phase: PhaseQuerySucceeded, - } - cacheItem := ExecutionStateCacheItem{ - ExecutionState: state, - Identifier: "some-id", - } - - iw := &cacheMocks.ItemWrapper{} - iw.OnGetItem().Return(cacheItem) - iw.OnGetID().Return("some-id") - - newCacheItem, err := q.SyncQuboleQuery(ctx, []cache.ItemWrapper{iw}) - assert.NoError(t, err) - assert.Equal(t, cache.Unchanged, newCacheItem[0].Action) - assert.Equal(t, cacheItem, newCacheItem[0].Item) - }) - - t.Run("move to success", func(t *testing.T) { - mockCache := &cacheMocks.AutoRefresh{} - mockQubole := &quboleMocks.QuboleClient{} - mockSecretManager := &mocks.SecretManager{} - mockSecretManager.OnGetMatch(mock.Anything, mock.Anything).Return("fake key", nil) - - testScope := promutils.NewTestScope() - - q := QuboleHiveExecutionsCache{ - AutoRefresh: mockCache, - quboleClient: mockQubole, - scope: testScope, - secretManager: mockSecretManager, - cfg: config.GetQuboleConfig(), - } - - state := ExecutionState{ - CommandID: "123456", - Phase: PhaseSubmitted, - } - cacheItem := ExecutionStateCacheItem{ - ExecutionState: state, - Identifier: "some-id", - } - mockQubole.OnGetCommandStatusMatch(mock.Anything, mock.MatchedBy(func(commandId string) bool { - return commandId == state.CommandID - }), mock.Anything).Return(client.QuboleStatusDone, nil) - - iw := &cacheMocks.ItemWrapper{} - iw.OnGetItem().Return(cacheItem) - iw.OnGetID().Return("some-id") - - newCacheItem, err := q.SyncQuboleQuery(ctx, []cache.ItemWrapper{iw}) - newExecutionState := newCacheItem[0].Item.(ExecutionStateCacheItem) - assert.NoError(t, err) - assert.Equal(t, cache.Update, newCacheItem[0].Action) - assert.Equal(t, PhaseQuerySucceeded, newExecutionState.Phase) - }) -} diff --git a/go/tasks/plugins/hive/executor.go b/go/tasks/plugins/hive/executor.go deleted file mode 100644 index 87d540fc5..000000000 --- a/go/tasks/plugins/hive/executor.go +++ /dev/null @@ -1,168 +0,0 @@ -package hive - -import ( - "context" - - "github.com/lyft/flytestdlib/cache" - - "github.com/lyft/flyteplugins/go/tasks/errors" - pluginMachinery "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" - "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" - "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flytestdlib/promutils" -) - -// This is the name of this plugin effectively. In Flyte plugin configuration, use this string to enable this plugin. -const quboleHiveExecutorID = "qubole-hive-executor" - -// Version of the custom state this plugin stores. Useful for backwards compatibility if you one day need to update -// the structure of the stored state -const pluginStateVersion = 0 - -const hiveTaskType = "hive" // This needs to match the type defined in Flytekit constants.py - -const DefaultClusterPrimaryLabel = "default" - -type QuboleHiveExecutor struct { - id string - metrics QuboleHiveExecutorMetrics - quboleClient client.QuboleClient - executionsCache cache.AutoRefresh - cfg *config.Config -} - -func (q QuboleHiveExecutor) GetID() string { - return q.id -} - -func (q QuboleHiveExecutor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { - incomingState := ExecutionState{} - - // We assume here that the first time this function is called, the custom state we get back is whatever we passed in, - // namely the zero-value of our struct. - if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { - logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state when handling [%s] [%s]", - q.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) - return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, - "Failed to unmarshal custom state in Handle") - } - - // Do what needs to be done, and give this function everything it needs to do its job properly - // TODO: Play around with making this return a transition directly. How will that pattern affect the multi-Qubole plugin - outgoingState, transformError := HandleExecutionState(ctx, tCtx, incomingState, q.quboleClient, q.executionsCache, q.cfg, q.metrics) - - // Return if there was an error - if transformError != nil { - return core.UnknownTransition, transformError - } - - // If no error, then infer the new Phase from the various states - phaseInfo := MapExecutionStateToPhaseInfo(outgoingState, q.quboleClient) - - if err := tCtx.PluginStateWriter().Put(pluginStateVersion, outgoingState); err != nil { - return core.UnknownTransition, err - } - - return core.DoTransitionType(core.TransitionTypeBarrier, phaseInfo), nil -} - -func (q QuboleHiveExecutor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { - incomingState := ExecutionState{} - if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { - logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", - q.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) - return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") - } - - key, err := tCtx.SecretManager().Get(ctx, q.cfg.TokenKey) - if err != nil { - logger.Errorf(ctx, "Error reading token in Finalize [%s]", err) - return err - } - - return Abort(ctx, tCtx, incomingState, q.quboleClient, key) -} - -func (q QuboleHiveExecutor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { - incomingState := ExecutionState{} - if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { - logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", - q.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) - return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") - } - - return Finalize(ctx, tCtx, incomingState, q.metrics) -} - -func (q QuboleHiveExecutor) GetProperties() core.PluginProperties { - return core.PluginProperties{} -} - -func QuboleHiveExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { - cfg := config.GetQuboleConfig() - return InitializeHiveExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg.ClusterConfigs), client.NewQuboleClient(cfg)) -} - -func BuildResourceConfig(cfg []config.ClusterConfig) map[core.ResourceNamespace]int { - resourceConfig := make(map[core.ResourceNamespace]int, len(cfg)) - - for _, clusterCfg := range cfg { - resourceConfig[core.ResourceNamespace(clusterCfg.PrimaryLabel)] = clusterCfg.Limit - } - return resourceConfig -} - -func InitializeHiveExecutor(ctx context.Context, iCtx core.SetupContext, cfg *config.Config, resourceConfig map[core.ResourceNamespace]int, - quboleClient client.QuboleClient) (core.Plugin, error) { - logger.Infof(ctx, "Initializing a Hive executor with a resource config [%v]", resourceConfig) - q, err := NewQuboleHiveExecutor(ctx, cfg, quboleClient, iCtx.SecretManager(), iCtx.MetricsScope()) - if err != nil { - logger.Errorf(ctx, "Failed to create a new QuboleHiveExecutor due to error: [%v]", err) - return nil, err - } - - for clusterPrimaryLabel, clusterLimit := range resourceConfig { - logger.Infof(ctx, "Registering resource quota ([%v]) and namespace quota cap ([%v]) for cluster [%v]", clusterPrimaryLabel) - if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, clusterPrimaryLabel, clusterLimit); err != nil { - logger.Errorf(ctx, "Resource quota registration for [%v] failed due to error [%v]", clusterPrimaryLabel, err) - return nil, err - } - } - - return q, nil -} - -// type PluginLoader func(ctx context.Context, iCtx SetupContext) (Plugin, error) -func NewQuboleHiveExecutor(ctx context.Context, cfg *config.Config, quboleClient client.QuboleClient, secretManager core.SecretManager, scope promutils.Scope) (QuboleHiveExecutor, error) { - executionsAutoRefreshCache, err := NewQuboleHiveExecutionsCache(ctx, quboleClient, secretManager, cfg, scope.NewSubScope(hiveTaskType)) - if err != nil { - logger.Errorf(ctx, "Failed to create AutoRefreshCache in QuboleHiveExecutor Setup. Error: %v", err) - return QuboleHiveExecutor{}, err - } - - err = executionsAutoRefreshCache.Start(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to start AutoRefreshCache. Error: %v", err) - } - - return QuboleHiveExecutor{ - id: quboleHiveExecutorID, - cfg: cfg, - metrics: getQuboleHiveExecutorMetrics(scope.NewSubScope("hive")), - quboleClient: quboleClient, - executionsCache: executionsAutoRefreshCache, - }, nil -} - -func init() { - pluginMachinery.PluginRegistry().RegisterCorePlugin( - core.PluginEntry{ - ID: quboleHiveExecutorID, - RegisteredTaskTypes: []core.TaskType{hiveTaskType}, - LoadPlugin: QuboleHiveExecutorLoader, - IsDefault: false, - DefaultForTaskTypes: []core.TaskType{hiveTaskType}, - }) -} diff --git a/go/tasks/plugins/hive/executor_metrics.go b/go/tasks/plugins/hive/executor_metrics.go deleted file mode 100644 index 519a13602..000000000 --- a/go/tasks/plugins/hive/executor_metrics.go +++ /dev/null @@ -1,36 +0,0 @@ -package hive - -import ( - "github.com/lyft/flytestdlib/promutils" - "github.com/lyft/flytestdlib/promutils/labeled" - "github.com/prometheus/client_golang/prometheus" -) - -type QuboleHiveExecutorMetrics struct { - Scope promutils.Scope - ResourceReleased labeled.Counter - ResourceReleaseFailed labeled.Counter - AllocationGranted labeled.Counter - AllocationNotGranted labeled.Counter - ResourceWaitTime prometheus.Summary -} - -var ( - tokenAgeObjectives = map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001, 1.0: 0.0} -) - -func getQuboleHiveExecutorMetrics(scope promutils.Scope) QuboleHiveExecutorMetrics { - return QuboleHiveExecutorMetrics{ - Scope: scope, - ResourceReleased: labeled.NewCounter("resource_release_success", - "Resource allocation token released", scope, labeled.EmitUnlabeledMetric), - ResourceReleaseFailed: labeled.NewCounter("resource_release_failed", - "Error releasing allocation token", scope, labeled.EmitUnlabeledMetric), - AllocationGranted: labeled.NewCounter("allocation_grant_success", - "Allocation request granted", scope, labeled.EmitUnlabeledMetric), - AllocationNotGranted: labeled.NewCounter("allocation_grant_failed", - "Allocation request did not fail but not granted", scope, labeled.EmitUnlabeledMetric), - ResourceWaitTime: scope.MustNewSummaryWithOptions("resource_wait_time", "Duration the execution has been waiting for a resource allocation token", - promutils.SummaryOptions{Objectives: tokenAgeObjectives}), - } -} diff --git a/go/tasks/plugins/hive/plugin.go b/go/tasks/plugins/hive/plugin.go new file mode 100644 index 000000000..e42f02720 --- /dev/null +++ b/go/tasks/plugins/hive/plugin.go @@ -0,0 +1,153 @@ +package hive + +import ( + "context" + "strconv" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/errors" + + "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" + + pluginMachinery "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/remote" +) + +const ( + quboleHiveExecutorID = "qubole-hive-executor" + hiveTaskType = "hive" +) + +type QuboleHivePlugin struct { + client client.QuboleClient + apiKey string + resourceQuotas map[core.ResourceNamespace]int + properties remote.PluginProperties +} + +func (q QuboleHivePlugin) GetPluginProperties() remote.PluginProperties { + return q.properties +} + +func (q QuboleHivePlugin) ResourceRequirements(ctx context.Context, tCtx remote.TaskExecutionContext) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + clusterPrimaryLabel, err := composeResourceNamespaceWithClusterPrimaryLabel(ctx, tCtx) + if err != nil { + return "", core.ResourceConstraintsSpec{}, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueID) + } + + resourceConstraintsSpec := createResourceConstraintsSpec(ctx, tCtx, clusterPrimaryLabel) + return clusterPrimaryLabel, resourceConstraintsSpec, nil +} + +func (q QuboleHivePlugin) Create(ctx context.Context, tCtx remote.TaskExecutionContext) ( + createdResources remote.ResourceMeta, err error) { + taskName, query, clusterLabelOverride, tags, timeoutSec, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return nil, err + } + + clusterPrimaryLabel := getClusterPrimaryLabel(ctx, tCtx, clusterLabelOverride) + + taskExecutionIdentifier := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() + commandMetadata := client.CommandMetadata{ + TaskName: taskName, + Domain: taskExecutionIdentifier.GetTaskId().GetDomain(), + Project: taskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetProject(), + Labels: tCtx.TaskExecutionMetadata().GetLabels(), + AttemptNumber: taskExecutionIdentifier.GetRetryAttempt(), + MaxAttempts: tCtx.TaskExecutionMetadata().GetMaxAttempts(), + } + + cmdDetails, err := q.client.ExecuteHiveCommand(ctx, query, timeoutSec, + clusterPrimaryLabel, q.apiKey, tags, commandMetadata) + if err != nil { + return nil, err + } + + // If we succeed, then store the command id returned from Qubole, and update our state. Also, add to the + // AutoRefreshCache so we start getting updates. + commandID := strconv.FormatInt(cmdDetails.ID, 10) + logger.Infof(ctx, "Created Qubole ID [%s]", commandID) + + return Resource{ + CommandID: commandID, + URI: cmdDetails.URI.String(), + }, nil +} + +func (q QuboleHivePlugin) Get(ctx context.Context, meta remote.ResourceMeta) ( + newMeta remote.ResourceMeta, err error) { + r := meta.(Resource) + logger.Debugf(ctx, "Retrieving Hive job [%s]", r.CommandID) + + // Get an updated status from Qubole + commandStatus, err := q.client.GetCommandStatus(ctx, r.CommandID, q.apiKey) + if err != nil { + logger.Errorf(ctx, "Error from Qubole command %s. Error: %v", r.CommandID, err) + return nil, err + } + + return Resource{ + CommandStatus: commandStatus, + CommandID: r.CommandID, + URI: r.URI, + }, nil +} + +func (q QuboleHivePlugin) Delete(ctx context.Context, meta remote.ResourceMeta) error { + r := meta.(Resource) + logger.Debugf(ctx, "Killing Hive job [%s]", r.CommandID) + + err := q.client.KillCommand(ctx, r.CommandID, q.apiKey) + if err != nil { + logger.Errorf(ctx, "Error terminating Qubole command [%s]. Error: %v", + r.CommandID, err) + return err + } + + return nil +} + +func (q QuboleHivePlugin) Status(_ context.Context, resource remote.ResourceMeta) ( + phase core.PhaseInfo, err error) { + r := resource.(Resource) + return r.GetPhaseInfo(), nil +} + +func QuboleHivePluginLoader(ctx context.Context, iCtx remote.PluginSetupContext) ( + remote.Plugin, error) { + + cfg := config.GetQuboleConfig() + apiKey, err := iCtx.SecretManager().Get(ctx, cfg.TokenKey) + if err != nil { + return nil, errors.Wrapf(errors.RuntimeFailure, err, "Failed to read token from secrets manager") + } + + return QuboleHivePlugin{ + client: client.NewQuboleClient(cfg), + apiKey: apiKey, + properties: remote.PluginProperties{ + ResourceQuotas: BuildResourceConfig(cfg.ClusterConfigs), + ReadRateLimiter: cfg.ReadRateLimiter, + WriteRateLimiter: cfg.WriteRateLimiter, + Caching: cfg.Caching, + ResourceMeta: Resource{}, + }, + }, nil +} + +func init() { + pluginMachinery.PluginRegistry().RegisterRemotePlugin( + remote.PluginEntry{ + ID: quboleHiveExecutorID, + SupportedTaskTypes: []core.TaskType{hiveTaskType}, + PluginLoader: QuboleHivePluginLoader, + DefaultForTaskTypes: []core.TaskType{hiveTaskType}, + }) +} diff --git a/go/tasks/plugins/hive/resource.go b/go/tasks/plugins/hive/resource.go new file mode 100644 index 000000000..e079dbbdb --- /dev/null +++ b/go/tasks/plugins/hive/resource.go @@ -0,0 +1,57 @@ +package hive + +import ( + "fmt" + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" + + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" +) + +type Resource struct { + CommandID string + CommandStatus client.QuboleStatus + URI string +} + +func (r Resource) ConstructTaskInfo() *core.TaskInfo { + logs := make([]*idlCore.TaskLog, 0, 1) + t := time.Now() + logs = append(logs) + return &core.TaskInfo{ + Logs: []*idlCore.TaskLog{ + { + Name: fmt.Sprintf("Status: %s [%s]", r.CommandStatus, r.CommandID), + MessageFormat: idlCore.TaskLog_UNKNOWN, + Uri: r.URI, + }, + }, + OccurredAt: &t, + } +} + +func (r Resource) GetPhaseInfo() core.PhaseInfo { + var phaseInfo core.PhaseInfo + t := time.Now() + + switch r.CommandStatus { + case client.QuboleStatusUnknown: + phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") + case client.QuboleStatusWaiting: + phaseInfo = core.PhaseInfoQueued(t, core.DefaultPhaseVersion, "Waiting for Qubole launch") + case client.QuboleStatusRunning: + phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, r.ConstructTaskInfo()) + case client.QuboleStatusDone: + phaseInfo = core.PhaseInfoSuccess(r.ConstructTaskInfo()) + case client.QuboleStatusCancelled: + fallthrough + case client.QuboleStatusError: + phaseInfo = core.PhaseInfoFailure(errors.DownstreamSystemError, "Query failed", r.ConstructTaskInfo()) + } + + return phaseInfo +} diff --git a/go/tasks/plugins/hive/transformer.go b/go/tasks/plugins/hive/transformer.go new file mode 100644 index 000000000..ea4eaf583 --- /dev/null +++ b/go/tasks/plugins/hive/transformer.go @@ -0,0 +1,163 @@ +package hive + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/remote" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" + "github.com/lyft/flytestdlib/logger" +) + +func BuildResourceConfig(cfg []config.ClusterConfig) map[core.ResourceNamespace]int { + resourceConfig := make(map[core.ResourceNamespace]int, len(cfg)) + + for _, clusterCfg := range cfg { + resourceConfig[core.ResourceNamespace(clusterCfg.PrimaryLabel)] = clusterCfg.Limit + } + + return resourceConfig +} + +func composeResourceNamespaceWithClusterPrimaryLabel(ctx context.Context, tCtx remote.TaskExecutionContext) (core.ResourceNamespace, error) { + _, _, clusterLabelOverride, _, _, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return "", err + } + + clusterPrimaryLabel := getClusterPrimaryLabel(ctx, tCtx, clusterLabelOverride) + return core.ResourceNamespace(clusterPrimaryLabel), nil +} + +func createResourceConstraintsSpec(ctx context.Context, _ remote.TaskExecutionContext, targetClusterPrimaryLabel core.ResourceNamespace) core.ResourceConstraintsSpec { + cfg := config.GetQuboleConfig() + constraintsSpec := core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: nil, + NamespaceScopeResourceConstraint: nil, + } + if cfg.ClusterConfigs == nil { + logger.Infof(ctx, "No cluster config is found. Returning an empty resource constraints spec") + return constraintsSpec + } + for _, cluster := range cfg.ClusterConfigs { + if cluster.PrimaryLabel == string(targetClusterPrimaryLabel) { + constraintsSpec.ProjectScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(cluster.Limit) * cluster.ProjectScopeQuotaProportionCap)} + constraintsSpec.NamespaceScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(cluster.Limit) * cluster.NamespaceScopeQuotaProportionCap)} + break + } + } + logger.Infof(ctx, "Created a resource constraints spec: [%v]", constraintsSpec) + return constraintsSpec +} + +func validateQuboleHiveJob(hiveJob plugins.QuboleHiveJob) error { + if hiveJob.Query == nil { + return errors.Errorf(errors.BadTaskSpecification, + "Query could not be found. Please ensure that you are at least on Flytekit version 0.3.0 or later.") + } + return nil +} + +// This function is the link between the output written by the SDK, and the execution side. It extracts the query +// out of the task template. +func GetQueryInfo(ctx context.Context, tCtx remote.TaskExecutionContext) ( + taskName, query, cluster string, tags []string, timeoutSec uint32, err error) { + + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return "", "", "", []string{}, 0, err + } + + hiveJob := plugins.QuboleHiveJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &hiveJob) + if err != nil { + return "", "", "", []string{}, 0, err + } + + if err := validateQuboleHiveJob(hiveJob); err != nil { + return "", "", "", []string{}, 0, err + } + + query = hiveJob.Query.GetQuery() + cluster = hiveJob.ClusterLabel + timeoutSec = hiveJob.Query.TimeoutSec + taskName = taskTemplate.Id.Name + tags = hiveJob.Tags + tags = append(tags, fmt.Sprintf("ns:%s", tCtx.TaskExecutionMetadata().GetNamespace())) + for k, v := range tCtx.TaskExecutionMetadata().GetLabels() { + tags = append(tags, fmt.Sprintf("%s:%s", k, v)) + } + + logger.Debugf(ctx, "QueryInfo: query: [%v], cluster: [%v], timeoutSec: [%v], tags: [%v]", query, cluster, timeoutSec, tags) + return +} + +func mapLabelToPrimaryLabel(ctx context.Context, quboleCfg *config.Config, label string) (primaryLabel string, found bool) { + primaryLabel = quboleCfg.DefaultClusterLabel + found = false + + if label == "" { + logger.Debugf(ctx, "Input cluster label is an empty string; falling back to using the default primary label [%v]", label, primaryLabel) + return + } + + // Using a linear search because N is small and because of ClusterConfig's struct definition + // which is determined specifically for the readability of the corresponding configmap yaml file + for _, clusterCfg := range quboleCfg.ClusterConfigs { + for _, l := range clusterCfg.Labels { + if label != "" && l == label { + logger.Debugf(ctx, "Found the primary label [%v] for label [%v]", clusterCfg.PrimaryLabel, label) + primaryLabel, found = clusterCfg.PrimaryLabel, true + break + } + } + } + + if !found { + logger.Debugf(ctx, "Cannot find the primary cluster label for label [%v] in configmap; "+ + "falling back to using the default primary label [%v]", label, primaryLabel) + } + + return primaryLabel, found +} + +func mapProjectDomainToDestinationClusterLabel(ctx context.Context, tCtx remote.TaskExecutionContext, quboleCfg *config.Config) (string, bool) { + tExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() + project := tExecID.NodeExecutionId.GetExecutionId().GetProject() + domain := tExecID.NodeExecutionId.GetExecutionId().GetDomain() + logger.Debugf(ctx, "No clusterLabelOverride. Finding the pre-defined cluster label for (project: %v, domain: %v)", project, domain) + // Using a linear search because N is small + for _, m := range quboleCfg.DestinationClusterConfigs { + if project == m.Project && domain == m.Domain { + logger.Debugf(ctx, "Found the pre-defined cluster label [%v] for (project: %v, domain: %v)", m.ClusterLabel, project, domain) + return m.ClusterLabel, true + } + } + + // This function finds the label, not primary label, so in the case where no mapping is found, this function should return an empty string + return "", false +} + +func getClusterPrimaryLabel(ctx context.Context, tCtx remote.TaskExecutionContext, clusterLabelOverride string) string { + cfg := config.GetQuboleConfig() + + // If override is not empty and if it has a mapping, we return the mapped primary label + if clusterLabelOverride != "" { + if primaryLabel, found := mapLabelToPrimaryLabel(ctx, cfg, clusterLabelOverride); found { + return primaryLabel + } + } + + // If override is empty or if the override does not have a mapping, we return the primary label mapped using (project, domain) + if clusterLabel, found := mapProjectDomainToDestinationClusterLabel(ctx, tCtx, cfg); found { + primaryLabel, _ := mapLabelToPrimaryLabel(ctx, cfg, clusterLabel) + return primaryLabel + } + + // Else we return the default primary label + return cfg.DefaultClusterLabel +} diff --git a/go/tasks/plugins/hive/transformer_test.go b/go/tasks/plugins/hive/transformer_test.go new file mode 100644 index 000000000..d6f9498aa --- /dev/null +++ b/go/tasks/plugins/hive/transformer_test.go @@ -0,0 +1,161 @@ +package hive + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" +) + +const ( + DefaultClusterPrimaryLabel = "default" +) + +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + +func TestGetQueryInfo(t *testing.T) { + ctx := context.Background() + + taskTemplate := GetSingleHiveQueryTaskTemplate() + mockTaskReader := &mocks.TaskReader{} + mockTaskReader.On("Read", mock.Anything).Return(&taskTemplate, nil) + + mockTaskExecutionContext := mocks.TaskExecutionContext{} + mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) + + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetNamespace").Return("myproject-staging") + taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) + mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) + + taskName, query, cluster, tags, timeout, err := GetQueryInfo(ctx, &mockTaskExecutionContext) + assert.NoError(t, err) + assert.Equal(t, "sample_hive_task_test_name", taskName) + assert.Equal(t, "select 'one'", query) + assert.Equal(t, "default", cluster) + assert.Equal(t, []string{"flyte_plugin_test", "ns:myproject-staging", "sample:label"}, tags) + assert.Equal(t, 500, int(timeout)) +} + +func TestValidateQuboleHiveJob(t *testing.T) { + hiveJob := plugins.QuboleHiveJob{ + ClusterLabel: "default", + Tags: []string{"flyte_plugin_test", "sample:label"}, + Query: nil, + } + err := validateQuboleHiveJob(hiveJob) + assert.Error(t, err) +} + +func createMockQuboleCfg() *config.Config { + return &config.Config{ + DefaultClusterLabel: "default", + ClusterConfigs: []config.ClusterConfig{ + {PrimaryLabel: "primary A", Labels: []string{"primary A", "A", "label A", "A-prod"}, Limit: 10}, + {PrimaryLabel: "primary B", Labels: []string{"B"}, Limit: 10}, + {PrimaryLabel: "primary C", Labels: []string{"C-prod"}, Limit: 1}, + }, + DestinationClusterConfigs: []config.DestinationClusterConfig{ + {Project: "project A", Domain: "domain X", ClusterLabel: "A-prod"}, + {Project: "project A", Domain: "domain Y", ClusterLabel: "A"}, + {Project: "project A", Domain: "domain Z", ClusterLabel: "B"}, + {Project: "project C", Domain: "domain X", ClusterLabel: "C-prod"}, + }, + } +} + +func Test_mapLabelToPrimaryLabel(t *testing.T) { + ctx := context.TODO() + mockQuboleCfg := createMockQuboleCfg() + + type args struct { + ctx context.Context + quboleCfg *config.Config + label string + } + tests := []struct { + name string + args args + want string + wantFound bool + }{ + {name: "Label has a mapping", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "A-prod"}, want: "primary A", wantFound: true}, + {name: "Label has a typo", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "a"}, want: DefaultClusterPrimaryLabel, wantFound: false}, + {name: "Label has a mapping 2", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "C-prod"}, want: "primary C", wantFound: true}, + {name: "Label has a typo 2", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "C_prod"}, want: DefaultClusterPrimaryLabel, wantFound: false}, + {name: "Label has a mapping 3", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "primary A"}, want: "primary A", wantFound: true}, + {name: "Label has no mapping", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: "D"}, want: DefaultClusterPrimaryLabel, wantFound: false}, + {name: "Label is an empty string", args: args{ctx: ctx, quboleCfg: mockQuboleCfg, label: ""}, want: DefaultClusterPrimaryLabel, wantFound: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, found := mapLabelToPrimaryLabel(tt.args.ctx, tt.args.quboleCfg, tt.args.label); got != tt.want || found != tt.wantFound { + t.Errorf("mapLabelToPrimaryLabel() = (%v, %v), want (%v, %v)", got, found, tt.want, tt.wantFound) + } + }) + } +} + +func createMockTaskExecutionContextWithProjectDomain(project string, domain string) *mocks.TaskExecutionContext { + mockTaskExecutionContext := mocks.TaskExecutionContext{} + taskExecID := &pluginsCoreMocks.TaskExecutionID{} + taskExecID.OnGetID().Return(idlCore.TaskExecutionIdentifier{ + NodeExecutionId: &idlCore.NodeExecutionIdentifier{ExecutionId: &idlCore.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: "random name", + }}, + }) + + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.OnGetTaskExecutionID().Return(taskExecID) + mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) + return &mockTaskExecutionContext +} + +func Test_getClusterPrimaryLabel(t *testing.T) { + ctx := context.TODO() + err := config.SetQuboleConfig(createMockQuboleCfg()) + assert.Nil(t, err) + + type args struct { + ctx context.Context + tCtx core.TaskExecutionContext + clusterLabelOverride string + } + tests := []struct { + name string + args args + want string + }{ + {name: "Override is not empty + override has NO existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: "AAAA"}, want: "primary B"}, + {name: "Override is not empty + override has NO existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: "blh"}, want: DefaultClusterPrimaryLabel}, + {name: "Override is not empty + override has an existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain blah"), clusterLabelOverride: "C-prod"}, want: "primary C"}, + {name: "Override is not empty + override has an existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain A"), clusterLabelOverride: "C-prod"}, want: "primary C"}, + {name: "Override is empty + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain X"), clusterLabelOverride: ""}, want: "primary A"}, + {name: "Override is empty + project-domain has an existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: ""}, want: "primary B"}, + {name: "Override is empty + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, + {name: "Override is empty + project-domain has NO existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain X"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getClusterPrimaryLabel(tt.args.ctx, tt.args.tCtx, tt.args.clusterLabelOverride); got != tt.want { + t.Errorf("getClusterPrimaryLabel() = %v, want %v", got, tt.want) + } + }) + } +}