From e0ac32e0a3452103dae4d1168552ecf737f44f2a Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 3 Jan 2020 10:34:05 -0800 Subject: [PATCH] Consolidate and reorganize project, domain & workflow resource matching (#49) --- Gopkg.lock | 14 +- boilerplate/lyft/golang_test_targets/Makefile | 4 +- pkg/clusterresource/controller.go | 23 +- pkg/clusterresource/controller_test.go | 30 ++- pkg/manager/impl/execution_manager.go | 4 +- pkg/manager/impl/execution_manager_test.go | 14 +- pkg/manager/impl/executions/queues.go | 233 ++++-------------- pkg/manager/impl/executions/queues_test.go | 172 ++++++++----- .../impl/project_attributes_manager.go | 43 ++++ .../impl/project_attributes_manager_test.go | 37 +++ ...o => project_domain_attributes_manager.go} | 25 +- .../project_domain_attributes_manager_test.go | 39 +++ pkg/manager/impl/shared/constants.go | 2 + pkg/manager/impl/testutils/attributes.go | 13 + .../impl/validation/attributes_validator.go | 74 ++++++ .../validation/attributes_validator_test.go | 166 +++++++++++++ .../validation/project_domain_validator.go | 20 -- .../project_domain_validator_test.go | 33 --- pkg/manager/impl/validation/task_validator.go | 51 +++- .../impl/validation/task_validator_test.go | 48 +++- .../impl/workflow_attributes_manager.go | 43 ++++ .../impl/workflow_attributes_manager_test.go | 41 +++ pkg/manager/interfaces/project_attributes.go | 13 + ...domain.go => project_domain_attributes.go} | 4 +- pkg/manager/interfaces/workflow_attributes.go | 13 + pkg/manager/mocks/project_domain.go | 2 +- pkg/repositories/config/migrations.go | 37 ++- pkg/repositories/factory.go | 4 +- ...ain_repo.go => project_attributes_repo.go} | 34 +-- .../gormimpl/project_attributes_repo_test.go | 57 +++++ .../project_domain_attributes_repo.go | 73 ++++++ ...=> project_domain_attributes_repo_test.go} | 26 +- .../gormimpl/workflow_attributes_repo.go | 75 ++++++ .../gormimpl/workflow_attributes_repo_test.go | 62 +++++ .../interfaces/project_attributes_repo.go | 14 ++ .../project_domain_attributes_repo.go | 14 ++ .../interfaces/project_domain_repo.go | 14 -- .../interfaces/workflow_attributes_repo.go | 14 ++ .../mocks/project_attributes_repo.go | 35 +++ .../mocks/project_domain_attributes_repo.go | 35 +++ pkg/repositories/mocks/project_domain_repo.go | 35 --- pkg/repositories/mocks/repository.go | 52 ++-- pkg/repositories/mocks/workflow_attributes.go | 37 +++ pkg/repositories/models/project_attributes.go | 10 + pkg/repositories/models/project_domain.go | 10 - .../models/project_domain_attributes.go | 11 + .../models/workflow_attributes.go | 12 + pkg/repositories/postgres_repo.go | 52 ++-- .../transformers/project_attributes.go | 35 +++ .../transformers/project_attributes_test.go | 63 +++++ .../transformers/project_domain.go | 32 --- .../transformers/project_domain_attributes.go | 37 +++ .../project_domain_attributes_test.go | 68 +++++ .../transformers/project_domain_test.go | 57 ----- .../transformers/workflow_attributes.go | 39 +++ .../transformers/workflow_attributes_test.go | 71 ++++++ pkg/resourcematching/overrides.go | 77 ++++++ pkg/resourcematching/overrides_test.go | 164 ++++++++++++ pkg/rpc/adminservice/attributes.go | 64 +++++ pkg/rpc/adminservice/base.go | 30 ++- pkg/rpc/adminservice/metrics.go | 34 ++- pkg/rpc/adminservice/project_domain.go | 28 --- pkg/rpc/adminservice/tests/util.go | 18 +- pkg/runtime/interfaces/queue_configuration.go | 10 +- tests/attributes_test.go | 121 +++++++++ tests/project_domain_test.go | 52 ---- 66 files changed, 2163 insertions(+), 711 deletions(-) create mode 100644 pkg/manager/impl/project_attributes_manager.go create mode 100644 pkg/manager/impl/project_attributes_manager_test.go rename pkg/manager/impl/{project_domain_manager.go => project_domain_attributes_manager.go} (53%) create mode 100644 pkg/manager/impl/project_domain_attributes_manager_test.go create mode 100644 pkg/manager/impl/testutils/attributes.go create mode 100644 pkg/manager/impl/validation/attributes_validator.go create mode 100644 pkg/manager/impl/validation/attributes_validator_test.go delete mode 100644 pkg/manager/impl/validation/project_domain_validator.go delete mode 100644 pkg/manager/impl/validation/project_domain_validator_test.go create mode 100644 pkg/manager/impl/workflow_attributes_manager.go create mode 100644 pkg/manager/impl/workflow_attributes_manager_test.go create mode 100644 pkg/manager/interfaces/project_attributes.go rename pkg/manager/interfaces/{project_domain.go => project_domain_attributes.go} (58%) create mode 100644 pkg/manager/interfaces/workflow_attributes.go rename pkg/repositories/gormimpl/{project_domain_repo.go => project_attributes_repo.go} (53%) create mode 100644 pkg/repositories/gormimpl/project_attributes_repo_test.go create mode 100644 pkg/repositories/gormimpl/project_domain_attributes_repo.go rename pkg/repositories/gormimpl/{project_domain_repo_test.go => project_domain_attributes_repo_test.go} (53%) create mode 100644 pkg/repositories/gormimpl/workflow_attributes_repo.go create mode 100644 pkg/repositories/gormimpl/workflow_attributes_repo_test.go create mode 100644 pkg/repositories/interfaces/project_attributes_repo.go create mode 100644 pkg/repositories/interfaces/project_domain_attributes_repo.go delete mode 100644 pkg/repositories/interfaces/project_domain_repo.go create mode 100644 pkg/repositories/interfaces/workflow_attributes_repo.go create mode 100644 pkg/repositories/mocks/project_attributes_repo.go create mode 100644 pkg/repositories/mocks/project_domain_attributes_repo.go delete mode 100644 pkg/repositories/mocks/project_domain_repo.go create mode 100644 pkg/repositories/mocks/workflow_attributes.go create mode 100644 pkg/repositories/models/project_attributes.go delete mode 100644 pkg/repositories/models/project_domain.go create mode 100644 pkg/repositories/models/project_domain_attributes.go create mode 100644 pkg/repositories/models/workflow_attributes.go create mode 100644 pkg/repositories/transformers/project_attributes.go create mode 100644 pkg/repositories/transformers/project_attributes_test.go delete mode 100644 pkg/repositories/transformers/project_domain.go create mode 100644 pkg/repositories/transformers/project_domain_attributes.go create mode 100644 pkg/repositories/transformers/project_domain_attributes_test.go delete mode 100644 pkg/repositories/transformers/project_domain_test.go create mode 100644 pkg/repositories/transformers/workflow_attributes.go create mode 100644 pkg/repositories/transformers/workflow_attributes_test.go create mode 100644 pkg/resourcematching/overrides.go create mode 100644 pkg/resourcematching/overrides_test.go create mode 100644 pkg/rpc/adminservice/attributes.go delete mode 100644 pkg/rpc/adminservice/project_domain.go create mode 100644 tests/attributes_test.go delete mode 100644 tests/project_domain_test.go diff --git a/Gopkg.lock b/Gopkg.lock index cd81c90931..af1aa6e938 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -453,7 +453,7 @@ version = "v1.2.0" [[projects]] - digest = "1:adf75c64d9fd1b82df8924fec495eb015ab53e5b2db6db1c2b1b8b1c27145848" + digest = "1:b1ab65fa43650fb503aa710091eb9d13c6c5d7f0cf14684039827b457a1f2ac9" name = "github.com/lyft/flyteidl" packages = [ "clients/go/admin", @@ -466,9 +466,9 @@ "gen/pb-go/flyteidl/service", ] pruneopts = "UT" - revision = "4e175db2a7338f787fc7e939417718e59af639a4" + revision = "6deb3c002d84b2012f02cdecc32902f475deda83" source = "https://github.com/lyft/flyteidl" - version = "v0.16.2" + version = "v0.16.4" [[projects]] digest = "1:938998e14bd5e42c54f3b640a41d869eb79029ad7c623fa47c604b8480c781fc" @@ -814,10 +814,11 @@ [[projects]] branch = "master" - digest = "1:31e33f76456ccf54819ab4a646cf01271d1a99d7712ab84bf1a9e7b61cd2031b" + digest = "1:dbd0f5d2c3d68686f4b80fcc71c632d4f9b5cf622f17c64affecfc90f7650639" name = "golang.org/x/oauth2" packages = [ ".", + "clientcredentials", "google", "internal", "jws", @@ -1183,6 +1184,7 @@ "github.com/aws/aws-sdk-go/service/ses/sesiface", "github.com/benbjohnson/clock", "github.com/coreos/go-oidc", + "github.com/evanphx/json-patch", "github.com/gogo/protobuf/proto", "github.com/golang/glog", "github.com/golang/protobuf/jsonpb", @@ -1197,6 +1199,7 @@ "github.com/graymeta/stow/s3", "github.com/grpc-ecosystem/go-grpc-middleware", "github.com/grpc-ecosystem/go-grpc-middleware/auth", + "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils", "github.com/grpc-ecosystem/go-grpc-prometheus", "github.com/grpc-ecosystem/grpc-gateway/runtime", "github.com/grpc/grpc-go/credentials/oauth", @@ -1246,7 +1249,10 @@ "k8s.io/apimachinery/pkg/api/errors", "k8s.io/apimachinery/pkg/api/resource", "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/runtime", "k8s.io/apimachinery/pkg/runtime/schema", + "k8s.io/apimachinery/pkg/types", + "k8s.io/apimachinery/pkg/util/json", "k8s.io/apimachinery/pkg/util/validation", "k8s.io/apimachinery/pkg/util/wait", "k8s.io/client-go/kubernetes/scheme", diff --git a/boilerplate/lyft/golang_test_targets/Makefile b/boilerplate/lyft/golang_test_targets/Makefile index ff844da4ae..948af0c633 100644 --- a/boilerplate/lyft/golang_test_targets/Makefile +++ b/boilerplate/lyft/golang_test_targets/Makefile @@ -10,8 +10,8 @@ lint: #lints the package for common code smells # However, that call seem to have some effects (e.g. https://github.com/golang/go/issues/29452) which, for some # reason, allows the subsequent calls to succeed. # TODO: Evaluate whether this is still a problem after moving admin dependency system to go modules. - GO111MODULE=off GL_DEBUG=linters_output,loader,env golangci-lint run --exclude deprecated -v || true - GO111MODULE=off GL_DEBUG=linters_output,loader,env golangci-lint run --deadline=5m --exclude deprecated -v + GO111MODULE=off golangci-lint run --exclude deprecated -v || true + GO111MODULE=off golangci-lint run --deadline=5m --exclude deprecated -v # If code is failing goimports linter, this will fix. # skips 'vendor' diff --git a/pkg/clusterresource/controller.go b/pkg/clusterresource/controller.go index 228869d2c9..35a8e94a7d 100644 --- a/pkg/clusterresource/controller.go +++ b/pkg/clusterresource/controller.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" + "github.com/lyft/flyteadmin/pkg/resourcematching" "github.com/lyft/flyteadmin/pkg/executioncluster/interfaces" @@ -174,21 +174,18 @@ func (c *controller) getCustomTemplateValues( customTemplateValues[key] = value } collectedErrs := make([]error, 0) - // All project-domain defaults saved in the database take precedence over the domain-specific defaults. - projectDomainModel, err := c.db.ProjectDomainRepo().Get(ctx, project, domain) - if err != nil { - if err.(errors.FlyteAdminError).Code() != codes.NotFound { - // Not found is fine because not every project-domain combination will have specific custom resource - // attributes. - collectedErrs = append(collectedErrs, err) - } - } - projectDomain, err := transformers.FromProjectDomainModel(projectDomainModel) + // All override values saved in the database take precedence over the domain-specific defaults. + attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ + Db: c.db, + Project: project, + Domain: domain, + Resource: admin.MatchableResource_CLUSTER_RESOURCE, + }) if err != nil { collectedErrs = append(collectedErrs, err) } - if len(projectDomain.Attributes) > 0 { - for templateKey, templateValue := range projectDomain.Attributes { + if attributes != nil && attributes.GetClusterResourceAttributes() != nil { + for templateKey, templateValue := range attributes.GetClusterResourceAttributes().Attributes { customTemplateValues[fmt.Sprintf(templateVariableFormat, templateKey)] = templateValue } } diff --git a/pkg/clusterresource/controller_test.go b/pkg/clusterresource/controller_test.go index cfd61176da..c91fc5d4ca 100644 --- a/pkg/clusterresource/controller_test.go +++ b/pkg/clusterresource/controller_test.go @@ -7,9 +7,6 @@ import ( "testing" "time" - "github.com/lyft/flyteadmin/pkg/errors" - "google.golang.org/grpc/codes" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" @@ -154,15 +151,20 @@ func TestGetCustomTemplateValues(t *testing.T) { projectDomainAttributes := admin.ProjectDomainAttributes{ Project: "project-foo", Domain: "domain-bar", - Attributes: map[string]string{ - "var1": "val1", - "var2": "val2", + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ClusterResourceAttributes{ClusterResourceAttributes: &admin.ClusterResourceAttributes{ + Attributes: map[string]string{ + "var1": "val1", + "var2": "val2", + }, + }, + }, }, } - projectDomainModel, err := transformers.ToProjectDomainModel(projectDomainAttributes) + projectDomainModel, err := transformers.ToProjectDomainAttributesModel(projectDomainAttributes, admin.MatchableResource_CLUSTER_RESOURCE) assert.Nil(t, err) - mockRepository.ProjectDomainRepo().(*repositoryMocks.MockProjectDomainRepo).GetFunction = func( - ctx context.Context, project, domain string) (models.ProjectDomain, error) { + mockRepository.ProjectDomainAttributesRepo().(*repositoryMocks.MockProjectDomainAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { assert.Equal(t, "project-foo", project) assert.Equal(t, "domain-bar", domain) return projectDomainModel, nil @@ -188,10 +190,6 @@ func TestGetCustomTemplateValues(t *testing.T) { func TestGetCustomTemplateValues_NothingToOverride(t *testing.T) { mockRepository := repositoryMocks.NewMockRepository() - mockRepository.ProjectDomainRepo().(*repositoryMocks.MockProjectDomainRepo).GetFunction = func( - ctx context.Context, project, domain string) (models.ProjectDomain, error) { - return models.ProjectDomain{}, errors.NewFlyteAdminError(codes.NotFound, "not found") - } testController := controller{ db: mockRepository, } @@ -209,9 +207,9 @@ func TestGetCustomTemplateValues_NothingToOverride(t *testing.T) { func TestGetCustomTemplateValues_InvalidDBModel(t *testing.T) { mockRepository := repositoryMocks.NewMockRepository() - mockRepository.ProjectDomainRepo().(*repositoryMocks.MockProjectDomainRepo).GetFunction = func( - ctx context.Context, project, domain string) (models.ProjectDomain, error) { - return models.ProjectDomain{ + mockRepository.ProjectDomainAttributesRepo().(*repositoryMocks.MockProjectDomainAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { + return models.ProjectDomainAttributes{ Attributes: []byte("i'm invalid"), }, nil } diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 528296cb84..9a3e73d318 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -246,7 +246,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( // Dynamically assign task resource defaults. for _, task := range workflow.Closure.CompiledWorkflow.Tasks { - validation.SetDefaults(ctx, m.config.TaskResourceConfiguration(), task) + validation.SetDefaults(ctx, m.config.TaskResourceConfiguration(), task, m.db, name) } // Dynamically assign execution queues. @@ -926,7 +926,7 @@ func NewExecutionManager( userScope promutils.Scope, publisher notificationInterfaces.Publisher, urlData dataInterfaces.RemoteURLInterface) interfaces.ExecutionInterface { - queueAllocator := executions.NewQueueAllocator(config) + queueAllocator := executions.NewQueueAllocator(config, db) systemMetrics := newExecutionSystemMetrics(systemScope) userMetrics := executionUserMetrics{ diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index d21c12278a..314c49cca0 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -337,10 +337,8 @@ func TestCreateExecution_TaggedQueue(t *testing.T) { }, }, []runtimeInterfaces.WorkflowConfig{ { - Project: "project", - Domain: "domain", - WorkflowName: "name", - Tags: []string{"tag"}, + Domain: "domain", + Tags: []string{"tag"}, }, }), nil, nil, nil, nil) @@ -1551,7 +1549,7 @@ func TestListExecutions_TransformerError(t *testing.T) { func TestExecutionManager_PublishNotifications(t *testing.T) { repository := repositoryMocks.NewMockRepository() - queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider()) + queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider(), repository) mockApplicationConfig := runtimeMocks.MockApplicationProvider{} mockApplicationConfig.SetNotificationsConfig(runtimeInterfaces.NotificationsConfig{ @@ -1647,7 +1645,7 @@ func TestExecutionManager_PublishNotifications(t *testing.T) { func TestExecutionManager_PublishNotificationsTransformError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider()) + queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider(), repository) var execManager = &ExecutionManager{ db: repository, config: getMockExecutionsConfigProvider(), @@ -1688,7 +1686,7 @@ func TestExecutionManager_PublishNotificationsTransformError(t *testing.T) { func TestExecutionManager_TestExecutionManager_PublishNotificationsTransformError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider()) + queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider(), repository) publishFunc := func(ctx context.Context, key string, msg proto.Message) error { return errors.New("error publishing message") } @@ -1759,7 +1757,7 @@ func TestExecutionManager_TestExecutionManager_PublishNotificationsTransformErro func TestExecutionManager_PublishNotificationsNoPhaseMatch(t *testing.T) { repository := repositoryMocks.NewMockRepository() - queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider()) + queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider(), repository) var myExecManager = &ExecutionManager{ db: repository, diff --git a/pkg/manager/impl/executions/queues.go b/pkg/manager/impl/executions/queues.go index ea1e553023..aaa6c119be 100644 --- a/pkg/manager/impl/executions/queues.go +++ b/pkg/manager/impl/executions/queues.go @@ -2,6 +2,11 @@ package executions import ( "context" + "math/rand" + + "github.com/lyft/flyteadmin/pkg/repositories" + "github.com/lyft/flyteadmin/pkg/resourcematching" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" "github.com/lyft/flytestdlib/logger" @@ -10,10 +15,6 @@ import ( runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" ) -type project = string -type domain = string -type workflowName = string - type tag = string type singleQueueConfiguration struct { @@ -21,55 +22,18 @@ type singleQueueConfiguration struct { DynamicQueue string } -type queueConfigSet = map[singleQueueConfiguration]bool - type queues = []singleQueueConfiguration type queueConfig = map[tag]queues -// Catch-all queues for a project -type defaultProjectQueueAssignment = map[project]singleQueueConfiguration - -/** -Catch-all queues for a project + domain combo -project: { - domain: queue -} -*/ -type defaultProjectDomainQueueAssignment = map[project]map[domain]singleQueueConfiguration - -/** -Expanded workflowConfig structure: - project: { - domain: { - workflow: queue - } - } -*/ -// Stores an execution queue (when it exists) that matches all tags specified by a workflow config -type workflowQueueAssignment = map[project]map[domain]map[workflowName]singleQueueConfiguration - type QueueAllocator interface { GetQueue(ctx context.Context, identifier core.Identifier) singleQueueConfiguration } type queueAllocatorImpl struct { - queueConfigMap queueConfig - defaultQueue singleQueueConfiguration - defaultProjectQueueAssignmentMap defaultProjectQueueAssignment - defaultProjectDomainQueueAssignmentMap defaultProjectDomainQueueAssignment - workflowQueueAssignmentMap workflowQueueAssignment - config runtimeInterfaces.Configuration -} - -// Returns an arbitrary map entry's key from the input map. Used when a workflow can be run on multiple queues. -func getAnyMapKey(input queueConfigSet) singleQueueConfiguration { - for key := range input { - return key - } - // Nothing can be returned. - logger.Error(context.Background(), "can't find any map key for empty map") - return singleQueueConfiguration{} + queueConfigMap queueConfig + config runtimeInterfaces.Configuration + db repositories.RepositoryInterface } func (q *queueAllocatorImpl) refreshExecutionQueues(executionQueues []runtimeInterfaces.ExecutionQueue) { @@ -90,159 +54,62 @@ func (q *queueAllocatorImpl) refreshExecutionQueues(executionQueues []runtimeInt q.queueConfigMap = queueConfigMap } -func (q *queueAllocatorImpl) findQueueCandidates( - config runtimeInterfaces.WorkflowConfig) queueConfigSet { - // go through and find queues that match *all* specified tags - queueCandidates := make(queueConfigSet) - for _, queue := range q.queueConfigMap[config.Tags[0]] { - queueCandidates[queue] = true - } - for i := 1; i < len(config.Tags); i++ { - filteredQueueCandidates := make(queueConfigSet) - for _, queue := range q.queueConfigMap[config.Tags[i]] { - if ok := queueCandidates[queue]; ok { - filteredQueueCandidates[queue] = true - } - } - if len(filteredQueueCandidates) == 0 { - break - } - queueCandidates = filteredQueueCandidates - } - return queueCandidates -} - -func (q *queueAllocatorImpl) refreshWorkflowQueueMap(workflowConfigs []runtimeInterfaces.WorkflowConfig) { - logger.Debug(context.Background(), "refreshing workflow configs") - var workflowQueueMap = make(workflowQueueAssignment) - var projectQueueMap = make(defaultProjectQueueAssignment) - var projectDomainQueueMap = make(defaultProjectDomainQueueAssignment) - for _, config := range workflowConfigs { - var queue singleQueueConfiguration - // go through and find queues that match *all* specified tags - queueCandidates := q.findQueueCandidates(config) - - if len(queueCandidates) > 0 { - queue = getAnyMapKey(queueCandidates) - } - - if config.Project == "" { - // This is a default queue assignment - q.defaultQueue = queue - continue - } - - // Now assign the queue to the most-specific configuration that is possible. - projectSubMap, ok := workflowQueueMap[config.Project] - if !ok { - projectSubMap = make(map[domain]map[workflowName]singleQueueConfiguration) - } - // This queue applies to *all* workflows in this project - if config.Domain == "" { - projectQueueMap[config.Project] = queue - continue - } - - defaultProjectDomainMap, ok := projectDomainQueueMap[config.Project] - if !ok { - defaultProjectDomainMap = make(map[domain]singleQueueConfiguration) - projectDomainQueueMap[config.Project] = defaultProjectDomainMap - } - - // This queue applies to *all* workflows in this project + domain combo - if config.WorkflowName == "" { - defaultProjectDomainMap[config.Domain] = queue - continue - } - - // This queue applies to individual workflows with this project + domain + workflowName combo - domainSubMap, ok := projectSubMap[config.Domain] - if !ok { - domainSubMap = make(map[workflowName]singleQueueConfiguration) - } - - domainSubMap[config.WorkflowName] = queue - projectSubMap[config.Domain] = domainSubMap - workflowQueueMap[config.Project] = projectSubMap - } - q.defaultProjectQueueAssignmentMap = projectQueueMap - q.defaultProjectDomainQueueAssignmentMap = projectDomainQueueMap - q.workflowQueueAssignmentMap = workflowQueueMap -} - -// Returns a queue specifically matching identifier project, domain, and name -// Barring a match for that, a queue matching a combination of project + domain will be returned. -// And if there is no existing match for that, a queue matching the project will be returned if it exists. -func (q *queueAllocatorImpl) getQueueForIdentifier(identifier core.Identifier) *singleQueueConfiguration { - projectSubMap, ok := q.workflowQueueAssignmentMap[identifier.Project] - if !ok { - return nil - } - domainSubMap, ok := projectSubMap[identifier.Domain] - if !ok { - return nil - } - queue, ok := domainSubMap[identifier.Name] - if !ok { - return nil - } - return &queue -} - -func (q *queueAllocatorImpl) getQueueForProjectAndDomain(identifier core.Identifier) *singleQueueConfiguration { - domainSubMap, ok := q.defaultProjectDomainQueueAssignmentMap[identifier.Project] - if !ok { - return nil - } - defaultDomainQueue, ok := domainSubMap[identifier.Domain] - if !ok { - return nil - } - return &defaultDomainQueue -} - -func (q *queueAllocatorImpl) getQueueForProject(identifier core.Identifier) *singleQueueConfiguration { - queue, ok := q.defaultProjectQueueAssignmentMap[identifier.Project] - if !ok { - return nil - } - return &queue -} - func (q *queueAllocatorImpl) GetQueue(ctx context.Context, identifier core.Identifier) singleQueueConfiguration { // NOTE: If refreshing the execution queues & workflow configs on every call to GetQueue becomes too slow we should // investigate caching the computed queue assignments. executionQueues := q.config.QueueConfiguration().GetExecutionQueues() q.refreshExecutionQueues(executionQueues) - workflowConfigs := q.config.QueueConfiguration().GetWorkflowConfigs() - q.refreshWorkflowQueueMap(workflowConfigs) - - logger.Debugf(ctx, - "Evaluating execution queue for [%+v] with available queues [%+v] and available workflow configs [%+v]", - identifier, executionQueues, workflowConfigs) + attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ + Db: q.db, + Project: identifier.Project, + Domain: identifier.Domain, + Workflow: identifier.Name, + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }) + if err != nil { + logger.Warningf(ctx, "Failed to fetch override values when assigning execution queue for [%+v] with err: %v", + identifier, err) + } - queue := q.getQueueForIdentifier(identifier) - if queue != nil { - logger.Debugf(ctx, "Found queue for identifier [%+v]: %v", identifier, queue) - return *queue + if attributes != nil && attributes.GetExecutionQueueAttributes() != nil { + for _, tag := range attributes.GetExecutionQueueAttributes().Tags { + matches, ok := q.queueConfigMap[tag] + if !ok { + continue + } + return matches[rand.Intn(len(matches))] + } } - queue = q.getQueueForProjectAndDomain(identifier) - if queue != nil { - logger.Debugf(ctx, "Found queue for project+domain [%s/%s]: %v", identifier.Project, identifier.Domain, queue) - return *queue + var tags []string + var defaultTags []string + // If we've made it this far, check to see if a domain-specific default workflow config exists for this particular domain. + for _, workflowConfig := range q.config.QueueConfiguration().GetWorkflowConfigs() { + if workflowConfig.Domain == identifier.Domain { + tags = workflowConfig.Tags + } else if len(workflowConfig.Domain) == 0 { + defaultTags = workflowConfig.Tags + } } - queue = q.getQueueForProject(identifier) - if queue != nil { - logger.Debugf(ctx, "Found queue for project [%s]: %v", identifier.Project, queue) - return *queue + if len(tags) == 0 { + // Use the uber-default queue + tags = defaultTags } - return q.defaultQueue + for _, tag := range tags { + matches, ok := q.queueConfigMap[tag] + if !ok { + continue + } + return matches[rand.Intn(len(matches))] + } + logger.Infof(ctx, "found no matching queue for [%+v]", identifier) + return singleQueueConfiguration{} } -func NewQueueAllocator(config runtimeInterfaces.Configuration) QueueAllocator { +func NewQueueAllocator(config runtimeInterfaces.Configuration, db repositories.RepositoryInterface) QueueAllocator { queueAllocator := queueAllocatorImpl{ config: config, + db: db, } return &queueAllocator } diff --git a/pkg/manager/impl/executions/queues_test.go b/pkg/manager/impl/executions/queues_test.go index bd9e72c375..f107b5d322 100644 --- a/pkg/manager/impl/executions/queues_test.go +++ b/pkg/manager/impl/executions/queues_test.go @@ -4,6 +4,13 @@ import ( "context" "testing" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" @@ -11,25 +18,9 @@ import ( "github.com/stretchr/testify/assert" ) -func TestGetAnyMapKey(t *testing.T) { - fooConfig := singleQueueConfiguration{ - PrimaryQueue: "foo primary", - DynamicQueue: "foo dynamic", - } - testMap := map[singleQueueConfiguration]bool{ - fooConfig: true, - } - key := getAnyMapKey(testMap) - assert.Equal(t, fooConfig, key) - - barConfig := singleQueueConfiguration{ - PrimaryQueue: "bar primary", - DynamicQueue: "bar dynamic", - } - testMap[barConfig] = true - key = getAnyMapKey(testMap) - assert.Contains(t, []string{"foo primary", "bar primary"}, key.PrimaryQueue) -} +const testProject = "project" +const testDomain = "domain" +const testWorkflow = "name" func TestGetQueue(t *testing.T) { executionQueues := []runtimeInterfaces.ExecutionQueue{ @@ -39,35 +30,43 @@ func TestGetQueue(t *testing.T) { Attributes: []string{"attribute"}, }, } - workflowConfigs := []runtimeInterfaces.WorkflowConfig{ - { - Project: "project", - Domain: "domain", - WorkflowName: "name", - Tags: []string{"attribute"}, - }, - { - Project: "project", - Domain: "domain", - WorkflowName: "name2", - Tags: []string{"another attribute"}, - }, - { - Project: "project", - Domain: "domain2", - WorkflowName: "name", - Tags: []string{"another attribute"}, - }, - { - Project: "project2", - Domain: "domain", - WorkflowName: "name", - Tags: []string{"another attribute"}, - }, + db := mocks.NewMockRepository() + db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) { + response := models.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + Resource: resource, + } + if project == testProject && domain == testDomain && workflow == testWorkflow { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attribute"}, + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + response.Attributes = marshalledMatchingAttributes + } else { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"another attribute"}, + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + response.Attributes = marshalledMatchingAttributes + } + return response, nil } + queueAllocator := NewQueueAllocator(runtimeMocks.NewMockConfigurationProvider( - nil, runtimeMocks.NewMockQueueConfigurationProvider(executionQueues, workflowConfigs), - nil, nil, nil, nil)) + nil, runtimeMocks.NewMockQueueConfigurationProvider(executionQueues, nil), + nil, nil, nil, nil), db) queueConfig := singleQueueConfiguration{ PrimaryQueue: "queue primary", DynamicQueue: "queue dynamic", @@ -121,25 +120,74 @@ func TestGetQueueDefaults(t *testing.T) { { Tags: []string{"default"}, }, - { - Project: "project", - Tags: []string{"attr1"}, - }, - { - Project: "project", - Domain: "domain", - Tags: []string{"attr2"}, - }, - { - Project: "project", - Domain: "domain", - WorkflowName: "workflow", - Tags: []string{"attr3"}, - }, } + db := mocks.NewMockRepository() + db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) { + if project == testProject && domain == testDomain && workflow == "workflow" && + resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr3"}, + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + return models.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + return models.WorkflowAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { + if project == testProject && domain == testDomain && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr2"}, + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + return models.ProjectDomainAttributes{ + Project: project, + Domain: domain, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + return models.ProjectDomainAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).GetFunction = func( + ctx context.Context, project, resource string) (models.ProjectAttributes, error) { + if project == testProject && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr1"}, + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + return models.ProjectAttributes{ + Project: project, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + return models.ProjectAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + queueAllocator := NewQueueAllocator(runtimeMocks.NewMockConfigurationProvider( nil, runtimeMocks.NewMockQueueConfigurationProvider(executionQueues, workflowConfigs), nil, - nil, nil, nil)) + nil, nil, nil), db) assert.Equal(t, singleQueueConfiguration{ PrimaryQueue: "default primary", DynamicQueue: "default dynamic", diff --git a/pkg/manager/impl/project_attributes_manager.go b/pkg/manager/impl/project_attributes_manager.go new file mode 100644 index 0000000000..810485d8e7 --- /dev/null +++ b/pkg/manager/impl/project_attributes_manager.go @@ -0,0 +1,43 @@ +package impl + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/manager/impl/validation" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" + + "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +type ProjectAttributesManager struct { + db repositories.RepositoryInterface +} + +func (m *ProjectAttributesManager) UpdateProjectAttributes( + ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) { + var resource admin.MatchableResource + var err error + if resource, err = validation.ValidateProjectAttributesUpdateRequest(request); err != nil { + return nil, err + } + + model, err := transformers.ToProjectAttributesModel(*request.Attributes, resource) + if err != nil { + return nil, err + } + err = m.db.ProjectAttributesRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + + return &admin.ProjectAttributesUpdateResponse{}, nil +} + +func NewProjectAttributesManager(db repositories.RepositoryInterface) interfaces.ProjectAttributesInterface { + return &ProjectAttributesManager{ + db: db, + } +} diff --git a/pkg/manager/impl/project_attributes_manager_test.go b/pkg/manager/impl/project_attributes_manager_test.go new file mode 100644 index 0000000000..6add5fa418 --- /dev/null +++ b/pkg/manager/impl/project_attributes_manager_test.go @@ -0,0 +1,37 @@ +package impl + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" +) + +func TestUpdateProjectAttributes(t *testing.T) { + request := admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: "project", + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + var createOrUpdateCalled bool + db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.ProjectAttributes) error { + assert.Equal(t, "project", input.Project) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewProjectAttributesManager(db) + _, err := manager.UpdateProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) +} diff --git a/pkg/manager/impl/project_domain_manager.go b/pkg/manager/impl/project_domain_attributes_manager.go similarity index 53% rename from pkg/manager/impl/project_domain_manager.go rename to pkg/manager/impl/project_domain_attributes_manager.go index f790edb92c..a2831e58ee 100644 --- a/pkg/manager/impl/project_domain_manager.go +++ b/pkg/manager/impl/project_domain_attributes_manager.go @@ -10,28 +10,28 @@ import ( "github.com/lyft/flyteadmin/pkg/manager/interfaces" "github.com/lyft/flyteadmin/pkg/repositories" - runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" ) -type ProjectDomainManager struct { - db repositories.RepositoryInterface - config runtimeInterfaces.Configuration +type ProjectDomainAttributesManager struct { + db repositories.RepositoryInterface } -func (m *ProjectDomainManager) UpdateProjectDomain( +func (m *ProjectDomainAttributesManager) UpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { - if err := validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil { + var resource admin.MatchableResource + var err error + if resource, err = validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil { return nil, err } ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain) - model, err := transformers.ToProjectDomainModel(*request.Attributes) + model, err := transformers.ToProjectDomainAttributesModel(*request.Attributes, resource) if err != nil { return nil, err } - err = m.db.ProjectDomainRepo().CreateOrUpdate(ctx, model) + err = m.db.ProjectDomainAttributesRepo().CreateOrUpdate(ctx, model) if err != nil { return nil, err } @@ -39,10 +39,9 @@ func (m *ProjectDomainManager) UpdateProjectDomain( return &admin.ProjectDomainAttributesUpdateResponse{}, nil } -func NewProjectDomainManager( - db repositories.RepositoryInterface, config runtimeInterfaces.Configuration) interfaces.ProjectDomainInterface { - return &ProjectDomainManager{ - db: db, - config: config, +func NewProjectDomainAttributesManager( + db repositories.RepositoryInterface) interfaces.ProjectDomainAttributesInterface { + return &ProjectDomainAttributesManager{ + db: db, } } diff --git a/pkg/manager/impl/project_domain_attributes_manager_test.go b/pkg/manager/impl/project_domain_attributes_manager_test.go new file mode 100644 index 0000000000..fc182fcb06 --- /dev/null +++ b/pkg/manager/impl/project_domain_attributes_manager_test.go @@ -0,0 +1,39 @@ +package impl + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" +) + +func TestUpdateProjectDomainAttributes(t *testing.T) { + request := admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + var createOrUpdateCalled bool + db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.ProjectDomainAttributes) error { + assert.Equal(t, "project", input.Project) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewProjectDomainAttributesManager(db) + _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) +} diff --git a/pkg/manager/impl/shared/constants.go b/pkg/manager/impl/shared/constants.go index dd498654f9..3210043cba 100644 --- a/pkg/manager/impl/shared/constants.go +++ b/pkg/manager/impl/shared/constants.go @@ -33,4 +33,6 @@ const ( ParentTaskExecutionID = "parent_task_execution_id" UserInputs = "user_inputs" ProjectDomain = "project_domain" + Attributes = "attributes" + MatchingAttributes = "matching_attributes" ) diff --git a/pkg/manager/impl/testutils/attributes.go b/pkg/manager/impl/testutils/attributes.go new file mode 100644 index 0000000000..a8a430b5b3 --- /dev/null +++ b/pkg/manager/impl/testutils/attributes.go @@ -0,0 +1,13 @@ +package testutils + +import "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + +var ExecutionQueueAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{ + "foo", "bar", "baz", + }, + }, + }, +} diff --git a/pkg/manager/impl/validation/attributes_validator.go b/pkg/manager/impl/validation/attributes_validator.go new file mode 100644 index 0000000000..827879c895 --- /dev/null +++ b/pkg/manager/impl/validation/attributes_validator.go @@ -0,0 +1,74 @@ +package validation + +import ( + "fmt" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/manager/impl/shared" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +var defaultMatchableResource = admin.MatchableResource(-1) + +func validateMatchingAttributes(attributes *admin.MatchingAttributes, identifier string) (admin.MatchableResource, error) { + if attributes == nil { + return defaultMatchableResource, shared.GetMissingArgumentError(shared.MatchingAttributes) + } + if attributes.GetTaskResourceAttributes() != nil { + return admin.MatchableResource_TASK_RESOURCE, nil + } else if attributes.GetClusterResourceAttributes() != nil { + return admin.MatchableResource_CLUSTER_RESOURCE, nil + } else if attributes.GetExecutionQueueAttributes() != nil { + return admin.MatchableResource_EXECUTION_QUEUE, nil + } + return defaultMatchableResource, errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "Unrecognized matching attributes type for request %s", identifier) +} + +func ValidateProjectAttributesUpdateRequest(request admin.ProjectAttributesUpdateRequest) ( + admin.MatchableResource, error) { + if request.Attributes == nil { + return defaultMatchableResource, shared.GetMissingArgumentError(shared.Attributes) + } + if err := ValidateEmptyStringField(request.Attributes.Project, shared.Project); err != nil { + return defaultMatchableResource, err + } + + return validateMatchingAttributes(request.Attributes.MatchingAttributes, request.Attributes.Project) +} + +func ValidateProjectDomainAttributesUpdateRequest(request admin.ProjectDomainAttributesUpdateRequest) ( + admin.MatchableResource, error) { + if request.Attributes == nil { + return defaultMatchableResource, shared.GetMissingArgumentError(shared.Attributes) + } + if err := ValidateEmptyStringField(request.Attributes.Project, shared.Project); err != nil { + return defaultMatchableResource, err + } + if err := ValidateEmptyStringField(request.Attributes.Domain, shared.Domain); err != nil { + return defaultMatchableResource, err + } + + return validateMatchingAttributes(request.Attributes.MatchingAttributes, + fmt.Sprintf("%s-%s", request.Attributes.Project, request.Attributes.Domain)) +} + +func ValidateWorkflowAttributesUpdateRequest(request admin.WorkflowAttributesUpdateRequest) ( + admin.MatchableResource, error) { + if request.Attributes == nil { + return defaultMatchableResource, shared.GetMissingArgumentError(shared.Attributes) + } + if err := ValidateEmptyStringField(request.Attributes.Project, shared.Project); err != nil { + return defaultMatchableResource, err + } + if err := ValidateEmptyStringField(request.Attributes.Domain, shared.Domain); err != nil { + return defaultMatchableResource, err + } + if err := ValidateEmptyStringField(request.Attributes.Workflow, shared.Name); err != nil { + return defaultMatchableResource, err + } + + return validateMatchingAttributes(request.Attributes.MatchingAttributes, + fmt.Sprintf("%s-%s-%s", request.Attributes.Project, request.Attributes.Domain, request.Attributes.Workflow)) +} diff --git a/pkg/manager/impl/validation/attributes_validator_test.go b/pkg/manager/impl/validation/attributes_validator_test.go new file mode 100644 index 0000000000..face354f78 --- /dev/null +++ b/pkg/manager/impl/validation/attributes_validator_test.go @@ -0,0 +1,166 @@ +package validation + +import ( + "testing" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" +) + +func TestValidateMatchingAttributes(t *testing.T) { + testCases := []struct { + attributes *admin.MatchingAttributes + identifier string + expectedMatchableResource admin.MatchableResource + expectedErr error + }{ + { + nil, + "foo", + defaultMatchableResource, + errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing matching_attributes"), + }, + { + &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + }, + }, + }, + }, + "foo", + admin.MatchableResource_TASK_RESOURCE, + nil, + }, + { + &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ClusterResourceAttributes{ + ClusterResourceAttributes: &admin.ClusterResourceAttributes{ + Attributes: map[string]string{ + "bar": "baz", + }, + }, + }, + }, + "foo", + admin.MatchableResource_CLUSTER_RESOURCE, + nil, + }, + { + &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"bar", "baz"}, + }, + }, + }, + "foo", + admin.MatchableResource_EXECUTION_QUEUE, + nil, + }, + } + for _, tc := range testCases { + matchableResource, err := validateMatchingAttributes(tc.attributes, tc.identifier) + assert.Equal(t, tc.expectedMatchableResource, matchableResource) + assert.EqualValues(t, tc.expectedErr, err) + } +} + +func TestValidateProjectAttributesUpdateRequest(t *testing.T) { + _, err := ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{}) + assert.Equal(t, "missing attributes", err.Error()) + + _, err = ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{}}) + assert.Equal(t, "missing project", err.Error()) + + matchableResource, err := ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: "project", + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + }, + }, + }, + }, + }}) + assert.Equal(t, admin.MatchableResource_TASK_RESOURCE, matchableResource) + assert.Nil(t, err) +} + +func TestValidateProjectDomainAttributesUpdateRequest(t *testing.T) { + _, err := ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{}) + assert.Equal(t, "missing attributes", err.Error()) + + _, err = ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{}}) + assert.Equal(t, "missing project", err.Error()) + + _, err = ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: "project", + }}) + assert.Equal(t, "missing domain", err.Error()) + + matchableResource, err := ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ClusterResourceAttributes{ + ClusterResourceAttributes: &admin.ClusterResourceAttributes{ + Attributes: map[string]string{ + "bar": "baz", + }, + }, + }, + }, + }}) + assert.Equal(t, admin.MatchableResource_CLUSTER_RESOURCE, matchableResource) + assert.Nil(t, err) +} + +func TestValidateWorkflowAttributesUpdateRequest(t *testing.T) { + _, err := ValidateWorkflowAttributesUpdateRequest(admin.WorkflowAttributesUpdateRequest{}) + assert.Equal(t, "missing attributes", err.Error()) + + _, err = ValidateWorkflowAttributesUpdateRequest(admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{}}) + assert.Equal(t, "missing project", err.Error()) + + _, err = ValidateWorkflowAttributesUpdateRequest(admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: "project", + }}) + assert.Equal(t, "missing domain", err.Error()) + + _, err = ValidateWorkflowAttributesUpdateRequest(admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: "project", + Domain: "domain", + }}) + assert.Equal(t, "missing name", err.Error()) + + matchableResource, err := ValidateWorkflowAttributesUpdateRequest(admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"bar", "baz"}, + }, + }, + }, + }}) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE, matchableResource) + assert.Nil(t, err) +} diff --git a/pkg/manager/impl/validation/project_domain_validator.go b/pkg/manager/impl/validation/project_domain_validator.go deleted file mode 100644 index 718e2b7758..0000000000 --- a/pkg/manager/impl/validation/project_domain_validator.go +++ /dev/null @@ -1,20 +0,0 @@ -package validation - -import ( - "github.com/lyft/flyteadmin/pkg/manager/impl/shared" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -func ValidateProjectDomainAttributesUpdateRequest(request admin.ProjectDomainAttributesUpdateRequest) error { - if request.Attributes == nil { - return shared.GetMissingArgumentError(shared.ProjectDomain) - } - if err := ValidateEmptyStringField(request.Attributes.Project, shared.Project); err != nil { - return err - } - if err := ValidateEmptyStringField(request.Attributes.Domain, shared.Domain); err != nil { - return err - } - // Resource attributes are not a required field and therefore are not checked in validation. - return nil -} diff --git a/pkg/manager/impl/validation/project_domain_validator_test.go b/pkg/manager/impl/validation/project_domain_validator_test.go deleted file mode 100644 index 29341e76c5..0000000000 --- a/pkg/manager/impl/validation/project_domain_validator_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package validation - -import ( - "testing" - - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" -) - -func TestValidateProjectDomainAttributesUpdateRequest(t *testing.T) { - err := ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{}) - assert.EqualError(t, err, "missing project_domain") - - err = ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ - Attributes: &admin.ProjectDomainAttributes{}, - }) - assert.EqualError(t, err, "missing project") - - err = ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ - Attributes: &admin.ProjectDomainAttributes{ - Project: "project", - }, - }) - assert.EqualError(t, err, "missing domain") - - err = ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{ - Attributes: &admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - }, - }) - assert.Nil(t, err) -} diff --git a/pkg/manager/impl/validation/task_validator.go b/pkg/manager/impl/validation/task_validator.go index 1849de20d3..5627685297 100644 --- a/pkg/manager/impl/validation/task_validator.go +++ b/pkg/manager/impl/validation/task_validator.go @@ -4,6 +4,8 @@ package validation import ( "context" + "github.com/lyft/flyteadmin/pkg/resourcematching" + "github.com/lyft/flyteadmin/pkg/repositories" "github.com/lyft/flyteadmin/pkg/common" @@ -222,7 +224,7 @@ func validateTaskResources( func assignResourcesIfUnset(ctx context.Context, identifier *core.Identifier, platformValues runtimeInterfaces.TaskResourceSet, - resourceEntries []*core.Resources_ResourceEntry) []*core.Resources_ResourceEntry { + resourceEntries []*core.Resources_ResourceEntry, taskResourceSpec *admin.TaskResourceSpec) []*core.Resources_ResourceEntry { var cpuIndex, memoryIndex = -1, -1 for idx, entry := range resourceEntries { switch entry.Name { @@ -232,22 +234,33 @@ func assignResourcesIfUnset(ctx context.Context, identifier *core.Identifier, memoryIndex = idx } } - if cpuIndex > 0 && memoryIndex > 00 { + if cpuIndex > 0 && memoryIndex > 0 { // nothing to do return resourceEntries } + if cpuIndex < 0 && platformValues.CPU != "" { logger.Debugf(ctx, "Setting 'cpu' for [%+v] to %s", identifier, platformValues.CPU) + cpuValue := platformValues.CPU + if taskResourceSpec != nil && len(taskResourceSpec.Cpu) > 0 { + // Use the custom attributes from the database rather than the platform defaults from the application config + cpuValue = taskResourceSpec.Cpu + } cpuResource := &core.Resources_ResourceEntry{ Name: core.Resources_CPU, - Value: platformValues.CPU, + Value: cpuValue, } resourceEntries = append(resourceEntries, cpuResource) } if memoryIndex < 0 && platformValues.Memory != "" { + memoryValue := platformValues.Memory + if taskResourceSpec != nil && len(taskResourceSpec.Memory) > 0 { + // Use the custom attributes from the database rather than the platform defaults from the application config + memoryValue = taskResourceSpec.Memory + } memoryResource := &core.Resources_ResourceEntry{ Name: core.Resources_MEMORY, - Value: platformValues.Memory, + Value: memoryValue, } logger.Debugf(ctx, "Setting 'memory' for [%+v] to %s", identifier, platformValues.Memory) resourceEntries = append(resourceEntries, memoryResource) @@ -260,7 +273,8 @@ func assignResourcesIfUnset(ctx context.Context, identifier *core.Identifier, // Note: The system will assign a system-default value for request but for limit it will deduce it from the request // itself => Limit := Min([Some-Multiplier X Request], System-Max). For now we are using a multiplier of 1. In // general we recommend the users to set limits close to requests for more predictability in the system. -func SetDefaults(ctx context.Context, taskConfig runtime.TaskResourceConfiguration, task *core.CompiledTask) { +func SetDefaults(ctx context.Context, taskConfig runtime.TaskResourceConfiguration, task *core.CompiledTask, + db repositories.RepositoryInterface, workflowName string) { if task == nil { logger.Warningf(ctx, "Can't set default resources for nil task.") return @@ -270,12 +284,35 @@ func SetDefaults(ctx context.Context, taskConfig runtime.TaskResourceConfigurati logger.Debugf(ctx, "Not setting default resources for task [%+v], no container resources found to check", task) return } + + attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ + Db: db, + Project: task.Template.Id.Project, + Domain: task.Template.Id.Domain, + Workflow: workflowName, + Resource: admin.MatchableResource_TASK_RESOURCE, + }) + if err != nil { + logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", + task.Template, err) + } + logger.Debugf(ctx, "Assigning task requested resources for [%+v]", task.Template.Id) + var taskResourceSpec *admin.TaskResourceSpec + if attributes != nil && attributes.GetTaskResourceAttributes() != nil { + taskResourceSpec = attributes.GetTaskResourceAttributes().Defaults + } task.Template.GetContainer().Resources.Requests = assignResourcesIfUnset( - ctx, task.Template.Id, taskConfig.GetDefaults(), task.Template.GetContainer().Resources.Requests) + ctx, task.Template.Id, taskConfig.GetDefaults(), task.Template.GetContainer().Resources.Requests, + taskResourceSpec) + logger.Debugf(ctx, "Assigning task resource limits for [%+v]", task.Template.Id) + if attributes != nil && attributes.GetTaskResourceAttributes() != nil { + taskResourceSpec = attributes.GetTaskResourceAttributes().Limits + } task.Template.GetContainer().Resources.Limits = assignResourcesIfUnset( - ctx, task.Template.Id, createTaskDefaultLimits(ctx, task), task.Template.GetContainer().Resources.Limits) + ctx, task.Template.Id, createTaskDefaultLimits(ctx, task), task.Template.GetContainer().Resources.Limits, + taskResourceSpec) } func createTaskDefaultLimits(ctx context.Context, task *core.CompiledTask) runtimeInterfaces.TaskResourceSet { diff --git a/pkg/manager/impl/validation/task_validator_test.go b/pkg/manager/impl/validation/task_validator_test.go index 080918b968..15151dc376 100644 --- a/pkg/manager/impl/validation/task_validator_test.go +++ b/pkg/manager/impl/validation/task_validator_test.go @@ -6,6 +6,9 @@ import ( "fmt" "testing" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/golang/protobuf/proto" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" @@ -414,6 +417,35 @@ func TestIsWholeNumber(t *testing.T) { } } +func TestAssignResourcesIfUnset(t *testing.T) { + platformValues := runtimeInterfaces.TaskResourceSet{ + CPU: "200m", + GPU: "8", + Memory: "200Gi", + } + taskResourceSpec := &admin.TaskResourceSpec{ + Cpu: "400m", + Memory: "400Gi", + } + assignedResources := assignResourcesIfUnset(context.Background(), &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + }, platformValues, []*core.Resources_ResourceEntry{}, taskResourceSpec) + + assert.EqualValues(t, []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: taskResourceSpec.Cpu, + }, + { + Name: core.Resources_MEMORY, + Value: taskResourceSpec.Memory, + }, + }, assignedResources) +} + func TestSetDefaults(t *testing.T) { task := &core.CompiledTask{ Template: &core.TaskTemplate{ @@ -429,6 +461,12 @@ func TestSetDefaults(t *testing.T) { }, }, }, + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "task_name", + Version: "version", + }, }, } @@ -443,7 +481,7 @@ func TestSetDefaults(t *testing.T) { GPU: "8", Memory: "500Gi", } - SetDefaults(context.Background(), &taskConfig, task) + SetDefaults(context.Background(), &taskConfig, task, mocks.NewMockRepository(), "workflow") assert.True(t, proto.Equal( &core.Container{ Resources: &core.Resources{ @@ -487,6 +525,12 @@ func TestSetDefaults_MissingDefaults(t *testing.T) { }, }, }, + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "task_name", + Version: "version", + }, }, } @@ -500,7 +544,7 @@ func TestSetDefaults_MissingDefaults(t *testing.T) { CPU: "300m", GPU: "8", } - SetDefaults(context.Background(), &taskConfig, task) + SetDefaults(context.Background(), &taskConfig, task, mocks.NewMockRepository(), "workflow") assert.True(t, proto.Equal( &core.Container{ Resources: &core.Resources{ diff --git a/pkg/manager/impl/workflow_attributes_manager.go b/pkg/manager/impl/workflow_attributes_manager.go new file mode 100644 index 0000000000..e20e2f2326 --- /dev/null +++ b/pkg/manager/impl/workflow_attributes_manager.go @@ -0,0 +1,43 @@ +package impl + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/manager/impl/validation" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" + + "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +type WorkflowAttributesManager struct { + db repositories.RepositoryInterface +} + +func (m *WorkflowAttributesManager) UpdateWorkflowAttributes( + ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) { + var resource admin.MatchableResource + var err error + if resource, err = validation.ValidateWorkflowAttributesUpdateRequest(request); err != nil { + return nil, err + } + + model, err := transformers.ToWorkflowAttributesModel(*request.Attributes, resource) + if err != nil { + return nil, err + } + err = m.db.WorkflowAttributesRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + + return &admin.WorkflowAttributesUpdateResponse{}, nil +} + +func NewWorkflowAttributesManager(db repositories.RepositoryInterface) interfaces.WorkflowAttributesInterface { + return &WorkflowAttributesManager{ + db: db, + } +} diff --git a/pkg/manager/impl/workflow_attributes_manager_test.go b/pkg/manager/impl/workflow_attributes_manager_test.go new file mode 100644 index 0000000000..052411b63f --- /dev/null +++ b/pkg/manager/impl/workflow_attributes_manager_test.go @@ -0,0 +1,41 @@ +package impl + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" +) + +func TestUpdateWorkflowAttributes(t *testing.T) { + request := admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + var createOrUpdateCalled bool + db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.WorkflowAttributes) error { + assert.Equal(t, "project", input.Project) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "workflow", input.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewWorkflowAttributesManager(db) + _, err := manager.UpdateWorkflowAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) +} diff --git a/pkg/manager/interfaces/project_attributes.go b/pkg/manager/interfaces/project_attributes.go new file mode 100644 index 0000000000..9c0dcd43e3 --- /dev/null +++ b/pkg/manager/interfaces/project_attributes.go @@ -0,0 +1,13 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +// Interface for managing project-specific attributes. +type ProjectAttributesInterface interface { + UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) +} diff --git a/pkg/manager/interfaces/project_domain.go b/pkg/manager/interfaces/project_domain_attributes.go similarity index 58% rename from pkg/manager/interfaces/project_domain.go rename to pkg/manager/interfaces/project_domain_attributes.go index fbd22f31c1..c3f0b30133 100644 --- a/pkg/manager/interfaces/project_domain.go +++ b/pkg/manager/interfaces/project_domain_attributes.go @@ -7,7 +7,7 @@ import ( ) // Interface for managing projects and domain -specific attributes. -type ProjectDomainInterface interface { - UpdateProjectDomain(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( +type ProjectDomainAttributesInterface interface { + UpdateProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) } diff --git a/pkg/manager/interfaces/workflow_attributes.go b/pkg/manager/interfaces/workflow_attributes.go new file mode 100644 index 0000000000..3897b36ef4 --- /dev/null +++ b/pkg/manager/interfaces/workflow_attributes.go @@ -0,0 +1,13 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +// Interface for managing project, domain and workflow -specific attributes. +type WorkflowAttributesInterface interface { + UpdateWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) +} diff --git a/pkg/manager/mocks/project_domain.go b/pkg/manager/mocks/project_domain.go index 365509cc81..d263b4a110 100644 --- a/pkg/manager/mocks/project_domain.go +++ b/pkg/manager/mocks/project_domain.go @@ -17,7 +17,7 @@ func (m *MockProjectDomainManager) SetUpdateProjectDomainAttributes(updateProjec m.updateProjectDomainFunc = updateProjectDomainFunc } -func (m *MockProjectDomainManager) UpdateProjectDomain( +func (m *MockProjectDomainManager) UpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { if m.updateProjectDomainFunc != nil { diff --git a/pkg/repositories/config/migrations.go b/pkg/repositories/config/migrations.go index fb07dd0bfd..b761680a21 100644 --- a/pkg/repositories/config/migrations.go +++ b/pkg/repositories/config/migrations.go @@ -139,24 +139,45 @@ var Migrations = []*gormigrate.Migration{ return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS InputsURI, DROP COLUMN IF EXISTS UserInputsURI").Error }, }, - // Add ProjectDomains with custom resource attributes. + // Create named_entity_metadata table. { - ID: "2019-10-28-project-domains", + ID: "2019-11-05-named-entity-metadata", Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.ProjectDomain{}).Error + return tx.AutoMigrate(&models.NamedEntityMetadata{}).Error }, Rollback: func(tx *gorm.DB) error { - return tx.DropTable("project_domains").Error + return tx.DropTable("named_entity_metadata").Error }, }, - // Create named_entity_metadata table. + // Add ProjectAttributes with custom resource attributes. { - ID: "2019-11-05-named-entity-metadata", + ID: "2019-12-30-project-attributes", Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NamedEntityMetadata{}).Error + return tx.AutoMigrate(&models.ProjectAttributes{}).Error }, Rollback: func(tx *gorm.DB) error { - return tx.DropTable("named_entity_metadata").Error + return tx.DropTable("project_attributes").Error + }, + }, + + // Add ProjectDomainAttributes with custom resource attributes. + { + ID: "2019-12-30-project-domain-attributes", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.ProjectDomainAttributes{}).Error + }, + Rollback: func(tx *gorm.DB) error { + return tx.DropTable("project_domain_attributes").Error + }, + }, + // Add WorkflowAttributes with custom resource attributes. + { + ID: "2019-12-30-workflow-attributes", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.WorkflowAttributes{}).Error + }, + Rollback: func(tx *gorm.DB) error { + return tx.DropTable("workflow_attributes").Error }, }, } diff --git a/pkg/repositories/factory.go b/pkg/repositories/factory.go index b7d62964ea..720b77b432 100644 --- a/pkg/repositories/factory.go +++ b/pkg/repositories/factory.go @@ -28,7 +28,9 @@ type RepositoryInterface interface { LaunchPlanRepo() interfaces.LaunchPlanRepoInterface ExecutionRepo() interfaces.ExecutionRepoInterface ProjectRepo() interfaces.ProjectRepoInterface - ProjectDomainRepo() interfaces.ProjectDomainRepoInterface + ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface + ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface + WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface NodeExecutionRepo() interfaces.NodeExecutionRepoInterface TaskExecutionRepo() interfaces.TaskExecutionRepoInterface NamedEntityRepo() interfaces.NamedEntityRepoInterface diff --git a/pkg/repositories/gormimpl/project_domain_repo.go b/pkg/repositories/gormimpl/project_attributes_repo.go similarity index 53% rename from pkg/repositories/gormimpl/project_domain_repo.go rename to pkg/repositories/gormimpl/project_attributes_repo.go index 3d33954a75..63fa124df4 100644 --- a/pkg/repositories/gormimpl/project_domain_repo.go +++ b/pkg/repositories/gormimpl/project_attributes_repo.go @@ -13,18 +13,18 @@ import ( flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" ) -type ProjectDomainRepo struct { +type ProjectAttributesRepo struct { db *gorm.DB errorTransformer errors.ErrorTransformer metrics gormMetrics } -func (r *ProjectDomainRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomain) error { +func (r *ProjectAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error { timer := r.metrics.GetDuration.Start() - var record models.ProjectDomain - tx := r.db.FirstOrCreate(&record, models.ProjectDomain{ - Project: input.Project, - Domain: input.Domain, + var record models.ProjectAttributes + tx := r.db.FirstOrCreate(&record, models.ProjectAttributes{ + Project: input.Project, + Resource: input.Resource, }) timer.Stop() if tx.Error != nil { @@ -41,28 +41,28 @@ func (r *ProjectDomainRepo) CreateOrUpdate(ctx context.Context, input models.Pro return nil } -func (r *ProjectDomainRepo) Get(ctx context.Context, project, domain string) (models.ProjectDomain, error) { - var model models.ProjectDomain +func (r *ProjectAttributesRepo) Get(ctx context.Context, project, resource string) (models.ProjectAttributes, error) { + var model models.ProjectAttributes timer := r.metrics.GetDuration.Start() - tx := r.db.Where(&models.ProjectDomain{ - Project: project, - Domain: domain, + tx := r.db.Where(&models.ProjectAttributes{ + Project: project, + Resource: resource, }).First(&model) timer.Stop() if tx.Error != nil { - return models.ProjectDomain{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + return models.ProjectAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) } if tx.RecordNotFound() { - return models.ProjectDomain{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project-domain [%s-%s] not found", project, domain) + return models.ProjectAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "project [%s] not found", project) } return model, nil } -func NewProjectDomainRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, - scope promutils.Scope) interfaces.ProjectDomainRepoInterface { +func NewProjectAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, + scope promutils.Scope) interfaces.ProjectAttributesRepoInterface { metrics := newMetrics(scope) - return &ProjectDomainRepo{ + return &ProjectAttributesRepo{ db: db, errorTransformer: errorTransformer, metrics: metrics, diff --git a/pkg/repositories/gormimpl/project_attributes_repo_test.go b/pkg/repositories/gormimpl/project_attributes_repo_test.go new file mode 100644 index 0000000000..f82bd84781 --- /dev/null +++ b/pkg/repositories/gormimpl/project_attributes_repo_test.go @@ -0,0 +1,57 @@ +package gormimpl + +import ( + "context" + "testing" + + mocket "github.com/Selvatico/go-mocket" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +const testProjectAttr = "project" +const testResourceAttr = "resource" + +func TestCreateProjectAttributes(t *testing.T) { + projectRepo := NewProjectAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + query := GlobalMock.NewMock() + query.WithQuery( + `INSERT INTO "project_attributes" ` + + `("created_at","updated_at","deleted_at","project","resource","attributes") VALUES (?,?,?,?,?,?)`) + + err := projectRepo.CreateOrUpdate(context.Background(), models.ProjectAttributes{ + Project: testProjectAttr, + Resource: testResourceAttr, + Attributes: []byte("attrs"), + }) + assert.NoError(t, err) + assert.True(t, query.Triggered) +} + +func TestGetProjectAttributes(t *testing.T) { + projectRepo := NewProjectAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + response := make(map[string]interface{}) + response["project"] = testProjectAttr + response["resource"] = testResourceAttr + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "project_attributes" WHERE "project_attributes"."deleted_at" ` + + `IS NULL AND (("project_attributes"."project" = project) AND ("project_attributes"."resource" = resource)) ` + + `ORDER BY "project_attributes"."id" ASC LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := projectRepo.Get(context.Background(), "project", "resource") + assert.Nil(t, err) + assert.Equal(t, testProjectAttr, output.Project) + assert.Equal(t, testResourceAttr, output.Resource) + assert.Equal(t, []byte("attrs"), output.Attributes) +} diff --git a/pkg/repositories/gormimpl/project_domain_attributes_repo.go b/pkg/repositories/gormimpl/project_domain_attributes_repo.go new file mode 100644 index 0000000000..256f122222 --- /dev/null +++ b/pkg/repositories/gormimpl/project_domain_attributes_repo.go @@ -0,0 +1,73 @@ +package gormimpl + +import ( + "context" + + "github.com/jinzhu/gorm" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flytestdlib/promutils" + "google.golang.org/grpc/codes" + + flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" +) + +type ProjectDomainAttributesRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + metrics gormMetrics +} + +func (r *ProjectDomainAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error { + timer := r.metrics.GetDuration.Start() + var record models.ProjectDomainAttributes + tx := r.db.FirstOrCreate(&record, models.ProjectDomainAttributes{ + Project: input.Project, + Domain: input.Domain, + Resource: input.Resource, + }) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + + timer = r.metrics.UpdateDuration.Start() + record.Attributes = input.Attributes + tx = r.db.Save(&record) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return nil +} + +func (r *ProjectDomainAttributesRepo) Get(ctx context.Context, project, domain, resource string) ( + models.ProjectDomainAttributes, error) { + var model models.ProjectDomainAttributes + timer := r.metrics.GetDuration.Start() + tx := r.db.Where(&models.ProjectDomainAttributes{ + Project: project, + Domain: domain, + Resource: resource, + }).First(&model) + timer.Stop() + if tx.Error != nil { + return models.ProjectDomainAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RecordNotFound() { + return models.ProjectDomainAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "project-domain [%s-%s] not found", project, domain) + } + return model, nil +} + +func NewProjectDomainAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, + scope promutils.Scope) interfaces.ProjectDomainAttributesRepoInterface { + metrics := newMetrics(scope) + return &ProjectDomainAttributesRepo{ + db: db, + errorTransformer: errorTransformer, + metrics: metrics, + } +} diff --git a/pkg/repositories/gormimpl/project_domain_repo_test.go b/pkg/repositories/gormimpl/project_domain_attributes_repo_test.go similarity index 53% rename from pkg/repositories/gormimpl/project_domain_repo_test.go rename to pkg/repositories/gormimpl/project_domain_attributes_repo_test.go index c5f9cbb186..7648c2a505 100644 --- a/pkg/repositories/gormimpl/project_domain_repo_test.go +++ b/pkg/repositories/gormimpl/project_domain_attributes_repo_test.go @@ -11,44 +11,48 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCreateProjectDomain(t *testing.T) { - projectRepo := NewProjectDomainRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) +func TestCreateProjectDomainAttributes(t *testing.T) { + projectRepo := NewProjectDomainAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) GlobalMock := mocket.Catcher.Reset() query := GlobalMock.NewMock() query.WithQuery( - `INSERT INTO "project_domains" ` + - `("created_at","updated_at","deleted_at","project","domain","attributes") VALUES (?,?,?,?,?,?)`) + `INSERT INTO "project_domain_attributes" ` + + `("created_at","updated_at","deleted_at","project","domain","resource","attributes") VALUES (?,?,?,?,?,?,?)`) - err := projectRepo.CreateOrUpdate(context.Background(), models.ProjectDomain{ + err := projectRepo.CreateOrUpdate(context.Background(), models.ProjectDomainAttributes{ Project: "project", Domain: "domain", + Resource: "resource", Attributes: []byte("attrs"), }) assert.NoError(t, err) assert.True(t, query.Triggered) } -func TestGetProjectDomain(t *testing.T) { - projectRepo := NewProjectDomainRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) +func TestGetProjectDomainAttributes(t *testing.T) { + projectRepo := NewProjectDomainAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) GlobalMock := mocket.Catcher.Reset() response := make(map[string]interface{}) response["project"] = "project" response["domain"] = "domain" + response["resource"] = "resource" response["attributes"] = []byte("attrs") query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "project_domains" WHERE "project_domains"."deleted_at" IS NULL AND ` + - `(("project_domains"."project" = project) AND ("project_domains"."domain" = domain)) ORDER BY ` + - `"project_domains"."id" ASC LIMIT 1`).WithReply( + query.WithQuery(`SELECT * FROM "project_domain_attributes" WHERE "project_domain_attributes"."deleted_at" ` + + `IS NULL AND (("project_domain_attributes"."project" = project) AND ("project_domain_attributes"."domain" = ` + + `domain) AND ("project_domain_attributes"."resource" = resource)) ORDER BY "project_domain_attributes"."id" ` + + `ASC LIMIT 1`).WithReply( []map[string]interface{}{ response, }) - output, err := projectRepo.Get(context.Background(), "project", "domain") + output, err := projectRepo.Get(context.Background(), "project", "domain", "resource") assert.Nil(t, err) assert.Equal(t, "project", output.Project) assert.Equal(t, "domain", output.Domain) + assert.Equal(t, "resource", output.Resource) assert.Equal(t, []byte("attrs"), output.Attributes) } diff --git a/pkg/repositories/gormimpl/workflow_attributes_repo.go b/pkg/repositories/gormimpl/workflow_attributes_repo.go new file mode 100644 index 0000000000..96ac9822b2 --- /dev/null +++ b/pkg/repositories/gormimpl/workflow_attributes_repo.go @@ -0,0 +1,75 @@ +package gormimpl + +import ( + "context" + + "github.com/jinzhu/gorm" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flytestdlib/promutils" + "google.golang.org/grpc/codes" + + flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" +) + +type WorkflowAttributesRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + metrics gormMetrics +} + +func (r *WorkflowAttributesRepo) CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error { + timer := r.metrics.GetDuration.Start() + var record models.WorkflowAttributes + tx := r.db.FirstOrCreate(&record, models.WorkflowAttributes{ + Project: input.Project, + Domain: input.Domain, + Workflow: input.Workflow, + Resource: input.Resource, + }) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + + timer = r.metrics.UpdateDuration.Start() + record.Attributes = input.Attributes + tx = r.db.Save(&record) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return nil +} + +func (r *WorkflowAttributesRepo) Get(ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) { + var model models.WorkflowAttributes + timer := r.metrics.GetDuration.Start() + tx := r.db.Where(&models.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + Resource: resource, + }).First(&model) + timer.Stop() + if tx.Error != nil { + return models.WorkflowAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RecordNotFound() { + return models.WorkflowAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "project-domain [%s-%s] not found", project, domain) + } + return model, nil +} + +func NewWorkflowAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, + scope promutils.Scope) interfaces.WorkflowAttributesRepoInterface { + metrics := newMetrics(scope) + return &WorkflowAttributesRepo{ + db: db, + errorTransformer: errorTransformer, + metrics: metrics, + } +} diff --git a/pkg/repositories/gormimpl/workflow_attributes_repo_test.go b/pkg/repositories/gormimpl/workflow_attributes_repo_test.go new file mode 100644 index 0000000000..9d37cd7482 --- /dev/null +++ b/pkg/repositories/gormimpl/workflow_attributes_repo_test.go @@ -0,0 +1,62 @@ +package gormimpl + +import ( + "context" + "testing" + + mocket "github.com/Selvatico/go-mocket" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestCreateWorkflowAttributes(t *testing.T) { + workflowRepo := NewWorkflowAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + query := GlobalMock.NewMock() + query.WithQuery( + `INSERT INTO "workflow_attributes" ("id","created_at","updated_at","deleted_at","project","domain",` + + `"workflow","resource","attributes") VALUES (?,?,?,?,?,?,?,?,?)`) + + err := workflowRepo.CreateOrUpdate(context.Background(), models.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + Resource: "resource", + Attributes: []byte("attrs"), + }) + assert.NoError(t, err) + assert.True(t, query.Triggered) +} + +func TestGetWorkflowAttributes(t *testing.T) { + workflowRepo := NewWorkflowAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + response := make(map[string]interface{}) + response["project"] = "project" + response["domain"] = "domain" + response["workflow"] = "workflow" + response["resource"] = "resource" + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "workflow_attributes" WHERE "workflow_attributes"."deleted_at" IS NULL AND` + + ` (("workflow_attributes"."project" = project) AND ("workflow_attributes"."domain" = domain) AND ` + + `("workflow_attributes"."workflow" = workflow) AND ("workflow_attributes"."resource" = resource)) ORDER BY ` + + `"workflow_attributes"."id" ASC LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := workflowRepo.Get(context.Background(), "project", "domain", "workflow", "resource") + assert.Nil(t, err) + assert.Equal(t, "project", output.Project) + assert.Equal(t, "domain", output.Domain) + assert.Equal(t, "workflow", output.Workflow) + assert.Equal(t, "resource", output.Resource) + assert.Equal(t, []byte("attrs"), output.Attributes) +} diff --git a/pkg/repositories/interfaces/project_attributes_repo.go b/pkg/repositories/interfaces/project_attributes_repo.go new file mode 100644 index 0000000000..79c68213e0 --- /dev/null +++ b/pkg/repositories/interfaces/project_attributes_repo.go @@ -0,0 +1,14 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type ProjectAttributesRepoInterface interface { + // Inserts or updates an existing ProjectAttributes model into the database store. + CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error + // Returns a matching ProjectAttributes model when it exists. + Get(ctx context.Context, project, resource string) (models.ProjectAttributes, error) +} diff --git a/pkg/repositories/interfaces/project_domain_attributes_repo.go b/pkg/repositories/interfaces/project_domain_attributes_repo.go new file mode 100644 index 0000000000..f63ebb8277 --- /dev/null +++ b/pkg/repositories/interfaces/project_domain_attributes_repo.go @@ -0,0 +1,14 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type ProjectDomainAttributesRepoInterface interface { + // Inserts or updates an existing ProjectDomainAttributes model into the database store. + CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error + // Returns a matching ProjectDomainAttributes model when it exists. + Get(ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) +} diff --git a/pkg/repositories/interfaces/project_domain_repo.go b/pkg/repositories/interfaces/project_domain_repo.go deleted file mode 100644 index aab7813fef..0000000000 --- a/pkg/repositories/interfaces/project_domain_repo.go +++ /dev/null @@ -1,14 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type ProjectDomainRepoInterface interface { - // Inserts or updates an existing ProjectDomain model into the database store. - CreateOrUpdate(ctx context.Context, input models.ProjectDomain) error - // Returns a matching project when it exists. - Get(ctx context.Context, project, domain string) (models.ProjectDomain, error) -} diff --git a/pkg/repositories/interfaces/workflow_attributes_repo.go b/pkg/repositories/interfaces/workflow_attributes_repo.go new file mode 100644 index 0000000000..fdb61095ce --- /dev/null +++ b/pkg/repositories/interfaces/workflow_attributes_repo.go @@ -0,0 +1,14 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type WorkflowAttributesRepoInterface interface { + // Inserts or updates an existing WorkflowAttributes model into the database store. + CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error + // Returns a matching WorkflowAttributes model when it exists. + Get(ctx context.Context, project, domain, workflow, resource string) (models.WorkflowAttributes, error) +} diff --git a/pkg/repositories/mocks/project_attributes_repo.go b/pkg/repositories/mocks/project_attributes_repo.go new file mode 100644 index 0000000000..d7f4ebdaf5 --- /dev/null +++ b/pkg/repositories/mocks/project_attributes_repo.go @@ -0,0 +1,35 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type CreateOrUpdateProjectAttributesFunction func(ctx context.Context, input models.ProjectAttributes) error +type GetProjectAttributesFunction func(ctx context.Context, project, resource string) (models.ProjectAttributes, error) + +type MockProjectAttributesRepo struct { + CreateOrUpdateFunction CreateOrUpdateProjectAttributesFunction + GetFunction GetProjectAttributesFunction +} + +func (r *MockProjectAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error { + if r.CreateOrUpdateFunction != nil { + return r.CreateOrUpdateFunction(ctx, input) + } + return nil +} + +func (r *MockProjectAttributesRepo) Get(ctx context.Context, project, resource string) ( + models.ProjectAttributes, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, project, resource) + } + return models.ProjectAttributes{}, nil +} + +func NewMockProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { + return &MockProjectAttributesRepo{} +} diff --git a/pkg/repositories/mocks/project_domain_attributes_repo.go b/pkg/repositories/mocks/project_domain_attributes_repo.go new file mode 100644 index 0000000000..ae243d0094 --- /dev/null +++ b/pkg/repositories/mocks/project_domain_attributes_repo.go @@ -0,0 +1,35 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type CreateOrUpdateProjectDomainAttributesFunction func(ctx context.Context, input models.ProjectDomainAttributes) error +type GetProjectDomainAttributesFunction func(ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) + +type MockProjectDomainAttributesRepo struct { + CreateOrUpdateFunction CreateOrUpdateProjectDomainAttributesFunction + GetFunction GetProjectDomainAttributesFunction +} + +func (r *MockProjectDomainAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error { + if r.CreateOrUpdateFunction != nil { + return r.CreateOrUpdateFunction(ctx, input) + } + return nil +} + +func (r *MockProjectDomainAttributesRepo) Get(ctx context.Context, project, domain, resource string) ( + models.ProjectDomainAttributes, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, project, domain, resource) + } + return models.ProjectDomainAttributes{}, nil +} + +func NewMockProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { + return &MockProjectDomainAttributesRepo{} +} diff --git a/pkg/repositories/mocks/project_domain_repo.go b/pkg/repositories/mocks/project_domain_repo.go deleted file mode 100644 index 5784e479bc..0000000000 --- a/pkg/repositories/mocks/project_domain_repo.go +++ /dev/null @@ -1,35 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type CreateOrUpdateProjectDomainFunction func(ctx context.Context, input models.ProjectDomain) error -type GetProjectDomainFunction func(ctx context.Context, project, domain string) (models.ProjectDomain, error) -type UpdateProjectDomainFunction func(ctx context.Context, input models.ProjectDomain) error - -type MockProjectDomainRepo struct { - CreateOrUpdateFunction CreateOrUpdateProjectDomainFunction - GetFunction GetProjectDomainFunction -} - -func (r *MockProjectDomainRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomain) error { - if r.CreateOrUpdateFunction != nil { - return r.CreateOrUpdateFunction(ctx, input) - } - return nil -} - -func (r *MockProjectDomainRepo) Get(ctx context.Context, project, domain string) (models.ProjectDomain, error) { - if r.GetFunction != nil { - return r.GetFunction(ctx, project, domain) - } - return models.ProjectDomain{}, nil -} - -func NewMockProjectDomainRepo() interfaces.ProjectDomainRepoInterface { - return &MockProjectDomainRepo{} -} diff --git a/pkg/repositories/mocks/repository.go b/pkg/repositories/mocks/repository.go index c10aa00d18..403dcbdf48 100644 --- a/pkg/repositories/mocks/repository.go +++ b/pkg/repositories/mocks/repository.go @@ -6,15 +6,17 @@ import ( ) type MockRepository struct { - taskRepo interfaces.TaskRepoInterface - workflowRepo interfaces.WorkflowRepoInterface - launchPlanRepo interfaces.LaunchPlanRepoInterface - executionRepo interfaces.ExecutionRepoInterface - nodeExecutionRepo interfaces.NodeExecutionRepoInterface - projectRepo interfaces.ProjectRepoInterface - projectDomainRepo interfaces.ProjectDomainRepoInterface - taskExecutionRepo interfaces.TaskExecutionRepoInterface - namedEntityRepo interfaces.NamedEntityRepoInterface + taskRepo interfaces.TaskRepoInterface + workflowRepo interfaces.WorkflowRepoInterface + launchPlanRepo interfaces.LaunchPlanRepoInterface + executionRepo interfaces.ExecutionRepoInterface + nodeExecutionRepo interfaces.NodeExecutionRepoInterface + projectRepo interfaces.ProjectRepoInterface + projectAttributesRepo interfaces.ProjectAttributesRepoInterface + projectDomainAttributesRepo interfaces.ProjectDomainAttributesRepoInterface + workflowAttributesRepo interfaces.WorkflowAttributesRepoInterface + taskExecutionRepo interfaces.TaskExecutionRepoInterface + namedEntityRepo interfaces.NamedEntityRepoInterface } func (r *MockRepository) TaskRepo() interfaces.TaskRepoInterface { @@ -41,8 +43,16 @@ func (r *MockRepository) ProjectRepo() interfaces.ProjectRepoInterface { return r.projectRepo } -func (r *MockRepository) ProjectDomainRepo() interfaces.ProjectDomainRepoInterface { - return r.projectDomainRepo +func (r *MockRepository) ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { + return r.projectDomainAttributesRepo +} + +func (r *MockRepository) WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { + return r.workflowAttributesRepo +} + +func (r *MockRepository) ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { + return r.projectAttributesRepo } func (r *MockRepository) TaskExecutionRepo() interfaces.TaskExecutionRepoInterface { @@ -55,14 +65,16 @@ func (r *MockRepository) NamedEntityRepo() interfaces.NamedEntityRepoInterface { func NewMockRepository() repositories.RepositoryInterface { return &MockRepository{ - taskRepo: NewMockTaskRepo(), - workflowRepo: NewMockWorkflowRepo(), - launchPlanRepo: NewMockLaunchPlanRepo(), - executionRepo: NewMockExecutionRepo(), - nodeExecutionRepo: NewMockNodeExecutionRepo(), - projectRepo: NewMockProjectRepo(), - projectDomainRepo: NewMockProjectDomainRepo(), - taskExecutionRepo: NewMockTaskExecutionRepo(), - namedEntityRepo: NewMockNamedEntityRepo(), + taskRepo: NewMockTaskRepo(), + workflowRepo: NewMockWorkflowRepo(), + launchPlanRepo: NewMockLaunchPlanRepo(), + executionRepo: NewMockExecutionRepo(), + nodeExecutionRepo: NewMockNodeExecutionRepo(), + projectRepo: NewMockProjectRepo(), + projectAttributesRepo: NewMockProjectAttributesRepo(), + projectDomainAttributesRepo: NewMockProjectDomainAttributesRepo(), + workflowAttributesRepo: NewMockWorkflowAttributesRepo(), + taskExecutionRepo: NewMockTaskExecutionRepo(), + namedEntityRepo: NewMockNamedEntityRepo(), } } diff --git a/pkg/repositories/mocks/workflow_attributes.go b/pkg/repositories/mocks/workflow_attributes.go new file mode 100644 index 0000000000..ede2e77728 --- /dev/null +++ b/pkg/repositories/mocks/workflow_attributes.go @@ -0,0 +1,37 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type CreateOrUpdateWorkflowAttributesFunction func(ctx context.Context, input models.WorkflowAttributes) error +type GetWorkflowAttributesFunction func(ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) +type UpdateWorkflowAttributesFunction func(ctx context.Context, input models.WorkflowAttributes) error + +type MockWorkflowAttributesRepo struct { + CreateOrUpdateFunction CreateOrUpdateWorkflowAttributesFunction + GetFunction GetWorkflowAttributesFunction +} + +func (r *MockWorkflowAttributesRepo) CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error { + if r.CreateOrUpdateFunction != nil { + return r.CreateOrUpdateFunction(ctx, input) + } + return nil +} + +func (r *MockWorkflowAttributesRepo) Get(ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, project, domain, workflow, resource) + } + return models.WorkflowAttributes{}, nil +} + +func NewMockWorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { + return &MockWorkflowAttributesRepo{} +} diff --git a/pkg/repositories/models/project_attributes.go b/pkg/repositories/models/project_attributes.go new file mode 100644 index 0000000000..e9a5cc87d8 --- /dev/null +++ b/pkg/repositories/models/project_attributes.go @@ -0,0 +1,10 @@ +package models + +// Represents project-domain customizable configuration. +type ProjectAttributes struct { + BaseModel + Project string `gorm:"primary_key"` + Resource string `gorm:"primary_key"` + // Serialized flyteidl.admin.MatchingAttributes. + Attributes []byte +} diff --git a/pkg/repositories/models/project_domain.go b/pkg/repositories/models/project_domain.go deleted file mode 100644 index 00266f4b15..0000000000 --- a/pkg/repositories/models/project_domain.go +++ /dev/null @@ -1,10 +0,0 @@ -package models - -// Represents project-domain customizable configuration. -type ProjectDomain struct { - BaseModel - Project string `gorm:"primary_key"` - Domain string `gorm:"primary_key"` - // Key-value pairs of substitutable resource attributes. - Attributes []byte -} diff --git a/pkg/repositories/models/project_domain_attributes.go b/pkg/repositories/models/project_domain_attributes.go new file mode 100644 index 0000000000..8674121c77 --- /dev/null +++ b/pkg/repositories/models/project_domain_attributes.go @@ -0,0 +1,11 @@ +package models + +// Represents project-domain customizable configuration. +type ProjectDomainAttributes struct { + BaseModel + Project string `gorm:"primary_key"` + Domain string `gorm:"primary_key"` + Resource string `gorm:"primary_key"` + // Serialized flyteidl.admin.MatchingAttributes. + Attributes []byte +} diff --git a/pkg/repositories/models/workflow_attributes.go b/pkg/repositories/models/workflow_attributes.go new file mode 100644 index 0000000000..4d9afd504a --- /dev/null +++ b/pkg/repositories/models/workflow_attributes.go @@ -0,0 +1,12 @@ +package models + +// Represents project-domain customizable configuration. +type WorkflowAttributes struct { + BaseModel + Project string `gorm:"primary_key"` + Domain string `gorm:"primary_key"` + Workflow string `gorm:"primary_key"` + Resource string `gorm:"primary_key"` + // Serialized flyteidl.admin.MatchingAttributes. + Attributes []byte +} diff --git a/pkg/repositories/postgres_repo.go b/pkg/repositories/postgres_repo.go index 189cde7980..f7efc94d50 100644 --- a/pkg/repositories/postgres_repo.go +++ b/pkg/repositories/postgres_repo.go @@ -9,15 +9,17 @@ import ( ) type PostgresRepo struct { - executionRepo interfaces.ExecutionRepoInterface - namedEntityRepo interfaces.NamedEntityRepoInterface - launchPlanRepo interfaces.LaunchPlanRepoInterface - projectRepo interfaces.ProjectRepoInterface - projectDomainRepo interfaces.ProjectDomainRepoInterface - nodeExecutionRepo interfaces.NodeExecutionRepoInterface - taskRepo interfaces.TaskRepoInterface - taskExecutionRepo interfaces.TaskExecutionRepoInterface - workflowRepo interfaces.WorkflowRepoInterface + executionRepo interfaces.ExecutionRepoInterface + namedEntityRepo interfaces.NamedEntityRepoInterface + launchPlanRepo interfaces.LaunchPlanRepoInterface + projectRepo interfaces.ProjectRepoInterface + projectAttributesRepo interfaces.ProjectAttributesRepoInterface + projectDomainAttributesRepo interfaces.ProjectDomainAttributesRepoInterface + nodeExecutionRepo interfaces.NodeExecutionRepoInterface + taskRepo interfaces.TaskRepoInterface + taskExecutionRepo interfaces.TaskExecutionRepoInterface + workflowRepo interfaces.WorkflowRepoInterface + workflowAttributesRepo interfaces.WorkflowAttributesRepoInterface } func (p *PostgresRepo) ExecutionRepo() interfaces.ExecutionRepoInterface { @@ -36,8 +38,12 @@ func (p *PostgresRepo) ProjectRepo() interfaces.ProjectRepoInterface { return p.projectRepo } -func (p *PostgresRepo) ProjectDomainRepo() interfaces.ProjectDomainRepoInterface { - return p.projectDomainRepo +func (p *PostgresRepo) ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { + return p.projectAttributesRepo +} + +func (p *PostgresRepo) ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { + return p.projectDomainAttributesRepo } func (p *PostgresRepo) NodeExecutionRepo() interfaces.NodeExecutionRepoInterface { @@ -56,16 +62,22 @@ func (p *PostgresRepo) WorkflowRepo() interfaces.WorkflowRepoInterface { return p.workflowRepo } +func (p *PostgresRepo) WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { + return p.workflowAttributesRepo +} + func NewPostgresRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) RepositoryInterface { return &PostgresRepo{ - executionRepo: gormimpl.NewExecutionRepo(db, errorTransformer, scope.NewSubScope("executions")), - launchPlanRepo: gormimpl.NewLaunchPlanRepo(db, errorTransformer, scope.NewSubScope("launch_plans")), - projectRepo: gormimpl.NewProjectRepo(db, errorTransformer, scope.NewSubScope("project")), - projectDomainRepo: gormimpl.NewProjectDomainRepo(db, errorTransformer, scope.NewSubScope("project_domain")), - namedEntityRepo: gormimpl.NewNamedEntityRepo(db, errorTransformer, scope.NewSubScope("named_entity")), - nodeExecutionRepo: gormimpl.NewNodeExecutionRepo(db, errorTransformer, scope.NewSubScope("node_executions")), - taskRepo: gormimpl.NewTaskRepo(db, errorTransformer, scope.NewSubScope("tasks")), - taskExecutionRepo: gormimpl.NewTaskExecutionRepo(db, errorTransformer, scope.NewSubScope("task_executions")), - workflowRepo: gormimpl.NewWorkflowRepo(db, errorTransformer, scope.NewSubScope("workflows")), + executionRepo: gormimpl.NewExecutionRepo(db, errorTransformer, scope.NewSubScope("executions")), + launchPlanRepo: gormimpl.NewLaunchPlanRepo(db, errorTransformer, scope.NewSubScope("launch_plans")), + projectRepo: gormimpl.NewProjectRepo(db, errorTransformer, scope.NewSubScope("project")), + projectAttributesRepo: gormimpl.NewProjectAttributesRepo(db, errorTransformer, scope.NewSubScope("project_attrs")), + projectDomainAttributesRepo: gormimpl.NewProjectDomainAttributesRepo(db, errorTransformer, scope.NewSubScope("project_domain_attrs")), + namedEntityRepo: gormimpl.NewNamedEntityRepo(db, errorTransformer, scope.NewSubScope("named_entity")), + nodeExecutionRepo: gormimpl.NewNodeExecutionRepo(db, errorTransformer, scope.NewSubScope("node_executions")), + taskRepo: gormimpl.NewTaskRepo(db, errorTransformer, scope.NewSubScope("tasks")), + taskExecutionRepo: gormimpl.NewTaskExecutionRepo(db, errorTransformer, scope.NewSubScope("task_executions")), + workflowRepo: gormimpl.NewWorkflowRepo(db, errorTransformer, scope.NewSubScope("workflows")), + workflowAttributesRepo: gormimpl.NewWorkflowAttributesRepo(db, errorTransformer, scope.NewSubScope("workflow_attrs")), } } diff --git a/pkg/repositories/transformers/project_attributes.go b/pkg/repositories/transformers/project_attributes.go new file mode 100644 index 0000000000..b628ba3189 --- /dev/null +++ b/pkg/repositories/transformers/project_attributes.go @@ -0,0 +1,35 @@ +package transformers + +import ( + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +func ToProjectAttributesModel(attributes admin.ProjectAttributes, resource admin.MatchableResource) (models.ProjectAttributes, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.ProjectAttributes{}, err + } + return models.ProjectAttributes{ + Project: attributes.Project, + Resource: resource.String(), + Attributes: attributeBytes, + }, nil +} + +func FromProjectAttributesModel(model models.ProjectAttributes) (admin.ProjectAttributes, error) { + var attributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &attributes) + if err != nil { + return admin.ProjectAttributes{}, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) + } + return admin.ProjectAttributes{ + Project: model.Project, + MatchingAttributes: &attributes, + }, nil +} diff --git a/pkg/repositories/transformers/project_attributes_test.go b/pkg/repositories/transformers/project_attributes_test.go new file mode 100644 index 0000000000..65a78395ab --- /dev/null +++ b/pkg/repositories/transformers/project_attributes_test.go @@ -0,0 +1,63 @@ +package transformers + +import ( + "testing" + + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + + "github.com/stretchr/testify/assert" +) + +var matchingTaskResourceAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + }, + }, + }, +} + +var projectAttributes = admin.ProjectAttributes{ + Project: "project", + MatchingAttributes: matchingTaskResourceAttributes, +} + +var marshalledAttributes, _ = proto.Marshal(matchingTaskResourceAttributes) + +func TestToProjectAttributesModel(t *testing.T) { + model, err := ToProjectAttributesModel(projectAttributes, admin.MatchableResource_TASK_RESOURCE) + assert.Nil(t, err) + assert.EqualValues(t, models.ProjectAttributes{ + Project: "project", + Resource: admin.MatchableResource_TASK_RESOURCE.String(), + Attributes: marshalledAttributes, + }, model) +} + +func TestFromProjectAttributesModel(t *testing.T) { + model := models.ProjectAttributes{ + Project: "project", + Resource: admin.MatchableResource_TASK_RESOURCE.String(), + Attributes: marshalledAttributes, + } + unmarshalledAttributes, err := FromProjectAttributesModel(model) + assert.Nil(t, err) + assert.True(t, proto.Equal(&projectAttributes, &unmarshalledAttributes)) +} + +func TestFromProjectAttributesModel_InvalidResourceAttributes(t *testing.T) { + model := models.ProjectAttributes{ + Project: "project", + Resource: admin.MatchableResource_TASK_RESOURCE.String(), + Attributes: []byte("i'm invalid!"), + } + _, err := FromProjectAttributesModel(model) + assert.NotNil(t, err) + assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) +} diff --git a/pkg/repositories/transformers/project_domain.go b/pkg/repositories/transformers/project_domain.go deleted file mode 100644 index 8cc63aee02..0000000000 --- a/pkg/repositories/transformers/project_domain.go +++ /dev/null @@ -1,32 +0,0 @@ -package transformers - -import ( - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" -) - -func ToProjectDomainModel(attributes admin.ProjectDomainAttributes) (models.ProjectDomain, error) { - attributeBytes, err := proto.Marshal(&attributes) - if err != nil { - return models.ProjectDomain{}, err - } - return models.ProjectDomain{ - Project: attributes.Project, - Domain: attributes.Domain, - Attributes: attributeBytes, - }, nil -} - -func FromProjectDomainModel(model models.ProjectDomain) (admin.ProjectDomainAttributes, error) { - var attributes admin.ProjectDomainAttributes - err := proto.Unmarshal(model.Attributes, &attributes) - if err != nil { - return admin.ProjectDomainAttributes{}, errors.NewFlyteAdminErrorf( - codes.Internal, "Failed to decode project domain resource attributes with err: %v", err) - } - return attributes, nil -} diff --git a/pkg/repositories/transformers/project_domain_attributes.go b/pkg/repositories/transformers/project_domain_attributes.go new file mode 100644 index 0000000000..ccedf335c9 --- /dev/null +++ b/pkg/repositories/transformers/project_domain_attributes.go @@ -0,0 +1,37 @@ +package transformers + +import ( + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +func ToProjectDomainAttributesModel(attributes admin.ProjectDomainAttributes, resource admin.MatchableResource) (models.ProjectDomainAttributes, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.ProjectDomainAttributes{}, err + } + return models.ProjectDomainAttributes{ + Project: attributes.Project, + Domain: attributes.Domain, + Resource: resource.String(), + Attributes: attributeBytes, + }, nil +} + +func FromProjectDomainAttributesModel(model models.ProjectDomainAttributes) (admin.ProjectDomainAttributes, error) { + var attributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &attributes) + if err != nil { + return admin.ProjectDomainAttributes{}, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) + } + return admin.ProjectDomainAttributes{ + Project: model.Project, + Domain: model.Domain, + MatchingAttributes: &attributes, + }, nil +} diff --git a/pkg/repositories/transformers/project_domain_attributes_test.go b/pkg/repositories/transformers/project_domain_attributes_test.go new file mode 100644 index 0000000000..f4c42cd929 --- /dev/null +++ b/pkg/repositories/transformers/project_domain_attributes_test.go @@ -0,0 +1,68 @@ +package transformers + +import ( + "testing" + + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + + "github.com/stretchr/testify/assert" +) + +var matchingExecutionQueueAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{ + "foo", + }, + }, + }, +} + +var projectDomainAttributes = admin.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + MatchingAttributes: matchingExecutionQueueAttributes, +} + +var marshalledExecutionQueueAttributes, _ = proto.Marshal(matchingExecutionQueueAttributes) + +func TestToProjectDomainAttributesModel(t *testing.T) { + + model, err := ToProjectDomainAttributesModel(projectDomainAttributes, admin.MatchableResource_EXECUTION_QUEUE) + assert.Nil(t, err) + assert.EqualValues(t, models.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledExecutionQueueAttributes, + }, model) +} + +func TestFromProjectDomainAttributesModel(t *testing.T) { + model := models.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledExecutionQueueAttributes, + } + unmarshalledAttributes, err := FromProjectDomainAttributesModel(model) + assert.Nil(t, err) + assert.True(t, proto.Equal(&projectDomainAttributes, &unmarshalledAttributes)) +} + +func TestFromProjectDomainAttributesModel_InvalidResourceAttributes(t *testing.T) { + model := models.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: []byte("i'm invalid!"), + } + _, err := FromProjectDomainAttributesModel(model) + assert.NotNil(t, err) + assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) +} diff --git a/pkg/repositories/transformers/project_domain_test.go b/pkg/repositories/transformers/project_domain_test.go deleted file mode 100644 index 6975c5f631..0000000000 --- a/pkg/repositories/transformers/project_domain_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package transformers - -import ( - "testing" - - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" - - "github.com/stretchr/testify/assert" -) - -var attributes = admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - Attributes: map[string]string{ - "cpu": "100", - }, -} - -var marshalledAttributes, _ = proto.Marshal(&attributes) - -func TestToProjectDomainModel(t *testing.T) { - - model, err := ToProjectDomainModel(attributes) - assert.Nil(t, err) - assert.EqualValues(t, models.ProjectDomain{ - Project: "project", - Domain: "domain", - Attributes: marshalledAttributes, - }, model) -} - -func TestFromProjectDomainModel(t *testing.T) { - model := models.ProjectDomain{ - Project: "project", - Domain: "domain", - Attributes: marshalledAttributes, - } - unmarshalledAttributes, err := FromProjectDomainModel(model) - assert.Nil(t, err) - assert.True(t, proto.Equal(&attributes, &unmarshalledAttributes)) -} - -func TestFromProjectDomainModel_InvalidResourceAttributes(t *testing.T) { - model := models.ProjectDomain{ - Project: "project", - Domain: "domain", - Attributes: []byte("i'm invalid!"), - } - _, err := FromProjectDomainModel(model) - assert.NotNil(t, err) - assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) -} diff --git a/pkg/repositories/transformers/workflow_attributes.go b/pkg/repositories/transformers/workflow_attributes.go new file mode 100644 index 0000000000..34758c22e3 --- /dev/null +++ b/pkg/repositories/transformers/workflow_attributes.go @@ -0,0 +1,39 @@ +package transformers + +import ( + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +func ToWorkflowAttributesModel(attributes admin.WorkflowAttributes, resource admin.MatchableResource) (models.WorkflowAttributes, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.WorkflowAttributes{}, err + } + return models.WorkflowAttributes{ + Project: attributes.Project, + Domain: attributes.Domain, + Workflow: attributes.Workflow, + Resource: resource.String(), + Attributes: attributeBytes, + }, nil +} + +func FromWorkflowAttributesModel(model models.WorkflowAttributes) (admin.WorkflowAttributes, error) { + var attributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &attributes) + if err != nil { + return admin.WorkflowAttributes{}, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) + } + return admin.WorkflowAttributes{ + Project: model.Project, + Domain: model.Domain, + Workflow: model.Workflow, + MatchingAttributes: &attributes, + }, nil +} diff --git a/pkg/repositories/transformers/workflow_attributes_test.go b/pkg/repositories/transformers/workflow_attributes_test.go new file mode 100644 index 0000000000..b23b6f1885 --- /dev/null +++ b/pkg/repositories/transformers/workflow_attributes_test.go @@ -0,0 +1,71 @@ +package transformers + +import ( + "testing" + + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + + "github.com/stretchr/testify/assert" +) + +var matchingClusterResourceAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ClusterResourceAttributes{ + ClusterResourceAttributes: &admin.ClusterResourceAttributes{ + Attributes: map[string]string{ + "foo": "bar", + }, + }, + }, +} + +var workflowAttributes = admin.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + MatchingAttributes: matchingClusterResourceAttributes, +} + +var marshalledClusterResourceAttributes, _ = proto.Marshal(matchingClusterResourceAttributes) + +func TestToWorkflowAttributesModel(t *testing.T) { + model, err := ToWorkflowAttributesModel(workflowAttributes, admin.MatchableResource_EXECUTION_QUEUE) + assert.Nil(t, err) + assert.EqualValues(t, models.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledClusterResourceAttributes, + }, model) +} + +func TestFromWorkflowAttributesModel(t *testing.T) { + model := models.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledClusterResourceAttributes, + } + unmarshalledAttributes, err := FromWorkflowAttributesModel(model) + assert.Nil(t, err) + assert.True(t, proto.Equal(&workflowAttributes, &unmarshalledAttributes)) +} + +func TestFromWorkflowAttributesModel_InvalidResourceAttributes(t *testing.T) { + model := models.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: []byte("i'm invalid!"), + } + _, err := FromWorkflowAttributesModel(model) + assert.NotNil(t, err) + assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) +} diff --git a/pkg/resourcematching/overrides.go b/pkg/resourcematching/overrides.go new file mode 100644 index 0000000000..06aba4f0fc --- /dev/null +++ b/pkg/resourcematching/overrides.go @@ -0,0 +1,77 @@ +package resourcematching + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +type GetOverrideValuesInput struct { + Db repositories.RepositoryInterface + Project string + Domain string + Workflow string + Resource admin.MatchableResource +} + +func isNotFoundErr(err error) bool { + return err.(errors.FlyteAdminError) != nil && err.(errors.FlyteAdminError).Code() == codes.NotFound +} + +func GetOverrideValuesToApply(ctx context.Context, input GetOverrideValuesInput) ( + *admin.MatchingAttributes, error) { + if len(input.Project) == 0 || len(input.Domain) == 0 { + return nil, errors.NewFlyteAdminErrorf( + codes.InvalidArgument, "Invalid overrides values request configuration: [%+v]", input) + } + if len(input.Workflow) > 0 { + // Only the workflow input argument is optional + workflowAttributesModel, err := input.Db.WorkflowAttributesRepo().Get( + ctx, input.Project, input.Domain, input.Workflow, input.Resource.String()) + if err != nil && !isNotFoundErr(err) { + // Not found is fine, since not every workflow will necessarily have resource overrides. + // Any other error should be bubbled back up. + return nil, err + } else if err == nil { + workflowAttributes, err := transformers.FromWorkflowAttributesModel(workflowAttributesModel) + if err != nil { + return nil, err + } + return workflowAttributes.MatchingAttributes, nil + } + } + + projectDomainAttributesModel, err := input.Db.ProjectDomainAttributesRepo().Get( + ctx, input.Project, input.Domain, input.Resource.String()) + if err != nil && !isNotFoundErr(err) { + // Not found is fine, since not every project+domain will necessarily have resource overrides. + // Any other error should be bubbled back up. + return nil, err + } else if err == nil { + projectDomainAttributes, err := transformers.FromProjectDomainAttributesModel(projectDomainAttributesModel) + if err != nil { + return nil, err + } + return projectDomainAttributes.MatchingAttributes, nil + } + + projectAttributesModel, err := input.Db.ProjectAttributesRepo().Get(ctx, input.Project, input.Resource.String()) + if err != nil && !isNotFoundErr(err) { + // Not found is fine, since not every project will necessarily have resource overrides. + // Any other error should be bubbled back up. + return nil, err + } else if err == nil { + projectAttributes, err := transformers.FromProjectAttributesModel(projectAttributesModel) + if err != nil { + return nil, err + } + return projectAttributes.MatchingAttributes, nil + } + + // If we've made it this far then there are no matching overrides. + return nil, nil +} diff --git a/pkg/resourcematching/overrides_test.go b/pkg/resourcematching/overrides_test.go new file mode 100644 index 0000000000..0fa8c038b5 --- /dev/null +++ b/pkg/resourcematching/overrides_test.go @@ -0,0 +1,164 @@ +package resourcematching + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" +) + +const testProject = "project" +const testDomain = "domain" +const testWorkflow = "workflow" + +func TestIsNotFoundErr(t *testing.T) { + isNotFound := errors.NewFlyteAdminError(codes.NotFound, "foo") + assert.True(t, isNotFoundErr(isNotFound)) + + invalidArgs := errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar") + assert.False(t, isNotFoundErr(invalidArgs)) +} + +func TestGetOverrideValuesToApply(t *testing.T) { + db := mocks.NewMockRepository() + matchingWorkflowAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr3"}, + }, + }, + } + db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, workflow, resource string) ( + models.WorkflowAttributes, error) { + if project == testProject && domain == testDomain && workflow == testWorkflow && + resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + + marshalledMatchingAttributes, _ := proto.Marshal(matchingWorkflowAttributes) + return models.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + if workflow == "error" { + return models.WorkflowAttributes{}, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar") + } + return models.WorkflowAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + matchingProjectDomainAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr2"}, + }, + }, + } + db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).GetFunction = func( + ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { + if project == testProject && domain == testDomain && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + marshalledMatchingAttributes, _ := proto.Marshal(matchingProjectDomainAttributes) + return models.ProjectDomainAttributes{ + Project: project, + Domain: domain, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + return models.ProjectDomainAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + matchingProjectAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{"attr1"}, + }, + }, + } + db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).GetFunction = func( + ctx context.Context, project, resource string) (models.ProjectAttributes, error) { + if project == testProject && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + marshalledMatchingAttributes, _ := proto.Marshal(matchingProjectAttributes) + return models.ProjectAttributes{ + Project: project, + Resource: resource, + Attributes: marshalledMatchingAttributes, + }, nil + } + return models.ProjectAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + + testCases := []struct { + input GetOverrideValuesInput + expectedMatchingAttributes *admin.MatchingAttributes + expectedErr error + }{ + { + GetOverrideValuesInput{ + Db: db, + Project: "project", + Domain: "domain", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }, + matchingWorkflowAttributes, + nil, + }, + { + GetOverrideValuesInput{ + Db: db, + Project: "project", + Domain: "domain", + Workflow: "workflow2", + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }, + matchingProjectDomainAttributes, + nil, + }, + { + GetOverrideValuesInput{ + Db: db, + Project: "project", + Domain: "domain2", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }, + matchingProjectAttributes, + nil, + }, + { + GetOverrideValuesInput{ + Db: db, + Project: "project2", + Domain: "domain", + Workflow: "workflow", + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }, + nil, + nil, + }, + {GetOverrideValuesInput{ + Db: db, + Project: "project", + Domain: "domain", + Workflow: "error", + Resource: admin.MatchableResource_EXECUTION_QUEUE, + }, + nil, + errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar"), + }, + } + for _, tc := range testCases { + matchingAttributes, err := GetOverrideValuesToApply(context.Background(), tc.input) + assert.True(t, proto.Equal(tc.expectedMatchingAttributes, matchingAttributes), + fmt.Sprintf("invalid value for [%+v]", tc.input)) + assert.EqualValues(t, tc.expectedErr, err) + } +} diff --git a/pkg/rpc/adminservice/attributes.go b/pkg/rpc/adminservice/attributes.go new file mode 100644 index 0000000000..c4fd131401 --- /dev/null +++ b/pkg/rpc/adminservice/attributes.go @@ -0,0 +1,64 @@ +package adminservice + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/rpc/adminservice/util" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func (m *AdminService) UpdateWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) { + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.WorkflowAttributesUpdateResponse + var err error + m.Metrics.workflowAttributesEndpointMetrics.update.Time(func() { + response, err = m.WorkflowAttributesManager.UpdateWorkflowAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.workflowAttributesEndpointMetrics.update) + } + + return response, nil +} + +func (m *AdminService) UpdateProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesUpdateRequest) ( + *admin.ProjectDomainAttributesUpdateResponse, error) { + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ProjectDomainAttributesUpdateResponse + var err error + m.Metrics.projectDomainAttributesEndpointMetrics.update.Time(func() { + response, err = m.ProjectDomainAttributesManager.UpdateProjectDomainAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.projectDomainAttributesEndpointMetrics.update) + } + + return response, nil +} + +func (m *AdminService) UpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) { + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ProjectAttributesUpdateResponse + var err error + m.Metrics.projectAttributesEndpointMetrics.update.Time(func() { + response, err = m.ProjectAttributesManager.UpdateProjectAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.projectAttributesEndpointMetrics.update) + } + + return response, nil +} diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go index dced405c5c..a20d835865 100644 --- a/pkg/rpc/adminservice/base.go +++ b/pkg/rpc/adminservice/base.go @@ -23,16 +23,18 @@ import ( ) type AdminService struct { - TaskManager interfaces.TaskInterface - WorkflowManager interfaces.WorkflowInterface - LaunchPlanManager interfaces.LaunchPlanInterface - ExecutionManager interfaces.ExecutionInterface - NodeExecutionManager interfaces.NodeExecutionInterface - TaskExecutionManager interfaces.TaskExecutionInterface - ProjectManager interfaces.ProjectInterface - ProjectDomainManager interfaces.ProjectDomainInterface - NamedEntityManager interfaces.NamedEntityInterface - Metrics AdminMetrics + TaskManager interfaces.TaskInterface + WorkflowManager interfaces.WorkflowInterface + LaunchPlanManager interfaces.LaunchPlanInterface + ExecutionManager interfaces.ExecutionInterface + NodeExecutionManager interfaces.NodeExecutionInterface + TaskExecutionManager interfaces.TaskExecutionInterface + ProjectManager interfaces.ProjectInterface + ProjectAttributesManager interfaces.ProjectAttributesInterface + ProjectDomainAttributesManager interfaces.ProjectDomainAttributesInterface + WorkflowAttributesManager interfaces.WorkflowAttributesInterface + NamedEntityManager interfaces.NamedEntityInterface + Metrics AdminMetrics } // Intercepts all admin requests to handle panics during execution. @@ -159,8 +161,10 @@ func NewAdminServer(kubeConfig, master string) *AdminService { db, adminScope.NewSubScope("node_execution_manager"), urlData), TaskExecutionManager: manager.NewTaskExecutionManager( db, adminScope.NewSubScope("task_execution_manager"), urlData), - ProjectManager: manager.NewProjectManager(db, configuration), - ProjectDomainManager: manager.NewProjectDomainManager(db, configuration), - Metrics: InitMetrics(adminScope), + ProjectManager: manager.NewProjectManager(db, configuration), + ProjectAttributesManager: manager.NewProjectAttributesManager(db), + ProjectDomainAttributesManager: manager.NewProjectDomainAttributesManager(db), + WorkflowAttributesManager: manager.NewWorkflowAttributesManager(db), + Metrics: InitMetrics(adminScope), } } diff --git a/pkg/rpc/adminservice/metrics.go b/pkg/rpc/adminservice/metrics.go index b1975bc233..351b9d22ad 100644 --- a/pkg/rpc/adminservice/metrics.go +++ b/pkg/rpc/adminservice/metrics.go @@ -56,7 +56,7 @@ type projectEndpointMetrics struct { list util.RequestMetrics } -type projectDomainEndpointMetrics struct { +type attributeEndpointMetrics struct { scope promutils.Scope update util.RequestMetrics @@ -93,15 +93,17 @@ type AdminMetrics struct { Scope promutils.Scope PanicCounter prometheus.Counter - executionEndpointMetrics executionEndpointMetrics - launchPlanEndpointMetrics launchPlanEndpointMetrics - namedEntityEndpointMetrics namedEntityEndpointMetrics - nodeExecutionEndpointMetrics nodeExecutionEndpointMetrics - projectEndpointMetrics projectEndpointMetrics - projectDomainEndpointMetrics projectDomainEndpointMetrics - taskEndpointMetrics taskEndpointMetrics - taskExecutionEndpointMetrics taskExecutionEndpointMetrics - workflowEndpointMetrics workflowEndpointMetrics + executionEndpointMetrics executionEndpointMetrics + launchPlanEndpointMetrics launchPlanEndpointMetrics + namedEntityEndpointMetrics namedEntityEndpointMetrics + nodeExecutionEndpointMetrics nodeExecutionEndpointMetrics + projectEndpointMetrics projectEndpointMetrics + projectAttributesEndpointMetrics attributeEndpointMetrics + projectDomainAttributesEndpointMetrics attributeEndpointMetrics + workflowAttributesEndpointMetrics attributeEndpointMetrics + taskEndpointMetrics taskEndpointMetrics + taskExecutionEndpointMetrics taskExecutionEndpointMetrics + workflowEndpointMetrics workflowEndpointMetrics } func InitMetrics(adminScope promutils.Scope) AdminMetrics { @@ -149,9 +151,17 @@ func InitMetrics(adminScope promutils.Scope) AdminMetrics { register: util.NewRequestMetrics(adminScope, "register_project"), list: util.NewRequestMetrics(adminScope, "list_projects"), }, - projectDomainEndpointMetrics: projectDomainEndpointMetrics{ + projectAttributesEndpointMetrics: attributeEndpointMetrics{ scope: adminScope, - update: util.NewRequestMetrics(adminScope, "update_project_domain"), + update: util.NewRequestMetrics(adminScope, "update_project_attrs"), + }, + projectDomainAttributesEndpointMetrics: attributeEndpointMetrics{ + scope: adminScope, + update: util.NewRequestMetrics(adminScope, "update_project_domain_attrs"), + }, + workflowAttributesEndpointMetrics: attributeEndpointMetrics{ + scope: adminScope, + update: util.NewRequestMetrics(adminScope, "update_workflow_attrs"), }, taskEndpointMetrics: taskEndpointMetrics{ scope: adminScope, diff --git a/pkg/rpc/adminservice/project_domain.go b/pkg/rpc/adminservice/project_domain.go deleted file mode 100644 index 1b93a23086..0000000000 --- a/pkg/rpc/adminservice/project_domain.go +++ /dev/null @@ -1,28 +0,0 @@ -package adminservice - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/rpc/adminservice/util" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func (m *AdminService) UpdateProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesUpdateRequest) ( - *admin.ProjectDomainAttributesUpdateResponse, error) { - defer m.interceptPanic(ctx, request) - if request == nil { - return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") - } - var response *admin.ProjectDomainAttributesUpdateResponse - var err error - m.Metrics.projectEndpointMetrics.register.Time(func() { - response, err = m.ProjectDomainManager.UpdateProjectDomain(ctx, *request) - }) - if err != nil { - return nil, util.TransformAndRecordError(err, &m.Metrics.projectDomainEndpointMetrics.update) - } - - return response, nil -} diff --git a/pkg/rpc/adminservice/tests/util.go b/pkg/rpc/adminservice/tests/util.go index 57cda02c35..a60aad68b6 100644 --- a/pkg/rpc/adminservice/tests/util.go +++ b/pkg/rpc/adminservice/tests/util.go @@ -20,14 +20,14 @@ type NewMockAdminServerInput struct { func NewMockAdminServer(input NewMockAdminServerInput) *adminservice.AdminService { var testScope = mockScope.NewTestScope() return &adminservice.AdminService{ - ExecutionManager: input.executionManager, - LaunchPlanManager: input.launchPlanManager, - NodeExecutionManager: input.nodeExecutionManager, - TaskManager: input.taskManager, - ProjectManager: input.projectManager, - ProjectDomainManager: input.projectDomainManager, - WorkflowManager: input.workflowManager, - TaskExecutionManager: input.taskExecutionManager, - Metrics: adminservice.InitMetrics(testScope), + ExecutionManager: input.executionManager, + LaunchPlanManager: input.launchPlanManager, + NodeExecutionManager: input.nodeExecutionManager, + TaskManager: input.taskManager, + ProjectManager: input.projectManager, + ProjectDomainAttributesManager: input.projectDomainManager, + WorkflowManager: input.workflowManager, + TaskExecutionManager: input.taskExecutionManager, + Metrics: adminservice.InitMetrics(testScope), } } diff --git a/pkg/runtime/interfaces/queue_configuration.go b/pkg/runtime/interfaces/queue_configuration.go index 0abea6d04a..cbeaeeb95d 100644 --- a/pkg/runtime/interfaces/queue_configuration.go +++ b/pkg/runtime/interfaces/queue_configuration.go @@ -8,14 +8,16 @@ type ExecutionQueue struct { Attributes []string } +func (q ExecutionQueue) GetAttributes() []string { + return q.Attributes +} + type ExecutionQueues []ExecutionQueue // Defines the specific resource attributes (tags) a workflow requires to run. type WorkflowConfig struct { - Project string `json:"project"` - Domain string `json:"domain"` - WorkflowName string `json:"workflowName"` - Tags []string `json:"tags"` + Domain string `json:"domain"` + Tags []string `json:"tags"` } type WorkflowConfigs []WorkflowConfig diff --git a/tests/attributes_test.go b/tests/attributes_test.go new file mode 100644 index 0000000000..7c425d6017 --- /dev/null +++ b/tests/attributes_test.go @@ -0,0 +1,121 @@ +// +build integration + +package tests + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/gormimpl" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" + + "github.com/stretchr/testify/assert" + + databaseConfig "github.com/lyft/flyteadmin/pkg/repositories/config" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +var matchingAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + }, + }, + }, +} + +func TestUpdateProjectAttributes(t *testing.T) { + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + req := admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: "admintests", + MatchingAttributes: matchingAttributes, + }, + } + + _, err := client.UpdateProjectAttributes(ctx, &req) + assert.Nil(t, err) + + // If we ever expose get/list ProjectAttributes APIs update the test below to call those instead. + db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(geDbConfig(), adminScope)) + defer db.Close() + + errorsTransformer := errors.NewPostgresErrorTransformer(adminScope.NewSubScope("project_attrs_errors")) + projectRepo := gormimpl.NewProjectAttributesRepo(db, errorsTransformer, adminScope.NewSubScope("project_attrs")) + + attributes, err := projectRepo.Get(ctx, "admintests", admin.MatchableResource_TASK_RESOURCE.String()) + assert.Nil(t, err) + + projectAttributes, err := transformers.FromProjectAttributesModel(attributes) + assert.True(t, proto.Equal(matchingAttributes, projectAttributes.MatchingAttributes)) +} + +func TestUpdateProjectDomainAttributes(t *testing.T) { + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + req := admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: "admintests", + Domain: "development", + MatchingAttributes: matchingAttributes, + }, + } + + _, err := client.UpdateProjectDomainAttributes(ctx, &req) + assert.Nil(t, err) + + // If we ever expose get/list ProjectDomainAttributes APIs update the test below to call those instead. + db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getLocalDbConfig(), adminScope)) + defer db.Close() + + errorsTransformer := errors.NewPostgresErrorTransformer(adminScope.NewSubScope("project_domain_attrs_errors")) + projectDomainRepo := gormimpl.NewProjectDomainAttributesRepo(db, errorsTransformer, adminScope.NewSubScope("project_domain_attrs")) + + attributes, err := projectDomainRepo.Get(ctx, "admintests", "development", + admin.MatchableResource_TASK_RESOURCE.String()) + assert.Nil(t, err) + + projectDomainAttributes, err := transformers.FromProjectDomainAttributesModel(attributes) + assert.True(t, proto.Equal(matchingAttributes, projectDomainAttributes.MatchingAttributes)) +} + +func TestUpdateWorkflowAttributes(t *testing.T) { + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + req := admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: "admintests", + Domain: "development", + Workflow: "workflow", + MatchingAttributes: matchingAttributes, + }, + } + + _, err := client.UpdateWorkflowAttributes(ctx, &req) + assert.Nil(t, err) + + // If we ever expose get/list WorkflowAttributes APIs update the test below to call those instead. + db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getLocalDbConfig(), adminScope)) + defer db.Close() + + errorsTransformer := errors.NewPostgresErrorTransformer(adminScope.NewSubScope("workflow_attrs_errors")) + workflowRepo := gormimpl.NewWorkflowAttributesRepo(db, errorsTransformer, adminScope.NewSubScope("workflow_attrs")) + + attributes, err := workflowRepo.Get(ctx, "admintests", "development", "workflow", + admin.MatchableResource_TASK_RESOURCE.String()) + assert.Nil(t, err) + + workflowAttributes, err := transformers.FromWorkflowAttributesModel(attributes) + assert.True(t, proto.Equal(matchingAttributes, workflowAttributes.MatchingAttributes)) +} diff --git a/tests/project_domain_test.go b/tests/project_domain_test.go deleted file mode 100644 index cad3563701..0000000000 --- a/tests/project_domain_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// +build integration - -package tests - -import ( - "context" - "testing" - - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/gormimpl" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" - - "github.com/stretchr/testify/assert" - - databaseConfig "github.com/lyft/flyteadmin/pkg/repositories/config" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -func TestUpdateProjectDomain(t *testing.T) { - ctx := context.Background() - client, conn := GetTestAdminServiceClient() - defer conn.Close() - - req := admin.ProjectDomainAttributesUpdateRequest{ - Attributes: &admin.ProjectDomainAttributes{ - Project: "admintests", - Domain: "development", - Attributes: map[string]string{ - "foo": "bar", - }, - }, - } - - _, err := client.UpdateProjectDomainAttributes(ctx, &req) - assert.Nil(t, err) - - // If we ever expose get/list ProjectDomainAttributes APIs update the test below to call those instead. - db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getDbConfig(), adminScope)) - defer db.Close() - - errorsTransformer := errors.NewPostgresErrorTransformer(adminScope.NewSubScope("errors")) - projectDomainRepo := gormimpl.NewProjectDomainRepo(db, errorsTransformer, adminScope.NewSubScope("project_domain")) - - attributes, err := projectDomainRepo.Get(ctx, "admintests", "development") - assert.Nil(t, err) - - projectDomain, err := transformers.FromProjectDomainModel(attributes) - - assert.EqualValues(t, map[string]string{ - "foo": "bar", - }, projectDomain.Attributes) -}