diff --git a/pkg/manager/impl/named_entity_manager.go b/pkg/manager/impl/named_entity_manager.go index b44476779..e3bafc3f4 100644 --- a/pkg/manager/impl/named_entity_manager.go +++ b/pkg/manager/impl/named_entity_manager.go @@ -68,16 +68,15 @@ func (m *NamedEntityManager) GetNamedEntity(ctx context.Context, request admin.N return util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id) } -func (m *NamedEntityManager) updateQueryFilters(identityFilters []common.InlineFilter, requestFilters string) ( - []common.InlineFilter, error) { +func (m *NamedEntityManager) getQueryFilters(requestFilters string) ([]common.InlineFilter, error) { + filters := make([]common.InlineFilter, 0) if len(requestFilters) == 0 { - return identityFilters, nil + return filters, nil } additionalFilters, err := util.ParseFilters(requestFilters, common.NamedEntityMetadata) if err != nil { return nil, err } - var finalizedFilters = identityFilters for _, filter := range additionalFilters { if strings.Contains(filter.GetField(), state) { filterWithDefaultValue, err := common.NewWithDefaultValueFilter( @@ -85,12 +84,12 @@ func (m *NamedEntityManager) updateQueryFilters(identityFilters []common.InlineF if err != nil { return nil, err } - finalizedFilters = append(finalizedFilters, filterWithDefaultValue) + filters = append(filters, filterWithDefaultValue) } else { - finalizedFilters = append(finalizedFilters, filter) + filters = append(filters, filter) } } - return finalizedFilters, nil + return filters, nil } func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admin.NamedEntityListRequest) ( @@ -101,17 +100,10 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi } ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) - identifierFilters, err := util.GetDbFilters(util.FilterSpec{ - Project: request.Project, - Domain: request.Domain, - }, common.ResourceTypeToEntity[request.ResourceType]) - if err != nil { - return nil, err - } // HACK: In order to filter by state (if requested) - we need to amend the filter to use COALESCE // e.g. eq(state, 1) becomes 'WHERE (COALESCE(state, 0) = '1')' since not every NamedEntity necessarily // has an entry, and therefore the default state value '0' (active), should be assumed. - filters, err := m.updateQueryFilters(identifierFilters, request.Filters) + filters, err := m.getQueryFilters(request.Filters) if err != nil { return nil, err } @@ -127,14 +119,19 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListNamedEntities", request.Token) } - listInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + listInput := repoInterfaces.ListNamedEntityInput{ + ListResourceInput: repoInterfaces.ListResourceInput{ + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameter: sortParameter, + }, + Project: request.Project, + Domain: request.Domain, + ResourceType: request.ResourceType, } - output, err := m.db.NamedEntityRepo().List(ctx, request.ResourceType, listInput) + output, err := m.db.NamedEntityRepo().List(ctx, listInput) if err != nil { logger.Debugf(ctx, "Failed to list named entities of type: %s with project: %s, domain: %s. Returned error was: %v", request.ResourceType, request.Project, request.Domain, err) diff --git a/pkg/manager/impl/named_entity_manager_test.go b/pkg/manager/impl/named_entity_manager_test.go index 8c46d356f..91e3685e2 100644 --- a/pkg/manager/impl/named_entity_manager_test.go +++ b/pkg/manager/impl/named_entity_manager_test.go @@ -4,8 +4,6 @@ import ( "context" "testing" - "github.com/lyft/flyteadmin/pkg/common" - "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" "github.com/lyft/flyteadmin/pkg/repositories" "github.com/lyft/flyteadmin/pkg/repositories/interfaces" @@ -86,21 +84,15 @@ func TestNamedEntityManager_Get_BadRequest(t *testing.T) { assert.Nil(t, response) } -func TestNamedEntityManager_UpdateQueryFilters(t *testing.T) { - identityFilter, err := common.NewSingleValueFilter(common.NamedEntityMetadata, common.Equal, "project", "proj") - assert.NoError(t, err) - +func TestNamedEntityManager_getQueryFilters(t *testing.T) { repository := getMockRepositoryForNETest() manager := NewNamedEntityManager(repository, getMockConfigForNETest(), mockScope.NewTestScope()) - updatedFilters, err := manager.(*NamedEntityManager).updateQueryFilters([]common.InlineFilter{ - identityFilter, - }, "eq(state, 0)") + updatedFilters, err := manager.(*NamedEntityManager).getQueryFilters("eq(state, 0)") assert.NoError(t, err) - assert.Len(t, updatedFilters, 2) + assert.Len(t, updatedFilters, 1) - assert.Equal(t, "project", updatedFilters[0].GetField()) - assert.Equal(t, "state", updatedFilters[1].GetField()) - queryExp, err := updatedFilters[1].GetGormQueryExpr() + assert.Equal(t, "state", updatedFilters[0].GetField()) + queryExp, err := updatedFilters[0].GetGormQueryExpr() assert.NoError(t, err) assert.Equal(t, "COALESCE(state, 0) = ?", queryExp.Query) assert.Equal(t, "0", queryExp.Args) diff --git a/pkg/repositories/gormimpl/named_entity_repo.go b/pkg/repositories/gormimpl/named_entity_repo.go index e73e0cfdd..9a193feb2 100644 --- a/pkg/repositories/gormimpl/named_entity_repo.go +++ b/pkg/repositories/gormimpl/named_entity_repo.go @@ -16,6 +16,34 @@ import ( "github.com/lyft/flytestdlib/promutils" ) +const innerJoinTableAlias = "entities" + +var resourceTypeToTableName = map[core.ResourceType]string{ + core.ResourceType_LAUNCH_PLAN: launchPlanTableName, + core.ResourceType_WORKFLOW: workflowTableName, + core.ResourceType_TASK: taskTableName, +} + +var joinString = "RIGHT JOIN ? AS entities ON named_entity_metadata.resource_type = %d AND " + + "named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND " + + "named_entity_metadata.name = entities.name" + +func getSubQueryJoin(db *gorm.DB, tableName string, input interfaces.ListNamedEntityInput) *gorm.DB { + tx := db.Select([]string{Project, Domain, Name}). + Table(tableName). + Where(map[string]interface{}{Project: input.Project, Domain: input.Domain}). + Limit(input.Limit). + Offset(input.Offset). + Group(identifierGroupBy) + + // Apply consistent sort ordering. + if input.SortParameter != nil { + tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + } + + return db.Joins(fmt.Sprintf(joinString, input.ResourceType), tx.SubQuery()) +} + var leftJoinWorkflowNameToMetadata = fmt.Sprintf( "LEFT JOIN %s ON %s.resource_type = %d AND %s.project = %s.project AND %s.domain = %s.domain AND %s.name = %s.name", namedEntityMetadataTableName, namedEntityMetadataTableName, core.ResourceType_WORKFLOW, namedEntityMetadataTableName, workflowTableName, namedEntityMetadataTableName, workflowTableName, @@ -31,23 +59,16 @@ var leftJoinTaskNameToMetadata = fmt.Sprintf( namedEntityMetadataTableName, taskTableName, namedEntityMetadataTableName, taskTableName) -var resourceTypeToTableName = map[core.ResourceType]string{ - core.ResourceType_LAUNCH_PLAN: launchPlanTableName, - core.ResourceType_WORKFLOW: workflowTableName, - core.ResourceType_TASK: taskTableName, -} - var resourceTypeToMetadataJoin = map[core.ResourceType]string{ core.ResourceType_LAUNCH_PLAN: leftJoinLaunchPlanNameToMetadata, core.ResourceType_WORKFLOW: leftJoinWorkflowNameToMetadata, core.ResourceType_TASK: leftJoinTaskNameToMetadata, } -func getGroupByForNamedEntity(tableName string) string { - return fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s", - tableName, Project, tableName, Domain, tableName, Name, namedEntityMetadataTableName, Description, - namedEntityMetadataTableName, State) -} +var getGroupByForNamedEntity = fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s", + innerJoinTableAlias, Project, innerJoinTableAlias, Domain, innerJoinTableAlias, Name, namedEntityMetadataTableName, + Description, + namedEntityMetadataTableName, State) func getSelectForNamedEntity(tableName string, resourceType core.ResourceType) []string { return []string{ @@ -141,29 +162,33 @@ func (r *NamedEntityRepo) Get(ctx context.Context, input interfaces.GetNamedEnti return namedEntity, nil } -func (r *NamedEntityRepo) List(ctx context.Context, resourceType core.ResourceType, input interfaces.ListResourceInput) ( +func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEntityInput) ( interfaces.NamedEntityCollectionOutput, error) { - // Validate input. - if err := ValidateListInput(input); err != nil { - return interfaces.NamedEntityCollectionOutput{}, err + // Validate input. Filters aren't required because they're implicit in the Project & Domain specified by the input. + if len(input.Project) == 0 { + return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(Project) + } + if len(input.Domain) == 0 { + return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(Domain) + } + if input.Limit == 0 { + return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(limit) } - tableName, tableFound := resourceTypeToTableName[resourceType] - joinString, joinFound := resourceTypeToMetadataJoin[resourceType] - if !tableFound || !joinFound { - return interfaces.NamedEntityCollectionOutput{}, adminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, "Cannot list entity names for resource type: %v", resourceType) + tableName, tableFound := resourceTypeToTableName[input.ResourceType] + if !tableFound { + return interfaces.NamedEntityCollectionOutput{}, adminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, + "Cannot list entity names for resource type: %v", input.ResourceType) } - tx := r.db.Table(tableName).Limit(input.Limit).Offset(input.Offset) - tx = tx.Joins(joinString) + tx := getSubQueryJoin(r.db, tableName, input) // Apply filters tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters) if err != nil { return interfaces.NamedEntityCollectionOutput{}, err } - // Apply sort ordering. if input.SortParameter != nil { tx = tx.Order(input.SortParameter.GetGormOrderExpr()) @@ -172,8 +197,11 @@ func (r *NamedEntityRepo) List(ctx context.Context, resourceType core.ResourceTy // Scan the results into a list of named entities var entities []models.NamedEntity timer := r.metrics.ListDuration.Start() - tx.Select(getSelectForNamedEntity(tableName, resourceType)).Group(getGroupByForNamedEntity(tableName)).Scan(&entities) + + tx.Select(getSelectForNamedEntity(innerJoinTableAlias, input.ResourceType)).Table(namedEntityMetadataTableName).Group(getGroupByForNamedEntity).Scan(&entities) + timer.Stop() + if tx.Error != nil { return interfaces.NamedEntityCollectionOutput{}, r.errorTransformer.ToFlyteAdminError(tx.Error) } diff --git a/pkg/repositories/gormimpl/named_entity_repo_test.go b/pkg/repositories/gormimpl/named_entity_repo_test.go index 4513049fd..87532c194 100644 --- a/pkg/repositories/gormimpl/named_entity_repo_test.go +++ b/pkg/repositories/gormimpl/named_entity_repo_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/lyft/flyteadmin/pkg/common" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" mocket "github.com/Selvatico/go-mocket" @@ -133,3 +135,44 @@ func TestUpdateNamedEntity_CreateNew(t *testing.T) { assert.NoError(t, err) assert.True(t, mockQuery.Triggered) } + +func TestListNamedEntity(t *testing.T) { + metadataRepo := NewNamedEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + results := make([]map[string]interface{}, 0) + metadata := getMockNamedEntityResponseFromDb(models.NamedEntity{ + NamedEntityKey: models.NamedEntityKey{ + ResourceType: resourceType, + Project: project, + Domain: domain, + Name: name, + }, + NamedEntityMetadataFields: models.NamedEntityMetadataFields{ + Description: description, + }, + }) + results = append(results, metadata) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + mockQuery := GlobalMock.NewMock() + + mockQuery.WithQuery( + `GROUP BY project, domain, name ORDER BY name desc LIMIT 20 OFFSET 0) AS entities`).WithReply(results) + + sortParameter, _ := common.NewSortParameter(admin.Sort{ + Direction: admin.Sort_DESCENDING, + Key: "name", + }) + output, err := metadataRepo.List(context.Background(), interfaces.ListNamedEntityInput{ + ResourceType: resourceType, + Project: "admintests", + Domain: "development", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 20, + SortParameter: sortParameter, + }, + }) + assert.NoError(t, err) + assert.Len(t, output.Entities, 1) +} diff --git a/pkg/repositories/interfaces/named_entity_repo.go b/pkg/repositories/interfaces/named_entity_repo.go index 084cce399..a9ffb0cdf 100644 --- a/pkg/repositories/interfaces/named_entity_repo.go +++ b/pkg/repositories/interfaces/named_entity_repo.go @@ -15,6 +15,14 @@ type GetNamedEntityInput struct { Name string } +// Parameters for querying multiple resources. +type ListNamedEntityInput struct { + ListResourceInput + Project string + Domain string + ResourceType core.ResourceType +} + type NamedEntityCollectionOutput struct { Entities []models.NamedEntity } @@ -23,7 +31,7 @@ type NamedEntityCollectionOutput struct { type NamedEntityRepoInterface interface { // Returns NamedEntity objects matching the provided query. A limit is // required - List(ctx context.Context, resourceType core.ResourceType, input ListResourceInput) (NamedEntityCollectionOutput, error) + List(ctx context.Context, input ListNamedEntityInput) (NamedEntityCollectionOutput, error) // Updates NamedEntity record, will create metadata if it does not exist Update(ctx context.Context, input models.NamedEntity) error // Gets metadata (if available) associated with a NamedEntity diff --git a/pkg/repositories/mocks/named_entity_metadata_repo.go b/pkg/repositories/mocks/named_entity_metadata_repo.go index fb05046cf..619cdbcf7 100644 --- a/pkg/repositories/mocks/named_entity_metadata_repo.go +++ b/pkg/repositories/mocks/named_entity_metadata_repo.go @@ -4,14 +4,12 @@ package mocks import ( "context" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" "github.com/lyft/flyteadmin/pkg/repositories/models" ) type GetNamedEntityFunc func(input interfaces.GetNamedEntityInput) (models.NamedEntity, error) -type ListNamedEntityFunc func(resourceType core.ResourceType, input interfaces.ListResourceInput) (interfaces.NamedEntityCollectionOutput, error) +type ListNamedEntityFunc func(input interfaces.ListNamedEntityInput) (interfaces.NamedEntityCollectionOutput, error) type UpdateNamedEntityFunc func(input models.NamedEntity) error type MockNamedEntityRepo struct { @@ -45,9 +43,9 @@ func (r *MockNamedEntityRepo) Get( }, nil } -func (r *MockNamedEntityRepo) List(ctx context.Context, resourceType core.ResourceType, input interfaces.ListResourceInput) (interfaces.NamedEntityCollectionOutput, error) { +func (r *MockNamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEntityInput) (interfaces.NamedEntityCollectionOutput, error) { if r.listFunction != nil { - return r.listFunction(resourceType, input) + return r.listFunction(input) } return interfaces.NamedEntityCollectionOutput{}, nil }