From bdbfb5ee1b1d05a4ef6f9f079b63c59bd8a3db83 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 9 Feb 2024 15:24:11 -0800 Subject: [PATCH] Always force org in database where clauses (#65) --- .../pkg/manager/impl/project_manager_test.go | 29 ++++++++++++----- flyteadmin/pkg/manager/impl/util/filters.go | 11 +++---- .../pkg/manager/impl/util/filters_test.go | 2 ++ .../pkg/repositories/gormimpl/common.go | 13 ++++++++ .../gormimpl/description_entity_repo_test.go | 31 +++++++++++++++++- .../repositories/gormimpl/execution_repo.go | 5 ++- .../gormimpl/execution_repo_test.go | 8 ++--- .../repositories/gormimpl/launch_plan_repo.go | 9 +++--- .../gormimpl/launch_plan_repo_test.go | 11 +++---- .../gormimpl/named_entity_repo_test.go | 2 -- .../gormimpl/node_execution_repo.go | 32 ++++--------------- .../gormimpl/node_execution_repo_test.go | 6 ++-- .../pkg/repositories/gormimpl/project_repo.go | 3 +- .../gormimpl/project_repo_test.go | 8 ++--- .../repositories/gormimpl/resource_repo.go | 10 +++--- .../gormimpl/resource_repo_test.go | 7 ++-- .../pkg/repositories/gormimpl/signal_repo.go | 6 ++-- .../repositories/gormimpl/signal_repo_test.go | 6 ++-- .../gormimpl/task_execution_repo.go | 8 ++--- .../gormimpl/task_execution_repo_test.go | 4 +-- .../pkg/repositories/gormimpl/task_repo.go | 3 +- .../repositories/gormimpl/task_repo_test.go | 2 +- .../repositories/gormimpl/workflow_repo.go | 3 +- .../gormimpl/workflow_repo_test.go | 2 +- .../scheduler/repositories/gormimpl/common.go | 9 ++++++ .../gormimpl/schedulable_entity_repo.go | 11 +++---- 26 files changed, 132 insertions(+), 109 deletions(-) create mode 100644 flyteadmin/scheduler/repositories/gormimpl/common.go diff --git a/flyteadmin/pkg/manager/impl/project_manager_test.go b/flyteadmin/pkg/manager/impl/project_manager_test.go index 1abd46d856..31523a05f8 100644 --- a/flyteadmin/pkg/manager/impl/project_manager_test.go +++ b/flyteadmin/pkg/manager/impl/project_manager_test.go @@ -46,13 +46,23 @@ func getMockApplicationConfigForProjectManagerTest() runtimeInterfaces.Applicati return &mockApplicationConfig } -func testListProjects(request admin.ProjectListRequest, token string, orderExpr string, queryExpr *common.GormQueryExpr, t *testing.T) { +func expectedOrgQueryExpr() *common.GormQueryExpr { + return &common.GormQueryExpr{ + Query: "org = ?", + Args: "", + } +} + +func testListProjects(request admin.ProjectListRequest, token string, orderExpr string, queryExprs []*common.GormQueryExpr, t *testing.T) { repository := repositoryMocks.NewMockRepository() repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).ListProjectsFunction = func( ctx context.Context, input interfaces.ListResourceInput) ([]models.Project, error) { if len(input.InlineFilters) != 0 { - q, _ := input.InlineFilters[0].GetGormQueryExpr() - assert.Equal(t, *queryExpr, q) + for idx, inlineFilter := range input.InlineFilters { + q, _ := inlineFilter.GetGormQueryExpr() + assert.Equal(t, *queryExprs[idx], q) + } + } assert.Equal(t, orderExpr, input.SortParameter.GetGormOrderExpr()) activeState := int32(admin.Project_ACTIVE) @@ -82,7 +92,7 @@ func TestListProjects_NoFilters_LimitOne(t *testing.T) { testListProjects(admin.ProjectListRequest{ Token: "1", Limit: 1, - }, "2", "identifier asc", nil, t) + }, "2", "identifier asc", []*common.GormQueryExpr{expectedOrgQueryExpr()}, t) } func TestListProjects_HighLimit_SortBy_Filter(t *testing.T) { @@ -94,14 +104,17 @@ func TestListProjects_HighLimit_SortBy_Filter(t *testing.T) { Key: "name", Direction: admin.Sort_DESCENDING, }, - }, "", "name desc", &common.GormQueryExpr{ - Query: "name = ?", - Args: "foo", + }, "", "name desc", []*common.GormQueryExpr{ + expectedOrgQueryExpr(), + { + Query: "name = ?", + Args: "foo", + }, }, t) } func TestListProjects_NoToken_NoLimit(t *testing.T) { - testListProjects(admin.ProjectListRequest{}, "", "identifier asc", nil, t) + testListProjects(admin.ProjectListRequest{}, "", "identifier asc", []*common.GormQueryExpr{expectedOrgQueryExpr()}, t) } func TestProjectManager_CreateProject(t *testing.T) { diff --git a/flyteadmin/pkg/manager/impl/util/filters.go b/flyteadmin/pkg/manager/impl/util/filters.go index b26275e512..68f4d19ae8 100644 --- a/flyteadmin/pkg/manager/impl/util/filters.go +++ b/flyteadmin/pkg/manager/impl/util/filters.go @@ -210,13 +210,12 @@ type FilterSpec struct { func getIdentifierFilters(entity common.Entity, spec FilterSpec) ([]common.InlineFilter, error) { filters := make([]common.InlineFilter, 0) - if spec.Org != "" { - orgFilter, err := GetSingleValueEqualityFilter(entity, shared.Org, spec.Org) - if err != nil { - return nil, err - } - filters = append(filters, orgFilter) + // Always apply the org filter even when it's omitted + orgFilter, err := GetSingleValueEqualityFilter(entity, shared.Org, spec.Org) + if err != nil { + return nil, err } + filters = append(filters, orgFilter) if spec.Project != "" { projectFilter, err := GetSingleValueEqualityFilter(entity, shared.Project, spec.Project) diff --git a/flyteadmin/pkg/manager/impl/util/filters_test.go b/flyteadmin/pkg/manager/impl/util/filters_test.go index 47fbd8fc46..1067a0570c 100644 --- a/flyteadmin/pkg/manager/impl/util/filters_test.go +++ b/flyteadmin/pkg/manager/impl/util/filters_test.go @@ -155,12 +155,14 @@ func TestGetDbFilters(t *testing.T) { assert.NoError(t, err) // Init expected values for filters. + orgFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Org, "") projectFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Project, "project") domainFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Domain, "domain") nameFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Name, "name") versionFilter, _ := common.NewSingleValueFilter(common.LaunchPlan, common.NotEqual, shared.Version, "TheWorst") workflowNameFilter, _ := common.NewSingleValueFilter(common.Workflow, common.Equal, shared.Name, "workflow") expectedFilters := []common.InlineFilter{ + orgFilter, projectFilter, domainFilter, nameFilter, diff --git a/flyteadmin/pkg/repositories/gormimpl/common.go b/flyteadmin/pkg/repositories/gormimpl/common.go index 735c9e8ba2..30b8d429c3 100644 --- a/flyteadmin/pkg/repositories/gormimpl/common.go +++ b/flyteadmin/pkg/repositories/gormimpl/common.go @@ -116,3 +116,16 @@ func applyScopedFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFil } return tx, nil } + +const ( + orgColumn = "org" + executionOrgColumn = "execution_org" +) + +func getOrgFilter(org string) map[string]interface{} { + return map[string]interface{}{orgColumn: org} +} + +func getExecutionOrgFilter(executionOrg string) map[string]interface{} { + return map[string]interface{}{executionOrgColumn: executionOrg} +} diff --git a/flyteadmin/pkg/repositories/gormimpl/description_entity_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/description_entity_repo_test.go index a50dc2bdc5..0ae63a41e2 100644 --- a/flyteadmin/pkg/repositories/gormimpl/description_entity_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/description_entity_repo_test.go @@ -54,6 +54,33 @@ func TestGetDescriptionEntity(t *testing.T) { assert.Equal(t, shortDescription, output.ShortDescription) } +func TestGetDescriptionEntityNoOrg(t *testing.T) { + descriptionEntityRepo := NewDescriptionEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + descriptionEntities := make([]map[string]interface{}, 0) + descriptionEntity := getMockDescriptionEntityResponseFromDb(version, []byte{1, 2}) + descriptionEntities = append(descriptionEntities, descriptionEntity) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery(`SELECT * FROM "description_entities" WHERE project = $1 AND domain = $2 AND name = $3 AND version = $4 AND org = $5 LIMIT 1`). + WithReply(descriptionEntities) + output, err := descriptionEntityRepo.Get(context.Background(), interfaces.GetDescriptionEntityInput{ + ResourceType: resourceType, + Project: project, + Domain: domain, + Name: name, + Version: version, + }) + assert.Empty(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, name, output.Name) + assert.Equal(t, version, output.Version) + assert.Equal(t, shortDescription, output.ShortDescription) +} + func TestListDescriptionEntities(t *testing.T) { descriptionEntityRepo := NewDescriptionEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) @@ -65,7 +92,8 @@ func TestListDescriptionEntities(t *testing.T) { } GlobalMock := mocket.Catcher.Reset() - GlobalMock.NewMock().WithReply(descriptionEntities) + GlobalMock.Logging = true + GlobalMock.NewMock().WithQuery("SELECT * FROM \"description_entities\" WHERE project = $1 AND domain = $2 AND name = $3 AND org = $4").WithReply(descriptionEntities) collection, err := descriptionEntityRepo.List(context.Background(), interfaces.ListResourceInput{}) assert.Equal(t, 0, len(collection.Entities)) @@ -76,6 +104,7 @@ func TestListDescriptionEntities(t *testing.T) { getEqualityFilter(common.Workflow, "project", project), getEqualityFilter(common.Workflow, "domain", domain), getEqualityFilter(common.Workflow, "name", name), + getEqualityFilter(common.Workflow, "org", ""), }, Limit: 20, }) diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go index 2bd2168c78..00ce3e8db7 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go @@ -40,9 +40,8 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m Project: input.Project, Domain: input.Domain, Name: input.Name, - Org: input.Org, }, - }).Take(&execution) + }).Where(getExecutionOrgFilter(input.Org)).Take(&execution) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -60,7 +59,7 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m func (r *ExecutionRepo) Update(ctx context.Context, execution models.Execution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).Model(&execution).Updates(execution) + tx := r.db.WithContext(ctx).Model(&execution).Where(getExecutionOrgFilter(execution.Org)).Updates(execution) timer.Stop() if err := tx.Error; err != nil { return r.errorTransformer.ToFlyteAdminError(err) diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go index 73c40272c7..692d0a40c7 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go @@ -47,10 +47,7 @@ func TestUpdateExecution(t *testing.T) { updated := false // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,` + - `"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,` + - `"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "` + - `execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16`).WithCallback( + GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "execution_org" = $14 AND "execution_project" = $15 AND "execution_domain" = $16 AND "execution_name" = $17`).WithCallback( func(s string, values []driver.NamedValue) { updated = true }, @@ -129,13 +126,12 @@ func TestGetExecution(t *testing.T) { GlobalMock.Logging = true // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`SELECT * FROM "executions" WHERE "executions"."execution_project" = $1 AND "executions"."execution_domain" = $2 AND "executions"."execution_name" = $3 AND "executions"."execution_org" = $4 LIMIT 1`).WithReply(executions) + GlobalMock.NewMock().WithQuery(`SELECT * FROM "executions" WHERE "executions"."execution_project" = $1 AND "executions"."execution_domain" = $2 AND "executions"."execution_name" = $3 AND "execution_org" = $4 LIMIT 1`).WithReply(executions) output, err := executionRepo.Get(context.Background(), interfaces.Identifier{ Project: "project", Domain: "domain", Name: "1", - Org: testOrg, }) assert.NoError(t, err) assert.EqualValues(t, expectedExecution, output) diff --git a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go index 355460965d..cff302c399 100644 --- a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go @@ -41,7 +41,7 @@ func (r *LaunchPlanRepo) Create(ctx context.Context, input models.LaunchPlan) er func (r *LaunchPlanRepo) Update(ctx context.Context, input models.LaunchPlan) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).Model(&input).Updates(input) + tx := r.db.WithContext(ctx).Model(&input).Where(getOrgFilter(input.Org)).Updates(input) timer.Stop() if err := tx.Error; err != nil { return r.errorTransformer.ToFlyteAdminError(err) @@ -58,9 +58,8 @@ func (r *LaunchPlanRepo) Get(ctx context.Context, input interfaces.Identifier) ( Domain: input.Domain, Name: input.Name, Version: input.Version, - Org: input.Org, }, - }).Take(&launchPlan) + }).Where(getOrgFilter(input.Org)).Take(&launchPlan) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -90,7 +89,7 @@ func (r *LaunchPlanRepo) SetActive( // There is a launch plan to disable as part of this transaction if toDisable != nil { - tx.Model(&toDisable).UpdateColumns(toDisable) + tx.Model(&toDisable).Where(getOrgFilter(toDisable.Org)).UpdateColumns(toDisable) if err := tx.Error; err != nil { tx.Rollback() return r.errorTransformer.ToFlyteAdminError(err) @@ -98,7 +97,7 @@ func (r *LaunchPlanRepo) SetActive( } // And update the desired version. - tx.Model(&toEnable).UpdateColumns(toEnable) + tx.Model(&toEnable).Where(getOrgFilter(toEnable.Org)).UpdateColumns(toEnable) if err := tx.Error; err != nil { tx.Rollback() return r.errorTransformer.ToFlyteAdminError(err) diff --git a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go index 85a5a0b304..75c65986e7 100644 --- a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go @@ -64,7 +64,6 @@ func TestGetLaunchPlan(t *testing.T) { Domain: domain, Name: name, Version: version, - Org: testOrg, }, Spec: launchPlanSpec, WorkflowID: workflowID, @@ -77,20 +76,18 @@ func TestGetLaunchPlan(t *testing.T) { GlobalMock.Logging = true // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "launch_plans" WHERE "launch_plans"."project" = $1 AND "launch_plans"."domain" = $2 AND "launch_plans"."name" = $3 AND "launch_plans"."version" = $4 AND "launch_plans"."org" = $5 LIMIT 1`).WithReply(launchPlans) + `SELECT * FROM "launch_plans" WHERE "launch_plans"."project" = $1 AND "launch_plans"."domain" = $2 AND "launch_plans"."name" = $3 AND "launch_plans"."version" = $4 AND "org" = $5 LIMIT 1`).WithReply(launchPlans) output, err := launchPlanRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, Name: name, Version: version, - Org: testOrg, }) assert.NoError(t, err) assert.Equal(t, project, output.Project) assert.Equal(t, domain, output.Domain) assert.Equal(t, name, output.Name) assert.Equal(t, version, output.Version) - assert.Equal(t, testOrg, output.Org) assert.Equal(t, launchPlanSpec, output.Spec) } @@ -102,7 +99,7 @@ func TestSetInactiveLaunchPlan(t *testing.T) { mockDb := GlobalMock.NewMock() updated := false mockDb.WithQuery( - `UPDATE "launch_plans" SET "id"=$1,"updated_at"=$2,"project"=$3,"domain"=$4,"name"=$5,"version"=$6,"closure"=$7,"state"=$8 WHERE "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback( + `UPDATE "launch_plans" SET "id"=$1,"updated_at"=$2,"project"=$3,"domain"=$4,"name"=$5,"version"=$6,"closure"=$7,"state"=$8 WHERE "org" = $9 AND "project" = $10 AND "domain" = $11 AND "name" = $12 AND "version" = $13`).WithCallback( func(s string, values []driver.NamedValue) { updated = true }, @@ -133,7 +130,7 @@ func TestSetActiveLaunchPlan(t *testing.T) { mockQuery := GlobalMock.NewMock() updated := false mockQuery.WithQuery( - `UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "project" = $8 AND "domain" = $9 AND "name" = $10 AND "version" = $11`).WithCallback( + `UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "org" = $8 AND "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback( func(s string, values []driver.NamedValue) { updated = true }, @@ -176,7 +173,7 @@ func TestSetActiveLaunchPlan_NoCurrentlyActiveLaunchPlan(t *testing.T) { mockQuery := GlobalMock.NewMock() updated := false mockQuery.WithQuery( - `UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "project" = $8 AND "domain" = $9 AND "name" = $10 AND "version" = $11`).WithCallback( + `UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "org" = $8 AND "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback( func(s string, values []driver.NamedValue) { updated = true }, diff --git a/flyteadmin/pkg/repositories/gormimpl/named_entity_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/named_entity_repo_test.go index 8587ac0ef9..70b377c544 100644 --- a/flyteadmin/pkg/repositories/gormimpl/named_entity_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/named_entity_repo_test.go @@ -37,7 +37,6 @@ func TestGetNamedEntity(t *testing.T) { Project: project, Domain: domain, Name: name, - Org: testOrg, }, NamedEntityMetadataFields: models.NamedEntityMetadataFields{ Description: description, @@ -62,7 +61,6 @@ func TestGetNamedEntity(t *testing.T) { assert.Equal(t, name, output.Name) assert.Equal(t, resourceType, output.ResourceType) assert.Equal(t, description, output.Description) - assert.Equal(t, testOrg, output.Org) } func TestUpdateNamedEntity_WithExisting(t *testing.T) { diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index 3161f857bd..f855d8b52d 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -10,7 +10,6 @@ import ( adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -41,23 +40,14 @@ func (r *NodeExecutionRepo) Get(ctx context.Context, input interfaces.NodeExecut Project: input.NodeExecutionIdentifier.ExecutionId.Project, Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, Name: input.NodeExecutionIdentifier.ExecutionId.Name, - Org: input.NodeExecutionIdentifier.ExecutionId.Org, }, }, - }).Take(&nodeExecution) + }).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Take(&nodeExecution) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.NodeExecution{}, - adminErrors.GetMissingEntityError("node execution", &core.NodeExecutionIdentifier{ - NodeId: input.NodeExecutionIdentifier.NodeId, - ExecutionId: &core.WorkflowExecutionIdentifier{ - Project: input.NodeExecutionIdentifier.ExecutionId.Project, - Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, - Name: input.NodeExecutionIdentifier.ExecutionId.Name, - Org: input.NodeExecutionIdentifier.ExecutionId.Org, - }, - }) + adminErrors.GetMissingEntityError("node execution", &input.NodeExecutionIdentifier) } else if tx.Error != nil { return models.NodeExecution{}, r.errorTransformer.ToFlyteAdminError(tx.Error) } @@ -75,23 +65,14 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface Project: input.NodeExecutionIdentifier.ExecutionId.Project, Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, Name: input.NodeExecutionIdentifier.ExecutionId.Name, - Org: input.NodeExecutionIdentifier.ExecutionId.Org, }, }, - }).Preload("ChildNodeExecutions").Take(&nodeExecution) + }).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Preload("ChildNodeExecutions").Take(&nodeExecution) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.NodeExecution{}, - adminErrors.GetMissingEntityError("node execution", &core.NodeExecutionIdentifier{ - NodeId: input.NodeExecutionIdentifier.NodeId, - ExecutionId: &core.WorkflowExecutionIdentifier{ - Project: input.NodeExecutionIdentifier.ExecutionId.Project, - Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, - Name: input.NodeExecutionIdentifier.ExecutionId.Name, - Org: input.NodeExecutionIdentifier.ExecutionId.Org, - }, - }) + adminErrors.GetMissingEntityError("node execution", &input.NodeExecutionIdentifier) } else if tx.Error != nil { return models.NodeExecution{}, r.errorTransformer.ToFlyteAdminError(tx.Error) } @@ -101,7 +82,7 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface func (r *NodeExecutionRepo) Update(ctx context.Context, nodeExecution *models.NodeExecution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).Model(&nodeExecution).Updates(nodeExecution) + tx := r.db.WithContext(ctx).Model(&nodeExecution).Where(getExecutionOrgFilter(nodeExecution.Org)).Updates(nodeExecution) timer.Stop() if err := tx.Error; err != nil { return r.errorTransformer.ToFlyteAdminError(err) @@ -155,10 +136,9 @@ func (r *NodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExe Project: input.NodeExecutionIdentifier.ExecutionId.Project, Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, Name: input.NodeExecutionIdentifier.ExecutionId.Name, - Org: input.NodeExecutionIdentifier.ExecutionId.Org, }, }, - }).Take(&nodeExecution) + }).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Take(&nodeExecution) timer.Stop() if tx.Error != nil { return false, r.errorTransformer.ToFlyteAdminError(tx.Error) diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index 7e911d31a6..7147e62aa2 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -65,7 +65,7 @@ func TestUpdateNodeExecution(t *testing.T) { GlobalMock := mocket.Catcher.Reset() // Only match on queries that append the name filter nodeExecutionQuery := GlobalMock.NewMock() - nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE "execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16 AND "node_id" = $17`) + nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE "execution_org" = $14 AND "execution_project" = $15 AND "execution_domain" = $16 AND "execution_name" = $17 AND "node_id" = $18`) err := nodeExecutionRepo.Update(context.Background(), &models.NodeExecution{ BaseModel: models.BaseModel{ID: 1}, @@ -138,7 +138,7 @@ func TestGetNodeExecution(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.NewMock().WithQuery( - `SELECT * FROM "node_executions" WHERE "node_executions"."execution_project" = $1 AND "node_executions"."execution_domain" = $2 AND "node_executions"."execution_name" = $3 AND "node_executions"."node_id" = $4 LIMIT 1`).WithReply(nodeExecutions) + `SELECT * FROM "node_executions" WHERE "node_executions"."execution_project" = $1 AND "node_executions"."execution_domain" = $2 AND "node_executions"."execution_name" = $3 AND "node_executions"."node_id" = $4 AND "execution_org" = $5 LIMIT 1`).WithReply(nodeExecutions) output, err := nodeExecutionRepo.Get(context.Background(), interfaces.NodeExecutionResource{ NodeExecutionIdentifier: core.NodeExecutionIdentifier{ NodeId: "1", @@ -362,7 +362,7 @@ func TestNodeExecutionExists(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.NewMock().WithQuery( - `SELECT "id" FROM "node_executions" WHERE "node_executions"."execution_project" = $1 AND "node_executions"."execution_domain" = $2 AND "node_executions"."execution_name" = $3 AND "node_executions"."node_id" = $4 LIMIT 1`).WithReply(nodeExecutions) + `SELECT "id" FROM "node_executions" WHERE "node_executions"."execution_project" = $1 AND "node_executions"."execution_domain" = $2 AND "node_executions"."execution_name" = $3 AND "node_executions"."node_id" = $4 AND "execution_org" = $5 LIMIT 1`).WithReply(nodeExecutions) exists, err := nodeExecutionRepo.Exists(context.Background(), interfaces.NodeExecutionResource{ NodeExecutionIdentifier: core.NodeExecutionIdentifier{ NodeId: "1", diff --git a/flyteadmin/pkg/repositories/gormimpl/project_repo.go b/flyteadmin/pkg/repositories/gormimpl/project_repo.go index 30f2935f0b..7da981ea0c 100644 --- a/flyteadmin/pkg/repositories/gormimpl/project_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/project_repo.go @@ -36,8 +36,7 @@ func (r *ProjectRepo) Get(ctx context.Context, projectID, org string) (models.Pr timer := r.metrics.GetDuration.Start() tx := r.db.WithContext(ctx).Where(&models.Project{ Identifier: projectID, - Org: org, - }).Take(&project) + }).Where(getOrgFilter(org)).Take(&project) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Project{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "project [%s] not found", projectID) diff --git a/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go index 252c852dda..9a5e03e0ed 100644 --- a/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go @@ -50,26 +50,24 @@ func TestGetProject(t *testing.T) { response["name"] = "project_name" response["description"] = "project_description" response["state"] = admin.Project_ACTIVE - response["org"] = testOrg - output, err := projectRepo.Get(context.Background(), "project_id", testOrg) + output, err := projectRepo.Get(context.Background(), "project_id", "") assert.Empty(t, output) assert.EqualError(t, err, "project [project_id] not found") query := GlobalMock.NewMock() GlobalMock.Logging = true - query.WithQuery(`SELECT * FROM "projects" WHERE "projects"."identifier" = $1 AND "projects"."org" = $2 LIMIT 1`).WithReply( + query.WithQuery(`SELECT * FROM "projects" WHERE "projects"."identifier" = $1 AND "org" = $2 LIMIT 1`).WithReply( []map[string]interface{}{ response, }) - output, err = projectRepo.Get(context.Background(), "project_id", testOrg) + output, err = projectRepo.Get(context.Background(), "project_id", "") assert.Nil(t, err) assert.Equal(t, "project_id", output.Identifier) assert.Equal(t, "project_name", output.Name) assert.Equal(t, "project_description", output.Description) assert.Equal(t, int32(admin.Project_ACTIVE), *output.State) - assert.Equal(t, testOrg, output.Org) } func testListProjects(input interfaces.ListResourceInput, sql string, t *testing.T) { diff --git a/flyteadmin/pkg/repositories/gormimpl/resource_repo.go b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go index daaa8a86b8..fe0d102f7b 100644 --- a/flyteadmin/pkg/repositories/gormimpl/resource_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go @@ -62,7 +62,7 @@ func (r *ResourceRepo) CreateOrUpdate(ctx context.Context, input models.Resource } timer := r.metrics.GetDuration.Start() var record models.Resource - tx := r.db.WithContext(ctx).FirstOrCreate(&record, models.Resource{ + tx := r.db.WithContext(ctx).Where(getOrgFilter(input.Org)).FirstOrCreate(&record, models.Resource{ Project: input.Project, Domain: input.Domain, Workflow: input.Workflow, @@ -171,8 +171,7 @@ func (r *ResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) (mo Workflow: ID.Workflow, LaunchPlan: ID.LaunchPlan, ResourceType: ID.ResourceType, - Org: ID.Org, - }).First(&model) + }).Where(getOrgFilter(ID.Org)).First(&model) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -188,7 +187,7 @@ func (r *ResourceRepo) ListAll(ctx context.Context, resourceType, org string) ([ var resources []models.Resource timer := r.metrics.ListDuration.Start() - tx := r.db.WithContext(ctx).Where(&models.Resource{Org: org, ResourceType: resourceType}).Order(priorityDescending).Find(&resources) + tx := r.db.WithContext(ctx).Where(&models.Resource{ResourceType: resourceType}).Where(getOrgFilter(org)).Order(priorityDescending).Find(&resources) timer.Stop() if tx.Error != nil { @@ -206,8 +205,7 @@ func (r *ResourceRepo) Delete(ctx context.Context, ID interfaces.ResourceID) err Workflow: ID.Workflow, LaunchPlan: ID.LaunchPlan, ResourceType: ID.ResourceType, - Org: ID.Org, - }).Unscoped().Delete(models.Resource{}) + }).Where(getOrgFilter(ID.Org)).Unscoped().Delete(models.Resource{}) }) if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { diff --git a/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go index c6fb3d3c25..ae0f4896ce 100644 --- a/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go @@ -63,7 +63,7 @@ func TestUpdateWorkflowAttributes_WithExisting(t *testing.T) { mockSelectQuery := GlobalMock.NewMock() mockSelectQuery.WithQuery( - `SELECT * FROM "resources" WHERE "resources"."project" = $1 AND "resources"."domain" = $2 AND "resources"."org" = $3 AND "resources"."resource_type" = $4 AND "resources"."priority" = $5 ORDER BY "resources"."id" LIMIT 1`).WithReply(results) + `SELECT * FROM "resources" WHERE "org" = $1 AND "resources"."project" = $2 AND "resources"."domain" = $3 AND "resources"."resource_type" = $4 AND "resources"."priority" = $5 ORDER BY "resources"."id" LIMIT 1`).WithReply(results) mockSaveQuery := GlobalMock.NewMock() mockSaveQuery.WithQuery( @@ -73,7 +73,6 @@ func TestUpdateWorkflowAttributes_WithExisting(t *testing.T) { ResourceType: resourceType.String(), Project: project, Domain: domain, - Org: testOrg, Priority: 2, }) assert.NoError(t, err) @@ -174,7 +173,7 @@ func TestGetRawWorkflowAttributes(t *testing.T) { response["attributes"] = []byte("attrs") query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."project" = $1 AND "resources"."domain" = $2 AND "resources"."workflow" = $3 AND "resources"."launch_plan" = $4 AND "resources"."resource_type" = $5 ORDER BY "resources"."id" LIMIT 1`).WithReply( + query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."project" = $1 AND "resources"."domain" = $2 AND "resources"."workflow" = $3 AND "resources"."launch_plan" = $4 AND "resources"."resource_type" = $5 AND "org" = $6 ORDER BY "resources"."id" LIMIT 1`).WithReply( []map[string]interface{}{ response, }) @@ -217,7 +216,7 @@ func TestListAll(t *testing.T) { response["launch_plan"] = "launch_plan" response["attributes"] = []byte("attrs") - fakeResponse := query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."org" = $1 AND "resources"."resource_type" = $2 ORDER BY priority desc`).WithReply( + fakeResponse := query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."resource_type" = $1 AND "org" = $2 ORDER BY priority desc`).WithReply( []map[string]interface{}{response}) output, err := resourceRepo.ListAll(context.Background(), "resource", "org") assert.Nil(t, err) diff --git a/flyteadmin/pkg/repositories/gormimpl/signal_repo.go b/flyteadmin/pkg/repositories/gormimpl/signal_repo.go index dccfbda748..4152a43459 100644 --- a/flyteadmin/pkg/repositories/gormimpl/signal_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/signal_repo.go @@ -27,7 +27,7 @@ func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Si timer := s.metrics.GetDuration.Start() tx := s.db.Where(&models.Signal{ SignalKey: input, - }).Take(&signal) + }).Where(getExecutionOrgFilter(input.Org)).Take(&signal) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Signal{}, adminerrors.NewFlyteAdminError(codes.NotFound, "signal does not exist") @@ -41,7 +41,7 @@ func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Si // GetOrCreate returns a signal if it already exists, if not it creates a new one given the input func (s *SignalRepo) GetOrCreate(ctx context.Context, input *models.Signal) error { timer := s.metrics.CreateDuration.Start() - tx := s.db.FirstOrCreate(&input, input) + tx := s.db.Where(getExecutionOrgFilter(input.Org)).FirstOrCreate(&input, input) timer.Stop() if tx.Error != nil { return s.errorTransformer.ToFlyteAdminError(tx.Error) @@ -85,7 +85,7 @@ func (s *SignalRepo) Update(ctx context.Context, input models.SignalKey, value [ } timer := s.metrics.GetDuration.Start() - tx := s.db.Model(&signal).Select("value").Updates(signal) + tx := s.db.Model(&signal).Where(getExecutionOrgFilter(input.Org)).Select("value").Updates(signal) timer.Stop() if tx.Error != nil { return s.errorTransformer.ToFlyteAdminError(tx.Error) diff --git a/flyteadmin/pkg/repositories/gormimpl/signal_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/signal_repo_test.go index e7a63b5739..c6fc7fcc24 100644 --- a/flyteadmin/pkg/repositories/gormimpl/signal_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/signal_repo_test.go @@ -66,7 +66,7 @@ func TestGetSignal(t *testing.T) { mockSelectQuery := GlobalMock.NewMock() mockSelectQuery.WithQuery( - `SELECT * FROM "signals" WHERE "signals"."execution_project" = $1 AND "signals"."execution_domain" = $2 AND "signals"."execution_name" = $3 AND "signals"."signal_id" = $4 LIMIT 1`) + `SELECT * FROM "signals" WHERE "signals"."execution_project" = $1 AND "signals"."execution_domain" = $2 AND "signals"."execution_name" = $3 AND "signals"."signal_id" = $4 AND "execution_org" = $5 LIMIT 1`) // retrieve non-existent signalModel lookupSignalModel, err := signalRepo.Get(ctx, signalModel.SignalKey) @@ -110,7 +110,7 @@ func TestGetOrCreateSignal(t *testing.T) { signalModels := []map[string]interface{}{toSignalMap(*signalModel)} mockSelectQuery := GlobalMock.NewMock() mockSelectQuery.WithQuery( - `SELECT * FROM "signals" WHERE "signals"."id" = $1 AND "signals"."created_at" = $2 AND "signals"."updated_at" = $3 AND "signals"."execution_project" = $4 AND "signals"."execution_domain" = $5 AND "signals"."execution_name" = $6 AND "signals"."signal_id" = $7 AND "signals"."execution_project" = $8 AND "signals"."execution_domain" = $9 AND "signals"."execution_name" = $10 AND "signals"."signal_id" = $11 ORDER BY "signals"."id" LIMIT 1`).WithReply(signalModels) + `SELECT * FROM "signals" WHERE "execution_org" = $1 AND "signals"."id" = $2 AND "signals"."created_at" = $3 AND "signals"."updated_at" = $4 AND "signals"."execution_project" = $5 AND "signals"."execution_domain" = $6 AND "signals"."execution_name" = $7 AND "signals"."signal_id" = $8 AND "signals"."execution_project" = $9 AND "signals"."execution_domain" = $10 AND "signals"."execution_name" = $11 AND "signals"."signal_id" = $12 ORDER BY "signals"."id" LIMIT 1`).WithReply(signalModels) // retrieve existing signalModel lookupSignalModel := &models.Signal{} @@ -165,7 +165,7 @@ func TestUpdateSignal(t *testing.T) { // update signalModel does not exits mockUpdateQuery := GlobalMock.NewMock() mockUpdateQuery.WithQuery( - `UPDATE "signals" SET "updated_at"=$1,"value"=$2 WHERE "execution_project" = $3 AND "execution_domain" = $4 AND "execution_name" = $5 AND "signal_id" = $6`).WithRowsNum(0) + `UPDATE "signals" SET "updated_at"=$1,"value"=$2 WHERE "execution_org" = $3 AND "execution_project" = $4 AND "execution_domain" = $5 AND "execution_name" = $6 AND "signal_id" = $7`).WithRowsNum(0) err := signalRepo.Update(ctx, signalModel.SignalKey, signalModel.Value) assert.Error(t, err) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go index ac0f75ba1a..0a4b80a298 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go @@ -40,7 +40,6 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe Domain: input.TaskExecutionID.TaskId.Domain, Name: input.TaskExecutionID.TaskId.Name, Version: input.TaskExecutionID.TaskId.Version, - Org: input.TaskExecutionID.TaskId.Org, }, NodeExecutionKey: models.NodeExecutionKey{ NodeID: input.TaskExecutionID.NodeExecutionId.NodeId, @@ -48,12 +47,12 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe Project: input.TaskExecutionID.NodeExecutionId.ExecutionId.Project, Domain: input.TaskExecutionID.NodeExecutionId.ExecutionId.Domain, Name: input.TaskExecutionID.NodeExecutionId.ExecutionId.Name, - Org: input.TaskExecutionID.NodeExecutionId.ExecutionId.Org, }, }, RetryAttempt: &input.TaskExecutionID.RetryAttempt, }, - }).Preload("ChildNodeExecution").Take(&taskExecution) + }).Where(getExecutionOrgFilter(input.TaskExecutionID.NodeExecutionId.ExecutionId.Org)). + Where(getOrgFilter(input.TaskExecutionID.TaskId.Org)).Preload("ChildNodeExecution").Take(&taskExecution) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -82,7 +81,8 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe func (r *TaskExecutionRepo) Update(ctx context.Context, execution models.TaskExecution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).WithContext(ctx).Updates(&execution) // TODO @hmaersaw - need to add WithContext to all db calls to link otel spans + tx := r.db.WithContext(ctx).WithContext(ctx).Where(getOrgFilter(execution.Org)).Where(getExecutionOrgFilter(execution.ExecutionKey.Org)). + Updates(&execution) // TODO @hmaersaw - need to add WithContext to all db calls to link otel spans timer.Stop() if err := tx.Error; err != nil { diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go index a5a9415277..d5046984fe 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go @@ -85,7 +85,7 @@ func TestUpdateTaskExecution(t *testing.T) { GlobalMock.Logging = true taskExecutionQuery := GlobalMock.NewMock() - taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "updated_at"=$1,"phase"=$2,"input_uri"=$3,"closure"=$4,"started_at"=$5,"task_execution_created_at"=$6,"task_execution_updated_at"=$7,"duration"=$8 WHERE "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12 AND "execution_project" = $13 AND "execution_domain" = $14 AND "execution_name" = $15 AND "node_id" = $16 AND "retry_attempt" = $17`) + taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "updated_at"=$1,"phase"=$2,"input_uri"=$3,"closure"=$4,"started_at"=$5,"task_execution_created_at"=$6,"task_execution_updated_at"=$7,"duration"=$8 WHERE "org" = $9 AND "execution_org" = $10 AND "project" = $11 AND "domain" = $12 AND "name" = $13 AND "version" = $14 AND "execution_project" = $15 AND "execution_domain" = $16 AND "execution_name" = $17 AND "node_id" = $18 AND "retry_attempt" = $19`) err := taskExecutionRepo.Update(context.Background(), testTaskExecution) assert.NoError(t, err) assert.True(t, taskExecutionQuery.Triggered) @@ -100,7 +100,7 @@ func TestGetTaskExecution(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true GlobalMock.NewMock().WithQuery( - `SELECT * FROM "task_executions" WHERE "task_executions"."project" = $1 AND "task_executions"."domain" = $2 AND "task_executions"."name" = $3 AND "task_executions"."version" = $4 AND "task_executions"."execution_project" = $5 AND "task_executions"."execution_domain" = $6 AND "task_executions"."execution_name" = $7 AND "task_executions"."node_id" = $8 AND "task_executions"."retry_attempt" = $9 LIMIT 1`). + `SELECT * FROM "task_executions" WHERE "task_executions"."project" = $1 AND "task_executions"."domain" = $2 AND "task_executions"."name" = $3 AND "task_executions"."version" = $4 AND "task_executions"."execution_project" = $5 AND "task_executions"."execution_domain" = $6 AND "task_executions"."execution_name" = $7 AND "task_executions"."node_id" = $8 AND "task_executions"."retry_attempt" = $9 AND "execution_org" = $10 AND "org" = $11 LIMIT 1`). WithReply(taskExecutions) output, err := taskExecutionRepo.Get(context.Background(), interfaces.GetTaskExecutionInput{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index bdb9958702..8dfaf59135 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -55,9 +55,8 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models Domain: input.Domain, Name: input.Name, Version: input.Version, - Org: input.Org, }, - }).Take(&task) + }).Where(getOrgFilter(input.Org)).Take(&task) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 1e8ad0fbea..2ce0a17ecf 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -64,7 +64,7 @@ func TestGetTask(t *testing.T) { GlobalMock.Logging = true // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 AND "tasks"."version" = $4 AND "tasks"."org" = $5 LIMIT 1`). + `SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 AND "tasks"."version" = $4 AND "org" = $5`). WithReply(tasks) output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ Project: project, diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go b/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go index 9642df6f93..c6956e6853 100644 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go @@ -49,9 +49,8 @@ func (r *WorkflowRepo) Get(ctx context.Context, input interfaces.Identifier) (mo Domain: input.Domain, Name: input.Name, Version: input.Version, - Org: input.Org, }, - }).Take(&workflow) + }).Where(getOrgFilter(input.Org)).Take(&workflow) timer.Stop() if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) { diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go index fa4b6f4c85..04fb3bb9a3 100644 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go @@ -55,7 +55,7 @@ func TestGetWorkflow(t *testing.T) { GlobalMock := mocket.Catcher.Reset() // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "workflows" WHERE "workflows"."project" = $1 AND "workflows"."domain" = $2 AND "workflows"."name" = $3 AND "workflows"."version" = $4 AND "workflows"."org" = $5 LIMIT 1`).WithReply(workflows) + `SELECT * FROM "workflows" WHERE "workflows"."project" = $1 AND "workflows"."domain" = $2 AND "workflows"."name" = $3 AND "workflows"."version" = $4 AND "org" = $5 LIMIT 1`).WithReply(workflows) output, err := workflowRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, diff --git a/flyteadmin/scheduler/repositories/gormimpl/common.go b/flyteadmin/scheduler/repositories/gormimpl/common.go new file mode 100644 index 0000000000..d3c336ec9a --- /dev/null +++ b/flyteadmin/scheduler/repositories/gormimpl/common.go @@ -0,0 +1,9 @@ +package gormimpl + +const ( + orgColumn = "org" +) + +func getOrgFilter(org string) map[string]interface{} { + return map[string]interface{}{orgColumn: org} +} diff --git a/flyteadmin/scheduler/repositories/gormimpl/schedulable_entity_repo.go b/flyteadmin/scheduler/repositories/gormimpl/schedulable_entity_repo.go index 6c27474974..ada25f5789 100644 --- a/flyteadmin/scheduler/repositories/gormimpl/schedulable_entity_repo.go +++ b/flyteadmin/scheduler/repositories/gormimpl/schedulable_entity_repo.go @@ -24,7 +24,7 @@ type SchedulableEntityRepo struct { func (r *SchedulableEntityRepo) Create(ctx context.Context, input models.SchedulableEntity) error { timer := r.metrics.GetDuration.Start() var record models.SchedulableEntity - tx := r.db.Omit("id").FirstOrCreate(&record, input) + tx := r.db.Omit("id").FirstOrCreate(&record, input).Where(getOrgFilter(input.Org)) timer.Stop() if tx.Error != nil { return r.errorTransformer.ToFlyteAdminError(tx.Error) @@ -42,9 +42,8 @@ func (r *SchedulableEntityRepo) Activate(ctx context.Context, input models.Sched Domain: input.Domain, Name: input.Name, Version: input.Version, - Org: input.Org, }, - }).Take(&schedulableEntity) + }).Where(getOrgFilter(input.Org)).Take(&schedulableEntity) timer.Stop() if tx.Error != nil { @@ -92,9 +91,8 @@ func (r *SchedulableEntityRepo) Get(ctx context.Context, ID models.SchedulableEn Domain: ID.Domain, Name: ID.Name, Version: ID.Version, - Org: ID.Org, }, - }).Take(&schedulableEntity) + }).Where(getOrgFilter(ID.Org)).Take(&schedulableEntity) timer.Stop() if tx.Error != nil { @@ -123,9 +121,8 @@ func activateOrDeactivate(r *SchedulableEntityRepo, ID models.SchedulableEntityK Domain: ID.Domain, Name: ID.Name, Version: ID.Version, - Org: ID.Org, }, - }).Update("active", activate) + }).Where(getOrgFilter(ID.Org)).Update("active", activate) timer.Stop() if tx.Error != nil { if errors.Is(tx.Error, gorm.ErrRecordNotFound) {