Skip to content

Commit

Permalink
Added UpdateExecution API implementation (flyteorg#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmahindrakar-oss authored Jan 13, 2022
1 parent be9e098 commit 577a5ef
Show file tree
Hide file tree
Showing 14 changed files with 266 additions and 9 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,5 @@ require (
)

replace github.com/robfig/cron/v3 => github.com/unionai/cron/v3 v3.0.2-0.20210825070134-bfc34418fe84

replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v0.21.20-0.20220111070000-bdd241a81330
7 changes: 2 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,8 @@ github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flyteorg/flyteidl v0.21.4/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteidl v0.21.16 h1:MepQ5iNnYbYDlR5kfqS4PX6KEk6raNA9VzI08zkvAK0=
github.com/flyteorg/flyteidl v0.21.16/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteidl v0.21.17 h1:UJy5mQqgn99PBxVmiwPQeNFb0YdeqsCxFvOgZDIOLiE=
github.com/flyteorg/flyteidl v0.21.17/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteidl v0.21.20-0.20220111070000-bdd241a81330 h1:pHiSSq3bVs9bKTc2SVnd7oLLuEn4uyXmx54bZqGZJ5M=
github.com/flyteorg/flyteidl v0.21.20-0.20220111070000-bdd241a81330/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteplugins v0.7.1 h1:YdCEQtdPeol7u6LkopGTIfPLAhy3KcclQa+DZFauK8w=
github.com/flyteorg/flyteplugins v0.7.1/go.mod h1:kOiuXk1ddIEVSPoHcc4kBfVQcLuyf8jw3vWJT2Was90=
github.com/flyteorg/flytepropeller v0.15.13 h1:SObqD0/oPzSt1fXRJO8g0zm9IxjzwPcdSOXmOc70v4E=
Expand Down
51 changes: 51 additions & 0 deletions pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,32 @@ func (m *ExecutionManager) GetExecution(
return execution, nil
}

func (m *ExecutionManager) UpdateExecution(
ctx context.Context, request admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error) {
if err := validation.ValidateWorkflowExecutionIdentifier(request.Id); err != nil {
logger.Debugf(ctx, "UpdateExecution request [%+v] failed validation with err: %v", request, err)
return nil, err
}
ctx = getExecutionContext(ctx, request.Id)
executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err)
return nil, err
}

stateInt := int32(admin.ExecutionStatus_EXECUTION_ACTIVE)
if request.Status != nil {
stateInt = int32(request.Status.State)
}
executionModel.State = &stateInt

if err := m.db.ExecutionRepo().Update(ctx, *executionModel); err != nil {
return nil, err
}

return &admin.ExecutionUpdateResponse{}, nil
}

func (m *ExecutionManager) GetExecutionData(
ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) {
ctx = getExecutionContext(ctx, request.Id)
Expand Down Expand Up @@ -1357,6 +1383,12 @@ func (m *ExecutionManager) ListExecutions(
for _, filter := range filters {
joinTableEntities[filter.GetEntity()] = true
}

// Check if state filter exists and if not then add filter to fetch only ACTIVE executions
if filters, err = addStateFilter(filters); err != nil {
return nil, err
}

listExecutionsInput := repositoryInterfaces.ListResourceInput{
Limit: int(request.Limit),
Offset: offset,
Expand Down Expand Up @@ -1586,3 +1618,22 @@ func (m *ExecutionManager) addProjectLabels(ctx context.Context, projectName str
}
return initialLabels, nil
}

func addStateFilter(filters []common.InlineFilter) ([]common.InlineFilter, error) {
var stateFilterExists bool
for _, inlineFilter := range filters {
if inlineFilter.GetField() == shared.State {
stateFilterExists = true
}
}

if !stateFilterExists {
stateFilter, err := common.NewSingleValueFilter(common.Execution, common.Equal, shared.State,
admin.ExecutionStatus_EXECUTION_ACTIVE)
if err != nil {
return filters, err
}
filters = append(filters, stateFilter)
}
return filters, nil
}
111 changes: 111 additions & 0 deletions pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,81 @@ func TestGetExecution_TransformerError(t *testing.T) {
assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code())
}

func TestUpdateExecution(t *testing.T) {
t.Run("invalid execution identifier", func(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
_, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{
Id: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
},
})
assert.Error(t, err)
})

t.Run("empty status passed", func(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
updateExecFunc := func(ctx context.Context, execModel models.Execution) error {
stateInt := int32(admin.ExecutionStatus_EXECUTION_ACTIVE)
assert.Equal(t, stateInt, *execModel.State)
return nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc)
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
updateResponse, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{
Id: &executionIdentifier,
})
assert.NoError(t, err)
assert.NotNil(t, updateResponse)
})

t.Run("archive status passed", func(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
updateExecFunc := func(ctx context.Context, execModel models.Execution) error {
stateInt := int32(admin.ExecutionStatus_EXECUTION_ARCHIVED)
assert.Equal(t, stateInt, *execModel.State)
return nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc)
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
updateResponse, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{
Id: &executionIdentifier,
Status: &admin.ExecutionStatus{State: admin.ExecutionStatus_EXECUTION_ARCHIVED},
})
assert.NoError(t, err)
assert.NotNil(t, updateResponse)
})

t.Run("update error", func(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
updateExecFunc := func(ctx context.Context, execModel models.Execution) error {
return fmt.Errorf("some db error")
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecFunc)
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
_, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{
Id: &executionIdentifier,
Status: &admin.ExecutionStatus{State: admin.ExecutionStatus_EXECUTION_ARCHIVED},
})
assert.Error(t, err)
})

t.Run("get execution error", func(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
getExecFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, fmt.Errorf("some db error")
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(getExecFunc)
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
_, err := execManager.UpdateExecution(context.Background(), admin.ExecutionUpdateRequest{
Id: &executionIdentifier,
Status: &admin.ExecutionStatus{State: admin.ExecutionStatus_EXECUTION_ARCHIVED},
})
assert.Error(t, err)
})
}

func TestListExecutions(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
executionListFunc := func(
Expand Down Expand Up @@ -3749,3 +3824,39 @@ func TestFromAdminProtoTaskResourceSpec(t *testing.T) {
GPU: resource.MustParse("2"),
}, taskResourceSet)
}

func TestAddStateFilter(t *testing.T) {
t.Run("empty filters", func(t *testing.T) {
var filters []common.InlineFilter
updatedFilters, err := addStateFilter(filters)
assert.Nil(t, err)
assert.NotNil(t, updatedFilters)
assert.Equal(t, 1, len(updatedFilters))

assert.Equal(t, shared.State, updatedFilters[0].GetField())
assert.Equal(t, common.Execution, updatedFilters[0].GetEntity())

expression, err := updatedFilters[0].GetGormQueryExpr()
assert.NoError(t, err)
assert.Equal(t, "state = ?", expression.Query)
})

t.Run("passed state filter", func(t *testing.T) {
filter, err := common.NewSingleValueFilter(common.Execution, common.NotEqual, "state", "0")
assert.NoError(t, err)
filters := []common.InlineFilter{filter}

updatedFilters, err := addStateFilter(filters)
assert.Nil(t, err)
assert.NotNil(t, updatedFilters)
assert.Equal(t, 1, len(updatedFilters))

assert.Equal(t, shared.State, updatedFilters[0].GetField())
assert.Equal(t, common.Execution, updatedFilters[0].GetEntity())

expression, err := updatedFilters[0].GetGormQueryExpr()
assert.NoError(t, err)
assert.Equal(t, "state <> ?", expression.Query)
})

}
1 change: 1 addition & 0 deletions pkg/manager/interfaces/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ExecutionInterface interface {
CreateWorkflowEvent(ctx context.Context, request admin.WorkflowExecutionEventRequest) (
*admin.WorkflowExecutionEventResponse, error)
GetExecution(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error)
UpdateExecution(ctx context.Context, request admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error)
GetExecutionData(ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (
*admin.WorkflowExecutionGetDataResponse, error)
ListExecutions(ctx context.Context, request admin.ResourceListRequest) (*admin.ExecutionList, error)
Expand Down
14 changes: 14 additions & 0 deletions pkg/manager/mocks/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RecoverExecutionFunc func(ctx context.Context, request admin.ExecutionRecov
type CreateExecutionEventFunc func(ctx context.Context, request admin.WorkflowExecutionEventRequest) (
*admin.WorkflowExecutionEventResponse, error)
type GetExecutionFunc func(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error)
type UpdateExecutionFunc func(ctx context.Context, request admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error)
type GetExecutionDataFunc func(ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (
*admin.WorkflowExecutionGetDataResponse, error)
type ListExecutionFunc func(ctx context.Context, request admin.ResourceListRequest) (*admin.ExecutionList, error)
Expand All @@ -30,6 +31,7 @@ type MockExecutionManager struct {
RecoverExecutionFunc RecoverExecutionFunc
createExecutionEventFunc CreateExecutionEventFunc
getExecutionFunc GetExecutionFunc
updateExecutionFunc UpdateExecutionFunc
getExecutionDataFunc GetExecutionDataFunc
listExecutionFunc ListExecutionFunc
terminateExecutionFunc TerminateExecutionFunc
Expand Down Expand Up @@ -94,6 +96,18 @@ func (m *MockExecutionManager) GetExecution(
return nil, nil
}

func (m *MockExecutionManager) SetUpdateExecutionCallback(updateExecutionFunc UpdateExecutionFunc) {
m.updateExecutionFunc = updateExecutionFunc
}

func (m *MockExecutionManager) UpdateExecution(
ctx context.Context, request admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error) {
if m.updateExecutionFunc != nil {
return m.updateExecutionFunc(ctx, request)
}
return nil, nil
}

func (m *MockExecutionManager) SetGetDataCallback(getExecutionDataFunc GetExecutionDataFunc) {
m.getExecutionDataFunc = getExecutionDataFunc
}
Expand Down
11 changes: 11 additions & 0 deletions pkg/repositories/config/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,15 @@ var Migrations = []*gormigrate.Migration{
return tx.Migrator().DropTable(&schedulerModels.ScheduleEntitiesSnapshot{}, "schedulable_entities_snapshot")
},
},

// Add state to execution model
{
ID: "2022-01-11-execution-state",
Migrate: func(tx *gorm.DB) error {
return tx.AutoMigrate(&models.Execution{})
},
Rollback: func(tx *gorm.DB) error {
return tx.Table("execution").Migrator().DropColumn(&models.Execution{}, "state")
},
},
}
6 changes: 3 additions & 3 deletions pkg/repositories/gormimpl/execution_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ func (r *ExecutionRepo) Update(ctx context.Context, execution models.Execution)

func (r *ExecutionRepo) List(ctx context.Context, input interfaces.ListResourceInput) (
interfaces.ExecutionCollectionOutput, error) {
var err error
// First validate input.
if err := ValidateListInput(input); err != nil {
if err = ValidateListInput(input); err != nil {
return interfaces.ExecutionCollectionOutput{}, err
}
var executions []models.Execution
Expand All @@ -89,11 +90,10 @@ func (r *ExecutionRepo) List(ctx context.Context, input interfaces.ListResourceI
}

// Apply filters
tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters)
tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters)
if err != nil {
return interfaces.ExecutionCollectionOutput{}, err
}

// Apply sort ordering.
if input.SortParameter != nil {
tx = tx.Order(input.SortParameter.GetGormOrderExpr())
Expand Down
2 changes: 1 addition & 1 deletion pkg/repositories/gormimpl/execution_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func TestListExecutionsForWorkflow(t *testing.T) {
GlobalMock.Logging = true

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(`SELECT "executions"."id","executions"."created_at","executions"."updated_at","executions"."deleted_at","executions"."execution_project","executions"."execution_domain","executions"."execution_name","executions"."launch_plan_id","executions"."workflow_id","executions"."task_id","executions"."phase","executions"."closure","executions"."spec","executions"."started_at","executions"."execution_created_at","executions"."execution_updated_at","executions"."duration","executions"."abort_cause","executions"."mode","executions"."source_execution_id","executions"."parent_node_execution_id","executions"."cluster","executions"."inputs_uri","executions"."user_inputs_uri","executions"."error_kind","executions"."error_code","executions"."user" FROM "executions" INNER JOIN workflows ON executions.workflow_id = workflows.id INNER JOIN tasks ON executions.task_id = tasks.id WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 AND (workflows.name = $4) AND tasks.name = $5 LIMIT`).WithReply(executions)
GlobalMock.NewMock().WithQuery(`SELECT "executions"."id","executions"."created_at","executions"."updated_at","executions"."deleted_at","executions"."execution_project","executions"."execution_domain","executions"."execution_name","executions"."launch_plan_id","executions"."workflow_id","executions"."task_id","executions"."phase","executions"."closure","executions"."spec","executions"."started_at","executions"."execution_created_at","executions"."execution_updated_at","executions"."duration","executions"."abort_cause","executions"."mode","executions"."source_execution_id","executions"."parent_node_execution_id","executions"."cluster","executions"."inputs_uri","executions"."user_inputs_uri","executions"."error_kind","executions"."error_code","executions"."user","executions"."state" FROM "executions" INNER JOIN workflows ON executions.workflow_id = workflows.id INNER JOIN tasks ON executions.task_id = tasks.id WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 AND (workflows.name = $4) AND tasks.name = $5 LIMIT 20`).WithReply(executions)

collection, err := executionRepo.List(context.Background(), interfaces.ListResourceInput{
InlineFilters: []common.InlineFilter{
Expand Down
2 changes: 2 additions & 0 deletions pkg/repositories/models/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ type Execution struct {
// The user responsible for launching this execution.
// This is also stored in the spec but promoted as a column for filtering.
User string `gorm:"index" valid:"length(0|255)"`
// GORM doesn't save the zero value for ints, so we use a pointer for the State field
State *int32 `gorm:"default:0"`
}
25 changes: 25 additions & 0 deletions pkg/rpc/adminservice/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ func (m *AdminService) GetExecution(
return response, nil
}

func (m *AdminService) UpdateExecution(
ctx context.Context, request *admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error) {
defer m.interceptPanic(ctx, request)
requestedAt := time.Now()
if request == nil {
return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed")
}
var response *admin.ExecutionUpdateResponse
var err error
m.Metrics.executionEndpointMetrics.update.Time(func() {
response, err = m.ExecutionManager.UpdateExecution(ctx, *request)
})
audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest(
"UpdateExecution",
audit.ParametersFromExecutionIdentifier(request.Id),
audit.ReadWrite,
requestedAt,
).WithResponse(time.Now(), err).Log(ctx)
if err != nil {
return nil, util.TransformAndRecordError(err, &m.Metrics.executionEndpointMetrics.update)
}
m.Metrics.executionEndpointMetrics.update.Success()
return response, nil
}

func (m *AdminService) GetExecutionData(
ctx context.Context, request *admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) {
defer m.interceptPanic(ctx, request)
Expand Down
2 changes: 2 additions & 0 deletions pkg/rpc/adminservice/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type executionEndpointMetrics struct {
recover util.RequestMetrics
createEvent util.RequestMetrics
get util.RequestMetrics
update util.RequestMetrics
getData util.RequestMetrics
list util.RequestMetrics
terminate util.RequestMetrics
Expand Down Expand Up @@ -125,6 +126,7 @@ func InitMetrics(adminScope promutils.Scope) AdminMetrics {
recover: util.NewRequestMetrics(adminScope, "recover_execution"),
createEvent: util.NewRequestMetrics(adminScope, "create_execution_event"),
get: util.NewRequestMetrics(adminScope, "get_execution"),
update: util.NewRequestMetrics(adminScope, "update_execution"),
getData: util.NewRequestMetrics(adminScope, "get_execution_data"),
list: util.NewRequestMetrics(adminScope, "list_execution"),
terminate: util.NewRequestMetrics(adminScope, "terminate_execution"),
Expand Down
Loading

0 comments on commit 577a5ef

Please sign in to comment.