Skip to content

Commit

Permalink
Avoid making nodeExecContext public
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Newton <[email protected]>
  • Loading branch information
Tom-Newton committed Jan 16, 2024
1 parent 99fc625 commit 216e1a3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
40 changes: 20 additions & 20 deletions flytepropeller/pkg/controller/nodes/node_exec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (e nodeExecMetadata) GetLabels() map[string]string {
return e.nodeLabels
}

type NodeExecContext struct {
type nodeExecContext struct {
store *storage.DataStore
tr interfaces.TaskReader
md interfaces.NodeExecutionMetadata
Expand All @@ -135,78 +135,78 @@ type NodeExecContext struct {
ic executors.ExecutionContext
}

func (e NodeExecContext) ExecutionContext() executors.ExecutionContext {
func (e nodeExecContext) ExecutionContext() executors.ExecutionContext {
return e.ic
}

func (e NodeExecContext) ContextualNodeLookup() executors.NodeLookup {
func (e nodeExecContext) ContextualNodeLookup() executors.NodeLookup {
return e.nl
}

func (e NodeExecContext) OutputShardSelector() ioutils.ShardSelector {
func (e nodeExecContext) OutputShardSelector() ioutils.ShardSelector {
return e.shardSelector
}

func (e NodeExecContext) RawOutputPrefix() storage.DataReference {
func (e nodeExecContext) RawOutputPrefix() storage.DataReference {
return e.rawOutputPrefix
}

func (e NodeExecContext) EnqueueOwnerFunc() func() error {
func (e nodeExecContext) EnqueueOwnerFunc() func() error {
return e.enqueueOwner
}

func (e NodeExecContext) TaskReader() interfaces.TaskReader {
func (e nodeExecContext) TaskReader() interfaces.TaskReader {
return e.tr
}

func (e NodeExecContext) NodeStateReader() interfaces.NodeStateReader {
func (e nodeExecContext) NodeStateReader() interfaces.NodeStateReader {
return e.nsm
}

func (e NodeExecContext) NodeStateWriter() interfaces.NodeStateWriter {
func (e nodeExecContext) NodeStateWriter() interfaces.NodeStateWriter {
return e.nsm
}

func (e NodeExecContext) DataStore() *storage.DataStore {
func (e nodeExecContext) DataStore() *storage.DataStore {
return e.store
}

func (e NodeExecContext) InputReader() io.InputReader {
func (e nodeExecContext) InputReader() io.InputReader {
return e.inputs
}

func (e NodeExecContext) EventsRecorder() interfaces.EventRecorder {
func (e nodeExecContext) EventsRecorder() interfaces.EventRecorder {
return e.eventRecorder
}

func (e NodeExecContext) NodeID() v1alpha1.NodeID {
func (e nodeExecContext) NodeID() v1alpha1.NodeID {
return e.node.GetID()
}

func (e NodeExecContext) Node() v1alpha1.ExecutableNode {
func (e nodeExecContext) Node() v1alpha1.ExecutableNode {
return e.node
}

func (e NodeExecContext) CurrentAttempt() uint32 {
func (e nodeExecContext) CurrentAttempt() uint32 {
return e.nodeStatus.GetAttempts()
}

func (e NodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus {
func (e nodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus {
return e.nodeStatus
}

func (e NodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata {
func (e nodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata {
return e.md
}

func (e NodeExecContext) MaxDatasetSizeBytes() int64 {
func (e nodeExecContext) MaxDatasetSizeBytes() int64 {
return e.maxDatasetSizeBytes
}

func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup,
node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold int32,
maxDatasetSize int64, taskEventRecorder events.TaskEventRecorder, nodeEventRecorder events.NodeEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager,
enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *NodeExecContext {
enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext {

md := nodeExecMetadata{
Meta: execContext,
Expand All @@ -230,7 +230,7 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext
nodeLabels[NodeInterruptibleLabel] = strconv.FormatBool(interruptible)
md.nodeLabels = nodeLabels

return &NodeExecContext{
return &nodeExecContext{
md: md,
store: store,
node: node,
Expand Down
6 changes: 5 additions & 1 deletion flytepropeller/pkg/controller/workflow/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type fakeRemoteWritePlugin struct {
t assert.TestingT
}

type fakeNodeExecContext interface {
Node() v1alpha1.ExecutableNode
}

func (f fakeRemoteWritePlugin) Handle(ctx context.Context, tCtx pluginCore.TaskExecutionContext) (pluginCore.Transition, error) {
logger.Infof(ctx, "----------------------------------------------------------------------------------------------")
logger.Infof(ctx, "Handle called for %s", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
Expand Down Expand Up @@ -517,7 +521,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {
h := &nodemocks.NodeHandler{}
h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

startNodeMatcher := mock.MatchedBy(func(nodeExecContext *nodes.NodeExecContext) bool {
startNodeMatcher := mock.MatchedBy(func(nodeExecContext fakeNodeExecContext) bool {
return nodeExecContext.Node().IsStartNode()
})
h.OnHandleMatch(mock.Anything, startNodeMatcher).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil)
Expand Down

0 comments on commit 216e1a3

Please sign in to comment.