diff --git a/cmd/entrypoints/k8s_secret.go b/cmd/entrypoints/k8s_secret.go index 2ad9356a5e..54dd3c36c1 100644 --- a/cmd/entrypoints/k8s_secret.go +++ b/cmd/entrypoints/k8s_secret.go @@ -17,10 +17,13 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/flyteorg/flyteadmin/pkg/config" executioncluster "github.com/flyteorg/flyteadmin/pkg/executioncluster/impl" + "github.com/flyteorg/flyteadmin/pkg/executioncluster/interfaces" "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/promutils" + "github.com/spf13/cobra" "github.com/spf13/pflag" "k8s.io/client-go/kubernetes" @@ -99,7 +102,15 @@ func persistSecrets(ctx context.Context, _ *pflag.FlagSet) error { initializationErrorCounter := scope.NewSubScope("secrets").MustNewCounter( "flyteclient_initialization_error", "count of errors encountered initializing a flyte client from kube config") - listTargetsProvider, err := executioncluster.NewListTargets(initializationErrorCounter, executioncluster.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) + + var listTargetsProvider interfaces.ListTargetsInterface + var err error + if len(configuration.ClusterConfiguration().GetClusterConfigs()) == 0 { + serverConfig := config.GetConfig() + listTargetsProvider, err = executioncluster.NewInCluster(initializationErrorCounter, serverConfig.KubeConfig, serverConfig.Master) + } else { + listTargetsProvider, err = executioncluster.NewListTargets(initializationErrorCounter, executioncluster.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) + } if err != nil { return err } diff --git a/go.mod b/go.mod index d51e41c8c2..fc403728f8 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/benbjohnson/clock v1.1.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.9.0+incompatible - github.com/flyteorg/flyteidl v0.21.18 + github.com/flyteorg/flyteidl v0.21.24 github.com/flyteorg/flyteplugins v0.9.1 github.com/flyteorg/flytepropeller v0.16.14 github.com/flyteorg/flytestdlib v0.4.7 diff --git a/go.sum b/go.sum index fd464a385c..ca8f78d1ae 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,9 @@ github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4 github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v0.21.11/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteidl v0.21.18 h1:tOrb8U96mJPbiYFDGgoafn/XO2EAWK3U6JWzPIlrKO4= github.com/flyteorg/flyteidl v0.21.18/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.21.24 h1:e2wPBK4aiLE+fw2zmhUDNg39QoJk6Lf5lQRvj8XgtFk= +github.com/flyteorg/flyteidl v0.21.24/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flyteplugins v0.9.1 h1:Z0gxSvG7LeI+COfEmuzkhz9RnJ4E5wWUcjj5qh1uKuw= github.com/flyteorg/flyteplugins v0.9.1/go.mod h1:OEGQztPFDJG4DV7tS9lDsRRd521iUINn5dcsBf6bW5k= github.com/flyteorg/flytepropeller v0.16.14 h1:zG+UnfZLPCQdwh7ORm3BNwXlb6Sp2Wwa7I7NnZYcPDw= diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 5908049bc5..999a12434d 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -1267,6 +1267,31 @@ func (m *ExecutionManager) GetExecution( return execution, nil } +func (m *ExecutionManager) UpdateExecution(ctx context.Context, request admin.ExecutionUpdateRequest, + requestedAt time.Time) (*admin.ExecutionUpdateResponse, error) { + if err := validation.ValidateWorkflowExecutionIdentifier(request.Id); err != nil { + logger.Debugf(ctx, "UpdateExecution request [%+v] failed validation with err: %v", request, err) + return nil, err + } + ctx = getExecutionContext(ctx, request.Id) + executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id) + if err != nil { + logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err) + return nil, err + } + + if err = transformers.UpdateExecutionModelStateChangeDetails(executionModel, request.State, requestedAt, + getUser(ctx)); err != nil { + return nil, err + } + + if err := m.db.ExecutionRepo().Update(ctx, *executionModel); err != nil { + return nil, err + } + + return &admin.ExecutionUpdateResponse{}, nil +} + func (m *ExecutionManager) GetExecutionData( ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) { ctx = getExecutionContext(ctx, request.Id) @@ -1358,6 +1383,11 @@ func (m *ExecutionManager) ListExecutions( joinTableEntities[filter.GetEntity()] = true } + // Check if state filter exists and if not then add filter to fetch only ACTIVE executions + if filters, err = addStateFilter(filters); err != nil { + return nil, err + } + listExecutionsInput := repositoryInterfaces.ListResourceInput{ Limit: int(request.Limit), Offset: offset, @@ -1587,3 +1617,22 @@ func (m *ExecutionManager) addProjectLabels(ctx context.Context, projectName str } return initialLabels, nil } + +func addStateFilter(filters []common.InlineFilter) ([]common.InlineFilter, error) { + var stateFilterExists bool + for _, inlineFilter := range filters { + if inlineFilter.GetField() == shared.State { + stateFilterExists = true + } + } + + if !stateFilterExists { + stateFilter, err := common.NewSingleValueFilter(common.Execution, common.Equal, shared.State, + admin.ExecutionState_EXECUTION_ACTIVE) + if err != nil { + return filters, err + } + filters = append(filters, stateFilter) + } + return filters, nil +} diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 23c37026f0..2ef7a72763 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -67,6 +67,10 @@ var specBytes, _ = proto.Marshal(spec) var phase = core.WorkflowExecution_RUNNING.String() var closure = admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: testutils.MockCreatedAtProto, + }, } var closureBytes, _ = proto.Marshal(&closure) @@ -106,6 +110,10 @@ func getLegacyClosure() *admin.ExecutionClosure { return &admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, ComputedInputs: getLegacySpec().Inputs, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: testutils.MockCreatedAtProto, + }, } } @@ -1491,6 +1499,10 @@ func TestCreateWorkflowEvent_StartedRunning(t *testing.T) { Phase: core.WorkflowExecution_RUNNING, StartedAt: occurredAtProto, UpdatedAt: occurredAtProto, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: testutils.MockCreatedAtProto, + }, } closureBytes, _ := proto.Marshal(&closure) updateExecutionFunc := func( @@ -1746,6 +1758,9 @@ func TestGetExecution(t *testing.T) { assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) return models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: "project", Domain: "domain", @@ -1820,6 +1835,89 @@ func TestGetExecution_TransformerError(t *testing.T) { assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code()) } +func TestUpdateExecution(t *testing.T) { + t.Run("invalid execution identifier", func(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + _, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + }, + }, time.Now()) + assert.Error(t, err) + }) + + t.Run("empty status passed", func(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + updateExecFuncCalled := false + updateExecFunc := func(ctx context.Context, execModel models.Execution) error { + stateInt := int32(admin.ExecutionState_EXECUTION_ACTIVE) + assert.Equal(t, stateInt, *execModel.State) + updateExecFuncCalled = true + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + updateResponse, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{ + Id: &executionIdentifier, + }, time.Now()) + assert.NoError(t, err) + assert.NotNil(t, updateResponse) + assert.True(t, updateExecFuncCalled) + }) + + t.Run("archive status passed", func(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + updateExecFuncCalled := false + updateExecFunc := func(ctx context.Context, execModel models.Execution) error { + stateInt := int32(admin.ExecutionState_EXECUTION_ARCHIVED) + assert.Equal(t, stateInt, *execModel.State) + updateExecFuncCalled = true + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + updateResponse, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{ + Id: &executionIdentifier, + State: admin.ExecutionState_EXECUTION_ARCHIVED, + }, time.Now()) + assert.NoError(t, err) + assert.NotNil(t, updateResponse) + assert.True(t, updateExecFuncCalled) + }) + + t.Run("update error", func(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + updateExecFunc := func(ctx context.Context, execModel models.Execution) error { + return fmt.Errorf("some db error") + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + _, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{ + Id: &executionIdentifier, + State: admin.ExecutionState_EXECUTION_ARCHIVED, + }, time.Now()) + assert.Error(t, err) + assert.Equal(t, "some db error", err.Error()) + }) + + t.Run("get execution error", func(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + getExecFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{}, fmt.Errorf("some db error") + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(getExecFunc) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + _, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{ + Id: &executionIdentifier, + State: admin.ExecutionState_EXECUTION_ARCHIVED, + }, time.Now()) + assert.Error(t, err) + assert.Equal(t, "some db error", err.Error()) + }) +} + func TestListExecutions(t *testing.T) { repository := repositoryMocks.NewMockRepository() executionListFunc := func( @@ -1850,6 +1948,9 @@ func TestListExecutions(t *testing.T) { return interfaces.ExecutionCollectionOutput{ Executions: []models.Execution{ { + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: projectValue, Domain: domainValue, @@ -1859,6 +1960,9 @@ func TestListExecutions(t *testing.T) { Closure: closureBytes, }, { + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: projectValue, Domain: domainValue, @@ -2538,6 +2642,9 @@ func TestGetExecution_Legacy(t *testing.T) { assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) return models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: "project", Domain: "domain", @@ -2577,6 +2684,9 @@ func TestGetExecutionData_LegacyModel(t *testing.T) { executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: "project", Domain: "domain", @@ -2770,6 +2880,9 @@ func TestListExecutions_LegacyModel(t *testing.T) { return interfaces.ExecutionCollectionOutput{ Executions: []models.Execution{ { + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: projectValue, Domain: domainValue, @@ -2779,6 +2892,9 @@ func TestListExecutions_LegacyModel(t *testing.T) { Closure: getLegacyClosureBytes(), }, { + BaseModel: models.BaseModel{ + CreatedAt: testutils.MockCreatedAtValue, + }, ExecutionKey: models.ExecutionKey{ Project: projectValue, Domain: domainValue, @@ -3749,3 +3865,39 @@ func TestFromAdminProtoTaskResourceSpec(t *testing.T) { GPU: resource.MustParse("2"), }, taskResourceSet) } + +func TestAddStateFilter(t *testing.T) { + t.Run("empty filters", func(t *testing.T) { + var filters []common.InlineFilter + updatedFilters, err := addStateFilter(filters) + assert.Nil(t, err) + assert.NotNil(t, updatedFilters) + assert.Equal(t, 1, len(updatedFilters)) + + assert.Equal(t, shared.State, updatedFilters[0].GetField()) + assert.Equal(t, common.Execution, updatedFilters[0].GetEntity()) + + expression, err := updatedFilters[0].GetGormQueryExpr() + assert.NoError(t, err) + assert.Equal(t, "state = ?", expression.Query) + }) + + t.Run("passed state filter", func(t *testing.T) { + filter, err := common.NewSingleValueFilter(common.Execution, common.NotEqual, "state", "0") + assert.NoError(t, err) + filters := []common.InlineFilter{filter} + + updatedFilters, err := addStateFilter(filters) + assert.Nil(t, err) + assert.NotNil(t, updatedFilters) + assert.Equal(t, 1, len(updatedFilters)) + + assert.Equal(t, shared.State, updatedFilters[0].GetField()) + assert.Equal(t, common.Execution, updatedFilters[0].GetEntity()) + + expression, err := updatedFilters[0].GetGormQueryExpr() + assert.NoError(t, err) + assert.Equal(t, "state <> ?", expression.Query) + }) + +} diff --git a/pkg/manager/interfaces/execution.go b/pkg/manager/interfaces/execution.go index 8850bf1489..3f11a43171 100644 --- a/pkg/manager/interfaces/execution.go +++ b/pkg/manager/interfaces/execution.go @@ -21,6 +21,8 @@ type ExecutionInterface interface { CreateWorkflowEvent(ctx context.Context, request admin.WorkflowExecutionEventRequest) ( *admin.WorkflowExecutionEventResponse, error) GetExecution(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error) + UpdateExecution(ctx context.Context, request admin.ExecutionUpdateRequest, requestedAt time.Time) ( + *admin.ExecutionUpdateResponse, error) GetExecutionData(ctx context.Context, request admin.WorkflowExecutionGetDataRequest) ( *admin.WorkflowExecutionGetDataResponse, error) ListExecutions(ctx context.Context, request admin.ResourceListRequest) (*admin.ExecutionList, error) diff --git a/pkg/manager/mocks/execution.go b/pkg/manager/mocks/execution.go index cfcc19d2cf..361234f8df 100644 --- a/pkg/manager/mocks/execution.go +++ b/pkg/manager/mocks/execution.go @@ -18,6 +18,8 @@ type RecoverExecutionFunc func(ctx context.Context, request admin.ExecutionRecov type CreateExecutionEventFunc func(ctx context.Context, request admin.WorkflowExecutionEventRequest) ( *admin.WorkflowExecutionEventResponse, error) type GetExecutionFunc func(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error) +type UpdateExecutionFunc func(ctx context.Context, request admin.ExecutionUpdateRequest, requestedAt time.Time) ( + *admin.ExecutionUpdateResponse, error) type GetExecutionDataFunc func(ctx context.Context, request admin.WorkflowExecutionGetDataRequest) ( *admin.WorkflowExecutionGetDataResponse, error) type ListExecutionFunc func(ctx context.Context, request admin.ResourceListRequest) (*admin.ExecutionList, error) @@ -30,6 +32,7 @@ type MockExecutionManager struct { RecoverExecutionFunc RecoverExecutionFunc createExecutionEventFunc CreateExecutionEventFunc getExecutionFunc GetExecutionFunc + updateExecutionFunc UpdateExecutionFunc getExecutionDataFunc GetExecutionDataFunc listExecutionFunc ListExecutionFunc terminateExecutionFunc TerminateExecutionFunc @@ -82,6 +85,18 @@ func (m *MockExecutionManager) CreateWorkflowEvent( return nil, nil } +func (m *MockExecutionManager) SetUpdateExecutionCallback(updateExecutionFunc UpdateExecutionFunc) { + m.updateExecutionFunc = updateExecutionFunc +} + +func (m *MockExecutionManager) UpdateExecution(ctx context.Context, request admin.ExecutionUpdateRequest, + requestedAt time.Time) (*admin.ExecutionUpdateResponse, error) { + if m.updateExecutionFunc != nil { + return m.updateExecutionFunc(ctx, request, requestedAt) + } + return nil, nil +} + func (m *MockExecutionManager) SetGetCallback(getExecutionFunc GetExecutionFunc) { m.getExecutionFunc = getExecutionFunc } diff --git a/pkg/repositories/config/migrations.go b/pkg/repositories/config/migrations.go index d50745a14b..0fba3bb032 100644 --- a/pkg/repositories/config/migrations.go +++ b/pkg/repositories/config/migrations.go @@ -359,6 +359,17 @@ var Migrations = []*gormigrate.Migration{ return alterTableColumnType(db, "id", "int") }, }, + + // Add state to execution model. + { + ID: "2022-01-11-execution-state", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Table("execution").Migrator().DropColumn(&models.Execution{}, "state") + }, + }, } func alterTableColumnType(db *sql.DB, columnName, columnType string) error { diff --git a/pkg/repositories/gormimpl/execution_repo_test.go b/pkg/repositories/gormimpl/execution_repo_test.go index c7335f1f32..0551128e03 100644 --- a/pkg/repositories/gormimpl/execution_repo_test.go +++ b/pkg/repositories/gormimpl/execution_repo_test.go @@ -310,7 +310,8 @@ func TestListExecutionsForWorkflow(t *testing.T) { GlobalMock.Logging = true // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`SELECT "executions"."id","executions"."created_at","executions"."updated_at","executions"."deleted_at","executions"."execution_project","executions"."execution_domain","executions"."execution_name","executions"."launch_plan_id","executions"."workflow_id","executions"."task_id","executions"."phase","executions"."closure","executions"."spec","executions"."started_at","executions"."execution_created_at","executions"."execution_updated_at","executions"."duration","executions"."abort_cause","executions"."mode","executions"."source_execution_id","executions"."parent_node_execution_id","executions"."cluster","executions"."inputs_uri","executions"."user_inputs_uri","executions"."error_kind","executions"."error_code","executions"."user" FROM "executions" INNER JOIN workflows ON executions.workflow_id = workflows.id INNER JOIN tasks ON executions.task_id = tasks.id WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 AND (workflows.name = $4) AND tasks.name = $5 LIMIT`).WithReply(executions) + GlobalMock.NewMock().WithQuery(`SELECT "executions"."id","executions"."created_at","executions"."updated_at","executions"."deleted_at","executions"."execution_project","executions"."execution_domain","executions"."execution_name","executions"."launch_plan_id","executions"."workflow_id","executions"."task_id","executions"."phase","executions"."closure","executions"."spec","executions"."started_at","executions"."execution_created_at","executions"."execution_updated_at","executions"."duration","executions"."abort_cause","executions"."mode","executions"."source_execution_id","executions"."parent_node_execution_id","executions"."cluster","executions"."inputs_uri","executions"."user_inputs_uri","executions"."error_kind","executions"."error_code","executions"."user","executions"."state" FROM "executions" INNER JOIN workflows ON executions.workflow_id = workflows.id INNER JOIN tasks ON executions.task_id = tasks.id WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 AND (workflows.name = $4) AND tasks.name = $5 LIMIT 20`).WithReply(executions) + collection, err := executionRepo.List(context.Background(), interfaces.ListResourceInput{ InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Execution, "project", project), diff --git a/pkg/repositories/models/execution.go b/pkg/repositories/models/execution.go index 14caeee1f3..fa429226b4 100644 --- a/pkg/repositories/models/execution.go +++ b/pkg/repositories/models/execution.go @@ -56,4 +56,6 @@ type Execution struct { // The user responsible for launching this execution. // This is also stored in the spec but promoted as a column for filtering. User string `gorm:"index" valid:"length(0|255)"` + // GORM doesn't save the zero value for ints, so we use a pointer for the State field + State *int32 `gorm:"index;default:0"` } diff --git a/pkg/repositories/transformers/execution.go b/pkg/repositories/transformers/execution.go index cc3142d8ca..c45bd1211b 100644 --- a/pkg/repositories/transformers/execution.go +++ b/pkg/repositories/transformers/execution.go @@ -5,21 +5,20 @@ import ( "fmt" "time" - "k8s.io/apimachinery/pkg/util/sets" - + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/storage" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "google.golang.org/grpc/codes" - "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteadmin/pkg/errors" - "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/timestamppb" + "k8s.io/apimachinery/pkg/util/sets" ) var clusterReassignablePhases = sets.NewString(core.WorkflowExecution_UNDEFINED.String(), core.WorkflowExecution_QUEUED.String()) @@ -65,6 +64,11 @@ func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, e UpdatedAt: createdAt, Notifications: input.Notifications, WorkflowId: input.WorkflowIdentifier, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + Principal: requestSpec.Metadata.Principal, + OccurredAt: createdAt, + }, } if input.Phase == core.WorkflowExecution_RUNNING { closure.StartedAt = createdAt @@ -76,6 +80,7 @@ func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, e return nil, errors.NewFlyteAdminError(codes.Internal, "Failed to serialize launch plan status") } + activeExecution := int32(admin.ExecutionState_EXECUTION_ACTIVE) executionModel := &models.Execution{ ExecutionKey: models.ExecutionKey{ Project: input.WorkflowExecutionID.Project, @@ -94,6 +99,7 @@ func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, e InputsURI: input.InputsURI, UserInputsURI: input.UserInputsURI, User: requestSpec.Metadata.Principal, + State: &activeExecution, } // A reference launch entity can be one of either or a task OR launch plan. Traditionally, workflows are executed // with a reference launch plan which is why this behavior is the default below. @@ -228,6 +234,39 @@ func UpdateExecutionModelState( return nil } +// UpdateExecutionModelStateChangeDetails Updates an existing model with stateUpdateTo, stateUpdateBy and +// statedUpdatedAt details from the request +func UpdateExecutionModelStateChangeDetails(executionModel *models.Execution, stateUpdatedTo admin.ExecutionState, + stateUpdatedAt time.Time, stateUpdatedBy string) error { + + var closure admin.ExecutionClosure + err := proto.Unmarshal(executionModel.Closure, &closure) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to unmarshal execution closure: %v", err) + } + // Update the indexed columns + stateInt := int32(stateUpdatedTo) + executionModel.State = &stateInt + + // Update the closure with the same + var stateUpdatedAtProto *timestamppb.Timestamp + // Default use the createdAt timestamp as the state change occurredAt time + if stateUpdatedAtProto, err = ptypes.TimestampProto(stateUpdatedAt); err != nil { + return err + } + closure.StateChangeDetails = &admin.ExecutionStateChangeDetails{ + State: stateUpdatedTo, + Principal: stateUpdatedBy, + OccurredAt: stateUpdatedAtProto, + } + marshaledClosure, err := proto.Marshal(&closure) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to marshal execution closure: %v", err) + } + executionModel.Closure = marshaledClosure + return nil +} + // The execution abort metadata is recorded but the phase is not actually updated *until* the abort event is propagated // by flytepropeller. The metadata is preemptively saved at the time of the abort. func SetExecutionAborted(execution *models.Execution, cause, principal string) error { @@ -263,15 +302,22 @@ func GetExecutionIdentifier(executionModel *models.Execution) core.WorkflowExecu func FromExecutionModel(executionModel models.Execution) (*admin.Execution, error) { var spec admin.ExecutionSpec - err := proto.Unmarshal(executionModel.Spec, &spec) - if err != nil { + var err error + if err = proto.Unmarshal(executionModel.Spec, &spec); err != nil { return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to unmarshal spec") } var closure admin.ExecutionClosure - err = proto.Unmarshal(executionModel.Closure, &closure) - if err != nil { + if err = proto.Unmarshal(executionModel.Closure, &closure); err != nil { return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to unmarshal closure") } + + if closure.StateChangeDetails == nil { + // Update execution state details from model for older executions + if closure.StateChangeDetails, err = PopulateDefaultStateChangeDetails(executionModel); err != nil { + return nil, err + } + } + id := GetExecutionIdentifier(&executionModel) if executionModel.Phase == core.WorkflowExecution_ABORTED.String() && closure.GetAbortMetadata() == nil { // In the case of data predating the AbortMetadata field we manually set it in the closure only @@ -293,6 +339,24 @@ func FromExecutionModel(executionModel models.Execution) (*admin.Execution, erro }, nil } +// PopulateDefaultStateChangeDetails used to populate execution state change details for older executions which donot +// have these details captured. Hence we construct a default state change details from existing data model. +func PopulateDefaultStateChangeDetails(executionModel models.Execution) (*admin.ExecutionStateChangeDetails, error) { + var err error + var occurredAt *timestamppb.Timestamp + + // Default use the createdAt timestamp as the state change occurredAt time + if occurredAt, err = ptypes.TimestampProto(executionModel.CreatedAt); err != nil { + return nil, err + } + + return &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: occurredAt, + Principal: executionModel.User, + }, nil +} + func FromExecutionModels(executionModels []models.Execution) ([]*admin.Execution, error) { executions := make([]*admin.Execution, len(executionModels)) for idx, executionModel := range executionModels { diff --git a/pkg/repositories/transformers/execution_test.go b/pkg/repositories/transformers/execution_test.go index 4a934c0eb9..0b8ece90e7 100644 --- a/pkg/repositories/transformers/execution_test.go +++ b/pkg/repositories/transformers/execution_test.go @@ -3,6 +3,8 @@ package transformers import ( "context" "fmt" + "math" + "strings" "testing" "time" @@ -109,6 +111,11 @@ func TestCreateExecutionModel(t *testing.T) { StartedAt: expectedCreatedAt, UpdatedAt: expectedCreatedAt, WorkflowId: workflowIdentifier, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: expectedCreatedAt, + Principal: principal, + }, }) assert.Equal(t, expectedClosure, execution.Closure) } @@ -481,26 +488,37 @@ func TestFromExecutionModel(t *testing.T) { specBytes, _ := proto.Marshal(spec) phase := core.WorkflowExecution_RUNNING.String() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) + createdAt := time.Date(2022, 01, 18, 0, 0, 0, 0, time.UTC) startedAtProto, _ := ptypes.TimestampProto(startedAt) + createdAtProto, _ := ptypes.TimestampProto(createdAt) closure := admin.ExecutionClosure{ ComputedInputs: spec.Inputs, Phase: core.WorkflowExecution_RUNNING, StartedAt: startedAtProto, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: createdAtProto, + }, } closureBytes, _ := proto.Marshal(&closure) - + stateInt := int32(admin.ExecutionState_EXECUTION_ACTIVE) executionModel := models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: createdAt, + }, ExecutionKey: models.ExecutionKey{ Project: "project", Domain: "domain", Name: "name", }, + User: "", Spec: specBytes, Phase: phase, Closure: closureBytes, LaunchPlanID: uint(1), WorkflowID: uint(2), StartedAt: &startedAt, + State: &stateInt, } execution, err := FromExecutionModel(executionModel) assert.Nil(t, err) @@ -548,7 +566,9 @@ func TestFromExecutionModels(t *testing.T) { specBytes, _ := proto.Marshal(spec) phase := core.WorkflowExecution_SUCCEEDED.String() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) + createdAt := time.Date(2022, 01, 18, 0, 0, 0, 0, time.UTC) startedAtProto, _ := ptypes.TimestampProto(startedAt) + createdAtProto, _ := ptypes.TimestampProto(createdAt) duration := 2 * time.Minute durationProto := ptypes.DurationProto(duration) closure := admin.ExecutionClosure{ @@ -556,11 +576,18 @@ func TestFromExecutionModels(t *testing.T) { Phase: core.WorkflowExecution_RUNNING, StartedAt: startedAtProto, Duration: durationProto, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: createdAtProto, + }, } closureBytes, _ := proto.Marshal(&closure) - + stateInt := int32(admin.ExecutionState_EXECUTION_ACTIVE) executionModels := []models.Execution{ { + BaseModel: models.BaseModel{ + CreatedAt: createdAt, + }, ExecutionKey: models.ExecutionKey{ Project: "project", Domain: "domain", @@ -573,6 +600,7 @@ func TestFromExecutionModels(t *testing.T) { WorkflowID: uint(2), StartedAt: &startedAt, Duration: duration, + State: &stateInt, }, } executions, err := FromExecutionModels(executionModels) @@ -713,3 +741,73 @@ func TestReassignCluster(t *testing.T) { assert.Equal(t, err.(errors.FlyteAdminError).Code(), codes.Internal) }) } + +func TestGetExecutionStateFromModel(t *testing.T) { + createdAt := time.Date(2022, 01, 90, 16, 0, 0, 0, time.UTC) + createdAtProto, _ := ptypes.TimestampProto(createdAt) + + t.Run("supporting older executions", func(t *testing.T) { + executionModel := models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: createdAt, + }, + } + executionStatus, err := PopulateDefaultStateChangeDetails(executionModel) + assert.Nil(t, err) + assert.NotNil(t, executionStatus) + assert.Equal(t, admin.ExecutionState_EXECUTION_ACTIVE, executionStatus.State) + assert.NotNil(t, executionStatus.OccurredAt) + assert.Equal(t, createdAtProto, executionStatus.OccurredAt) + }) + t.Run("incorrect created at", func(t *testing.T) { + createdAt := time.Unix(math.MinInt64, math.MinInt32).UTC() + executionModel := models.Execution{ + BaseModel: models.BaseModel{ + CreatedAt: createdAt, + }, + } + executionStatus, err := PopulateDefaultStateChangeDetails(executionModel) + assert.NotNil(t, err) + assert.Nil(t, executionStatus) + }) +} + +func TestUpdateExecutionModelStateChangeDetails(t *testing.T) { + t.Run("empty closure", func(t *testing.T) { + execModel := &models.Execution{} + stateUpdatedAt := time.Now() + statetUpdateAtProto, err := ptypes.TimestampProto(stateUpdatedAt) + assert.Nil(t, err) + err = UpdateExecutionModelStateChangeDetails(execModel, admin.ExecutionState_EXECUTION_ARCHIVED, + stateUpdatedAt, "dummyUser") + assert.Nil(t, err) + stateInt := int32(admin.ExecutionState_EXECUTION_ARCHIVED) + assert.Equal(t, execModel.State, &stateInt) + var closure admin.ExecutionClosure + err = proto.Unmarshal(execModel.Closure, &closure) + assert.Nil(t, err) + assert.NotNil(t, closure) + assert.NotNil(t, closure.StateChangeDetails) + assert.Equal(t, admin.ExecutionState_EXECUTION_ARCHIVED, closure.StateChangeDetails.State) + assert.Equal(t, "dummyUser", closure.StateChangeDetails.Principal) + assert.Equal(t, statetUpdateAtProto, closure.StateChangeDetails.OccurredAt) + + }) + t.Run("bad closure", func(t *testing.T) { + execModel := &models.Execution{ + Closure: []byte{1, 2, 3}, + } + err := UpdateExecutionModelStateChangeDetails(execModel, admin.ExecutionState_EXECUTION_ARCHIVED, + time.Now(), "dummyUser") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Failed to unmarshal execution closure") + }) + t.Run("bad stateUpdatedAt time", func(t *testing.T) { + execModel := &models.Execution{} + badTimeData := time.Unix(math.MinInt64, math.MinInt32).UTC() + err := UpdateExecutionModelStateChangeDetails(execModel, admin.ExecutionState_EXECUTION_ARCHIVED, + badTimeData, "dummyUser") + assert.NotNil(t, err) + assert.False(t, strings.Contains(err.Error(), "Failed to unmarshal execution closure")) + }) +} diff --git a/pkg/rpc/adminservice/execution.go b/pkg/rpc/adminservice/execution.go index adf7913cea..5a3d21f04a 100644 --- a/pkg/rpc/adminservice/execution.go +++ b/pkg/rpc/adminservice/execution.go @@ -142,6 +142,31 @@ func (m *AdminService) GetExecution( return response, nil } +func (m *AdminService) UpdateExecution( + ctx context.Context, request *admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error) { + defer m.interceptPanic(ctx, request) + requestedAt := time.Now() + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ExecutionUpdateResponse + var err error + m.Metrics.executionEndpointMetrics.update.Time(func() { + response, err = m.ExecutionManager.UpdateExecution(ctx, *request, requestedAt) + }) + audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( + "UpdateExecution", + audit.ParametersFromExecutionIdentifier(request.Id), + audit.ReadWrite, + requestedAt, + ).WithResponse(time.Now(), err).Log(ctx) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.executionEndpointMetrics.update) + } + m.Metrics.executionEndpointMetrics.update.Success() + return response, nil +} + func (m *AdminService) GetExecutionData( ctx context.Context, request *admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) { defer m.interceptPanic(ctx, request) diff --git a/pkg/rpc/adminservice/tests/execution_test.go b/pkg/rpc/adminservice/tests/execution_test.go index 1f6f0d6cf2..a086bff1e7 100644 --- a/pkg/rpc/adminservice/tests/execution_test.go +++ b/pkg/rpc/adminservice/tests/execution_test.go @@ -291,6 +291,46 @@ func TestGetExecutionError(t *testing.T) { assert.Nil(t, actualResponse) } +func TestUpdateExecution(t *testing.T) { + response := &admin.ExecutionUpdateResponse{} + mockExecutionManager := mocks.MockExecutionManager{} + mockExecutionManager.SetUpdateExecutionCallback( + func(ctx context.Context, + request admin.ExecutionUpdateRequest, requestedAt time.Time) (*admin.ExecutionUpdateResponse, error) { + assert.True(t, proto.Equal(&workflowExecutionIdentifier, request.Id)) + return response, nil + }, + ) + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + executionManager: &mockExecutionManager, + }) + + actualResponse, err := mockServer.UpdateExecution(context.Background(), &admin.ExecutionUpdateRequest{ + Id: &workflowExecutionIdentifier, + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(response, actualResponse)) +} + +func TestUpdateExecutionError(t *testing.T) { + mockExecutionManager := mocks.MockExecutionManager{} + mockExecutionManager.SetUpdateExecutionCallback( + func(ctx context.Context, + request admin.ExecutionUpdateRequest, requestedAt time.Time) (*admin.ExecutionUpdateResponse, error) { + return nil, errors.New("expected error") + }, + ) + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + executionManager: &mockExecutionManager, + }) + + actualResponse, err := mockServer.UpdateExecution(context.Background(), &admin.ExecutionUpdateRequest{ + Id: &workflowExecutionIdentifier, + }) + assert.EqualError(t, err, "expected error") + assert.Nil(t, actualResponse) +} + func TestListExecutions(t *testing.T) { mockExecutionManager := mocks.MockExecutionManager{} mockExecutionManager.SetListCallback(func(ctx context.Context, request admin.ResourceListRequest) (