diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go
index 89adfc8f8..6f30df67d 100644
--- a/pkg/controller/nodes/handler/state.go
+++ b/pkg/controller/nodes/handler/state.go
@@ -21,6 +21,7 @@ type TaskNodeState struct {
     PluginPhaseVersion                 uint32
     PluginState                        []byte
     PluginStateVersion                 uint32
+    BarrierClockTick                   uint32
     LastPhaseUpdatedAt                 time.Time
     PreviousNodeExecutionCheckpointURI storage.DataReference
     CleanupOnFailure                   bool
diff --git a/pkg/controller/nodes/handler/transition.go b/pkg/controller/nodes/handler/transition.go
index 8d145102d..335076b47 100644
--- a/pkg/controller/nodes/handler/transition.go
+++ b/pkg/controller/nodes/handler/transition.go
@@ -4,7 +4,6 @@ type TransitionType int
 
 const (
     TransitionTypeEphemeral TransitionType = iota
-    // @deprecated support for Barrier type transitions has been deprecated
     TransitionTypeBarrier
 )
 
diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go
index 7a961fed5..ec31901f2 100644
--- a/pkg/controller/nodes/node_state_manager.go
+++ b/pkg/controller/nodes/node_state_manager.go
@@ -85,6 +85,7 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState {
         PluginPhaseVersion:                 tn.GetPhaseVersion(),
         PluginStateVersion:                 tn.GetPluginStateVersion(),
         PluginState:                        tn.GetPluginState(),
+        BarrierClockTick:                   tn.GetBarrierClockTick(),
         LastPhaseUpdatedAt:                 tn.GetLastPhaseUpdatedAt(),
         PreviousNodeExecutionCheckpointURI: tn.GetPreviousNodeExecutionCheckpointPath(),
         CleanupOnFailure:                   tn.GetCleanupOnFailure(),
diff --git a/pkg/controller/nodes/task/barrier.go b/pkg/controller/nodes/task/barrier.go
new file mode 100644
index 000000000..0b0f84b6e
--- /dev/null
+++ b/pkg/controller/nodes/task/barrier.go
@@ -0,0 +1,61 @@
+package task
+
+import (
+    "context"
+    "time"
+
+    "github.com/flyteorg/flytestdlib/logger"
+    "k8s.io/apimachinery/pkg/util/cache"
+
+    "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config"
+)
+
+type BarrierKey = string
+
+type PluginCallLog struct {
+    PluginTransition *pluginRequestedTransition
+}
+
+type BarrierTransition struct {
+    BarrierClockTick uint32
+    CallLog          PluginCallLog
+}
+
+var NoBarrierTransition = BarrierTransition{BarrierClockTick: 0}
+
+type barrier struct {
+    barrierCacheExpiration time.Duration
+    barrierTransitions     *cache.LRUExpireCache
+    barrierEnabled         bool
+}
+
+func (b *barrier) RecordBarrierTransition(ctx context.Context, k BarrierKey, bt BarrierTransition) {
+    if b.barrierEnabled {
+        b.barrierTransitions.Add(k, bt, b.barrierCacheExpiration)
+    }
+}
+
+func (b *barrier) GetPreviousBarrierTransition(ctx context.Context, k BarrierKey) BarrierTransition {
+    if b.barrierEnabled {
+        if v, ok := b.barrierTransitions.Get(k); ok {
+            f, casted := v.(BarrierTransition)
+            if !casted {
+                logger.Errorf(ctx, "Failed to cast recorded value to BarrierTransition")
+                return NoBarrierTransition
+            }
+            return f
+        }
+    }
+    return NoBarrierTransition
+}
+
+func newLRUBarrier(_ context.Context, cfg config.BarrierConfig) *barrier {
+    b := &barrier{
+        barrierEnabled: cfg.Enabled,
+    }
+    if cfg.Enabled {
+        b.barrierCacheExpiration = cfg.CacheTTL.Duration
+        b.barrierTransitions = cache.NewLRUExpireCache(cfg.CacheSize)
+    }
+    return b
+}
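The new `barrier` type is a thin wrapper over `k8s.io/apimachinery`'s `LRUExpireCache`: entries are evicted both by size and by TTL, so a cache miss is indistinguishable from "never recorded". Below is a minimal, self-contained sketch of the behavior `barrier.go` relies on; the `entry` type and key names are illustrative stand-ins, not this PR's types.

```go
package main

import (
	"fmt"
	"time"

	"k8s.io/apimachinery/pkg/util/cache"
)

// entry mirrors the shape of BarrierTransition: a clock tick plus whatever
// call data may need to be replayed. A plain string stands in for the
// recorded plugin transition here.
type entry struct {
	tick uint32
	log  string
}

func main() {
	// A size-bounded, TTL-bounded cache -- the same primitive newLRUBarrier wires up.
	c := cache.NewLRUExpireCache(2)
	c.Add("exec-a", entry{tick: 1, log: "queued"}, 50*time.Millisecond)

	// Within the TTL the recorded transition comes back, as in
	// GetPreviousBarrierTransition.
	if v, ok := c.Get("exec-a"); ok {
		fmt.Println("hit, tick:", v.(entry).tick)
	}

	// After the TTL the entry silently disappears, which is why the handler
	// treats a miss as NoBarrierTransition (tick 0) and falls back to the
	// tick persisted in the node state.
	time.Sleep(60 * time.Millisecond)
	if _, ok := c.Get("exec-a"); !ok {
		fmt.Println("miss: fall back to stored BarrierClockTick")
	}
}
```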
diff --git a/pkg/controller/nodes/task/config/config.go b/pkg/controller/nodes/task/config/config.go
index 020795675..4bc2937c5 100644
--- a/pkg/controller/nodes/task/config/config.go
+++ b/pkg/controller/nodes/task/config/config.go
@@ -20,6 +20,11 @@ var (
     defaultConfig = &Config{
         TaskPlugins:            TaskPluginConfig{EnabledPlugins: []string{}, DefaultForTaskTypes: map[string]string{}},
         MaxPluginPhaseVersions: 100000,
+        BarrierConfig: BarrierConfig{
+            Enabled:   true,
+            CacheSize: 10000,
+            CacheTTL:  config.Duration{Duration: time.Minute * 30},
+        },
         BackOffConfig: BackOffConfig{
             BaseSecond:  2,
             MaxDuration: config.Duration{Duration: time.Second * 20},
@@ -32,10 +37,17 @@ var (
 type Config struct {
     TaskPlugins            TaskPluginConfig `json:"task-plugins" pflag:",Task plugin configuration"`
     MaxPluginPhaseVersions int32            `json:"max-plugin-phase-versions" pflag:",Maximum number of plugin phase versions allowed for one phase."`
+    BarrierConfig          BarrierConfig    `json:"barrier" pflag:",Config for Barrier implementation"`
     BackOffConfig          BackOffConfig    `json:"backoff" pflag:",Config for Exponential BackOff implementation"`
     MaxErrorMessageLength  int              `json:"maxLogMessageLength" pflag:",Deprecated!!! Max length of error message."`
 }
 
+type BarrierConfig struct {
+    Enabled   bool            `json:"enabled" pflag:",Enable Barrier transitions using in-memory context"`
+    CacheSize int             `json:"cache-size" pflag:",Max number of barriers to preserve in memory"`
+    CacheTTL  config.Duration `json:"cache-ttl" pflag:",Max duration that a barrier would be respected if the process is not restarted. This should account for the time required to store the record in persistent storage (across multiple rounds)."`
+}
+
 type TaskPluginConfig struct {
     EnabledPlugins []string `json:"enabled-plugins" pflag:",Plugins enabled currently"`
     // Maps task types to their plugin handler (by ID).
diff --git a/pkg/controller/nodes/task/config/config_flags.go b/pkg/controller/nodes/task/config/config_flags.go
index 540d0214d..a77a6f58e 100755
--- a/pkg/controller/nodes/task/config/config_flags.go
+++ b/pkg/controller/nodes/task/config/config_flags.go
@@ -52,6 +52,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
     cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError)
     cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "task-plugins.enabled-plugins"), defaultConfig.TaskPlugins.EnabledPlugins, "Plugins enabled currently")
     cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "max-plugin-phase-versions"), defaultConfig.MaxPluginPhaseVersions, "Maximum number of plugin phase versions allowed for one phase.")
+    cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "barrier.enabled"), defaultConfig.BarrierConfig.Enabled, "Enable Barrier transitions using in-memory context")
+    cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "barrier.cache-size"), defaultConfig.BarrierConfig.CacheSize, "Max number of barriers to preserve in memory")
+    cmdFlags.String(fmt.Sprintf("%v%v", prefix, "barrier.cache-ttl"), defaultConfig.BarrierConfig.CacheTTL.String(), "Max duration that a barrier would be respected if the process is not restarted. This should account for the time required to store the record in persistent storage (across multiple rounds).")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "backoff.base-second"), defaultConfig.BackOffConfig.BaseSecond, "The number of seconds representing the base duration of the exponential backoff")
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "backoff.max-duration"), defaultConfig.BackOffConfig.MaxDuration.String(), "The cap of the backoff duration")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxLogMessageLength"), defaultConfig.MaxErrorMessageLength, "Deprecated!!! Max length of error message.")
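For a sense of how the generated flags surface at the CLI, here is a hedged sketch using `spf13/pflag` directly. One simplification: the generated `config_flags.go` registers `cache-ttl` as a string and lets flytestdlib decode it, whereas this sketch uses pflag's native `Duration` flag; the flag names and defaults match the diff.

```go
package main

import (
	"fmt"
	"time"

	"github.com/spf13/pflag"
)

func main() {
	// Illustrative only: the three new barrier flags, prefix omitted.
	fs := pflag.NewFlagSet("barrier", pflag.ExitOnError)
	enabled := fs.Bool("barrier.enabled", true, "Enable Barrier transitions using in-memory context")
	size := fs.Int("barrier.cache-size", 10000, "Max number of barriers to preserve in memory")
	ttl := fs.Duration("barrier.cache-ttl", 30*time.Minute, "Max duration that a barrier would be respected without a restart")

	// e.g. flytepropeller ... --barrier.cache-ttl=10m
	if err := fs.Parse([]string{"--barrier.cache-ttl=10m"}); err != nil {
		panic(err)
	}
	fmt.Println(*enabled, *size, *ttl) // true 10000 10m0s
}
```

Note the TTL must comfortably exceed the worst-case time to persist node state across rounds; otherwise a replayable barrier transition can be evicted before its tick is stored.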
diff --git a/pkg/controller/nodes/task/config/config_flags_test.go b/pkg/controller/nodes/task/config/config_flags_test.go
index cc2f02534..ef4f327d6 100755
--- a/pkg/controller/nodes/task/config/config_flags_test.go
+++ b/pkg/controller/nodes/task/config/config_flags_test.go
@@ -127,6 +127,48 @@ func TestConfig_SetFlags(t *testing.T) {
             }
         })
     })
+    t.Run("Test_barrier.enabled", func(t *testing.T) {
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("barrier.enabled", testValue)
+            if vBool, err := cmdFlags.GetBool("barrier.enabled"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.BarrierConfig.Enabled)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
+    t.Run("Test_barrier.cache-size", func(t *testing.T) {
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("barrier.cache-size", testValue)
+            if vInt, err := cmdFlags.GetInt("barrier.cache-size"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.BarrierConfig.CacheSize)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
+    t.Run("Test_barrier.cache-ttl", func(t *testing.T) {
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := defaultConfig.BarrierConfig.CacheTTL.String()
+
+            cmdFlags.Set("barrier.cache-ttl", testValue)
+            if vString, err := cmdFlags.GetString("barrier.cache-ttl"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BarrierConfig.CacheTTL)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
     t.Run("Test_backoff.base-second", func(t *testing.T) {
 
         t.Run("Override", func(t *testing.T) {
diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go
index d2d2107db..b267be777 100644
--- a/pkg/controller/nodes/task/handler.go
+++ b/pkg/controller/nodes/task/handler.go
@@ -197,6 +197,7 @@ type Handler struct {
     kubeClient      pluginCore.KubeClient
     secretManager   pluginCore.SecretManager
     resourceManager resourcemanager.BaseResourceManager
+    barrierCache    *barrier
     cfg             *config.Config
     pluginScope     promutils.Scope
     eventConfig     *controllerConfig.EventConfig
@@ -567,19 +568,48 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex
         }
     }
 
+    barrierTick := uint32(0)
     occurredAt := time.Now()
     // STEP 2: If no cache-hit and not transitioning to PhaseWaitingForCache, then lets invoke the plugin and wait for a transition out of undefined
     if pluginTrns.execInfo.TaskNodeInfo == nil || (pluginTrns.pInfo.Phase() != pluginCore.PhaseWaitingForCache &&
         pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT) {
+        prevBarrier := t.barrierCache.GetPreviousBarrierTransition(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
+        // Let's start with the current barrierTick (the value to be stored) the same as the barrierTick in the cache
+        barrierTick = prevBarrier.BarrierClockTick
+        // Let's check whether the value in the cache is less than or equal to the one in the store
+        if barrierTick <= ts.BarrierClockTick {
+            var err error
+            pluginTrns, err = t.invokePlugin(ctx, p, tCtx, ts)
+            if err != nil {
+                return handler.UnknownTransition, errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "failed during plugin execution")
+            }
+            if pluginTrns.IsPreviouslyObserved() {
+                logger.Debugf(ctx, "No state change for Task, previously observed same transition. Short circuiting.")
+                return pluginTrns.FinalTransition(ctx)
+            }
+            // Now no matter what we should update the barrierTick (stored in state).
+            // This is because the state is ahead of the in-memory representation.
+            // This can happen when the process restarted or the barrier cache got reset.
+            barrierTick = ts.BarrierClockTick
+            // Now if the transition is of type barrier, let's tick the clock by one from the previously known value
+            // and store that in the cache
+            if pluginTrns.ttype == handler.TransitionTypeBarrier {
+                logger.Infof(ctx, "Barrier transition observed for Plugin [%s], TaskExecID [%s]. recording: [%s]", p.GetID(), tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), pluginTrns.pInfo.String())
+                barrierTick = barrierTick + 1
+                t.barrierCache.RecordBarrierTransition(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), BarrierTransition{
+                    BarrierClockTick: barrierTick,
+                    CallLog: PluginCallLog{
+                        PluginTransition: pluginTrns,
+                    },
+                })
-        var err error
-        pluginTrns, err = t.invokePlugin(ctx, p, tCtx, ts)
-        if err != nil {
-            return handler.UnknownTransition, errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "failed during plugin execution")
-        }
-        if pluginTrns.IsPreviouslyObserved() {
-            logger.Debugf(ctx, "No state change for Task, previously observed same transition. Short circuiting.")
-            return pluginTrns.FinalTransition(ctx)
+            }
+        } else {
+            // The barrier tick remains the one in the cache.
+            // The cache may get reset before we store the barrier tick;
+            // that would lose this information and potentially cause a replay.
+            logger.Infof(ctx, "Replaying Barrier transition for cache tick [%d] < stored tick [%d], Plugin [%s], TaskExecID [%s]. recording: [%s]", barrierTick, ts.BarrierClockTick, p.GetID(), tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), prevBarrier.CallLog.PluginTransition.pInfo.String())
+            pluginTrns = prevBarrier.CallLog.PluginTransition
         }
     }
@@ -655,6 +685,7 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex
         PluginStateVersion:                 pluginTrns.pluginStateVersion,
         PluginPhase:                        pluginTrns.pInfo.Phase(),
         PluginPhaseVersion:                 pluginTrns.pInfo.Version(),
+        BarrierClockTick:                   barrierTick,
         LastPhaseUpdatedAt:                 time.Now(),
         PreviousNodeExecutionCheckpointURI: ts.PreviousNodeExecutionCheckpointURI,
         CleanupOnFailure:                   ts.CleanupOnFailure || pluginTrns.pInfo.CleanupOnFailure(),
@@ -870,6 +901,7 @@ func New(ctx context.Context, kubeClient executors.Client, client catalog.Client
         asyncCatalog:    async,
         resourceManager: nil,
         secretManager:   secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()),
+        barrierCache:    newLRUBarrier(ctx, cfg.BarrierConfig),
         cfg:             cfg,
         eventConfig:     eventConfig,
         clusterID:       clusterID,
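The tick comparison in `Handle` above is the core of the change: invoke the plugin only when the in-memory clock has not run ahead of the tick persisted in the node state; otherwise replay the recorded transition instead of calling the plugin again. A distilled, runnable sketch of just that decision follows; the function name and return values are illustrative, not the handler's API.

```go
package main

import "fmt"

// decide captures the branches in Handle: cacheTick is the tick held in the
// in-memory barrier cache, storedTick is the tick persisted in the task node
// state. Names and return values are illustrative only.
func decide(cacheTick, storedTick uint32, isBarrier bool) (action string, nextTick uint32) {
	if cacheTick <= storedTick {
		// Cache is not ahead of the store: safe to invoke the plugin.
		// Start from the persisted tick, since the store is authoritative
		// after a restart or cache eviction.
		nextTick = storedTick
		if isBarrier {
			// A barrier transition advances the clock and is recorded, so a
			// crash before the state is persisted replays it rather than
			// re-executing the plugin.
			nextTick++
		}
		return "invoke-plugin", nextTick
	}
	// Cache is ahead of the store: the previous round's transition was
	// observed but never persisted, so replay the recorded transition.
	return "replay-cached-transition", cacheTick
}

func main() {
	fmt.Println(decide(0, 0, true))  // invoke-plugin 1
	fmt.Println(decide(1, 0, true))  // replay-cached-transition 1
	fmt.Println(decide(1, 1, true))  // invoke-plugin 2
	fmt.Println(decide(0, 2, false)) // invoke-plugin 2 (restart: cache reset, store ahead)
}
```

Fed the tuples from the table-driven test below, this reproduces the expected ticks (for example, cache tick 1 with stored tick 1 and a barrier transition advances to tick 2, matching the "barrier-trns-next" case).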
diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go
index 38506c4c7..5c77a96e7 100644
--- a/pkg/controller/nodes/task/handler_test.go
+++ b/pkg/controller/nodes/task/handler_test.go
@@ -5,6 +5,7 @@ import (
     "context"
     "fmt"
     "testing"
+    "time"
 
     "github.com/flyteorg/flyteidl/clients/go/coreutils"
     "github.com/golang/protobuf/proto"
@@ -705,8 +706,11 @@ func Test_task_Handle_NoCatalog(t *testing.T) {
                 defaultPlugins: map[pluginCore.TaskType]pluginCore.Plugin{
                     "test": fakeplugins.NewPhaseBasedPlugin(),
                 },
-                pluginScope: promutils.NewTestScope(),
-                catalog:     c,
+                pluginScope: promutils.NewTestScope(),
+                catalog:     c,
+                barrierCache: newLRUBarrier(context.TODO(), config.BarrierConfig{
+                    Enabled: false,
+                }),
                 resourceManager: noopRm,
                 taskMetricsMap:  make(map[MetricKey]*taskMetrics),
                 eventConfig:     eventConfig,
@@ -763,6 +767,310 @@
     }
 }
 
+func Test_task_Handle_Barrier(t *testing.T) {
+    // NOTE: Caching is disabled for this test
+
+    createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, prevBarrierClockTick uint32) *nodeMocks.NodeExecutionContext {
+        wfExecID := &core.WorkflowExecutionIdentifier{
+            Project: "project",
+            Domain:  "domain",
+            Name:    "name",
+        }
+
+        nodeID := "n1"
+
+        nm := &nodeMocks.NodeExecutionMetadata{}
+        nm.OnGetAnnotations().Return(map[string]string{})
+        nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
+            NodeId:      nodeID,
+            ExecutionId: wfExecID,
+        })
+        nm.OnGetK8sServiceAccount().Return("service-account")
+        nm.OnGetLabels().Return(map[string]string{})
+        nm.OnGetNamespace().Return("namespace")
+        nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
+        nm.OnGetOwnerReference().Return(v12.OwnerReference{
+            Kind: "sample",
+            Name: "name",
+        })
+        nm.OnIsInterruptible().Return(true)
+
+        taskID := &core.Identifier{}
+        tk := &core.TaskTemplate{
+            Id:   taskID,
+            Type: "test",
+            Metadata: &core.TaskMetadata{
+                Discoverable: false,
+            },
+            Interface: &core.TypedInterface{
+                Outputs: &core.VariableMap{
+                    Variables: map[string]*core.Variable{
+                        "x": {
+                            Type: &core.LiteralType{
+                                Type: &core.LiteralType_Simple{
+                                    Simple: core.SimpleType_BOOLEAN,
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        }
+        tr := &nodeMocks.TaskReader{}
+        tr.OnGetTaskID().Return(taskID)
+        tr.OnGetTaskType().Return(ttype)
+        tr.OnReadMatch(mock.Anything).Return(tk, nil)
+
+        ns := &flyteMocks.ExecutableNodeStatus{}
+        ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
+        ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
+
+        res := &v1.ResourceRequirements{}
+        n := &flyteMocks.ExecutableNode{}
+        ma := 5
+        n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma})
+        n.OnGetResources().Return(res)
+
+        ir := &ioMocks.InputReader{}
+        ir.OnGetInputPath().Return(storage.DataReference("input"))
+        ir.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil)
+        nCtx := &nodeMocks.NodeExecutionContext{}
+        nCtx.OnNodeExecutionMetadata().Return(nm)
+        nCtx.OnNode().Return(n)
+        nCtx.OnInputReader().Return(ir)
+        ds, err := storage.NewDataStore(
+            &storage.Config{
+                Type: storage.TypeMemory,
+            },
+            promutils.NewTestScope(),
+        )
+        assert.NoError(t, err)
+        nCtx.OnDataStore().Return(ds)
+        nCtx.OnCurrentAttempt().Return(uint32(1))
+        nCtx.OnTaskReader().Return(tr)
+        nCtx.OnMaxDatasetSizeBytes().Return(int64(1))
+        nCtx.OnNodeStatus().Return(ns)
+        nCtx.OnNodeID().Return("n1")
+        nCtx.OnEventsRecorder().Return(recorder)
+        nCtx.OnEnqueueOwnerFunc().Return(nil)
+
+        executionContext := &mocks.ExecutionContext{}
+        executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
+        executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
+        executionContext.OnGetParentInfo().Return(nil)
+        executionContext.OnIncrementParallelism().Return(1)
+        nCtx.OnExecutionContext().Return(executionContext)
+
+        nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
+        nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
+
+        st := bytes.NewBuffer([]byte{})
+        cod := codex.GobStateCodec{}
+        assert.NoError(t, cod.Encode(&fakeplugins.NextPhaseState{
+            Phase:        pluginCore.PhaseSuccess,
+            OutputExists: true,
+        }, st))
+        nr := &nodeMocks.NodeStateReader{}
+        nr.OnGetTaskNodeState().Return(handler.TaskNodeState{
+            PluginState:      st.Bytes(),
+            BarrierClockTick: prevBarrierClockTick,
+        })
+        nCtx.OnNodeStateReader().Return(nr)
+        nCtx.OnNodeStateWriter().Return(s)
+        return nCtx
+    }
+
+    noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())
+
+    trns := pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoQueued(time.Now(), 1, "z"))
+    type args struct {
+        prevTick  uint32
+        btrnsTick uint32
+        bTrns     *pluginCore.Transition
+        res       []fakeplugins.HandleResponse
+    }
+    type wantBarrier struct {
+        hit  bool
+        tick uint32
+    }
+    type want struct {
+        wantBarrier  wantBarrier
+        handlerPhase handler.EPhase
+        wantErr      bool
+        eventPhase   core.TaskExecution_Phase
+        pluginPhase  pluginCore.Phase
+        pluginVer    uint32
+    }
+    tests := []struct {
+        name string
+        args args
+        want want
+    }{
+        {
+            "ephemeral-trns",
+            args{
+                res: []fakeplugins.HandleResponse{
+                    {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeEphemeral, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))},
+                },
+            },
+            want{
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_RUNNING,
+                pluginPhase:  pluginCore.PhaseRunning,
+                pluginVer:    1,
+            },
+        },
+        {
+            "first-barrier-trns",
+            args{
+                res: []fakeplugins.HandleResponse{
+                    {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))},
+                },
+            },
+            want{
+                wantBarrier: wantBarrier{
+                    hit:  true,
+                    tick: 1,
+                },
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_RUNNING,
+                pluginPhase:  pluginCore.PhaseRunning,
+                pluginVer:    1,
+            },
+        },
+        {
+            "barrier-trns-replay",
+            args{
+                prevTick:  0,
+                btrnsTick: 1,
+                bTrns:     &trns,
+            },
+            want{
+                wantBarrier: wantBarrier{
+                    hit:  true,
+                    tick: 1,
+                },
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_QUEUED,
+                pluginPhase:  pluginCore.PhaseQueued,
+                pluginVer:    1,
+            },
+        },
+        {
+            "barrier-trns-next",
+            args{
+                prevTick:  1,
+                btrnsTick: 1,
+                bTrns:     &trns,
+                res: []fakeplugins.HandleResponse{
+                    {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))},
+                },
+            },
+            want{
+                wantBarrier: wantBarrier{
+                    hit:  true,
+                    tick: 2,
+                },
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_RUNNING,
+                pluginPhase:  pluginCore.PhaseRunning,
+                pluginVer:    1,
+            },
+        },
+        {
+            "barrier-trns-restart-case",
+            args{
+                prevTick: 2,
+                res: []fakeplugins.HandleResponse{
+                    {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))},
+                },
+            },
+            want{
+                wantBarrier: wantBarrier{
+                    hit:  true,
+                    tick: 3,
+                },
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_RUNNING,
+                pluginPhase:  pluginCore.PhaseRunning,
+                pluginVer:    1,
+            },
+        },
+        {
+            "barrier-trns-restart-case-ephemeral",
+            args{
+                prevTick: 2,
+                res: []fakeplugins.HandleResponse{
+                    {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeEphemeral, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))},
+                },
+            },
+            want{
+                wantBarrier: wantBarrier{
+                    hit: false,
+                },
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_RUNNING,
+                pluginPhase:  pluginCore.PhaseRunning,
+                pluginVer:    1,
+            },
+        },
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            state := &taskNodeStateHolder{}
+            ev := &fakeBufferedEventRecorder{}
+            nCtx := createNodeContext(ev, "test", state, tt.args.prevTick)
+            c := &pluginCatalogMocks.Client{}
+
+            tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, eventConfig, testClusterID, promutils.NewTestScope())
+            assert.NoError(t, err)
+            tk.resourceManager = noopRm
+
+            p := &pluginCoreMocks.Plugin{}
+            p.On("GetID").Return("plugin1")
+            p.OnGetProperties().Return(pluginCore.PluginProperties{})
+            tctx, err := tk.newTaskExecutionContext(context.TODO(), nCtx, p)
+            assert.NoError(t, err)
+            id := tctx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
+
+            if tt.args.bTrns != nil {
+                x := &pluginRequestedTransition{}
+                x.ObservedTransitionAndState(*tt.args.bTrns, 0, nil)
+                tk.barrierCache.RecordBarrierTransition(context.TODO(), id, BarrierTransition{tt.args.btrnsTick, PluginCallLog{x}})
+            }
+
+            tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{
+                "test": fakeplugins.NewReplayer("test", pluginCore.PluginProperties{},
+                    tt.args.res, nil, nil),
+            }
+
+            got, err := tk.Handle(context.TODO(), nCtx)
+            if (err != nil) != tt.want.wantErr {
+                t.Errorf("Handler.Handle() error = %v, wantErr %v", err, tt.want.wantErr)
+                return
+            }
+            if err == nil {
+                assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String())
+                if assert.Equal(t, 1, len(ev.evs)) {
+                    e := ev.evs[0]
+                    assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String())
+                }
+                assert.Equal(t, tt.want.pluginPhase.String(), state.s.PluginPhase.String())
+                assert.Equal(t, tt.want.pluginVer, state.s.PluginPhaseVersion)
+                if tt.want.wantBarrier.hit {
+                    assert.Len(t, tk.barrierCache.barrierTransitions.Keys(), 1)
+                    bt := tk.barrierCache.GetPreviousBarrierTransition(context.TODO(), id)
+                    assert.Equal(t, tt.want.wantBarrier.tick, bt.BarrierClockTick)
+                    assert.Equal(t, tt.want.wantBarrier.tick, state.s.BarrierClockTick)
+                } else {
+                    assert.Len(t, tk.barrierCache.barrierTransitions.Keys(), 0)
+                    assert.Equal(t, tt.args.prevTick, state.s.BarrierClockTick)
+                }
+            }
+        })
+    }
+}
+
 func Test_task_Abort(t *testing.T) {
     createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext {
         wfExecID := &core.WorkflowExecutionIdentifier{
diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go
index 6e09c103b..254713d20 100644
--- a/pkg/controller/nodes/transformers.go
+++ b/pkg/controller/nodes/transformers.go
@@ -242,6 +242,7 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.N
         t.SetLastPhaseUpdatedAt(nt.LastPhaseUpdatedAt)
         t.SetPluginState(nt.PluginState)
         t.SetPluginStateVersion(nt.PluginStateVersion)
+        t.SetBarrierClockTick(nt.BarrierClockTick)
         t.SetPreviousNodeExecutionCheckpointPath(nt.PreviousNodeExecutionCheckpointURI)
         t.SetCleanupOnFailure(nt.CleanupOnFailure)
     }