diff --git a/flyteadmin/auth/init_secrets.go b/flyteadmin/auth/init_secrets.go index 64e95c0d1..e272a17cf 100644 --- a/flyteadmin/auth/init_secrets.go +++ b/flyteadmin/auth/init_secrets.go @@ -33,7 +33,8 @@ var ( // GetInitSecretsCommand creates a command to issue secrets to be used for Auth settings. It writes the secrets to the // working directory. The expectation is that they are put in a location and made available to the serve command later. // To configure where the serve command looks for secrets, update this config: -// secrets: +// +// secrets: // secrets-prefix: func GetInitSecretsCommand() *cobra.Command { cmd := &cobra.Command{ diff --git a/flyteadmin/pkg/clusterresource/controller.go b/flyteadmin/pkg/clusterresource/controller.go index 1fb55e058..7e85e7188 100644 --- a/flyteadmin/pkg/clusterresource/controller.go +++ b/flyteadmin/pkg/clusterresource/controller.go @@ -269,10 +269,10 @@ func prepareDynamicCreate(target executioncluster.ExecutionTarget, config string // This function loops through the kubernetes resource template files in the configured template directory. // For each unapplied template file (wrt the namespace) this func attempts to -// 1) create k8s object resource from template by performing: -// a) read template file -// b) substitute templatized variables with their resolved values -// 2) create the resource on the kubernetes cluster and cache successful outcomes +// 1. create k8s object resource from template by performing: +// a) read template file +// b) substitute templatized variables with their resolved values +// 2. create the resource on the kubernetes cluster and cache successful outcomes func (c *controller) syncNamespace(ctx context.Context, project *admin.Project, domain *admin.Domain, namespace NamespaceName, templateValues, customTemplateValues templateValuesType) error { templateDir := c.config.ClusterResourceConfiguration().GetTemplatePath() @@ -445,8 +445,9 @@ func addResourceVersion(patch []byte, rv string) ([]byte, error) { } // createResourceFromTemplate this method perform following processes: -// 1) read template file pointed by templateDir and templateFileName -// 2) substitute templatized variables with their resolved values +// 1. read template file pointed by templateDir and templateFileName +// 2. substitute templatized variables with their resolved values +// // the method will return the kubernetes raw manifest func (c *controller) createResourceFromTemplate(ctx context.Context, templateDir string, templateFileName string, project *admin.Project, domain *admin.Domain, namespace NamespaceName, diff --git a/flyteadmin/pkg/errors/errors.go b/flyteadmin/pkg/errors/errors.go index 9cf9d5a2d..30922ae02 100644 --- a/flyteadmin/pkg/errors/errors.go +++ b/flyteadmin/pkg/errors/errors.go @@ -41,7 +41,8 @@ func (e *flyteAdminErrorImpl) String() string { } // enclose the error in the format that grpc server expect from golang: -// https://github.com/grpc/grpc-go/blob/master/status/status.go#L133 +// +// https://github.com/grpc/grpc-go/blob/master/status/status.go#L133 func (e *flyteAdminErrorImpl) WithDetails(details *admin.EventFailureReason) (FlyteAdminError, error) { s, err := e.status.WithDetails(details) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index aa73c8654..fa7e8fd26 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -2294,7 +2294,7 @@ func TestUpdateExecution(t *testing.T) { updateExecFuncCalled = true return nil } - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecFunc) r := plugins.NewRegistry() r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) @@ -2315,7 +2315,7 @@ func TestUpdateExecution(t *testing.T) { updateExecFuncCalled = true return nil } - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecFunc) r := plugins.NewRegistry() r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) @@ -2333,7 +2333,7 @@ func TestUpdateExecution(t *testing.T) { updateExecFunc := func(ctx context.Context, execModel models.Execution) error { return fmt.Errorf("some db error") } - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecFunc) r := plugins.NewRegistry() r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) @@ -2818,7 +2818,7 @@ func TestTerminateExecution(t *testing.T) { }, unmarshaledClosure.GetAbortMetadata())) return nil } - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecutionFunc) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecutionFunc) mockExecutor := workflowengineMocks.WorkflowExecutor{} mockExecutor.OnAbortMatch(mock.Anything, mock.MatchedBy(func(data workflowengineInterfaces.AbortData) bool { @@ -2860,7 +2860,7 @@ func TestTerminateExecution_PropellerError(t *testing.T) { updateCalled := false repository := repositoryMocks.NewMockRepository() - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(func( + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(func( context context.Context, execution models.Execution) error { updateCalled = true assert.Equal(t, core.WorkflowExecution_ABORTING.String(), execution.Phase) @@ -2892,7 +2892,7 @@ func TestTerminateExecution_DatabaseError(t *testing.T) { context context.Context, execution models.Execution) error { return expectedError } - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecutionFunc) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecutionFunc) mockExecutor := workflowengineMocks.WorkflowExecutor{} mockExecutor.OnAbortMatch(mock.Anything, mock.Anything).Return(nil) mockExecutor.OnID().Return("testMockExecutor") diff --git a/flyteadmin/pkg/manager/impl/executions/quality_of_service.go b/flyteadmin/pkg/manager/impl/executions/quality_of_service.go index 7fa1fb94f..490646612 100644 --- a/flyteadmin/pkg/manager/impl/executions/quality_of_service.go +++ b/flyteadmin/pkg/manager/impl/executions/quality_of_service.go @@ -64,11 +64,11 @@ func (q qualityOfServiceAllocator) getQualityOfServiceFromDb(ctx context.Context /* Users can specify the quality of service for an execution (in order of decreasing specificity) -- At CreateExecution request time -- In the LaunchPlan spec -- In the Workflow spec -- As an overridable MatchableResource (https://lyft.github.io/flyte/administrator/install/managing_customizable_resources.html) - for the underlying workflow + - At CreateExecution request time + - In the LaunchPlan spec + - In the Workflow spec + - As an overridable MatchableResource (https://lyft.github.io/flyte/administrator/install/managing_customizable_resources.html) + for the underlying workflow System administrators can specify default QualityOfService specs (https://github.com/flyteorg/flyteidl/blob/e9727afcedf8d4c30a1fc2eeac45593e426d9bb0/protos/flyteidl/core/execution.proto#L92)s diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go index 7e243101e..477bdcb16 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go @@ -410,7 +410,7 @@ func TestTransformNodeExecutionModel(t *testing.T) { ExecutionId: &workflowExecutionIdentifier, } t.Run("event version 0", func(t *testing.T) { - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(nodeExecID, &input.NodeExecutionIdentifier)) return models.NodeExecution{ @@ -432,7 +432,7 @@ func TestTransformNodeExecutionModel(t *testing.T) { }, }, }, nil - } + }) manager := NodeExecutionManager{ db: repository, @@ -484,11 +484,11 @@ func TestTransformNodeExecutionModel(t *testing.T) { }) t.Run("get with children err", func(t *testing.T) { expectedErr := flyteAdminErrors.NewFlyteAdminError(codes.Internal, "foo") - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(nodeExecID, &input.NodeExecutionIdentifier)) return models.NodeExecution{}, expectedErr - } + }) manager := NodeExecutionManager{ db: repository, @@ -501,7 +501,7 @@ func TestTransformNodeExecutionModel(t *testing.T) { func TestTransformNodeExecutionModelList(t *testing.T) { ctx := context.TODO() repository := repositoryMocks.NewMockRepository() - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{ NodeExecutionKey: models.NodeExecutionKey{ @@ -522,7 +522,7 @@ func TestTransformNodeExecutionModelList(t *testing.T) { }, }, }, nil - } + }) manager := NodeExecutionManager{ db: repository, @@ -600,45 +600,46 @@ func TestGetNodeExecutionParentNode(t *testing.T) { } metadataBytes, _ := proto.Marshal(&expectedMetadata) closureBytes, _ := proto.Marshal(&expectedClosure) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = func( - ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { - workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } - assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ - NodeId: "node id", - ExecutionId: &workflowExecutionIdentifier, - }, &input.NodeExecutionIdentifier)) - return models.NodeExecution{ - NodeExecutionKey: models.NodeExecutionKey{ - NodeID: "node id", - ExecutionKey: models.ExecutionKey{ - Project: "project", - Domain: "domain", - Name: "name", + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + } + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "node id", + ExecutionId: &workflowExecutionIdentifier, + }, &input.NodeExecutionIdentifier)) + return models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node id", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, }, - }, - Phase: core.NodeExecution_SUCCEEDED.String(), - InputURI: "input uri", - StartedAt: &occurredAt, - Closure: closureBytes, - NodeExecutionMetadata: metadataBytes, - ChildNodeExecutions: []models.NodeExecution{ - { - NodeExecutionKey: models.NodeExecutionKey{ - NodeID: "node-child", - ExecutionKey: models.ExecutionKey{ - Project: "project", - Domain: "domain", - Name: "name", + Phase: core.NodeExecution_SUCCEEDED.String(), + InputURI: "input uri", + StartedAt: &occurredAt, + Closure: closureBytes, + NodeExecutionMetadata: metadataBytes, + ChildNodeExecutions: []models.NodeExecution{ + { + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node-child", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, }, }, }, - }, - }, nil - } + }, nil + }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{}) nodeExecution, err := nodeExecManager.GetNodeExecution(context.Background(), admin.NodeExecutionGetRequest{ Id: &nodeExecutionIdentifier, @@ -664,33 +665,34 @@ func TestGetNodeExecutionEventVersion0(t *testing.T) { } metadataBytes, _ := proto.Marshal(&expectedMetadata) closureBytes, _ := proto.Marshal(&expectedClosure) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = func( - ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { - workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } - assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ - NodeId: "node id", - ExecutionId: &workflowExecutionIdentifier, - }, &input.NodeExecutionIdentifier)) - return models.NodeExecution{ - NodeExecutionKey: models.NodeExecutionKey{ - NodeID: "node id", - ExecutionKey: models.ExecutionKey{ - Project: "project", - Domain: "domain", - Name: "name", + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + } + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "node id", + ExecutionId: &workflowExecutionIdentifier, + }, &input.NodeExecutionIdentifier)) + return models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node id", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, }, - }, - Phase: core.NodeExecution_SUCCEEDED.String(), - InputURI: "input uri", - StartedAt: &occurredAt, - Closure: closureBytes, - NodeExecutionMetadata: metadataBytes, - }, nil - } + Phase: core.NodeExecution_SUCCEEDED.String(), + InputURI: "input uri", + StartedAt: &occurredAt, + Closure: closureBytes, + NodeExecutionMetadata: metadataBytes, + }, nil + }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{}) nodeExecution, err := nodeExecManager.GetNodeExecution(context.Background(), admin.NodeExecutionGetRequest{ @@ -809,24 +811,25 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { }, }, nil }) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).GetWithChildrenFunction = func( - ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { - return models.NodeExecution{ - NodeExecutionKey: models.NodeExecutionKey{ - NodeID: "node id", - ExecutionKey: models.ExecutionKey{ - Project: "project", - Domain: "domain", - Name: "name", + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + return models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node id", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, }, - }, - Phase: core.NodeExecution_SUCCEEDED.String(), - InputURI: "input uri", - StartedAt: &occurredAt, - Closure: closureBytes, - NodeExecutionMetadata: metadataBytes, - }, nil - } + Phase: core.NodeExecution_SUCCEEDED.String(), + InputURI: "input uri", + StartedAt: &occurredAt, + Closure: closureBytes, + NodeExecutionMetadata: metadataBytes, + }, nil + }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{}) nodeExecutions, err := nodeExecManager.ListNodeExecutions(context.Background(), admin.NodeExecutionListRequest{ WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 34f556e64..ec3bb434a 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -304,10 +304,11 @@ func TestCreateTaskEvent_MissingExecution(t *testing.T) { func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) { return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") }) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).ExistsFunction = func( - ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { - return false, expectedErr - } + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetExistsCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + return false, expectedErr + }) taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) assert.EqualError(t, err, "Failed to get existing node execution id: [node_id:\"node-id\""+ @@ -315,10 +316,11 @@ func TestCreateTaskEvent_MissingExecution(t *testing.T) { "with err: expected error") assert.Nil(t, resp) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).ExistsFunction = func( - ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { - return false, nil - } + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetExistsCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + return false, nil + }) taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) assert.EqualError(t, err, "failed to get existing node execution id: [node_id:\"node-id\""+ diff --git a/flyteadmin/pkg/repositories/errors/postgres.go b/flyteadmin/pkg/repositories/errors/postgres.go index a61c61345..e85becb66 100644 --- a/flyteadmin/pkg/repositories/errors/postgres.go +++ b/flyteadmin/pkg/repositories/errors/postgres.go @@ -2,7 +2,9 @@ // This errors utility translates postgres application error codes into internal error types. // The go postgres driver defines possible error codes here: https://github.com/lib/pq/blob/master/error.go // And the postgres standard defines error responses here: -// https://www.postgresql.org/docs/current/static/protocol-error-fields.html +// +// https://www.postgresql.org/docs/current/static/protocol-error-fields.html +// // Inspired by https://www.codementor.io/tamizhvendan/managing-data-in-golang-using-gorm-part-1-a9cdjb8nb package errors diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go index 7e3b58c12..e300dcd33 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go @@ -110,6 +110,41 @@ func (r *ExecutionRepo) List(ctx context.Context, input interfaces.ListResourceI }, nil } +func (r *ExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + var err error + tx := r.db.Model(&models.Execution{}) + + // Add join condition as required by user-specified filters (which can potentially include join table attrs). + if ok := input.JoinTableEntities[common.LaunchPlan]; ok { + tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.launch_plan_id = %s.id", + launchPlanTableName, executionTableName, launchPlanTableName)) + } + if ok := input.JoinTableEntities[common.Workflow]; ok { + tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.workflow_id = %s.id", + workflowTableName, executionTableName, workflowTableName)) + } + if ok := input.JoinTableEntities[common.Task]; ok { + tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.task_id = %s.id", + taskTableName, executionTableName, taskTableName)) + } + + // Apply filters + tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) + if err != nil { + return 0, err + } + + // Run the query + timer := r.metrics.CountDuration.Start() + var count int64 + tx = tx.Count(&count) + timer.Stop() + if tx.Error != nil { + return 0, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return count, nil +} + // Returns an instance of ExecutionRepoInterface func NewExecutionRepo( db *gorm.DB, errorTransformer adminErrors.ErrorTransformer, scope promutils.Scope) interfaces.ExecutionRepoInterface { diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go index 0551128e0..9a3dc194c 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go @@ -343,3 +343,41 @@ func TestListExecutionsForWorkflow(t *testing.T) { assert.Equal(t, time.Hour, execution.Duration) } } + +func TestCountExecutions(t *testing.T) { + executionRepo := NewExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "executions"`).WithReply([]map[string]interface{}{{"rows": 2}}) + + count, err := executionRepo.Count(context.Background(), interfaces.CountResourceInput{}) + assert.NoError(t, err) + assert.Equal(t, int64(2), count) +} + +func TestCountExecutions_Filters(t *testing.T) { + executionRepo := NewExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "executions" INNER JOIN workflows ON executions.workflow_id = workflows.id INNER JOIN tasks ON executions.task_id = tasks.id WHERE executions.phase = $1 AND "error_code" IS NULL`, + ).WithReply([]map[string]interface{}{{"rows": 3}}) + + count, err := executionRepo.Count(context.Background(), interfaces.CountResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.Execution, "phase", core.WorkflowExecution_FAILED.String()), + }, + MapFilters: []common.MapFilter{ + common.NewMapFilter(map[string]interface{}{ + "error_code": nil, + }), + }, + JoinTableEntities: map[common.Entity]bool{ + common.Workflow: true, + common.Task: true, + }, + }) + assert.NoError(t, err) + assert.Equal(t, int64(3), count) +} diff --git a/flyteadmin/pkg/repositories/gormimpl/metrics.go b/flyteadmin/pkg/repositories/gormimpl/metrics.go index f00225b4b..ae4a6a7d6 100644 --- a/flyteadmin/pkg/repositories/gormimpl/metrics.go +++ b/flyteadmin/pkg/repositories/gormimpl/metrics.go @@ -16,6 +16,7 @@ type gormMetrics struct { ListIdentifiersDuration promutils.StopWatch DeleteDuration promutils.StopWatch ExistsDuration promutils.StopWatch + CountDuration promutils.StopWatch } func newMetrics(scope promutils.Scope) gormMetrics { @@ -31,7 +32,11 @@ func newMetrics(scope promutils.Scope) gormMetrics { "list", "time taken to list entries", time.Millisecond), ListIdentifiersDuration: scope.MustNewStopWatch( "list_identifiers", "time taken to list identifier entries", time.Millisecond), - DeleteDuration: scope.MustNewStopWatch("delete", "time taken to delete an individual entry", time.Millisecond), - ExistsDuration: scope.MustNewStopWatch("exists", "time taken to determine whether an individual entry exists", time.Millisecond), + DeleteDuration: scope.MustNewStopWatch( + "delete", "time taken to delete an individual entry", time.Millisecond), + ExistsDuration: scope.MustNewStopWatch( + "exists", "time taken to determine whether an individual entry exists", time.Millisecond), + CountDuration: scope.MustNewStopWatch( + "count", "time taken to count entries", time.Millisecond), } } diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index af58fc65e..947ce863c 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -196,6 +196,32 @@ func (r *NodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExe return true, nil } +func (r *NodeExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + var err error + tx := r.db.Model(&models.NodeExecution{}) + + // Add join condition (joining multiple tables is fine even we only filter on a subset of table attributes). + // (this query isn't called for deletes). + tx = tx.Joins(innerJoinNodeExecToNodeEvents) + tx = tx.Joins(innerJoinExecToNodeExec) + + // Apply filters + tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) + if err != nil { + return 0, err + } + + // Run the query + timer := r.metrics.CountDuration.Start() + var count int64 + tx = tx.Count(&count) + timer.Stop() + if tx.Error != nil { + return 0, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return count, nil +} + // Returns an instance of NodeExecutionRepoInterface func NewNodeExecutionRepo( db *gorm.DB, errorTransformer adminErrors.ErrorTransformer, diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index d80df4e3b..c34579e3f 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -448,3 +448,36 @@ func TestNodeExecutionExists(t *testing.T) { assert.NoError(t, err) assert.True(t, exists) } + +func TestCountNodeExecutions(t *testing.T) { + nodeExecutionRepo := NewNodeExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "node_executions"`).WithReply([]map[string]interface{}{{"rows": 2}}) + + count, err := nodeExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{}) + assert.NoError(t, err) + assert.Equal(t, int64(2), count) +} + +func TestCountNodeExecutions_Filters(t *testing.T) { + nodeExecutionRepo := NewNodeExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "node_executions" INNER JOIN node_executions ON node_event_executions.node_execution_id = node_executions.id INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 AND "error_code" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) + + count, err := nodeExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.NodeExecution, "phase", core.NodeExecution_FAILED.String()), + }, + MapFilters: []common.MapFilter{ + common.NewMapFilter(map[string]interface{}{ + "error_code": nil, + }), + }, + }) + assert.NoError(t, err) + assert.Equal(t, int64(3), count) +} diff --git a/flyteadmin/pkg/repositories/gormimpl/resource_repo.go b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go index b1d3a46a3..e4b3f5dfd 100644 --- a/flyteadmin/pkg/repositories/gormimpl/resource_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go @@ -24,18 +24,18 @@ type ResourceRepo struct { const priorityDescending = "priority desc" /* - The data in the Resource repo maps to the following rules: - * Domain and ResourceType can never be empty. - * Empty string can be interpreted as all. Example: "" for Project field can be interpreted as all Projects for a domain. - * One cannot provide specific value for Project, unless a specific value for Domain is provided. - ** Project is always scoped within a domain. - ** Example: Domain="" Project="Lyft" is invalid. - * One cannot provide specific value for Workflow, unless a specific value for Domain and Project is provided. - ** Workflow is always scoped within a domain and project. - ** Example: Domain="staging" Project="" Workflow="W1" is invalid. - * One cannot provide specific value for Launch plan, unless a specific value for Domain, Project and Workflow is provided. - ** Launch plan is always scoped within a domain, project and workflow. - ** Example: Domain="staging" Project="Lyft" Workflow="" LaunchPlan= "l1" is invalid. +The data in the Resource repo maps to the following rules: +* Domain and ResourceType can never be empty. +* Empty string can be interpreted as all. Example: "" for Project field can be interpreted as all Projects for a domain. +* One cannot provide specific value for Project, unless a specific value for Domain is provided. +** Project is always scoped within a domain. +** Example: Domain="" Project="Lyft" is invalid. +* One cannot provide specific value for Workflow, unless a specific value for Domain and Project is provided. +** Workflow is always scoped within a domain and project. +** Example: Domain="staging" Project="" Workflow="W1" is invalid. +* One cannot provide specific value for Launch plan, unless a specific value for Domain, Project and Workflow is provided. +** Launch plan is always scoped within a domain, project and workflow. +** Example: Domain="staging" Project="Lyft" Workflow="" LaunchPlan= "l1" is invalid. */ func validateCreateOrUpdateResourceInput(project, domain, workflow, launchPlan, resourceType string) bool { if domain == "" || resourceType == "" { diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go index fab70b2a8..b864d802e 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go @@ -99,7 +99,7 @@ func (r *TaskExecutionRepo) List(ctx context.Context, input interfaces.ListResou tx := r.db.Limit(input.Limit).Offset(input.Offset).Preload("ChildNodeExecution") // And add three join conditions (joining multiple tables is fine even we only filter on a subset of table attributes). - // We are joining on task -> taskExec->NodeExec -> Exec. + // We are joining on task -> taskExec -> NodeExec -> Exec. // NOTE: the order in which the joins are called below are important because postgres will only know about certain // tables as they are joined. So we should do it in the order specified above. tx = tx.Joins(leftJoinTaskToTaskExec) @@ -129,6 +129,35 @@ func (r *TaskExecutionRepo) List(ctx context.Context, input interfaces.ListResou }, nil } +func (r *TaskExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + var err error + tx := r.db.Model(&models.TaskExecution{}) + + // Add three join conditions (joining multiple tables is fine even we only filter on a subset of table attributes). + // We are joining on task -> taskExec -> NodeExec -> Exec. + // NOTE: the order in which the joins are called below are important because postgres will only know about certain + // tables as they are joined. So we should do it in the order specified above. + tx = tx.Joins(leftJoinTaskToTaskExec) + tx = tx.Joins(innerJoinNodeExecToTaskExec) + tx = tx.Joins(innerJoinExecToNodeExec) + + // Apply filters + tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) + if err != nil { + return 0, err + } + + // Run the query + timer := r.metrics.CountDuration.Start() + var count int64 + tx = tx.Count(&count) + timer.Stop() + if tx.Error != nil { + return 0, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return count, nil +} + // Returns an instance of TaskExecutionRepoInterface func NewTaskExecutionRepo( db *gorm.DB, errorTransformer flyteAdminDbErrors.ErrorTransformer, scope promutils.Scope) interfaces.TaskExecutionRepoInterface { diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go index f4a01a768..955645332 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go @@ -201,3 +201,36 @@ func TestListTaskExecutionsForTaskExecution(t *testing.T) { assert.Equal(t, time.Hour, taskExecution.Duration) } } + +func TestCountTaskExecutions(t *testing.T) { + taskExecutionRepo := NewTaskExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "task_executions"`).WithReply([]map[string]interface{}{{"rows": 2}}) + + count, err := taskExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{}) + assert.NoError(t, err) + assert.Equal(t, int64(2), count) +} + +func TestCountTaskExecutions_Filters(t *testing.T) { + taskExecutionRepo := NewTaskExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT count(*) FROM "task_executions" LEFT JOIN tasks ON task_executions.project = tasks.project AND task_executions.domain = tasks.domain AND task_executions.name = tasks.name AND task_executions.version = tasks.version INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE task_executions.phase = $1 AND "task_execution_updated_at" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) + + count, err := taskExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.TaskExecution, "phase", core.TaskExecution_FAILED.String()), + }, + MapFilters: []common.MapFilter{ + common.NewMapFilter(map[string]interface{}{ + "task_execution_updated_at": nil, + }), + }, + }) + assert.NoError(t, err) + assert.Equal(t, int64(3), count) +} diff --git a/flyteadmin/pkg/repositories/interfaces/common.go b/flyteadmin/pkg/repositories/interfaces/common.go index b2bb1222c..60065a20c 100644 --- a/flyteadmin/pkg/repositories/interfaces/common.go +++ b/flyteadmin/pkg/repositories/interfaces/common.go @@ -22,7 +22,7 @@ type ListResourceInput struct { // pq driver value substitution. MapFilters []common.MapFilter SortParameter common.SortParameter - // A set of the entities (besides the primary table being queries) that should be joined with when performing + // A set of the entities (besides the primary table being queried) that should be joined with when performing // the list query. This enables filtering on non-primary entity attributes. JoinTableEntities map[common.Entity]bool } @@ -32,3 +32,15 @@ type UpdateResourceInput struct { Filters []common.InlineFilter Attributes map[string]interface{} } + +// Parameters for counting multiple resources. +type CountResourceInput struct { + InlineFilters []common.InlineFilter + // MapFilters refers to primary entity filters defined as map values rather than inline sql queries. + // These exist to permit filtering on "IS NULL" which isn't permitted with inline filter queries and + // pq driver value substitution. + MapFilters []common.MapFilter + // A set of the entities (besides the primary table being queried) that should be joined with when performing + // the count query. This enables filtering on non-primary entity attributes. + JoinTableEntities map[common.Entity]bool +} diff --git a/flyteadmin/pkg/repositories/interfaces/execution_repo.go b/flyteadmin/pkg/repositories/interfaces/execution_repo.go index 770bcd352..6a65d3d3f 100644 --- a/flyteadmin/pkg/repositories/interfaces/execution_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/execution_repo.go @@ -16,6 +16,8 @@ type ExecutionRepoInterface interface { Get(ctx context.Context, input Identifier) (models.Execution, error) // Returns executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (ExecutionCollectionOutput, error) + // Returns count of executions matching query parameters. + Count(ctx context.Context, input CountResourceInput) (int64, error) } // Response format for a query on workflows. diff --git a/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go b/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go index 2121df093..a1c72a872 100644 --- a/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go @@ -15,15 +15,17 @@ type NodeExecutionRepoInterface interface { Update(ctx context.Context, execution *models.NodeExecution) error // Get returns a matching execution if it exists. Get(ctx context.Context, input NodeExecutionResource) (models.NodeExecution, error) - // GetWithChildren Returns a matching execution with preloaded child node executions. This should only be called for legacy node executions + // GetWithChildren returns a matching execution with preloaded child node executions. This should only be called for legacy node executions // which were created with eventVersion == 0 GetWithChildren(ctx context.Context, input NodeExecutionResource) (models.NodeExecution, error) - // List eturns node executions matching query parameters. A limit must be provided for the results page size. + // List returns node executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (NodeExecutionCollectionOutput, error) - // ListEvents eturns node execution events matching query parameters. A limit must be provided for the results page size. + // ListEvents returns node execution events matching query parameters. A limit must be provided for the results page size. ListEvents(ctx context.Context, input ListResourceInput) (NodeExecutionEventCollectionOutput, error) // Exists returns whether a matching execution exists. Exists(ctx context.Context, input NodeExecutionResource) (bool, error) + // Returns count of node executions matching query parameters. + Count(ctx context.Context, input CountResourceInput) (int64, error) } type NodeExecutionResource struct { diff --git a/flyteadmin/pkg/repositories/interfaces/task_execution_repo.go b/flyteadmin/pkg/repositories/interfaces/task_execution_repo.go index 32fa6e2af..4ef26c116 100644 --- a/flyteadmin/pkg/repositories/interfaces/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/task_execution_repo.go @@ -17,6 +17,8 @@ type TaskExecutionRepoInterface interface { Get(ctx context.Context, input GetTaskExecutionInput) (models.TaskExecution, error) // Returns task executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (TaskExecutionCollectionOutput, error) + // Returns count of task executions matching query parameters. + Count(ctx context.Context, input CountResourceInput) (int64, error) } type GetTaskExecutionInput struct { diff --git a/flyteadmin/pkg/repositories/mocks/execution_repo.go b/flyteadmin/pkg/repositories/mocks/execution_repo.go index 27cb01015..19e4980fe 100644 --- a/flyteadmin/pkg/repositories/mocks/execution_repo.go +++ b/flyteadmin/pkg/repositories/mocks/execution_repo.go @@ -12,12 +12,14 @@ type UpdateExecutionFunc func(ctx context.Context, execution models.Execution) e type GetExecutionFunc func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) type ListExecutionFunc func(ctx context.Context, input interfaces.ListResourceInput) ( interfaces.ExecutionCollectionOutput, error) +type CountExecutionFunc func(ctx context.Context, input interfaces.CountResourceInput) (int64, error) type MockExecutionRepo struct { createFunction CreateExecutionFunc updateFunction UpdateExecutionFunc getFunction GetExecutionFunc listFunction ListExecutionFunc + countFunction CountExecutionFunc } func (r *MockExecutionRepo) Create(ctx context.Context, input models.Execution) error { @@ -31,10 +33,6 @@ func (r *MockExecutionRepo) SetCreateCallback(createFunction CreateExecutionFunc r.createFunction = createFunction } -func (r *MockExecutionRepo) SetUpdateCallback(updateFunction UpdateExecutionFunc) { - r.updateFunction = updateFunction -} - func (r *MockExecutionRepo) Update(ctx context.Context, execution models.Execution) error { if r.updateFunction != nil { return r.updateFunction(ctx, execution) @@ -42,8 +40,8 @@ func (r *MockExecutionRepo) Update(ctx context.Context, execution models.Executi return nil } -func (r *MockExecutionRepo) SetUpdateExecutionCallback(updateExecutionFunc UpdateExecutionFunc) { - r.updateFunction = updateExecutionFunc +func (r *MockExecutionRepo) SetUpdateCallback(updateFunction UpdateExecutionFunc) { + r.updateFunction = updateFunction } func (r *MockExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { @@ -69,6 +67,17 @@ func (r *MockExecutionRepo) SetListCallback(listFunction ListExecutionFunc) { r.listFunction = listFunction } +func (r *MockExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + if r.countFunction != nil { + return r.countFunction(ctx, input) + } + return 0, nil +} + +func (r *MockExecutionRepo) SetCountCallback(countFunction CountExecutionFunc) { + r.countFunction = countFunction +} + func NewMockExecutionRepo() interfaces.ExecutionRepoInterface { return &MockExecutionRepo{} } diff --git a/flyteadmin/pkg/repositories/mocks/node_execution_repo.go b/flyteadmin/pkg/repositories/mocks/node_execution_repo.go index 4e59005ba..b33c45018 100644 --- a/flyteadmin/pkg/repositories/mocks/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/mocks/node_execution_repo.go @@ -14,15 +14,18 @@ type ListNodeExecutionFunc func(ctx context.Context, input interfaces.ListResour interfaces.NodeExecutionCollectionOutput, error) type ListNodeExecutionEventFunc func(ctx context.Context, input interfaces.ListResourceInput) ( interfaces.NodeExecutionEventCollectionOutput, error) +type ExistsNodeExecutionFunc func(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) +type CountNodeExecutionFunc func(ctx context.Context, input interfaces.CountResourceInput) (int64, error) type MockNodeExecutionRepo struct { createFunction CreateNodeExecutionFunc updateFunction UpdateNodeExecutionFunc getFunction GetNodeExecutionFunc - GetWithChildrenFunction GetNodeExecutionFunc + getWithChildrenFunction GetNodeExecutionFunc listFunction ListNodeExecutionFunc listEventFunction ListNodeExecutionEventFunc - ExistsFunction func(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) + existsFunction ExistsNodeExecutionFunc + countFunction CountNodeExecutionFunc } func (r *MockNodeExecutionRepo) Create(ctx context.Context, input *models.NodeExecution) error { @@ -59,12 +62,16 @@ func (r *MockNodeExecutionRepo) SetGetCallback(getFunction GetNodeExecutionFunc) } func (r *MockNodeExecutionRepo) GetWithChildren(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { - if r.GetWithChildrenFunction != nil { - return r.GetWithChildrenFunction(ctx, input) + if r.getWithChildrenFunction != nil { + return r.getWithChildrenFunction(ctx, input) } return models.NodeExecution{}, nil } +func (r *MockNodeExecutionRepo) SetGetWithChildrenCallback(getWithChildrenFunction GetNodeExecutionFunc) { + r.getWithChildrenFunction = getWithChildrenFunction +} + func (r *MockNodeExecutionRepo) List(ctx context.Context, input interfaces.ListResourceInput) ( interfaces.NodeExecutionCollectionOutput, error) { if r.listFunction != nil { @@ -90,12 +97,27 @@ func (r *MockNodeExecutionRepo) SetListEventCallback(listEventFunction ListNodeE } func (r *MockNodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { - if r.ExistsFunction != nil { - return r.ExistsFunction(ctx, input) + if r.existsFunction != nil { + return r.existsFunction(ctx, input) } return true, nil } +func (r *MockNodeExecutionRepo) SetExistsCallback(existsFunction ExistsNodeExecutionFunc) { + r.existsFunction = existsFunction +} + +func (r *MockNodeExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + if r.countFunction != nil { + return r.countFunction(ctx, input) + } + return 0, nil +} + +func (r *MockNodeExecutionRepo) SetCountCallback(countFunction CountNodeExecutionFunc) { + r.countFunction = countFunction +} + func NewMockNodeExecutionRepo() interfaces.NodeExecutionRepoInterface { return &MockNodeExecutionRepo{} } diff --git a/flyteadmin/pkg/repositories/mocks/task_execution_repo.go b/flyteadmin/pkg/repositories/mocks/task_execution_repo.go index fa28887b5..d2d1188ed 100644 --- a/flyteadmin/pkg/repositories/mocks/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/mocks/task_execution_repo.go @@ -11,12 +11,14 @@ type CreateTaskExecutionFunc func(ctx context.Context, input models.TaskExecutio type GetTaskExecutionFunc func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) type UpdateTaskExecutionFunc func(ctx context.Context, execution models.TaskExecution) error type ListTaskExecutionFunc func(ctx context.Context, input interfaces.ListResourceInput) (interfaces.TaskExecutionCollectionOutput, error) +type CountTaskExecutionFunc func(ctx context.Context, input interfaces.CountResourceInput) (int64, error) type MockTaskExecutionRepo struct { createFunction CreateTaskExecutionFunc getFunction GetTaskExecutionFunc updateFunction UpdateTaskExecutionFunc listFunction ListTaskExecutionFunc + countFunction CountTaskExecutionFunc } func (r *MockTaskExecutionRepo) Create(ctx context.Context, input models.TaskExecution) error { @@ -63,6 +65,17 @@ func (r *MockTaskExecutionRepo) SetListCallback(listFunction ListTaskExecutionFu r.listFunction = listFunction } +func (r *MockTaskExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { + if r.countFunction != nil { + return r.countFunction(ctx, input) + } + return 0, nil +} + +func (r *MockTaskExecutionRepo) SetCountCallback(countFunction CountTaskExecutionFunc) { + r.countFunction = countFunction +} + func NewMockTaskExecutionRepo() interfaces.TaskExecutionRepoInterface { return &MockTaskExecutionRepo{} } diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index f5273da35..85af63bd5 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -305,7 +305,13 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, handler = httpServer } - err = http.ListenAndServe(cfg.GetHostAddress(), handler) + server := &http.Server{ + Addr: cfg.GetHostAddress(), + Handler: handler, + ReadHeaderTimeout: time.Duration(cfg.ReadHeaderTimeoutSeconds) * time.Second, + } + + err = server.ListenAndServe() if err != nil { return errors.Wrapf(err, "failed to Start HTTP Server") } diff --git a/flyteadmin/scheduler/core/doc.go b/flyteadmin/scheduler/core/doc.go index e909d8f18..47b7dd9de 100644 --- a/flyteadmin/scheduler/core/doc.go +++ b/flyteadmin/scheduler/core/doc.go @@ -1,9 +1,9 @@ // Package core // This is core package for the scheduler which includes -// - scheduler interface -// - scheduler implementation using gocron https://github.com/robfig/cron -// - updater which updates the schedules in the scheduler by reading periodically from the DB -// - snapshot runner which snapshot the schedules with there last exec times so that it can be used as check point -// in case of a crash. After a crash the scheduler replays the schedules from the last recorded snapshot. -// It relies on the admin idempotency aspect to fail executions if the execution with a scheduled time already exists with it. +// - scheduler interface +// - scheduler implementation using gocron https://github.com/robfig/cron +// - updater which updates the schedules in the scheduler by reading periodically from the DB +// - snapshot runner which snapshot the schedules with there last exec times so that it can be used as check point +// in case of a crash. After a crash the scheduler replays the schedules from the last recorded snapshot. +// It relies on the admin idempotency aspect to fail executions if the execution with a scheduled time already exists with it. package core