From 59340aee1ba04ad549f2b4a0e124a5cf5ee75569 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Tue, 5 Apr 2022 16:26:07 +0530 Subject: [PATCH] Passing security context only if its non empty (#300) * Rebase and fixes tests Signed-off-by: Prafulla Mahindrakar * updating to setup-go@v3 Signed-off-by: Prafulla Mahindrakar * Fixing the tests for go 1.18 by adding mocking for k8s copy context Signed-off-by: Prafulla Mahindrakar * fixes Signed-off-by: Prafulla Mahindrakar * test fixes Signed-off-by: Prafulla Mahindrakar * Fixes Signed-off-by: Prafulla Mahindrakar * Fixes Signed-off-by: Prafulla Mahindrakar --- flytectl/.github/workflows/checks.yml | 2 +- flytectl/cmd/create/execution.go | 8 +- flytectl/cmd/create/execution_util.go | 56 +++++---- flytectl/cmd/create/execution_util_test.go | 125 +++++++++++++++++++- flytectl/cmd/get/execution_test.go | 46 +------ flytectl/cmd/get/launch_plan_test.go | 16 +-- flytectl/cmd/register/register_util_test.go | 12 +- flytectl/cmd/sandbox/start.go | 8 +- flytectl/cmd/sandbox/start_test.go | 4 + flytectl/cmd/sandbox/teardown.go | 5 +- flytectl/cmd/sandbox/teardown_test.go | 12 +- flytectl/cmd/testutils/test_utils.go | 7 +- flytectl/pkg/k8s/k8s.go | 79 ++++++++----- flytectl/pkg/k8s/mocks/context_ops.go | 78 ++++++++++++ 14 files changed, 317 insertions(+), 141 deletions(-) create mode 100644 flytectl/pkg/k8s/mocks/context_ops.go diff --git a/flytectl/.github/workflows/checks.yml b/flytectl/.github/workflows/checks.yml index d31399b768..eee6209b0c 100644 --- a/flytectl/.github/workflows/checks.yml +++ b/flytectl/.github/workflows/checks.yml @@ -50,7 +50,7 @@ jobs: ~/.cache/go-build ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('go.sum') }} - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v3 with: go-version: '1.17' - name: Run GoReleaser dry run diff --git a/flytectl/cmd/create/execution.go b/flytectl/cmd/create/execution.go index 3693c97424..3ca734d3ed 100644 --- a/flytectl/cmd/create/execution.go +++ b/flytectl/cmd/create/execution.go @@ -177,15 +177,15 @@ func createExecutionCommand(ctx context.Context, args []string, cmdCtx cmdCore.C var executionRequest *admin.ExecutionCreateRequest switch execParams.execType { case Relaunch: - return relaunchExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx) + return relaunchExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig) case Recover: - return recoverExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx) + return recoverExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig) case Task: - if executionRequest, err = createExecutionRequestForTask(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx); err != nil { + if executionRequest, err = createExecutionRequestForTask(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig); err != nil { return err } case Workflow: - if executionRequest, err = createExecutionRequestForWorkflow(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx); err != nil { + if executionRequest, err = createExecutionRequestForWorkflow(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig); err != nil { return err } default: diff --git a/flytectl/cmd/create/execution_util.go b/flytectl/cmd/create/execution_util.go index 82adaf6110..ed862783d0 100644 --- a/flytectl/cmd/create/execution_util.go +++ b/flytectl/cmd/create/execution_util.go @@ -16,7 +16,7 @@ import ( ) func createExecutionRequestForWorkflow(ctx context.Context, workflowName, project, domain string, - cmdCtx cmdCore.CommandContext) (*admin.ExecutionCreateRequest, error) { + cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) (*admin.ExecutionCreateRequest, error) { // Fetch the launch plan lp, err := cmdCtx.AdminFetcherExt().FetchLPVersion(ctx, workflowName, executionConfig.Version, project, domain) if err != nil { @@ -35,23 +35,27 @@ func createExecutionRequestForWorkflow(ctx context.Context, workflowName, projec } // Set both deprecated field and new field for security identity passing - authRole := &admin.AuthRole{ - KubernetesServiceAccount: executionConfig.KubeServiceAcct, - AssumableIamRole: executionConfig.IamRoleARN, - } - - securityContext := &core.SecurityContext{ - RunAs: &core.Identity{ - K8SServiceAccount: executionConfig.KubeServiceAcct, - IamRole: executionConfig.IamRoleARN, - }, + var securityContext *core.SecurityContext + var authRole *admin.AuthRole + + if len(executionConfig.KubeServiceAcct) > 0 || len(executionConfig.IamRoleARN) > 0 { + authRole = &admin.AuthRole{ + KubernetesServiceAccount: executionConfig.KubeServiceAcct, + AssumableIamRole: executionConfig.IamRoleARN, + } + securityContext = &core.SecurityContext{ + RunAs: &core.Identity{ + K8SServiceAccount: executionConfig.KubeServiceAcct, + IamRole: executionConfig.IamRoleARN, + }, + } } return createExecutionRequest(lp.Id, inputs, securityContext, authRole), nil } func createExecutionRequestForTask(ctx context.Context, taskName string, project string, domain string, - cmdCtx cmdCore.CommandContext) (*admin.ExecutionCreateRequest, error) { + cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) (*admin.ExecutionCreateRequest, error) { // Fetch the task task, err := cmdCtx.AdminFetcherExt().FetchTaskVersion(ctx, taskName, executionConfig.Version, project, domain) if err != nil { @@ -69,16 +73,20 @@ func createExecutionRequestForTask(ctx context.Context, taskName string, project } // Set both deprecated field and new field for security identity passing - authRole := &admin.AuthRole{ - KubernetesServiceAccount: executionConfig.KubeServiceAcct, - AssumableIamRole: executionConfig.IamRoleARN, - } - - securityContext := &core.SecurityContext{ - RunAs: &core.Identity{ - K8SServiceAccount: executionConfig.KubeServiceAcct, - IamRole: executionConfig.IamRoleARN, - }, + var securityContext *core.SecurityContext + var authRole *admin.AuthRole + + if len(executionConfig.KubeServiceAcct) > 0 || len(executionConfig.IamRoleARN) > 0 { + authRole = &admin.AuthRole{ + KubernetesServiceAccount: executionConfig.KubeServiceAcct, + AssumableIamRole: executionConfig.IamRoleARN, + } + securityContext = &core.SecurityContext{ + RunAs: &core.Identity{ + K8SServiceAccount: executionConfig.KubeServiceAcct, + IamRole: executionConfig.IamRoleARN, + }, + } } id := &core.Identifier{ @@ -93,7 +101,7 @@ func createExecutionRequestForTask(ctx context.Context, taskName string, project } func relaunchExecution(ctx context.Context, executionName string, project string, domain string, - cmdCtx cmdCore.CommandContext) error { + cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) error { if executionConfig.DryRun { logger.Debugf(ctx, "skipping RelaunchExecution request (DryRun)") return nil @@ -113,7 +121,7 @@ func relaunchExecution(ctx context.Context, executionName string, project string } func recoverExecution(ctx context.Context, executionName string, project string, domain string, - cmdCtx cmdCore.CommandContext) error { + cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) error { if executionConfig.DryRun { logger.Debugf(ctx, "skipping RecoverExecution request (DryRun)") return nil diff --git a/flytectl/cmd/create/execution_util_test.go b/flytectl/cmd/create/execution_util_test.go index 643248daaa..0342d4b5c7 100644 --- a/flytectl/cmd/create/execution_util_test.go +++ b/flytectl/cmd/create/execution_util_test.go @@ -2,6 +2,7 @@ package create import ( "errors" + "fmt" "testing" "github.com/flyteorg/flytectl/cmd/config" @@ -9,6 +10,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) var ( @@ -40,13 +42,14 @@ func createExecutionUtilSetup() { Domain: config.GetConfig().Domain, }, } + executionConfig = &ExecutionConfig{} } func TestCreateExecutionForRelaunch(t *testing.T) { s := setup() createExecutionUtilSetup() s.MockAdminClient.OnRelaunchExecutionMatch(s.Ctx, relaunchRequest).Return(executionCreateResponse, nil) - err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx) + err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) assert.Nil(t, err) } @@ -54,7 +57,8 @@ func TestCreateExecutionForRelaunchNotFound(t *testing.T) { s := setup() createExecutionUtilSetup() s.MockAdminClient.OnRelaunchExecutionMatch(s.Ctx, relaunchRequest).Return(nil, errors.New("unknown execution")) - err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx) + err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.NotNil(t, err) assert.Equal(t, err, errors.New("unknown execution")) } @@ -63,7 +67,7 @@ func TestCreateExecutionForRecovery(t *testing.T) { s := setup() createExecutionUtilSetup() s.MockAdminClient.OnRecoverExecutionMatch(s.Ctx, recoverRequest).Return(executionCreateResponse, nil) - err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx) + err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) assert.Nil(t, err) } @@ -71,7 +75,120 @@ func TestCreateExecutionForRecoveryNotFound(t *testing.T) { s := setup() createExecutionUtilSetup() s.MockAdminClient.OnRecoverExecutionMatch(s.Ctx, recoverRequest).Return(nil, errors.New("unknown execution")) - err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx) + err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) assert.NotNil(t, err) assert.Equal(t, err, errors.New("unknown execution")) } + +func TestCreateExecutionRequestForWorkflow(t *testing.T) { + t.Run("successful", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + launchPlan := &admin.LaunchPlan{} + s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil) + execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.Nil(t, err) + assert.NotNil(t, execCreateRequest) + }) + t.Run("failed literal conversion", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + launchPlan := &admin.LaunchPlan{ + Spec: &admin.LaunchPlanSpec{ + DefaultInputs: &core.ParameterMap{ + Parameters: map[string]*core.Parameter{"nilparam": nil}, + }, + }, + } + s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil) + execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.NotNil(t, err) + assert.Nil(t, execCreateRequest) + assert.Equal(t, fmt.Errorf("parameter [nilparam] has nil Variable"), err) + }) + t.Run("failed fetch", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) + execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.NotNil(t, err) + assert.Nil(t, execCreateRequest) + assert.Equal(t, err, errors.New("failed")) + }) + t.Run("with security context", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + executionConfig.KubeServiceAcct = "default" + launchPlan := &admin.LaunchPlan{} + s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil) + s.MockAdminClient.OnGetLaunchPlanMatch(s.Ctx, mock.Anything).Return(launchPlan, nil) + execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.Nil(t, err) + assert.NotNil(t, execCreateRequest) + executionConfig.KubeServiceAcct = "" + }) +} + +func TestCreateExecutionRequestForTask(t *testing.T) { + t.Run("successful", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + task := &admin.Task{ + Id: &core.Identifier{ + Name: "taskName", + }, + } + s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil) + execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.Nil(t, err) + assert.NotNil(t, execCreateRequest) + }) + t.Run("failed literal conversion", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + task := &admin.Task{ + Closure: &admin.TaskClosure{ + CompiledTask: &core.CompiledTask{ + Template: &core.TaskTemplate{ + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "nilvar": nil, + }, + }, + }, + }, + }, + }, + } + s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil) + execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.NotNil(t, err) + assert.Nil(t, execCreateRequest) + assert.Equal(t, fmt.Errorf("variable [nilvar] has nil type"), err) + }) + t.Run("failed fetch", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) + execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.NotNil(t, err) + assert.Nil(t, execCreateRequest) + assert.Equal(t, err, errors.New("failed")) + }) + t.Run("with security context", func(t *testing.T) { + s := setup() + createExecutionUtilSetup() + executionConfig.KubeServiceAcct = "default" + task := &admin.Task{ + Id: &core.Identifier{ + Name: "taskName", + }, + } + s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil) + execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig) + assert.Nil(t, err) + assert.NotNil(t, execCreateRequest) + executionConfig.KubeServiceAcct = "" + }) +} diff --git a/flytectl/cmd/get/execution_test.go b/flytectl/cmd/get/execution_test.go index ff11d03db0..64476ad82f 100644 --- a/flytectl/cmd/get/execution_test.go +++ b/flytectl/cmd/get/execution_test.go @@ -32,18 +32,6 @@ func getExecutionSetup() { func TestListExecutionFunc(t *testing.T) { getExecutionSetup() s := setup() - ctx := s.Ctx - execListRequest := &admin.ResourceListRequest{ - Limit: 100, - SortBy: &admin.Sort{ - Key: "created_at", - Direction: admin.Sort_DESCENDING, - }, - Id: &admin.NamedEntityIdentifier{ - Project: projectValue, - Domain: domainValue, - }, - } executionResponse := &admin.Execution{ Id: &core.WorkflowExecutionIdentifier{ Project: projectValue, @@ -72,26 +60,14 @@ func TestListExecutionFunc(t *testing.T) { executionList := &admin.ExecutionList{ Executions: executions, } - s.MockAdminClient.OnListExecutionsMatch(mock.Anything, execListRequest).Return(executionList, nil) + s.FetcherExt.OnListExecutionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executionList, nil) err := getExecutionFunc(s.Ctx, []string{}, s.CmdCtx) assert.Nil(t, err) - s.MockAdminClient.AssertCalled(t, "ListExecutions", ctx, execListRequest) + s.FetcherExt.AssertCalled(t, "ListExecution", s.Ctx, projectValue, domainValue, execution.DefaultConfig.Filter) } func TestListExecutionFuncWithError(t *testing.T) { - ctx := context.Background() getExecutionSetup() - execListRequest := &admin.ResourceListRequest{ - Limit: 100, - SortBy: &admin.Sort{ - Key: "created_at", - }, - Id: &admin.NamedEntityIdentifier{ - Project: projectValue, - Domain: domainValue, - }, - } - _ = &admin.Execution{ Id: &core.WorkflowExecutionIdentifier{ Project: projectValue, @@ -118,23 +94,14 @@ func TestListExecutionFuncWithError(t *testing.T) { } s := setup() s.FetcherExt.OnListExecutionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("executions NotFound")) - s.MockAdminClient.OnListExecutionsMatch(mock.Anything, execListRequest).Return(nil, errors.New("executions NotFound")) err := getExecutionFunc(s.Ctx, []string{}, s.CmdCtx) assert.NotNil(t, err) assert.Equal(t, err, errors.New("executions NotFound")) - s.MockAdminClient.AssertCalled(t, "ListExecutions", ctx, execListRequest) + s.FetcherExt.AssertCalled(t, "ListExecution", s.Ctx, projectValue, domainValue, execution.DefaultConfig.Filter) } func TestGetExecutionFunc(t *testing.T) { - ctx := context.Background() getExecutionSetup() - execGetRequest := &admin.WorkflowExecutionGetRequest{ - Id: &core.WorkflowExecutionIdentifier{ - Project: projectValue, - Domain: domainValue, - Name: executionNameValue, - }, - } executionResponse := &admin.Execution{ Id: &core.WorkflowExecutionIdentifier{ Project: projectValue, @@ -161,14 +128,11 @@ func TestGetExecutionFunc(t *testing.T) { } args := []string{executionNameValue} s := setup() - //executionList := &admin.ExecutionList{ - // Executions: []*admin.Execution{executionResponse}, - //} - s.MockAdminClient.OnGetExecutionMatch(ctx, execGetRequest).Return(executionResponse, nil) + s.FetcherExt.OnFetchExecutionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything).Return(executionResponse, nil) err := getExecutionFunc(s.Ctx, args, s.CmdCtx) assert.Nil(t, err) - s.MockAdminClient.AssertCalled(t, "GetExecution", ctx, execGetRequest) + s.FetcherExt.AssertCalled(t, "FetchExecution", s.Ctx, executionNameValue, projectValue, domainValue) } func TestGetExecutionFuncForDetails(t *testing.T) { diff --git a/flytectl/cmd/get/launch_plan_test.go b/flytectl/cmd/get/launch_plan_test.go index 3831631f87..d04698d8b0 100644 --- a/flytectl/cmd/get/launch_plan_test.go +++ b/flytectl/cmd/get/launch_plan_test.go @@ -253,13 +253,10 @@ func TestGetLaunchPlanFuncWithError(t *testing.T) { func TestGetLaunchPlanFunc(t *testing.T) { s := setup() getLaunchPlanSetup() - s.MockAdminClient.OnListLaunchPlansMatch(s.Ctx, resourceGetRequest).Return(launchPlanListResponse, nil) - s.MockAdminClient.OnGetLaunchPlanMatch(s.Ctx, objectGetRequest).Return(launchPlan2, nil) - s.MockAdminClient.OnListLaunchPlanIdsMatch(s.Ctx, namedIDRequest).Return(namedIdentifierList, nil) - s.FetcherExt.OnFetchAllVerOfLP(s.Ctx, "launchplan1", "dummyProject", "dummyDomain", filters.Filters{}).Return(launchPlanListResponse.LaunchPlans, nil) + s.FetcherExt.OnFetchAllVerOfLPMatch(mock.Anything, mock.Anything, "dummyProject", "dummyDomain", filters.Filters{}).Return(launchPlanListResponse.LaunchPlans, nil) err := getLaunchPlanFunc(s.Ctx, argsLp, s.CmdCtx) assert.Nil(t, err) - s.MockAdminClient.AssertCalled(t, "ListLaunchPlans", s.Ctx, resourceGetRequest) + s.FetcherExt.AssertCalled(t, "FetchAllVerOfLP", s.Ctx, "launchplan1", "dummyProject", "dummyDomain", launchplan.DefaultConfig.Filter) tearDownAndVerify(t, s.Writer, `[{"id": {"name": "launchplan1","version": "v2"},"spec": {"workflowId": {"name": "workflow2"},"defaultInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}}},"closure": {"expectedInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}},"createdAt": "1970-01-01T00:00:01Z"}},{"id": {"name": "launchplan1","version": "v1"},"spec": {"workflowId": {"name": "workflow1"},"defaultInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}}},"closure": {"expectedInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}},"createdAt": "1970-01-01T00:00:00Z"}}]`) } @@ -268,11 +265,10 @@ func TestGetLaunchPlanFuncLatest(t *testing.T) { getLaunchPlanSetup() launchplan.DefaultConfig.Latest = true launchplan.DefaultConfig.Filter = filters.Filters{} - s.MockAdminClient.OnListLaunchPlansMatch(s.Ctx, resourceGetRequest).Return(launchPlanListResponse, nil) - s.MockAdminClient.OnGetLaunchPlanMatch(s.Ctx, objectGetRequest).Return(launchPlan2, nil) + s.FetcherExt.OnFetchLPLatestVersionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan2, nil) err := getLaunchPlanFunc(s.Ctx, argsLp, s.CmdCtx) assert.Nil(t, err) - s.MockAdminClient.AssertCalled(t, "ListLaunchPlans", s.Ctx, resourceGetRequest) + s.FetcherExt.AssertCalled(t, "FetchLPLatestVersion", s.Ctx, "launchplan1", projectValue, domainValue, launchplan.DefaultConfig.Filter) tearDownAndVerify(t, s.Writer, `{"id": {"name": "launchplan1","version": "v2"},"spec": {"workflowId": {"name": "workflow2"},"defaultInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}}},"closure": {"expectedInputs": {"parameters": {"numbers": {"var": {"type": {"collectionType": {"simple": "INTEGER"}},"description": "short desc"}},"numbers_count": {"var": {"type": {"simple": "INTEGER"},"description": "long description will be truncated in table"}},"run_local_at_count": {"var": {"type": {"simple": "INTEGER"},"description": "run_local_at_count"},"default": {"scalar": {"primitive": {"integer": "10"}}}}}},"createdAt": "1970-01-01T00:00:01Z"}}`) } @@ -280,9 +276,6 @@ func TestGetLaunchPlanWithVersion(t *testing.T) { s := testutils.SetupWithExt() getLaunchPlanSetup() launchplan.DefaultConfig.Version = "v2" - s.MockAdminClient.OnListLaunchPlansMatch(s.Ctx, resourceListRequest).Return(launchPlanListResponse, nil) - s.MockAdminClient.OnGetLaunchPlanMatch(s.Ctx, objectGetRequest).Return(launchPlan2, nil) - s.MockAdminClient.OnListLaunchPlanIdsMatch(s.Ctx, namedIDRequest).Return(namedIdentifierList, nil) s.FetcherExt.OnFetchLPVersion(s.Ctx, "launchplan1", "v2", "dummyProject", "dummyDomain").Return(launchPlan2, nil) err := getLaunchPlanFunc(s.Ctx, argsLp, s.CmdCtx) assert.Nil(t, err) @@ -294,7 +287,6 @@ func TestGetLaunchPlans(t *testing.T) { t.Run("no workflow filter", func(t *testing.T) { s := setup() getLaunchPlanSetup() - s.MockAdminClient.OnListLaunchPlansMatch(s.Ctx, resourceListRequest).Return(launchPlanListResponse, nil) s.FetcherExt.OnFetchAllVerOfLP(s.Ctx, "", "dummyProject", "dummyDomain", filters.Filters{}).Return(launchPlanListResponse.LaunchPlans, nil) argsLp = []string{} err := getLaunchPlanFunc(s.Ctx, argsLp, s.CmdCtx) diff --git a/flytectl/cmd/register/register_util_test.go b/flytectl/cmd/register/register_util_test.go index e9b4caae05..ca55bb348c 100644 --- a/flytectl/cmd/register/register_util_test.go +++ b/flytectl/cmd/register/register_util_test.go @@ -276,9 +276,7 @@ func TestRegisterFile(t *testing.T) { }, }, } - s.MockAdminClient.OnGetWorkflowMatch(mock.Anything, mock.Anything).Return(wf, nil) - s.FetcherExt.OnFetchWorkflowVersion(s.Ctx, "core.scheduled_workflows.lp_schedules.date_formatter_wf", "v0.3.59", "dummyProject", "dummyDomain").Return(wf, nil) - s.FetcherExt.OnFetchWorkflowVersion(s.Ctx, "core.scheduled_workflows.lp_schedules.date_formatter_wf", "", "dummyProject", "dummyDomain").Return(wf, nil) + s.FetcherExt.OnFetchWorkflowVersionMatch(s.Ctx, "core.scheduled_workflows.lp_schedules.date_formatter_wf", mock.Anything, "dummyProject", "dummyDomain").Return(wf, nil) args := []string{"testdata/152_my_cron_scheduled_lp_3.pb"} var registerResults []Result results, err := registerFile(s.Ctx, args[0], registerResults, s.CmdCtx, "", *rconfig.DefaultFilesConfig) @@ -624,7 +622,7 @@ func TestValidateLaunchSpec(t *testing.T) { s := setup() registerFilesSetup() - s.MockAdminClient.OnGetWorkflowMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) + s.FetcherExt.OnFetchWorkflowVersionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) lpSpec := &admin.LaunchPlanSpec{ WorkflowId: &core.Identifier{ Project: "projectValue", @@ -649,7 +647,7 @@ func TestValidateLaunchSpec(t *testing.T) { s := setup() registerFilesSetup() - s.MockAdminClient.OnGetWorkflowMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) + s.FetcherExt.OnFetchWorkflowVersionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed")) lpSpec := &admin.LaunchPlanSpec{ WorkflowId: &core.Identifier{ Project: "projectValue", @@ -711,7 +709,7 @@ func TestValidateLaunchSpec(t *testing.T) { }, }, } - s.MockAdminClient.OnGetWorkflowMatch(mock.Anything, mock.Anything).Return(wf, nil) + s.FetcherExt.OnFetchWorkflowVersionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(wf, nil) lpSpec := &admin.LaunchPlanSpec{ WorkflowId: &core.Identifier{ Project: "projectValue", @@ -773,7 +771,7 @@ func TestValidateLaunchSpec(t *testing.T) { }, }, } - s.MockAdminClient.OnGetWorkflowMatch(mock.Anything, mock.Anything).Return(wf, nil) + s.FetcherExt.OnFetchWorkflowVersionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(wf, nil) lpSpec := &admin.LaunchPlanSpec{ WorkflowId: &core.Identifier{ Project: "projectValue", diff --git a/flytectl/cmd/sandbox/start.go b/flytectl/cmd/sandbox/start.go index 7c373bb2ed..f45c791f08 100644 --- a/flytectl/cmd/sandbox/start.go +++ b/flytectl/cmd/sandbox/start.go @@ -150,14 +150,12 @@ func startSandboxCluster(ctx context.Context, args []string, cmdCtx cmdCore.Comm } func updateLocalKubeContext() error { - localConfigAccess := clientcmd.NewDefaultPathOptions() - - dockerConfigAccess := &clientcmd.PathOptions{ + srcConfigAccess := &clientcmd.PathOptions{ GlobalFile: docker.Kubeconfig, LoadingRules: clientcmd.NewDefaultClientConfigLoadingRules(), } - - return k8s.CopyKubeContext(dockerConfigAccess, localConfigAccess, sandboxDockerContext, sandboxContextName) + k8sCtxMgr := k8s.NewK8sContextManager() + return k8sCtxMgr.CopyContext(srcConfigAccess, sandboxDockerContext, sandboxContextName) } func startSandbox(ctx context.Context, cli docker.Docker, reader io.Reader) (*bufio.Scanner, error) { diff --git a/flytectl/cmd/sandbox/start_test.go b/flytectl/cmd/sandbox/start_test.go index ae938ab1e2..d6c789fd7b 100644 --- a/flytectl/cmd/sandbox/start_test.go +++ b/flytectl/cmd/sandbox/start_test.go @@ -24,6 +24,7 @@ import ( "github.com/flyteorg/flytectl/pkg/docker" "github.com/flyteorg/flytectl/pkg/docker/mocks" f "github.com/flyteorg/flytectl/pkg/filesystemutils" + k8sMocks "github.com/flyteorg/flytectl/pkg/k8s/mocks" "github.com/flyteorg/flytectl/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -554,9 +555,12 @@ func TestStartSandboxFunc(t *testing.T) { Follow: true, }).Return(reader, nil) mockDocker.OnContainerWaitMatch(ctx, mock.Anything, container.WaitConditionNotRunning).Return(bodyStatus, errCh) + mockK8sContextMgr := &k8sMocks.ContextOps{} docker.Client = mockDocker sandboxConfig.DefaultConfig.Source = "" sandboxConfig.DefaultConfig.Version = "" + k8s.ContextMgr = mockK8sContextMgr + mockK8sContextMgr.OnCopyContextMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) err = startSandboxCluster(ctx, []string{}, cmdCtx) assert.Nil(t, err) }) diff --git a/flytectl/cmd/sandbox/teardown.go b/flytectl/cmd/sandbox/teardown.go index bbdfa303c6..16ca11bf20 100644 --- a/flytectl/cmd/sandbox/teardown.go +++ b/flytectl/cmd/sandbox/teardown.go @@ -13,7 +13,6 @@ import ( cmdCore "github.com/flyteorg/flytectl/cmd/core" "github.com/flyteorg/flytectl/pkg/k8s" - "k8s.io/client-go/tools/clientcmd" ) const ( @@ -58,6 +57,6 @@ func tearDownSandbox(ctx context.Context, cli docker.Docker) error { } func removeSandboxKubeContext() error { - localConfigAccess := clientcmd.NewDefaultPathOptions() - return k8s.RemoveKubeContext(localConfigAccess, sandboxContextName) + k8sCtxMgr := k8s.NewK8sContextManager() + return k8sCtxMgr.RemoveContext(sandboxContextName) } diff --git a/flytectl/cmd/sandbox/teardown_test.go b/flytectl/cmd/sandbox/teardown_test.go index d95b12221d..0664342979 100644 --- a/flytectl/cmd/sandbox/teardown_test.go +++ b/flytectl/cmd/sandbox/teardown_test.go @@ -5,14 +5,14 @@ import ( "fmt" "testing" + "github.com/docker/docker/api/types" "github.com/flyteorg/flytectl/cmd/testutils" - "github.com/flyteorg/flytectl/pkg/configutil" - "github.com/flyteorg/flytectl/pkg/util" - - "github.com/docker/docker/api/types" "github.com/flyteorg/flytectl/pkg/docker" "github.com/flyteorg/flytectl/pkg/docker/mocks" + "github.com/flyteorg/flytectl/pkg/k8s" + k8sMocks "github.com/flyteorg/flytectl/pkg/k8s/mocks" + "github.com/flyteorg/flytectl/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -33,7 +33,9 @@ func TestTearDownFunc(t *testing.T) { mockDocker := &mocks.Docker{} mockDocker.OnContainerList(ctx, types.ContainerListOptions{All: true}).Return(containers, nil) mockDocker.OnContainerRemove(ctx, mock.Anything, types.ContainerRemoveOptions{Force: true}).Return(nil) - + mockK8sContextMgr := &k8sMocks.ContextOps{} + k8s.ContextMgr = mockK8sContextMgr + mockK8sContextMgr.OnRemoveContextMatch(mock.Anything).Return(nil) err := tearDownSandbox(ctx, mockDocker) assert.Nil(t, err) }) diff --git a/flytectl/cmd/testutils/test_utils.go b/flytectl/cmd/testutils/test_utils.go index 86577b1323..b94b7d3671 100644 --- a/flytectl/cmd/testutils/test_utils.go +++ b/flytectl/cmd/testutils/test_utils.go @@ -10,8 +10,6 @@ import ( "strings" "testing" - "github.com/flyteorg/flytectl/pkg/ext" - "github.com/flyteorg/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyteidl/clients/go/admin" @@ -61,11 +59,8 @@ func Setup() (s TestStruct) { s.UpdaterExt.OnAdminServiceClient().Return(s.MockClient.AdminClient()) s.DeleterExt.OnAdminServiceClient().Return(s.MockClient.AdminClient()) s.MockAdminClient = s.MockClient.AdminClient().(*mocks.AdminServiceClient) - fetcher := &ext.AdminFetcherExtClient{ - AdminClient: s.MockAdminClient, - } s.MockOutStream = s.Writer - s.CmdCtx = cmdCore.NewCommandContextWithExt(s.MockClient, fetcher, s.UpdaterExt, s.DeleterExt, s.MockOutStream) + s.CmdCtx = cmdCore.NewCommandContextWithExt(s.MockClient, s.FetcherExt, s.UpdaterExt, s.DeleterExt, s.MockOutStream) config.GetConfig().Project = projectValue config.GetConfig().Domain = domainValue config.GetConfig().Output = output diff --git a/flytectl/pkg/k8s/k8s.go b/flytectl/pkg/k8s/k8s.go index 74086d76a2..6f9ca3e07d 100644 --- a/flytectl/pkg/k8s/k8s.go +++ b/flytectl/pkg/k8s/k8s.go @@ -15,7 +15,29 @@ type K8s interface { CoreV1() corev1.CoreV1Interface } +//go:generate mockery -name=ContextOps -case=underscore +type ContextOps interface { + CopyContext(srcConfigAccess clientcmd.ConfigAccess, srcCtxName, targetCtxName string) error + RemoveContext(ctxName string) error +} + +// ContextManager context manager implementing ContextOps +type ContextManager struct { + configAccess clientcmd.ConfigAccess +} + +func NewK8sContextManager() ContextOps { + if ContextMgr != nil { + return ContextMgr + } + ContextMgr = &ContextManager{ + configAccess: clientcmd.NewDefaultPathOptions(), + } + return ContextMgr +} + var Client K8s +var ContextMgr ContextOps // GetK8sClient return the k8s client from sandbox kubeconfig func GetK8sClient(cfg, master string) (K8s, error) { @@ -34,70 +56,69 @@ func GetK8sClient(cfg, master string) (K8s, error) { return Client, nil } -// CopyKubeContext copies context fromContext part of fromConfigAccess to toContext part of toConfigAccess. -func CopyKubeContext(fromConfigAccess, toConfigAccess clientcmd.ConfigAccess, fromContext, toContext string) error { - _, err := toConfigAccess.GetStartingConfig() +// CopyKubeContext copies context srcCtxName part of srcConfigAccess to targetCtxName part of targetConfigAccess. +func (k *ContextManager) CopyContext(srcConfigAccess clientcmd.ConfigAccess, srcCtxName, targetCtxName string) error { + _, err := k.configAccess.GetStartingConfig() if err != nil { return err } - fromStartingConfig, err := fromConfigAccess.GetStartingConfig() + fromStartingConfig, err := srcConfigAccess.GetStartingConfig() if err != nil { return err } - _, exists := fromStartingConfig.Contexts[fromContext] + _, exists := fromStartingConfig.Contexts[srcCtxName] if !exists { - return fmt.Errorf("context %v doesn't exist", fromContext) + return fmt.Errorf("context %v doesn't exist", srcCtxName) } - toStartingConfig, err := toConfigAccess.GetStartingConfig() + toStartingConfig, err := k.configAccess.GetStartingConfig() if err != nil { return err } - _, exists = toStartingConfig.Contexts[toContext] + _, exists = toStartingConfig.Contexts[targetCtxName] if exists { - fmt.Printf("context %v already exist. Overwriting it\n", toContext) + fmt.Printf("context %v already exist. Overwriting it\n", targetCtxName) } else { - toStartingConfig.Contexts[toContext] = clientcmdapi.NewContext() + toStartingConfig.Contexts[targetCtxName] = clientcmdapi.NewContext() } - toStartingConfig.Clusters[toContext] = fromStartingConfig.Clusters[fromContext] - toStartingConfig.Clusters[toContext].LocationOfOrigin = toConfigAccess.GetDefaultFilename() - toStartingConfig.AuthInfos[toContext] = fromStartingConfig.AuthInfos[fromContext] - toStartingConfig.AuthInfos[toContext].LocationOfOrigin = toConfigAccess.GetDefaultFilename() - toStartingConfig.Contexts[toContext].Cluster = toContext - toStartingConfig.Contexts[toContext].AuthInfo = toContext - toStartingConfig.CurrentContext = toContext - - if err := clientcmd.ModifyConfig(toConfigAccess, *toStartingConfig, true); err != nil { + toStartingConfig.Clusters[targetCtxName] = fromStartingConfig.Clusters[srcCtxName] + toStartingConfig.Clusters[targetCtxName].LocationOfOrigin = k.configAccess.GetDefaultFilename() + toStartingConfig.AuthInfos[targetCtxName] = fromStartingConfig.AuthInfos[srcCtxName] + toStartingConfig.AuthInfos[targetCtxName].LocationOfOrigin = k.configAccess.GetDefaultFilename() + toStartingConfig.Contexts[targetCtxName].Cluster = targetCtxName + toStartingConfig.Contexts[targetCtxName].AuthInfo = targetCtxName + toStartingConfig.CurrentContext = targetCtxName + if err := clientcmd.ModifyConfig(k.configAccess, *toStartingConfig, true); err != nil { return err } - fmt.Printf("context modified for %q and switched over to it.\n", toContext) + fmt.Printf("context modified for %q and switched over to it.\n", targetCtxName) return nil } // RemoveKubeContext removes the contextToRemove from the kubeContext pointed to be fromConfigAccess -func RemoveKubeContext(fromConfigAccess clientcmd.ConfigAccess, contextToRemove string) error { - fromStartingConfig, err := fromConfigAccess.GetStartingConfig() +func (k *ContextManager) RemoveContext(ctxName string) error { + fromStartingConfig, err := k.configAccess.GetStartingConfig() if err != nil { return err } - _, exists := fromStartingConfig.Contexts[contextToRemove] + _, exists := fromStartingConfig.Contexts[ctxName] if !exists { - return fmt.Errorf("context %v doesn't exist", contextToRemove) + return fmt.Errorf("context %v doesn't exist", ctxName) } - delete(fromStartingConfig.Clusters, contextToRemove) - delete(fromStartingConfig.AuthInfos, contextToRemove) - delete(fromStartingConfig.Contexts, contextToRemove) + delete(fromStartingConfig.Clusters, ctxName) + delete(fromStartingConfig.AuthInfos, ctxName) + delete(fromStartingConfig.Contexts, ctxName) fromStartingConfig.CurrentContext = "" - if err := clientcmd.ModifyConfig(fromConfigAccess, *fromStartingConfig, true); err != nil { + if err := clientcmd.ModifyConfig(k.configAccess, *fromStartingConfig, true); err != nil { return err } - fmt.Printf("context removed for %q.\n", contextToRemove) + fmt.Printf("context removed for %q.\n", ctxName) return nil } diff --git a/flytectl/pkg/k8s/mocks/context_ops.go b/flytectl/pkg/k8s/mocks/context_ops.go new file mode 100644 index 0000000000..11b7003d0f --- /dev/null +++ b/flytectl/pkg/k8s/mocks/context_ops.go @@ -0,0 +1,78 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + clientcmd "k8s.io/client-go/tools/clientcmd" + + mock "github.com/stretchr/testify/mock" +) + +// ContextOps is an autogenerated mock type for the ContextOps type +type ContextOps struct { + mock.Mock +} + +type ContextOps_CopyContext struct { + *mock.Call +} + +func (_m ContextOps_CopyContext) Return(_a0 error) *ContextOps_CopyContext { + return &ContextOps_CopyContext{Call: _m.Call.Return(_a0)} +} + +func (_m *ContextOps) OnCopyContext(srcConfigAccess clientcmd.ConfigAccess, srcCtxName string, targetCtxName string) *ContextOps_CopyContext { + c := _m.On("CopyContext", srcConfigAccess, srcCtxName, targetCtxName) + return &ContextOps_CopyContext{Call: c} +} + +func (_m *ContextOps) OnCopyContextMatch(matchers ...interface{}) *ContextOps_CopyContext { + c := _m.On("CopyContext", matchers...) + return &ContextOps_CopyContext{Call: c} +} + +// CopyContext provides a mock function with given fields: srcConfigAccess, srcCtxName, targetCtxName +func (_m *ContextOps) CopyContext(srcConfigAccess clientcmd.ConfigAccess, srcCtxName string, targetCtxName string) error { + ret := _m.Called(srcConfigAccess, srcCtxName, targetCtxName) + + var r0 error + if rf, ok := ret.Get(0).(func(clientcmd.ConfigAccess, string, string) error); ok { + r0 = rf(srcConfigAccess, srcCtxName, targetCtxName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type ContextOps_RemoveContext struct { + *mock.Call +} + +func (_m ContextOps_RemoveContext) Return(_a0 error) *ContextOps_RemoveContext { + return &ContextOps_RemoveContext{Call: _m.Call.Return(_a0)} +} + +func (_m *ContextOps) OnRemoveContext(ctxName string) *ContextOps_RemoveContext { + c := _m.On("RemoveContext", ctxName) + return &ContextOps_RemoveContext{Call: c} +} + +func (_m *ContextOps) OnRemoveContextMatch(matchers ...interface{}) *ContextOps_RemoveContext { + c := _m.On("RemoveContext", matchers...) + return &ContextOps_RemoveContext{Call: c} +} + +// RemoveContext provides a mock function with given fields: ctxName +func (_m *ContextOps) RemoveContext(ctxName string) error { + ret := _m.Called(ctxName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(ctxName) + } else { + r0 = ret.Error(0) + } + + return r0 +}