Skip to content

Commit

Permalink
deep copying arraynode tasktemplate interface (#5479) (#322)
Browse files Browse the repository at this point in the history
## Overview
Cherry picking [this PR from OSS](#5479).

## Test Plan
Tested with the following workflow as repro:
```
@task(cache=True, cache_version="1.0")
def bAr(a: int) -> List[int]:
    return [a + 1, a + 2]

@task(cache=True, cache_version="1.0")
def bAr2(a: int) -> List[int]:
    return [a + 1, a + 2]

@workflow
def my_wf_2(a: int) -> List[List[int]]:
    x = bAr(a=a)
    return map_task(bAr2)(a=x)
```

## Rollout Plan (if applicable)
May be immediatley rolled out

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [x] To be upstreamed to OSS

## Issue
https://linear.app/unionai/issue/COR-1132/arraynode-cache-data-type-does-not-match-issue

## Checklist
* [ ] Added tests
* [ ] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation
  • Loading branch information
hamersaw authored Jun 14, 2024
1 parent 2616d4c commit 5dd2f6c
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ type arrayTaskReader struct {
}

func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) {
taskTemplate, err := a.TaskReader.Read(ctx)
originalTaskTemplate, err := a.TaskReader.Read(ctx)
if err != nil {
return nil, err
}

// convert output list variable to singular
outputVariables := make(map[string]*core.Variable)
for key, value := range taskTemplate.Interface.Outputs.Variables {
for key, value := range originalTaskTemplate.Interface.Outputs.Variables {
switch v := value.Type.Type.(type) {
case *core.LiteralType_CollectionType:
outputVariables[key] = &core.Variable{
Expand All @@ -69,10 +69,14 @@ func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error)
}
}

taskTemplate.Interface.Outputs = &core.VariableMap{
Variables: outputVariables,
taskTemplate := *originalTaskTemplate
taskTemplate.Interface = &core.TypedInterface{
Inputs: originalTaskTemplate.Interface.Inputs,
Outputs: &core.VariableMap{
Variables: outputVariables,
},
}
return taskTemplate, nil
return &taskTemplate, nil
}

type arrayNodeExecutionContext struct {
Expand Down

0 comments on commit 5dd2f6c

Please sign in to comment.