Skip to content

Commit

Permalink
Merge branch 'master' into update-ray-setup
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmirror-ops authored Nov 14, 2023
2 parents 34cfa38 + 72e7438 commit d627270
Show file tree
Hide file tree
Showing 11 changed files with 694 additions and 28 deletions.
32 changes: 30 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ray
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -37,6 +38,14 @@ const (
DisableUsageStatsStartParameter = "disable-usage-stats"
)

var logTemplateRegexes = struct {
RayClusterName *regexp.Regexp
RayJobID *regexp.Regexp
}{
tasklog.MustCreateRegex("rayClusterName"),
tasklog.MustCreateRegex("rayJobID"),
}

type rayJobResourceHandler struct {
}

Expand Down Expand Up @@ -442,8 +451,27 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
input := tasklog.Input{
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
ExtraTemplateVarsByScheme: &tasklog.TemplateVarsByScheme{},
}
if rayJob.Status.JobId != "" {
input.ExtraTemplateVarsByScheme.Common = append(
input.ExtraTemplateVarsByScheme.Common,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayJobID,
Value: rayJob.Status.JobId,
},
)
}
if rayJob.Status.RayClusterName != "" {
input.ExtraTemplateVarsByScheme.Common = append(
input.ExtraTemplateVarsByScheme.Common,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayClusterName,
Value: rayJob.Status.RayClusterName,
},
)
}

// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
Expand Down
114 changes: 111 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,21 @@ func newPluginContext() k8s.PluginContext {

taskExecID := &mocks.TaskExecutionID{}
taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{
TaskId: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Name: "my-task-name",
Project: "my-task-project",
Domain: "my-task-domain",
Version: "1",
},
NodeExecutionId: &core.NodeExecutionIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Name: "my_name",
Project: "my_project",
Domain: "my_domain",
Name: "my-execution-name",
Project: "my-execution-project",
Domain: "my-execution-domain",
},
},
RetryAttempt: 1,
})
taskExecID.OnGetUniqueNodeID().Return("unique-node")
taskExecID.OnGetGeneratedName().Return("generated-name")
Expand Down Expand Up @@ -678,6 +686,106 @@ func TestGetTaskPhase(t *testing.T) {
}
}

func TestGetEventInfo_LogTemplates(t *testing.T) {
pluginCtx := newPluginContext()
testCases := []struct {
name string
rayJob rayv1alpha1.RayJob
logPlugin tasklog.TemplateLogPlugin
expectedTaskLogs []*core.TaskLog
}{
{
name: "namespace",
rayJob: rayv1alpha1.RayJob{
ObjectMeta: metav1.ObjectMeta{
Namespace: "test-namespace",
},
},
logPlugin: tasklog.TemplateLogPlugin{
DisplayName: "namespace",
TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"},
},
expectedTaskLogs: []*core.TaskLog{
{
Name: "namespace",
Uri: "http://test/test-namespace",
},
},
},
{
name: "task execution ID",
rayJob: rayv1alpha1.RayJob{},
logPlugin: tasklog.TemplateLogPlugin{
DisplayName: "taskExecID",
TemplateURIs: []tasklog.TemplateURI{
"http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}",
},
Scheme: tasklog.TemplateSchemeTaskExecution,
},
expectedTaskLogs: []*core.TaskLog{
{
Name: "taskExecID",
Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1",
},
},
},
{
name: "ray cluster name",
rayJob: rayv1alpha1.RayJob{
ObjectMeta: metav1.ObjectMeta{
Namespace: "test-namespace",
},
Status: rayv1alpha1.RayJobStatus{
RayClusterName: "ray-cluster",
},
},
logPlugin: tasklog.TemplateLogPlugin{
DisplayName: "ray cluster name",
TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"},
},
expectedTaskLogs: []*core.TaskLog{
{
Name: "ray cluster name",
Uri: "http://test/test-namespace/ray-cluster",
},
},
},
{
name: "ray job ID",
rayJob: rayv1alpha1.RayJob{
ObjectMeta: metav1.ObjectMeta{
Namespace: "test-namespace",
},
Status: rayv1alpha1.RayJobStatus{
JobId: "ray-job-1",
},
},
logPlugin: tasklog.TemplateLogPlugin{
DisplayName: "ray job ID",
TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"},
},
expectedTaskLogs: []*core.TaskLog{
{
Name: "ray job ID",
Uri: "http://test/test-namespace/ray-job-1",
},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ti, err := getEventInfoForRayJob(
logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}},
pluginCtx,
&tc.rayJob,
)
assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
})
}
}

func TestGetEventInfo_DashboardURL(t *testing.T) {
pluginCtx := newPluginContext()
testCases := []struct {
Expand Down
10 changes: 9 additions & 1 deletion flyteplugins/go/tasks/plugins/webapi/athena/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package athena
import (
"context"
"fmt"
"strings"
"time"

awsSdk "github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -177,12 +178,19 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co

func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo {
timeNow := time.Now()
var consoleURL string
if strings.Contains(cfg.Region, "gov") {
consoleURL = "console.amazonaws-us-gov.com"
} else {
consoleURL = "console.aws.amazon.com"
}
return &core.TaskInfo{
OccurredAt: &timeNow,
Logs: []*idlCore.TaskLog{
{
Uri: fmt.Sprintf("https://%v.console.aws.amazon.com/athena/home?force&region=%v#query/history/%v",
Uri: fmt.Sprintf("https://%v.%v/athena/home?force&region=%v#query/history/%v",
cfg.Region,
consoleURL,
cfg.Region,
queryID),
Name: "Athena Query Console",
Expand Down
15 changes: 15 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,18 @@ func TestCreateTaskInfo(t *testing.T) {
assert.Len(t, taskInfo.ExternalResources, 1)
assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id")
}


func TestCreateTaskInfoGovAWS(t *testing.T) {
taskInfo := createTaskInfo("query_id", awsSdk.Config{
Region: "us-gov-east-1",
})
assert.EqualValues(t, []*idlCore.TaskLog{
{
Uri: "https://us-gov-east-1.console.amazonaws-us-gov.com/athena/home?force&region=us-gov-east-1#query/history/query_id",
Name: "Athena Query Console",
},
}, taskInfo.Logs)
assert.Len(t, taskInfo.ExternalResources, 1)
assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id")
}
49 changes: 49 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package v1alpha1

import (
"testing"
)

func TestArrayNodeSpec_GetSubNodeSpec(t *testing.T) {
nodeSpec := &NodeSpec{}
arrayNodeSpec := ArrayNodeSpec{
SubNodeSpec: nodeSpec,
}

if arrayNodeSpec.GetSubNodeSpec() != nodeSpec {
t.Errorf("Expected nodeSpec, but got a different value")
}
}

func TestArrayNodeSpec_GetParallelism(t *testing.T) {
parallelism := uint32(5)
arrayNodeSpec := ArrayNodeSpec{
Parallelism: parallelism,
}

if arrayNodeSpec.GetParallelism() != parallelism {
t.Errorf("Expected %d, but got %d", parallelism, arrayNodeSpec.GetParallelism())
}
}

func TestArrayNodeSpec_GetMinSuccesses(t *testing.T) {
minSuccesses := uint32(3)
arrayNodeSpec := ArrayNodeSpec{
MinSuccesses: &minSuccesses,
}

if *arrayNodeSpec.GetMinSuccesses() != minSuccesses {
t.Errorf("Expected %d, but got %d", minSuccesses, *arrayNodeSpec.GetMinSuccesses())
}
}

func TestArrayNodeSpec_GetMinSuccessRatio(t *testing.T) {
minSuccessRatio := float32(0.8)
arrayNodeSpec := ArrayNodeSpec{
MinSuccessRatio: &minSuccessRatio,
}

if *arrayNodeSpec.GetMinSuccessRatio() != minSuccessRatio {
t.Errorf("Expected %f, but got %f", minSuccessRatio, *arrayNodeSpec.GetMinSuccessRatio())
}
}
78 changes: 75 additions & 3 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package v1alpha1_test
package v1alpha1

import (
"bytes"
"encoding/json"
"io/ioutil"
"testing"

"github.com/golang/protobuf/jsonpb"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)

func TestMarshalUnMarshal_BranchTask(t *testing.T) {
r, err := ioutil.ReadFile("testdata/branch.json")
assert.NoError(t, err)
o := v1alpha1.NodeSpec{}
o := NodeSpec{}
err = json.Unmarshal(r, &o)
assert.NoError(t, err)
assert.NotNil(t, o.BranchNode.If)
Expand All @@ -25,3 +26,74 @@ func TestMarshalUnMarshal_BranchTask(t *testing.T) {
assert.NotEmpty(t, raw)
}
}

// TestBranchNodeSpecMethods tests the methods of the BranchNodeSpec struct.
func TestErrorMarshalAndUnmarshalJSON(t *testing.T) {
coreError := &core.Error{
FailedNodeId: "TestNode",
Message: "Test error message",
}

err := Error{Error: coreError}
data, jErr := err.MarshalJSON()
assert.Nil(t, jErr)

// Unmarshalling the JSON back to a new core.Error struct
var newCoreError core.Error
uErr := jsonpb.Unmarshal(bytes.NewReader(data), &newCoreError)
assert.Nil(t, uErr)
assert.Equal(t, coreError.Message, newCoreError.Message)
assert.Equal(t, coreError.FailedNodeId, newCoreError.FailedNodeId)
}

func TestBranchNodeSpecMethods(t *testing.T) {
// Creating a core.BooleanExpression instance for testing
boolExpr := &core.BooleanExpression{}

// Creating an Error instance for testing
errorMessage := &core.Error{
Message: "Test error",
}

ifNode := NodeID("ifNode")
elifNode := NodeID("elifNode")
elseNode := NodeID("elseNode")

// Creating a BranchNodeSpec instance for testing
branchNodeSpec := BranchNodeSpec{
If: IfBlock{
Condition: BooleanExpression{
BooleanExpression: boolExpr,
},
ThenNode: &ifNode,
},
ElseIf: []*IfBlock{
{
Condition: BooleanExpression{
BooleanExpression: boolExpr,
},
ThenNode: &elifNode,
},
},
Else: &elseNode,
ElseFail: &Error{Error: errorMessage},
}

assert.Equal(t, boolExpr, branchNodeSpec.If.GetCondition())

assert.Equal(t, &ifNode, branchNodeSpec.If.GetThenNode())

assert.Equal(t, &branchNodeSpec.If, branchNodeSpec.GetIf())

assert.Equal(t, &elseNode, branchNodeSpec.GetElse())

elifs := branchNodeSpec.GetElseIf()
assert.Equal(t, 1, len(elifs))
assert.Equal(t, boolExpr, elifs[0].GetCondition())
assert.Equal(t, &elifNode, elifs[0].GetThenNode())

assert.Equal(t, errorMessage, branchNodeSpec.GetElseFail())

branchNodeSpec.ElseFail = nil
assert.Nil(t, branchNodeSpec.GetElseFail())
}
Loading

0 comments on commit d627270

Please sign in to comment.