diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index aaed1357a7..be040d03ae 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -523,7 +523,7 @@ func (in *NodeStatus) GetNodeExecutionStatus(ctx context.Context, id NodeID) Exe n.SetParentTaskID(in.GetParentTaskID()) n.DataReferenceConstructor = in.DataReferenceConstructor if len(n.GetDataDir()) == 0 { - dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetDataDir(), id) + dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetOutputDir(), id) if err != nil { logger.Errorf(ctx, "Failed to construct data dir for node [%v]", id) return n @@ -552,9 +552,10 @@ func (in *NodeStatus) GetNodeExecutionStatus(ctx context.Context, id NodeID) Exe newNodeStatus := &NodeStatus{ MutableStruct: MutableStruct{}, } + newNodeStatus.SetParentTaskID(in.GetParentTaskID()) newNodeStatus.SetParentNodeID(in.GetParentNodeID()) - dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetDataDir(), id) + dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetOutputDir(), id) if err != nil { logger.Errorf(ctx, "Failed to construct data dir for node [%v]", id) return n diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go index 85827a7fc3..4d396cc6b1 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go @@ -1,9 +1,12 @@ package v1alpha1 import ( + "context" "encoding/json" "testing" + "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) @@ -190,3 +193,62 @@ func TestDynamicNodeStatus_SetExecutionError(t *testing.T) { }) } } + +func TestNodeStatus_GetNodeExecutionStatus(t *testing.T) { + ctx := context.Background() + t.Run("First Level", func(t *testing.T) { + t.Run("Not cached", func(t *testing.T) { + n := NodeStatus{ + SubNodeStatus: map[NodeID]*NodeStatus{}, + DataReferenceConstructor: storage.URLPathConstructor{}, + } + + newNode := n.GetNodeExecutionStatus(ctx, "abc") + assert.Equal(t, storage.DataReference("/abc/0"), newNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc"), newNode.GetDataDir()) + }) + + t.Run("cached", func(t *testing.T) { + n := NodeStatus{ + SubNodeStatus: map[NodeID]*NodeStatus{}, + DataReferenceConstructor: storage.URLPathConstructor{}, + } + + newNode := n.GetNodeExecutionStatus(ctx, "abc") + assert.Equal(t, storage.DataReference("/abc/0"), newNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc"), newNode.GetDataDir()) + + newNode = n.GetNodeExecutionStatus(ctx, "abc") + assert.Equal(t, storage.DataReference("/abc/0"), newNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc"), newNode.GetDataDir()) + }) + + t.Run("cached but datadir not populated", func(t *testing.T) { + n := NodeStatus{ + SubNodeStatus: map[NodeID]*NodeStatus{ + "abc": {}, + }, + DataReferenceConstructor: storage.URLPathConstructor{}, + } + + newNode := n.GetNodeExecutionStatus(ctx, "abc") + assert.Equal(t, storage.DataReference("/abc/0"), newNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc"), newNode.GetDataDir()) + }) + }) + + t.Run("Nested", func(t *testing.T) { + n := NodeStatus{ + SubNodeStatus: map[NodeID]*NodeStatus{}, + DataReferenceConstructor: storage.URLPathConstructor{}, + } + + newNode := n.GetNodeExecutionStatus(ctx, "abc") + assert.Equal(t, storage.DataReference("/abc/0"), newNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc"), newNode.GetDataDir()) + + subsubNode := newNode.GetNodeExecutionStatus(ctx, "xyz") + assert.Equal(t, storage.DataReference("/abc/0/xyz/0"), subsubNode.GetOutputDir()) + assert.Equal(t, storage.DataReference("/abc/0/xyz"), subsubNode.GetDataDir()) + }) +}