From 3377dc754718620ecffec9653932c2085b0dd5e2 Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Thu, 4 Apr 2024 16:02:26 -0700 Subject: [PATCH] Feature/array node workflow parallelism (#5062) * update arraynode proto parallelism field to varint compatible int64 Signed-off-by: Paul Dittamo * have array nodes utilize workflow parallelism Signed-off-by: Paul Dittamo * return if available parallelism is 0 Signed-off-by: Paul Dittamo * unit test Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo --- flyteidl/clients/go/assets/admin.swagger.json | 2 +- .../gen/pb-es/flyteidl/core/workflow_pb.ts | 8 +- .../gen/pb-go/flyteidl/core/workflow.pb.go | 6 +- .../flyteidl/service/admin.swagger.json | 2 +- flyteidl/gen/pb-js/flyteidl.d.ts | 4 +- flyteidl/gen/pb-js/flyteidl.js | 14 +-- .../pb_python/flyteidl/core/workflow_pb2.py | 2 +- flyteidl/gen/pb_rust/flyteidl.core.rs | 4 +- flyteidl/protos/flyteidl/core/workflow.proto | 2 +- .../pkg/apis/flyteworkflow/v1alpha1/array.go | 4 +- .../apis/flyteworkflow/v1alpha1/array_test.go | 2 +- .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 2 +- .../v1alpha1/mocks/ExecutableArrayNode.go | 10 +- .../pkg/controller/nodes/array/handler.go | 32 +++++-- .../controller/nodes/array/handler_test.go | 92 ++++++++++++++++--- 15 files changed, 136 insertions(+), 50 deletions(-) diff --git a/flyteidl/clients/go/assets/admin.swagger.json b/flyteidl/clients/go/assets/admin.swagger.json index bcaf46928bd..9273a467765 100644 --- a/flyteidl/clients/go/assets/admin.swagger.json +++ b/flyteidl/clients/go/assets/admin.swagger.json @@ -6483,7 +6483,7 @@ "description": "node is the sub-node that will be executed for each element in the array." }, "parallelism": { - "type": "integer", + "type": "string", "format": "int64", "description": "parallelism defines the minimum number of instances to bring up concurrently at any given\npoint. Note that this is an optimistic restriction and that, due to network partitioning or\nother failures, the actual number of currently running instances might be more. This has to\nbe a positive number if assigned. Default value is size." }, diff --git a/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts b/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts index 414e6eb3192..9efdcf91dde 100644 --- a/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts @@ -4,7 +4,7 @@ // @ts-nocheck import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; -import { Duration, Message, proto3 } from "@bufbuild/protobuf"; +import { Duration, Message, proto3, protoInt64 } from "@bufbuild/protobuf"; import { BooleanExpression } from "./condition_pb.js"; import { Error, LiteralType } from "./types_pb.js"; import { Identifier } from "./identifier_pb.js"; @@ -512,9 +512,9 @@ export class ArrayNode extends Message { * other failures, the actual number of currently running instances might be more. This has to * be a positive number if assigned. Default value is size. * - * @generated from field: uint32 parallelism = 2; + * @generated from field: int64 parallelism = 2; */ - parallelism = 0; + parallelism = protoInt64.zero; /** * @generated from oneof flyteidl.core.ArrayNode.success_criteria @@ -550,7 +550,7 @@ export class ArrayNode extends Message { static readonly typeName = "flyteidl.core.ArrayNode"; static readonly fields: FieldList = proto3.util.newFieldList(() => [ { no: 1, name: "node", kind: "message", T: Node }, - { no: 2, name: "parallelism", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, + { no: 2, name: "parallelism", kind: "scalar", T: 3 /* ScalarType.INT64 */ }, { no: 3, name: "min_successes", kind: "scalar", T: 13 /* ScalarType.UINT32 */, oneof: "success_criteria" }, { no: 4, name: "min_success_ratio", kind: "scalar", T: 2 /* ScalarType.FLOAT */, oneof: "success_criteria" }, ]); diff --git a/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go b/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go index 8b6cc6ab1dc..983bbdcf436 100644 --- a/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go @@ -727,7 +727,7 @@ type ArrayNode struct { // point. Note that this is an optimistic restriction and that, due to network partitioning or // other failures, the actual number of currently running instances might be more. This has to // be a positive number if assigned. Default value is size. - Parallelism uint32 `protobuf:"varint,2,opt,name=parallelism,proto3" json:"parallelism,omitempty"` + Parallelism int64 `protobuf:"varint,2,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // Types that are assignable to SuccessCriteria: // // *ArrayNode_MinSuccesses @@ -774,7 +774,7 @@ func (x *ArrayNode) GetNode() *Node { return nil } -func (x *ArrayNode) GetParallelism() uint32 { +func (x *ArrayNode) GetParallelism() int64 { if x != nil { return x.Parallelism } @@ -1724,7 +1724,7 @@ var file_flyteidl_core_workflow_proto_rawDesc = []byte{ 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, - 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, + 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x25, 0x0a, 0x0d, 0x6d, 0x69, 0x6e, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x0c, 0x6d, 0x69, 0x6e, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x11, 0x6d, diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json index bcaf46928bd..9273a467765 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json @@ -6483,7 +6483,7 @@ "description": "node is the sub-node that will be executed for each element in the array." }, "parallelism": { - "type": "integer", + "type": "string", "format": "int64", "description": "parallelism defines the minimum number of instances to bring up concurrently at any given\npoint. Note that this is an optimistic restriction and that, due to network partitioning or\nother failures, the actual number of currently running instances might be more. This has to\nbe a positive number if assigned. Default value is size." }, diff --git a/flyteidl/gen/pb-js/flyteidl.d.ts b/flyteidl/gen/pb-js/flyteidl.d.ts index e61fbc7acbf..0126d892f29 100644 --- a/flyteidl/gen/pb-js/flyteidl.d.ts +++ b/flyteidl/gen/pb-js/flyteidl.d.ts @@ -4506,7 +4506,7 @@ export namespace flyteidl { node?: (flyteidl.core.INode|null); /** ArrayNode parallelism */ - parallelism?: (number|null); + parallelism?: (Long|null); /** ArrayNode minSuccesses */ minSuccesses?: (number|null); @@ -4528,7 +4528,7 @@ export namespace flyteidl { public node?: (flyteidl.core.INode|null); /** ArrayNode parallelism. */ - public parallelism: number; + public parallelism: Long; /** ArrayNode minSuccesses. */ public minSuccesses: number; diff --git a/flyteidl/gen/pb-js/flyteidl.js b/flyteidl/gen/pb-js/flyteidl.js index 0f1ffe6b404..e288f8e111b 100644 --- a/flyteidl/gen/pb-js/flyteidl.js +++ b/flyteidl/gen/pb-js/flyteidl.js @@ -10808,7 +10808,7 @@ * @memberof flyteidl.core * @interface IArrayNode * @property {flyteidl.core.INode|null} [node] ArrayNode node - * @property {number|null} [parallelism] ArrayNode parallelism + * @property {Long|null} [parallelism] ArrayNode parallelism * @property {number|null} [minSuccesses] ArrayNode minSuccesses * @property {number|null} [minSuccessRatio] ArrayNode minSuccessRatio */ @@ -10838,11 +10838,11 @@ /** * ArrayNode parallelism. - * @member {number} parallelism + * @member {Long} parallelism * @memberof flyteidl.core.ArrayNode * @instance */ - ArrayNode.prototype.parallelism = 0; + ArrayNode.prototype.parallelism = $util.Long ? $util.Long.fromBits(0,0,false) : 0; /** * ArrayNode minSuccesses. @@ -10901,7 +10901,7 @@ if (message.node != null && message.hasOwnProperty("node")) $root.flyteidl.core.Node.encode(message.node, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); if (message.parallelism != null && message.hasOwnProperty("parallelism")) - writer.uint32(/* id 2, wireType 0 =*/16).uint32(message.parallelism); + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.parallelism); if (message.minSuccesses != null && message.hasOwnProperty("minSuccesses")) writer.uint32(/* id 3, wireType 0 =*/24).uint32(message.minSuccesses); if (message.minSuccessRatio != null && message.hasOwnProperty("minSuccessRatio")) @@ -10931,7 +10931,7 @@ message.node = $root.flyteidl.core.Node.decode(reader, reader.uint32()); break; case 2: - message.parallelism = reader.uint32(); + message.parallelism = reader.int64(); break; case 3: message.minSuccesses = reader.uint32(); @@ -10965,8 +10965,8 @@ return "node." + error; } if (message.parallelism != null && message.hasOwnProperty("parallelism")) - if (!$util.isInteger(message.parallelism)) - return "parallelism: integer expected"; + if (!$util.isInteger(message.parallelism) && !(message.parallelism && $util.isInteger(message.parallelism.low) && $util.isInteger(message.parallelism.high))) + return "parallelism: integer|Long expected"; if (message.minSuccesses != null && message.hasOwnProperty("minSuccesses")) { properties.successCriteria = 1; if (!$util.isInteger(message.minSuccesses)) diff --git a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py index 4c070243207..452c38c9c9d 100644 --- a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py @@ -22,7 +22,7 @@ from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/core/workflow.proto\x12\rflyteidl.core\x1a\x1d\x66lyteidl/core/condition.proto\x1a\x1d\x66lyteidl/core/execution.proto\x1a\x1e\x66lyteidl/core/identifier.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x19\x66lyteidl/core/types.proto\x1a\x1c\x66lyteidl/core/security.proto\x1a\x1egoogle/protobuf/duration.proto\"{\n\x07IfBlock\x12>\n\tcondition\x18\x01 \x01(\x0b\x32 .flyteidl.core.BooleanExpressionR\tcondition\x12\x30\n\tthen_node\x18\x02 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x08thenNode\"\xd4\x01\n\x0bIfElseBlock\x12*\n\x04\x63\x61se\x18\x01 \x01(\x0b\x32\x16.flyteidl.core.IfBlockR\x04\x63\x61se\x12,\n\x05other\x18\x02 \x03(\x0b\x32\x16.flyteidl.core.IfBlockR\x05other\x12\x32\n\telse_node\x18\x03 \x01(\x0b\x32\x13.flyteidl.core.NodeH\x00R\x08\x65lseNode\x12,\n\x05\x65rror\x18\x04 \x01(\x0b\x32\x14.flyteidl.core.ErrorH\x00R\x05\x65rrorB\t\n\x07\x64\x65\x66\x61ult\"A\n\nBranchNode\x12\x33\n\x07if_else\x18\x01 \x01(\x0b\x32\x1a.flyteidl.core.IfElseBlockR\x06ifElse\"\x97\x01\n\x08TaskNode\x12>\n\x0creference_id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0breferenceId\x12>\n\toverrides\x18\x02 \x01(\x0b\x32 .flyteidl.core.TaskNodeOverridesR\toverridesB\x0b\n\treference\"\xa6\x01\n\x0cWorkflowNode\x12\x42\n\x0elaunchplan_ref\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\rlaunchplanRef\x12\x45\n\x10sub_workflow_ref\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0esubWorkflowRefB\x0b\n\treference\"/\n\x10\x41pproveCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\"\x90\x01\n\x0fSignalCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\x12.\n\x04type\x18\x02 \x01(\x0b\x32\x1a.flyteidl.core.LiteralTypeR\x04type\x12\x30\n\x14output_variable_name\x18\x03 \x01(\tR\x12outputVariableName\"G\n\x0eSleepCondition\x12\x35\n\x08\x64uration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x08\x64uration\"\xc5\x01\n\x08GateNode\x12;\n\x07\x61pprove\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.ApproveConditionH\x00R\x07\x61pprove\x12\x38\n\x06signal\x18\x02 \x01(\x0b\x32\x1e.flyteidl.core.SignalConditionH\x00R\x06signal\x12\x35\n\x05sleep\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.SleepConditionH\x00R\x05sleepB\x0b\n\tcondition\"\xbf\x01\n\tArrayNode\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x04node\x12 \n\x0bparallelism\x18\x02 \x01(\rR\x0bparallelism\x12%\n\rmin_successes\x18\x03 \x01(\rH\x00R\x0cminSuccesses\x12,\n\x11min_success_ratio\x18\x04 \x01(\x02H\x00R\x0fminSuccessRatioB\x12\n\x10success_criteria\"\x8c\x03\n\x0cNodeMetadata\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x33\n\x07timeout\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x07timeout\x12\x36\n\x07retries\x18\x05 \x01(\x0b\x32\x1c.flyteidl.core.RetryStrategyR\x07retries\x12&\n\rinterruptible\x18\x06 \x01(\x08H\x00R\rinterruptible\x12\x1e\n\tcacheable\x18\x07 \x01(\x08H\x01R\tcacheable\x12%\n\rcache_version\x18\x08 \x01(\tH\x02R\x0c\x63\x61\x63heVersion\x12/\n\x12\x63\x61\x63he_serializable\x18\t \x01(\x08H\x03R\x11\x63\x61\x63heSerializableB\x15\n\x13interruptible_valueB\x11\n\x0f\x63\x61\x63heable_valueB\x15\n\x13\x63\x61\x63he_version_valueB\x1a\n\x18\x63\x61\x63he_serializable_value\"/\n\x05\x41lias\x12\x10\n\x03var\x18\x01 \x01(\tR\x03var\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\"\x9f\x04\n\x04Node\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x37\n\x08metadata\x18\x02 \x01(\x0b\x32\x1b.flyteidl.core.NodeMetadataR\x08metadata\x12.\n\x06inputs\x18\x03 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x06inputs\x12*\n\x11upstream_node_ids\x18\x04 \x03(\tR\x0fupstreamNodeIds\x12;\n\x0eoutput_aliases\x18\x05 \x03(\x0b\x32\x14.flyteidl.core.AliasR\routputAliases\x12\x36\n\ttask_node\x18\x06 \x01(\x0b\x32\x17.flyteidl.core.TaskNodeH\x00R\x08taskNode\x12\x42\n\rworkflow_node\x18\x07 \x01(\x0b\x32\x1b.flyteidl.core.WorkflowNodeH\x00R\x0cworkflowNode\x12<\n\x0b\x62ranch_node\x18\x08 \x01(\x0b\x32\x19.flyteidl.core.BranchNodeH\x00R\nbranchNode\x12\x36\n\tgate_node\x18\t \x01(\x0b\x32\x17.flyteidl.core.GateNodeH\x00R\x08gateNode\x12\x39\n\narray_node\x18\n \x01(\x0b\x32\x18.flyteidl.core.ArrayNodeH\x00R\tarrayNodeB\x08\n\x06target\"\xfc\x02\n\x10WorkflowMetadata\x12M\n\x12quality_of_service\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.QualityOfServiceR\x10qualityOfService\x12N\n\non_failure\x18\x02 \x01(\x0e\x32/.flyteidl.core.WorkflowMetadata.OnFailurePolicyR\tonFailure\x12=\n\x04tags\x18\x03 \x03(\x0b\x32).flyteidl.core.WorkflowMetadata.TagsEntryR\x04tags\x1a\x37\n\tTagsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"Q\n\x0fOnFailurePolicy\x12\x14\n\x10\x46\x41IL_IMMEDIATELY\x10\x00\x12(\n$FAIL_AFTER_EXECUTABLE_NODES_COMPLETE\x10\x01\"@\n\x18WorkflowMetadataDefaults\x12$\n\rinterruptible\x18\x01 \x01(\x08R\rinterruptible\"\xa2\x03\n\x10WorkflowTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.flyteidl.core.WorkflowMetadataR\x08metadata\x12;\n\tinterface\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12)\n\x05nodes\x18\x04 \x03(\x0b\x32\x13.flyteidl.core.NodeR\x05nodes\x12\x30\n\x07outputs\x18\x05 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x07outputs\x12\x36\n\x0c\x66\x61ilure_node\x18\x06 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x0b\x66\x61ilureNode\x12T\n\x11metadata_defaults\x18\x07 \x01(\x0b\x32\'.flyteidl.core.WorkflowMetadataDefaultsR\x10metadataDefaults\"\xc5\x01\n\x11TaskNodeOverrides\x12\x36\n\tresources\x18\x01 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x12\x65xtended_resources\x18\x02 \x01(\x0b\x32 .flyteidl.core.ExtendedResourcesR\x11\x65xtendedResources\x12\'\n\x0f\x63ontainer_image\x18\x03 \x01(\tR\x0e\x63ontainerImage\"\xba\x01\n\x12LaunchPlanTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\tinterface\x18\x02 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12<\n\x0c\x66ixed_inputs\x18\x03 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x0b\x66ixedInputsB\xb3\x01\n\x11\x63om.flyteidl.coreB\rWorkflowProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/core/workflow.proto\x12\rflyteidl.core\x1a\x1d\x66lyteidl/core/condition.proto\x1a\x1d\x66lyteidl/core/execution.proto\x1a\x1e\x66lyteidl/core/identifier.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x19\x66lyteidl/core/types.proto\x1a\x1c\x66lyteidl/core/security.proto\x1a\x1egoogle/protobuf/duration.proto\"{\n\x07IfBlock\x12>\n\tcondition\x18\x01 \x01(\x0b\x32 .flyteidl.core.BooleanExpressionR\tcondition\x12\x30\n\tthen_node\x18\x02 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x08thenNode\"\xd4\x01\n\x0bIfElseBlock\x12*\n\x04\x63\x61se\x18\x01 \x01(\x0b\x32\x16.flyteidl.core.IfBlockR\x04\x63\x61se\x12,\n\x05other\x18\x02 \x03(\x0b\x32\x16.flyteidl.core.IfBlockR\x05other\x12\x32\n\telse_node\x18\x03 \x01(\x0b\x32\x13.flyteidl.core.NodeH\x00R\x08\x65lseNode\x12,\n\x05\x65rror\x18\x04 \x01(\x0b\x32\x14.flyteidl.core.ErrorH\x00R\x05\x65rrorB\t\n\x07\x64\x65\x66\x61ult\"A\n\nBranchNode\x12\x33\n\x07if_else\x18\x01 \x01(\x0b\x32\x1a.flyteidl.core.IfElseBlockR\x06ifElse\"\x97\x01\n\x08TaskNode\x12>\n\x0creference_id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0breferenceId\x12>\n\toverrides\x18\x02 \x01(\x0b\x32 .flyteidl.core.TaskNodeOverridesR\toverridesB\x0b\n\treference\"\xa6\x01\n\x0cWorkflowNode\x12\x42\n\x0elaunchplan_ref\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\rlaunchplanRef\x12\x45\n\x10sub_workflow_ref\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0esubWorkflowRefB\x0b\n\treference\"/\n\x10\x41pproveCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\"\x90\x01\n\x0fSignalCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\x12.\n\x04type\x18\x02 \x01(\x0b\x32\x1a.flyteidl.core.LiteralTypeR\x04type\x12\x30\n\x14output_variable_name\x18\x03 \x01(\tR\x12outputVariableName\"G\n\x0eSleepCondition\x12\x35\n\x08\x64uration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x08\x64uration\"\xc5\x01\n\x08GateNode\x12;\n\x07\x61pprove\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.ApproveConditionH\x00R\x07\x61pprove\x12\x38\n\x06signal\x18\x02 \x01(\x0b\x32\x1e.flyteidl.core.SignalConditionH\x00R\x06signal\x12\x35\n\x05sleep\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.SleepConditionH\x00R\x05sleepB\x0b\n\tcondition\"\xbf\x01\n\tArrayNode\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x04node\x12 \n\x0bparallelism\x18\x02 \x01(\x03R\x0bparallelism\x12%\n\rmin_successes\x18\x03 \x01(\rH\x00R\x0cminSuccesses\x12,\n\x11min_success_ratio\x18\x04 \x01(\x02H\x00R\x0fminSuccessRatioB\x12\n\x10success_criteria\"\x8c\x03\n\x0cNodeMetadata\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x33\n\x07timeout\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x07timeout\x12\x36\n\x07retries\x18\x05 \x01(\x0b\x32\x1c.flyteidl.core.RetryStrategyR\x07retries\x12&\n\rinterruptible\x18\x06 \x01(\x08H\x00R\rinterruptible\x12\x1e\n\tcacheable\x18\x07 \x01(\x08H\x01R\tcacheable\x12%\n\rcache_version\x18\x08 \x01(\tH\x02R\x0c\x63\x61\x63heVersion\x12/\n\x12\x63\x61\x63he_serializable\x18\t \x01(\x08H\x03R\x11\x63\x61\x63heSerializableB\x15\n\x13interruptible_valueB\x11\n\x0f\x63\x61\x63heable_valueB\x15\n\x13\x63\x61\x63he_version_valueB\x1a\n\x18\x63\x61\x63he_serializable_value\"/\n\x05\x41lias\x12\x10\n\x03var\x18\x01 \x01(\tR\x03var\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\"\x9f\x04\n\x04Node\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x37\n\x08metadata\x18\x02 \x01(\x0b\x32\x1b.flyteidl.core.NodeMetadataR\x08metadata\x12.\n\x06inputs\x18\x03 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x06inputs\x12*\n\x11upstream_node_ids\x18\x04 \x03(\tR\x0fupstreamNodeIds\x12;\n\x0eoutput_aliases\x18\x05 \x03(\x0b\x32\x14.flyteidl.core.AliasR\routputAliases\x12\x36\n\ttask_node\x18\x06 \x01(\x0b\x32\x17.flyteidl.core.TaskNodeH\x00R\x08taskNode\x12\x42\n\rworkflow_node\x18\x07 \x01(\x0b\x32\x1b.flyteidl.core.WorkflowNodeH\x00R\x0cworkflowNode\x12<\n\x0b\x62ranch_node\x18\x08 \x01(\x0b\x32\x19.flyteidl.core.BranchNodeH\x00R\nbranchNode\x12\x36\n\tgate_node\x18\t \x01(\x0b\x32\x17.flyteidl.core.GateNodeH\x00R\x08gateNode\x12\x39\n\narray_node\x18\n \x01(\x0b\x32\x18.flyteidl.core.ArrayNodeH\x00R\tarrayNodeB\x08\n\x06target\"\xfc\x02\n\x10WorkflowMetadata\x12M\n\x12quality_of_service\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.QualityOfServiceR\x10qualityOfService\x12N\n\non_failure\x18\x02 \x01(\x0e\x32/.flyteidl.core.WorkflowMetadata.OnFailurePolicyR\tonFailure\x12=\n\x04tags\x18\x03 \x03(\x0b\x32).flyteidl.core.WorkflowMetadata.TagsEntryR\x04tags\x1a\x37\n\tTagsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"Q\n\x0fOnFailurePolicy\x12\x14\n\x10\x46\x41IL_IMMEDIATELY\x10\x00\x12(\n$FAIL_AFTER_EXECUTABLE_NODES_COMPLETE\x10\x01\"@\n\x18WorkflowMetadataDefaults\x12$\n\rinterruptible\x18\x01 \x01(\x08R\rinterruptible\"\xa2\x03\n\x10WorkflowTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.flyteidl.core.WorkflowMetadataR\x08metadata\x12;\n\tinterface\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12)\n\x05nodes\x18\x04 \x03(\x0b\x32\x13.flyteidl.core.NodeR\x05nodes\x12\x30\n\x07outputs\x18\x05 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x07outputs\x12\x36\n\x0c\x66\x61ilure_node\x18\x06 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x0b\x66\x61ilureNode\x12T\n\x11metadata_defaults\x18\x07 \x01(\x0b\x32\'.flyteidl.core.WorkflowMetadataDefaultsR\x10metadataDefaults\"\xc5\x01\n\x11TaskNodeOverrides\x12\x36\n\tresources\x18\x01 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x12\x65xtended_resources\x18\x02 \x01(\x0b\x32 .flyteidl.core.ExtendedResourcesR\x11\x65xtendedResources\x12\'\n\x0f\x63ontainer_image\x18\x03 \x01(\tR\x0e\x63ontainerImage\"\xba\x01\n\x12LaunchPlanTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\tinterface\x18\x02 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12<\n\x0c\x66ixed_inputs\x18\x03 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x0b\x66ixedInputsB\xb3\x01\n\x11\x63om.flyteidl.coreB\rWorkflowProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/flyteidl/gen/pb_rust/flyteidl.core.rs b/flyteidl/gen/pb_rust/flyteidl.core.rs index 280140c75d0..6922855798a 100644 --- a/flyteidl/gen/pb_rust/flyteidl.core.rs +++ b/flyteidl/gen/pb_rust/flyteidl.core.rs @@ -2401,8 +2401,8 @@ pub struct ArrayNode { /// point. Note that this is an optimistic restriction and that, due to network partitioning or /// other failures, the actual number of currently running instances might be more. This has to /// be a positive number if assigned. Default value is size. - #[prost(uint32, tag="2")] - pub parallelism: u32, + #[prost(int64, tag="2")] + pub parallelism: i64, #[prost(oneof="array_node::SuccessCriteria", tags="3, 4")] pub success_criteria: ::core::option::Option, } diff --git a/flyteidl/protos/flyteidl/core/workflow.proto b/flyteidl/protos/flyteidl/core/workflow.proto index dcbe9367f95..a305c8fad77 100644 --- a/flyteidl/protos/flyteidl/core/workflow.proto +++ b/flyteidl/protos/flyteidl/core/workflow.proto @@ -118,7 +118,7 @@ message ArrayNode { // point. Note that this is an optimistic restriction and that, due to network partitioning or // other failures, the actual number of currently running instances might be more. This has to // be a positive number if assigned. Default value is size. - uint32 parallelism = 2; + int64 parallelism = 2; oneof success_criteria { // min_successes is an absolute number of the minimum number of successful completions of diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go index 6680e741063..9916f6a0753 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go @@ -2,7 +2,7 @@ package v1alpha1 type ArrayNodeSpec struct { SubNodeSpec *NodeSpec - Parallelism uint32 + Parallelism int64 MinSuccesses *uint32 MinSuccessRatio *float32 } @@ -11,7 +11,7 @@ func (a *ArrayNodeSpec) GetSubNodeSpec() *NodeSpec { return a.SubNodeSpec } -func (a *ArrayNodeSpec) GetParallelism() uint32 { +func (a *ArrayNodeSpec) GetParallelism() int64 { return a.Parallelism } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array_test.go index 74ea26a4288..c17051b6bd7 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array_test.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array_test.go @@ -16,7 +16,7 @@ func TestArrayNodeSpec_GetSubNodeSpec(t *testing.T) { } func TestArrayNodeSpec_GetParallelism(t *testing.T) { - parallelism := uint32(5) + parallelism := int64(5) arrayNodeSpec := ArrayNodeSpec{ Parallelism: parallelism, } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index 664ebb6767f..d6f07b856f0 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -258,7 +258,7 @@ type ExecutableGateNode interface { type ExecutableArrayNode interface { GetSubNodeSpec() *NodeSpec - GetParallelism() uint32 + GetParallelism() int64 GetMinSuccesses() *uint32 GetMinSuccessRatio() *float32 } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go index a4ca819a6a8..742ceb2dbbb 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go @@ -84,7 +84,7 @@ type ExecutableArrayNode_GetParallelism struct { *mock.Call } -func (_m ExecutableArrayNode_GetParallelism) Return(_a0 uint32) *ExecutableArrayNode_GetParallelism { +func (_m ExecutableArrayNode_GetParallelism) Return(_a0 int64) *ExecutableArrayNode_GetParallelism { return &ExecutableArrayNode_GetParallelism{Call: _m.Call.Return(_a0)} } @@ -99,14 +99,14 @@ func (_m *ExecutableArrayNode) OnGetParallelismMatch(matchers ...interface{}) *E } // GetParallelism provides a mock function with given fields: -func (_m *ExecutableArrayNode) GetParallelism() uint32 { +func (_m *ExecutableArrayNode) GetParallelism() int64 { ret := _m.Called() - var r0 uint32 - if rf, ok := ret.Get(0).(func() uint32); ok { + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() } else { - r0 = ret.Get(0).(uint32) + r0 = ret.Get(0).(int64) } return r0 diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 9a79b1fffde..1084326a335 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -252,13 +252,29 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting case v1alpha1.ArrayNodePhaseExecuting: // process array node subNodes - currentParallelism := int(arrayNode.GetParallelism()) - if currentParallelism == 0 { - currentParallelism = len(arrayNodeState.SubNodePhases.GetItems()) + + availableParallelism := 0 + // using the workflow's parallelism if the array node parallelism is not set + useWorkflowParallelism := int(arrayNode.GetParallelism()) == -1 + if useWorkflowParallelism { + // greedily take all available slots + // TODO: This will need to be re-evaluated if we want to support dynamics & sub_workflows + currentParallelism := nCtx.ExecutionContext().CurrentParallelism() + maxParallelism := nCtx.ExecutionContext().GetExecutionConfig().MaxParallelism + availableParallelism = int(maxParallelism - currentParallelism) + } else { + availableParallelism = int(arrayNode.GetParallelism()) + if availableParallelism == 0 { + availableParallelism = len(arrayNodeState.SubNodePhases.GetItems()) + } } - nodeExecutionRequests := make([]*nodeExecutionRequest, 0, currentParallelism) + nodeExecutionRequests := make([]*nodeExecutionRequest, 0, availableParallelism) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + if availableParallelism == 0 { + break + } + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i)) @@ -298,11 +314,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO - this is a naive implementation of parallelism, if we want to support more // complex subNodes (ie. dynamics / subworkflows) we need to revisit this so that - // parallelism is handled during subNode evaluations. - currentParallelism-- - if currentParallelism == 0 { - break + // parallelism is handled during subNode evaluations + avoid deadlocks + if useWorkflowParallelism { + nCtx.ExecutionContext().IncrementParallelism() } + availableParallelism-- } workerErrorCollector := errorcollector.NewErrorMessageCollector() diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index b36f3a5b626..f5147905424 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -43,6 +43,7 @@ var ( }, }, } + workflowMaxParallelism = uint32(10) ) func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { @@ -74,7 +75,8 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter } func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string, - inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *handler.ArrayNodeState) interfaces.NodeExecutionContext { + inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *handler.ArrayNodeState, + currentParallelism uint32, maxParallelism uint32) interfaces.NodeExecutionContext { nCtx := &mocks.NodeExecutionContext{} nCtx.OnMaxDatasetSizeBytes().Return(9999999) @@ -91,7 +93,9 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte // ExecutionContext executionContext := &execmocks.ExecutionContext{} executionContext.OnGetEventVersion().Return(1) - executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) + executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{ + MaxParallelism: maxParallelism, + }) executionContext.OnGetExecutionID().Return( v1alpha1.ExecutionID{ WorkflowExecutionIdentifier: &idlcore.WorkflowExecutionIdentifier{ @@ -120,6 +124,8 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte }, nil, ) + executionContext.OnCurrentParallelism().Return(currentParallelism) + executionContext.On("IncrementParallelism").Run(func(args mock.Arguments) {}).Return(currentParallelism) executionContext.OnIncrementNodeExecutionCount().Return(1) executionContext.OnIncrementTaskExecutionCount().Return(1) executionContext.OnCurrentNodeExecutionCount().Return(1) @@ -258,7 +264,7 @@ func TestAbort(t *testing.T) { // create NodeExecutionContext eventRecorder := newBufferedEventRecorder() - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // evaluate node err := arrayNodeHandler.Abort(ctx, nCtx, "foo") @@ -354,7 +360,7 @@ func TestFinalize(t *testing.T) { // create NodeExecutionContext eventRecorder := newBufferedEventRecorder() - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // evaluate node err := arrayNodeHandler.Finalize(ctx, nCtx) @@ -425,7 +431,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // evaluate node transition, err := arrayNodeHandler.Handle(ctx, nCtx) @@ -480,6 +486,9 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedArrayNodePhase v1alpha1.ArrayNodePhase expectedTransitionPhase handler.EPhase expectedExternalResourcePhases []idlcore.TaskExecution_Phase + currentWfParallelism uint32 + maxWfParallelism uint32 + incrementParallelismCount uint32 }{ { name: "StartAllSubNodes", @@ -517,6 +526,65 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING}, }, + { + name: "UtilizeWfParallelismAllSubNodes", + parallelism: -1, + currentWfParallelism: 0, + incrementParallelismCount: 2, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + }, + { + name: "UtilizeWfParallelismSomeSubNodes", + parallelism: -1, + currentWfParallelism: workflowMaxParallelism - 1, + incrementParallelismCount: 1, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING}, + }, + { + name: "UtilizeWfParallelismNoSubNodes", + parallelism: -1, + currentWfParallelism: workflowMaxParallelism, + incrementParallelismCount: 0, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{}, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{}, + }, { name: "StartSubNodesNewAttempts", subNodePhases: []v1alpha1.NodePhase{ @@ -629,10 +697,10 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { eventRecorder := newBufferedEventRecorder() nodeSpec := arrayNodeSpec - nodeSpec.ArrayNode.Parallelism = uint32(test.parallelism) + nodeSpec.ArrayNode.Parallelism = int64(test.parallelism) nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, test.currentWfParallelism, workflowMaxParallelism) // initialize ArrayNodeHandler nodeHandler := &mocks.NodeHandler{} @@ -678,6 +746,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { } else { assert.Equal(t, 0, len(eventRecorder.taskExecutionEvents)) } + + nCtx.ExecutionContext().(*execmocks.ExecutionContext).AssertNumberOfCalls(t, "IncrementParallelism", int(test.incrementParallelismCount)) }) } } @@ -753,7 +823,7 @@ func TestHandleArrayNodePhaseExecutingSubNodeFailures(t *testing.T) { arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // initialize ArrayNodeHandler nodeHandler := &mocks.NodeHandler{} @@ -781,7 +851,7 @@ func TestHandleArrayNodePhaseExecutingSubNodeFailures(t *testing.T) { // evaluate node until failure attempts := 1 for { - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) _, err = arrayNodeHandler.Handle(ctx, nCtx) assert.NoError(t, err) @@ -866,7 +936,7 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { // create NodeExecutionContext eventRecorder := newBufferedEventRecorder() literalMap := &idlcore.LiteralMap{} - nCtx := createNodeExecutionContext(dataStore, eventRecorder, []string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, []string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // write mocked output files for i, outputValue := range test.outputValues { @@ -992,7 +1062,7 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { // create NodeExecutionContext eventRecorder := newBufferedEventRecorder() literalMap := &idlcore.LiteralMap{} - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism) // evaluate node transition, err := arrayNodeHandler.Handle(ctx, nCtx)