diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state.go b/flyteplugins/go/tasks/plugins/presto/execution_state.go index 6217b4b21f..e9bcb4e00d 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state.go @@ -56,7 +56,8 @@ func (p ExecutionPhase) String() string { } type ExecutionState struct { - Phase ExecutionPhase + CurrentPhase ExecutionPhase + PreviousPhase ExecutionPhase // This will store the command ID from Presto CommandID string `json:"commandId,omitempty"` @@ -105,7 +106,7 @@ func HandleExecutionState( var transformError error var newState ExecutionState - switch currentState.Phase { + switch currentState.CurrentPhase { case PhaseNotStarted: newState, transformError = GetAllocationToken(ctx, tCtx, currentState, metrics) @@ -125,8 +126,11 @@ func HandleExecutionState( // If there are still Presto statements to execute, increment the query count, reset the phase to 'queued' // and continue executing the remaining statements. In this case, we won't request another allocation token // as the 5 statements that get executed are all considered to be part of the same "query" - currentState.Phase = PhaseQueued + currentState.PreviousPhase = currentState.CurrentPhase + currentState.CurrentPhase = PhaseQueued } else { + //currentState.Phase = PhaseQuerySucceeded + currentState.PreviousPhase = currentState.CurrentPhase transformError = writeOutput(ctx, tCtx, currentState.CurrentPrestoQuery.ExternalLocation) } currentState.QueryCount++ @@ -172,11 +176,11 @@ func GetAllocationToken( } if allocationStatus == core.AllocationStatusGranted { - newState.Phase = PhaseQueued + newState.CurrentPhase = PhaseQueued } else if allocationStatus == core.AllocationStatusExhausted { - newState.Phase = PhaseNotStarted + newState.CurrentPhase = PhaseNotStarted } else if allocationStatus == core.AllocationStatusNamespaceQuotaExceeded { - newState.Phase = PhaseNotStarted + newState.CurrentPhase = PhaseNotStarted } else { return newState, errors.Errorf(errors.ResourceManagerFailure, "Got bad allocation result [%s] for token [%s]", allocationStatus, uniqueID) @@ -389,7 +393,8 @@ func KickOffQuery( commandID := response.ID logger.Infof(ctx, "Created Presto ID [%s] for token %s", commandID, uniqueID) currentState.CommandID = commandID - currentState.Phase = PhaseSubmitted + currentState.PreviousPhase = currentState.CurrentPhase + currentState.CurrentPhase = PhaseSubmitted currentState.URI = response.NextURI currentState.CurrentPrestoQueryUUID = uniqueID @@ -475,7 +480,8 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { var phaseInfo core.PhaseInfo t := time.Now() - switch state.Phase { + //switch state.Phase { + switch state.CurrentPhase { case PhaseNotStarted: phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") case PhaseQueued: @@ -515,7 +521,7 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { return &idlCore.TaskLog{ - Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandID), + Name: fmt.Sprintf("Status: %s [%s]", e.PreviousPhase, e.CommandID), MessageFormat: idlCore.TaskLog_UNKNOWN, Uri: e.URI, } @@ -551,11 +557,11 @@ func Finalize(ctx context.Context, tCtx core.TaskExecutionContext, _ ExecutionSt } func InTerminalState(e ExecutionState) bool { - return e.Phase == PhaseQuerySucceeded || e.Phase == PhaseQueryFailed + return e.CurrentPhase == PhaseQuerySucceeded || e.CurrentPhase == PhaseQueryFailed } func IsNotYetSubmitted(e ExecutionState) bool { - if e.Phase == PhaseNotStarted || e.Phase == PhaseQueued { + if e.CurrentPhase == PhaseNotStarted || e.CurrentPhase == PhaseQueued { return true } return false diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state_test.go b/flyteplugins/go/tasks/plugins/presto/execution_state_test.go index 9de1d3b215..32a0997467 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state_test.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state_test.go @@ -42,7 +42,7 @@ func TestInTerminalState(t *testing.T) { for _, tt := range stateTests { t.Run(tt.phase.String(), func(t *testing.T) { - e := ExecutionState{Phase: tt.phase} + e := ExecutionState{CurrentPhase: tt.phase} res := InTerminalState(e) assert.Equal(t, tt.isTerminal, res) }) @@ -63,7 +63,7 @@ func TestIsNotYetSubmitted(t *testing.T) { for _, tt := range stateTests { t.Run(tt.phase.String(), func(t *testing.T) { - e := ExecutionState{Phase: tt.phase} + e := ExecutionState{CurrentPhase: tt.phase} res := IsNotYetSubmitted(e) assert.Equal(t, tt.isNotYetSubmitted, res) }) @@ -98,7 +98,7 @@ func TestConstructTaskInfo(t *testing.T) { assert.NoError(t, err) e := ExecutionState{ - Phase: PhaseQuerySucceeded, + CurrentPhase: PhaseQuerySucceeded, CommandID: "123", SyncFailureCount: 0, URI: u.String(), @@ -111,7 +111,7 @@ func TestConstructTaskInfo(t *testing.T) { func TestMapExecutionStateToPhaseInfo(t *testing.T) { t.Run("NotStarted", func(t *testing.T) { e := ExecutionState{ - Phase: PhaseNotStarted, + CurrentPhase: PhaseNotStarted, } phaseInfo := MapExecutionStateToPhaseInfo(e) assert.Equal(t, core.PhaseNotReady, phaseInfo.Phase()) @@ -119,14 +119,14 @@ func TestMapExecutionStateToPhaseInfo(t *testing.T) { t.Run("Queued", func(t *testing.T) { e := ExecutionState{ - Phase: PhaseQueued, + CurrentPhase: PhaseQueued, CreationFailureCount: 0, } phaseInfo := MapExecutionStateToPhaseInfo(e) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) e = ExecutionState{ - Phase: PhaseQueued, + CurrentPhase: PhaseQueued, CreationFailureCount: 100, } phaseInfo = MapExecutionStateToPhaseInfo(e) @@ -136,7 +136,7 @@ func TestMapExecutionStateToPhaseInfo(t *testing.T) { t.Run("Submitted", func(t *testing.T) { e := ExecutionState{ - Phase: PhaseSubmitted, + CurrentPhase: PhaseSubmitted, } phaseInfo := MapExecutionStateToPhaseInfo(e) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) @@ -157,7 +157,7 @@ func TestGetAllocationToken(t *testing.T) { mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) assert.NoError(t, err) - assert.Equal(t, PhaseQueued, state.Phase) + assert.Equal(t, PhaseQueued, state.CurrentPhase) }) t.Run("exhausted", func(t *testing.T) { @@ -171,7 +171,7 @@ func TestGetAllocationToken(t *testing.T) { mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) assert.NoError(t, err) - assert.Equal(t, PhaseNotStarted, state.Phase) + assert.Equal(t, PhaseNotStarted, state.CurrentPhase) }) t.Run("namespace exhausted", func(t *testing.T) { @@ -185,7 +185,7 @@ func TestGetAllocationToken(t *testing.T) { mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) assert.NoError(t, err) - assert.Equal(t, PhaseNotStarted, state.Phase) + assert.Equal(t, PhaseNotStarted, state.CurrentPhase) }) t.Run("Request start time, if empty in current state, should be set", func(t *testing.T) { @@ -232,7 +232,7 @@ func TestAbort(t *testing.T) { x = true }).Return(nil) - err := Abort(ctx, ExecutionState{Phase: PhaseSubmitted, CommandID: "123456"}, mockPresto) + err := Abort(ctx, ExecutionState{CurrentPhase: PhaseSubmitted, CommandID: "123456"}, mockPresto) assert.NoError(t, err) assert.True(t, x) }) @@ -245,7 +245,7 @@ func TestAbort(t *testing.T) { x = true }).Return(nil) - err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandID: "123456"}, mockPresto) + err := Abort(ctx, ExecutionState{CurrentPhase: PhaseQuerySucceeded, CommandID: "123456"}, mockPresto) assert.NoError(t, err) assert.False(t, x) }) @@ -272,12 +272,12 @@ func TestMonitorQuery(t *testing.T) { ctx := context.Background() tCtx := GetMockTaskExecutionContext() state := ExecutionState{ - Phase: PhaseSubmitted, + CurrentPhase: PhaseSubmitted, } var getOrCreateCalled = false mockCache := &mocks2.AutoRefresh{} mockCache.OnGetOrCreateMatch(mock.AnythingOfType("string"), mock.Anything).Return(ExecutionStateCacheItem{ - ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, + ExecutionState: ExecutionState{CurrentPhase: PhaseQuerySucceeded}, Identifier: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", }, nil).Run(func(_ mock.Arguments) { getOrCreateCalled = true @@ -286,7 +286,7 @@ func TestMonitorQuery(t *testing.T) { newState, err := MonitorQuery(ctx, tCtx, state, mockCache) assert.NoError(t, err) assert.True(t, getOrCreateCalled) - assert.Equal(t, PhaseQuerySucceeded, newState.Phase) + assert.Equal(t, PhaseQuerySucceeded, newState.CurrentPhase) } func TestKickOffQuery(t *testing.T) { @@ -312,7 +312,7 @@ func TestKickOffQuery(t *testing.T) { state := ExecutionState{} newState, err := KickOffQuery(ctx, tCtx, state, mockPresto, mockCache) assert.NoError(t, err) - assert.Equal(t, PhaseSubmitted, newState.Phase) + assert.Equal(t, PhaseSubmitted, newState.CurrentPhase) assert.Equal(t, "1234567", newState.CommandID) assert.True(t, getOrCreateCalled) assert.True(t, prestoCalled) diff --git a/flyteplugins/go/tasks/plugins/presto/executions_cache.go b/flyteplugins/go/tasks/plugins/presto/executions_cache.go index a26c6b3a2d..28780ca46e 100644 --- a/flyteplugins/go/tasks/plugins/presto/executions_cache.go +++ b/flyteplugins/go/tasks/plugins/presto/executions_cache.go @@ -123,11 +123,12 @@ func (p *ExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch return nil, err } - if newExecutionPhase > executionStateCacheItem.Phase { + if newExecutionPhase > executionStateCacheItem.CurrentPhase { logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandID, - executionStateCacheItem.Identifier, executionStateCacheItem.Phase, newExecutionPhase) + executionStateCacheItem.Identifier, executionStateCacheItem.CurrentPhase, newExecutionPhase) - executionStateCacheItem.Phase = newExecutionPhase + executionStateCacheItem.PreviousPhase = executionStateCacheItem.CurrentPhase + executionStateCacheItem.CurrentPhase = newExecutionPhase resp = append(resp, cache.ItemSyncResponse{ ID: query.GetID(), diff --git a/flyteplugins/go/tasks/plugins/presto/executions_cache_test.go b/flyteplugins/go/tasks/plugins/presto/executions_cache_test.go index 3f6114762f..4620b15f5d 100644 --- a/flyteplugins/go/tasks/plugins/presto/executions_cache_test.go +++ b/flyteplugins/go/tasks/plugins/presto/executions_cache_test.go @@ -34,7 +34,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { } state := ExecutionState{ - Phase: PhaseQuerySucceeded, + CurrentPhase: PhaseQuerySucceeded, } cacheItem := ExecutionStateCacheItem{ ExecutionState: state, @@ -67,8 +67,8 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { } state := ExecutionState{ - CommandID: "123456", - Phase: PhaseSubmitted, + CommandID: "123456", + CurrentPhase: PhaseSubmitted, } cacheItem := ExecutionStateCacheItem{ ExecutionState: state, @@ -86,6 +86,6 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { newExecutionState := newCacheItem[0].Item.(ExecutionStateCacheItem) assert.NoError(t, err) assert.Equal(t, cache.Update, newCacheItem[0].Action) - assert.Equal(t, PhaseQuerySucceeded, newExecutionState.Phase) + assert.Equal(t, PhaseQuerySucceeded, newExecutionState.CurrentPhase) }) }