Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Improve ListNamedEntities perf: Use sub-query with group by (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Apr 23, 2020
1 parent 6a55199 commit f48108c
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 63 deletions.
39 changes: 18 additions & 21 deletions pkg/manager/impl/named_entity_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,28 @@ func (m *NamedEntityManager) GetNamedEntity(ctx context.Context, request admin.N
return util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id)
}

func (m *NamedEntityManager) updateQueryFilters(identityFilters []common.InlineFilter, requestFilters string) (
[]common.InlineFilter, error) {
func (m *NamedEntityManager) getQueryFilters(requestFilters string) ([]common.InlineFilter, error) {
filters := make([]common.InlineFilter, 0)
if len(requestFilters) == 0 {
return identityFilters, nil
return filters, nil
}
additionalFilters, err := util.ParseFilters(requestFilters, common.NamedEntityMetadata)
if err != nil {
return nil, err
}
var finalizedFilters = identityFilters
for _, filter := range additionalFilters {
if strings.Contains(filter.GetField(), state) {
filterWithDefaultValue, err := common.NewWithDefaultValueFilter(
strconv.Itoa(int(admin.NamedEntityState_NAMED_ENTITY_ACTIVE)), filter)
if err != nil {
return nil, err
}
finalizedFilters = append(finalizedFilters, filterWithDefaultValue)
filters = append(filters, filterWithDefaultValue)
} else {
finalizedFilters = append(finalizedFilters, filter)
filters = append(filters, filter)
}
}
return finalizedFilters, nil
return filters, nil
}

func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admin.NamedEntityListRequest) (
Expand All @@ -101,17 +100,10 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi
}
ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)

identifierFilters, err := util.GetDbFilters(util.FilterSpec{
Project: request.Project,
Domain: request.Domain,
}, common.ResourceTypeToEntity[request.ResourceType])
if err != nil {
return nil, err
}
// HACK: In order to filter by state (if requested) - we need to amend the filter to use COALESCE
// e.g. eq(state, 1) becomes 'WHERE (COALESCE(state, 0) = '1')' since not every NamedEntity necessarily
// has an entry, and therefore the default state value '0' (active), should be assumed.
filters, err := m.updateQueryFilters(identifierFilters, request.Filters)
filters, err := m.getQueryFilters(request.Filters)
if err != nil {
return nil, err
}
Expand All @@ -127,14 +119,19 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListNamedEntities", request.Token)
}
listInput := repoInterfaces.ListResourceInput{
Limit: int(request.Limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
listInput := repoInterfaces.ListNamedEntityInput{
ListResourceInput: repoInterfaces.ListResourceInput{
Limit: int(request.Limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
},
Project: request.Project,
Domain: request.Domain,
ResourceType: request.ResourceType,
}

output, err := m.db.NamedEntityRepo().List(ctx, request.ResourceType, listInput)
output, err := m.db.NamedEntityRepo().List(ctx, listInput)
if err != nil {
logger.Debugf(ctx, "Failed to list named entities of type: %s with project: %s, domain: %s. Returned error was: %v",
request.ResourceType, request.Project, request.Domain, err)
Expand Down
18 changes: 5 additions & 13 deletions pkg/manager/impl/named_entity_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"testing"

"github.com/lyft/flyteadmin/pkg/common"

"github.com/lyft/flyteadmin/pkg/manager/impl/testutils"
"github.com/lyft/flyteadmin/pkg/repositories"
"github.com/lyft/flyteadmin/pkg/repositories/interfaces"
Expand Down Expand Up @@ -86,21 +84,15 @@ func TestNamedEntityManager_Get_BadRequest(t *testing.T) {
assert.Nil(t, response)
}

func TestNamedEntityManager_UpdateQueryFilters(t *testing.T) {
identityFilter, err := common.NewSingleValueFilter(common.NamedEntityMetadata, common.Equal, "project", "proj")
assert.NoError(t, err)

func TestNamedEntityManager_getQueryFilters(t *testing.T) {
repository := getMockRepositoryForNETest()
manager := NewNamedEntityManager(repository, getMockConfigForNETest(), mockScope.NewTestScope())
updatedFilters, err := manager.(*NamedEntityManager).updateQueryFilters([]common.InlineFilter{
identityFilter,
}, "eq(state, 0)")
updatedFilters, err := manager.(*NamedEntityManager).getQueryFilters("eq(state, 0)")
assert.NoError(t, err)
assert.Len(t, updatedFilters, 2)
assert.Len(t, updatedFilters, 1)

assert.Equal(t, "project", updatedFilters[0].GetField())
assert.Equal(t, "state", updatedFilters[1].GetField())
queryExp, err := updatedFilters[1].GetGormQueryExpr()
assert.Equal(t, "state", updatedFilters[0].GetField())
queryExp, err := updatedFilters[0].GetGormQueryExpr()
assert.NoError(t, err)
assert.Equal(t, "COALESCE(state, 0) = ?", queryExp.Query)
assert.Equal(t, "0", queryExp.Args)
Expand Down
74 changes: 51 additions & 23 deletions pkg/repositories/gormimpl/named_entity_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,34 @@ import (
"github.com/lyft/flytestdlib/promutils"
)

const innerJoinTableAlias = "entities"

var resourceTypeToTableName = map[core.ResourceType]string{
core.ResourceType_LAUNCH_PLAN: launchPlanTableName,
core.ResourceType_WORKFLOW: workflowTableName,
core.ResourceType_TASK: taskTableName,
}

var joinString = "RIGHT JOIN ? AS entities ON named_entity_metadata.resource_type = %d AND " +
"named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND " +
"named_entity_metadata.name = entities.name"

func getSubQueryJoin(db *gorm.DB, tableName string, input interfaces.ListNamedEntityInput) *gorm.DB {
tx := db.Select([]string{Project, Domain, Name}).
Table(tableName).
Where(map[string]interface{}{Project: input.Project, Domain: input.Domain}).
Limit(input.Limit).
Offset(input.Offset).
Group(identifierGroupBy)

// Apply consistent sort ordering.
if input.SortParameter != nil {
tx = tx.Order(input.SortParameter.GetGormOrderExpr())
}

return db.Joins(fmt.Sprintf(joinString, input.ResourceType), tx.SubQuery())
}

var leftJoinWorkflowNameToMetadata = fmt.Sprintf(
"LEFT JOIN %s ON %s.resource_type = %d AND %s.project = %s.project AND %s.domain = %s.domain AND %s.name = %s.name", namedEntityMetadataTableName, namedEntityMetadataTableName, core.ResourceType_WORKFLOW, namedEntityMetadataTableName, workflowTableName,
namedEntityMetadataTableName, workflowTableName,
Expand All @@ -31,23 +59,16 @@ var leftJoinTaskNameToMetadata = fmt.Sprintf(
namedEntityMetadataTableName, taskTableName,
namedEntityMetadataTableName, taskTableName)

var resourceTypeToTableName = map[core.ResourceType]string{
core.ResourceType_LAUNCH_PLAN: launchPlanTableName,
core.ResourceType_WORKFLOW: workflowTableName,
core.ResourceType_TASK: taskTableName,
}

var resourceTypeToMetadataJoin = map[core.ResourceType]string{
core.ResourceType_LAUNCH_PLAN: leftJoinLaunchPlanNameToMetadata,
core.ResourceType_WORKFLOW: leftJoinWorkflowNameToMetadata,
core.ResourceType_TASK: leftJoinTaskNameToMetadata,
}

func getGroupByForNamedEntity(tableName string) string {
return fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s",
tableName, Project, tableName, Domain, tableName, Name, namedEntityMetadataTableName, Description,
namedEntityMetadataTableName, State)
}
var getGroupByForNamedEntity = fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s",
innerJoinTableAlias, Project, innerJoinTableAlias, Domain, innerJoinTableAlias, Name, namedEntityMetadataTableName,
Description,
namedEntityMetadataTableName, State)

func getSelectForNamedEntity(tableName string, resourceType core.ResourceType) []string {
return []string{
Expand Down Expand Up @@ -141,29 +162,33 @@ func (r *NamedEntityRepo) Get(ctx context.Context, input interfaces.GetNamedEnti
return namedEntity, nil
}

func (r *NamedEntityRepo) List(ctx context.Context, resourceType core.ResourceType, input interfaces.ListResourceInput) (
func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEntityInput) (
interfaces.NamedEntityCollectionOutput, error) {

// Validate input.
if err := ValidateListInput(input); err != nil {
return interfaces.NamedEntityCollectionOutput{}, err
// Validate input. Filters aren't required because they're implicit in the Project & Domain specified by the input.
if len(input.Project) == 0 {
return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(Project)
}
if len(input.Domain) == 0 {
return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(Domain)
}
if input.Limit == 0 {
return interfaces.NamedEntityCollectionOutput{}, errors.GetInvalidInputError(limit)
}

tableName, tableFound := resourceTypeToTableName[resourceType]
joinString, joinFound := resourceTypeToMetadataJoin[resourceType]
if !tableFound || !joinFound {
return interfaces.NamedEntityCollectionOutput{}, adminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, "Cannot list entity names for resource type: %v", resourceType)
tableName, tableFound := resourceTypeToTableName[input.ResourceType]
if !tableFound {
return interfaces.NamedEntityCollectionOutput{}, adminErrors.NewFlyteAdminErrorf(codes.InvalidArgument,
"Cannot list entity names for resource type: %v", input.ResourceType)
}

tx := r.db.Table(tableName).Limit(input.Limit).Offset(input.Offset)
tx = tx.Joins(joinString)
tx := getSubQueryJoin(r.db, tableName, input)

// Apply filters
tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters)
if err != nil {
return interfaces.NamedEntityCollectionOutput{}, err
}

// Apply sort ordering.
if input.SortParameter != nil {
tx = tx.Order(input.SortParameter.GetGormOrderExpr())
Expand All @@ -172,8 +197,11 @@ func (r *NamedEntityRepo) List(ctx context.Context, resourceType core.ResourceTy
// Scan the results into a list of named entities
var entities []models.NamedEntity
timer := r.metrics.ListDuration.Start()
tx.Select(getSelectForNamedEntity(tableName, resourceType)).Group(getGroupByForNamedEntity(tableName)).Scan(&entities)

tx.Select(getSelectForNamedEntity(innerJoinTableAlias, input.ResourceType)).Table(namedEntityMetadataTableName).Group(getGroupByForNamedEntity).Scan(&entities)

timer.Stop()

if tx.Error != nil {
return interfaces.NamedEntityCollectionOutput{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/repositories/gormimpl/named_entity_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"

"github.com/lyft/flyteadmin/pkg/common"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin"

mocket "github.com/Selvatico/go-mocket"
Expand Down Expand Up @@ -133,3 +135,44 @@ func TestUpdateNamedEntity_CreateNew(t *testing.T) {
assert.NoError(t, err)
assert.True(t, mockQuery.Triggered)
}

func TestListNamedEntity(t *testing.T) {
metadataRepo := NewNamedEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope())

results := make([]map[string]interface{}, 0)
metadata := getMockNamedEntityResponseFromDb(models.NamedEntity{
NamedEntityKey: models.NamedEntityKey{
ResourceType: resourceType,
Project: project,
Domain: domain,
Name: name,
},
NamedEntityMetadataFields: models.NamedEntityMetadataFields{
Description: description,
},
})
results = append(results, metadata)

GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true
mockQuery := GlobalMock.NewMock()

mockQuery.WithQuery(
`GROUP BY project, domain, name ORDER BY name desc LIMIT 20 OFFSET 0) AS entities`).WithReply(results)

sortParameter, _ := common.NewSortParameter(admin.Sort{
Direction: admin.Sort_DESCENDING,
Key: "name",
})
output, err := metadataRepo.List(context.Background(), interfaces.ListNamedEntityInput{
ResourceType: resourceType,
Project: "admintests",
Domain: "development",
ListResourceInput: interfaces.ListResourceInput{
Limit: 20,
SortParameter: sortParameter,
},
})
assert.NoError(t, err)
assert.Len(t, output.Entities, 1)
}
10 changes: 9 additions & 1 deletion pkg/repositories/interfaces/named_entity_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ type GetNamedEntityInput struct {
Name string
}

// Parameters for querying multiple resources.
type ListNamedEntityInput struct {
ListResourceInput
Project string
Domain string
ResourceType core.ResourceType
}

type NamedEntityCollectionOutput struct {
Entities []models.NamedEntity
}
Expand All @@ -23,7 +31,7 @@ type NamedEntityCollectionOutput struct {
type NamedEntityRepoInterface interface {
// Returns NamedEntity objects matching the provided query. A limit is
// required
List(ctx context.Context, resourceType core.ResourceType, input ListResourceInput) (NamedEntityCollectionOutput, error)
List(ctx context.Context, input ListNamedEntityInput) (NamedEntityCollectionOutput, error)
// Updates NamedEntity record, will create metadata if it does not exist
Update(ctx context.Context, input models.NamedEntity) error
// Gets metadata (if available) associated with a NamedEntity
Expand Down
8 changes: 3 additions & 5 deletions pkg/repositories/mocks/named_entity_metadata_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ package mocks
import (
"context"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"

"github.com/lyft/flyteadmin/pkg/repositories/interfaces"
"github.com/lyft/flyteadmin/pkg/repositories/models"
)

type GetNamedEntityFunc func(input interfaces.GetNamedEntityInput) (models.NamedEntity, error)
type ListNamedEntityFunc func(resourceType core.ResourceType, input interfaces.ListResourceInput) (interfaces.NamedEntityCollectionOutput, error)
type ListNamedEntityFunc func(input interfaces.ListNamedEntityInput) (interfaces.NamedEntityCollectionOutput, error)
type UpdateNamedEntityFunc func(input models.NamedEntity) error

type MockNamedEntityRepo struct {
Expand Down Expand Up @@ -45,9 +43,9 @@ func (r *MockNamedEntityRepo) Get(
}, nil
}

func (r *MockNamedEntityRepo) List(ctx context.Context, resourceType core.ResourceType, input interfaces.ListResourceInput) (interfaces.NamedEntityCollectionOutput, error) {
func (r *MockNamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEntityInput) (interfaces.NamedEntityCollectionOutput, error) {
if r.listFunction != nil {
return r.listFunction(resourceType, input)
return r.listFunction(input)
}
return interfaces.NamedEntityCollectionOutput{}, nil
}
Expand Down

0 comments on commit f48108c

Please sign in to comment.