Skip to content

Commit

Permalink
Passing security context only if its non empty (flyteorg#300)
Browse files Browse the repository at this point in the history
* Rebase and fixes tests

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* updating to setup-go@v3

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixing the tests for go 1.18 by adding mocking for k8s copy context

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* test fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Apr 5, 2022
1 parent 49a1122 commit 59340ae
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 141 deletions.
2 changes: 1 addition & 1 deletion flytectl/.github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('go.sum') }}
- uses: actions/setup-go@v2
- uses: actions/setup-go@v3
with:
go-version: '1.17'
- name: Run GoReleaser dry run
Expand Down
8 changes: 4 additions & 4 deletions flytectl/cmd/create/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ func createExecutionCommand(ctx context.Context, args []string, cmdCtx cmdCore.C
var executionRequest *admin.ExecutionCreateRequest
switch execParams.execType {
case Relaunch:
return relaunchExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx)
return relaunchExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig)
case Recover:
return recoverExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx)
return recoverExecution(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig)
case Task:
if executionRequest, err = createExecutionRequestForTask(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx); err != nil {
if executionRequest, err = createExecutionRequestForTask(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig); err != nil {
return err
}
case Workflow:
if executionRequest, err = createExecutionRequestForWorkflow(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx); err != nil {
if executionRequest, err = createExecutionRequestForWorkflow(ctx, execParams.name, sourceProject, sourceDomain, cmdCtx, executionConfig); err != nil {
return err
}
default:
Expand Down
56 changes: 32 additions & 24 deletions flytectl/cmd/create/execution_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

func createExecutionRequestForWorkflow(ctx context.Context, workflowName, project, domain string,
cmdCtx cmdCore.CommandContext) (*admin.ExecutionCreateRequest, error) {
cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) (*admin.ExecutionCreateRequest, error) {
// Fetch the launch plan
lp, err := cmdCtx.AdminFetcherExt().FetchLPVersion(ctx, workflowName, executionConfig.Version, project, domain)
if err != nil {
Expand All @@ -35,23 +35,27 @@ func createExecutionRequestForWorkflow(ctx context.Context, workflowName, projec
}

// Set both deprecated field and new field for security identity passing
authRole := &admin.AuthRole{
KubernetesServiceAccount: executionConfig.KubeServiceAcct,
AssumableIamRole: executionConfig.IamRoleARN,
}

securityContext := &core.SecurityContext{
RunAs: &core.Identity{
K8SServiceAccount: executionConfig.KubeServiceAcct,
IamRole: executionConfig.IamRoleARN,
},
var securityContext *core.SecurityContext
var authRole *admin.AuthRole

if len(executionConfig.KubeServiceAcct) > 0 || len(executionConfig.IamRoleARN) > 0 {
authRole = &admin.AuthRole{
KubernetesServiceAccount: executionConfig.KubeServiceAcct,
AssumableIamRole: executionConfig.IamRoleARN,
}
securityContext = &core.SecurityContext{
RunAs: &core.Identity{
K8SServiceAccount: executionConfig.KubeServiceAcct,
IamRole: executionConfig.IamRoleARN,
},
}
}

return createExecutionRequest(lp.Id, inputs, securityContext, authRole), nil
}

func createExecutionRequestForTask(ctx context.Context, taskName string, project string, domain string,
cmdCtx cmdCore.CommandContext) (*admin.ExecutionCreateRequest, error) {
cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) (*admin.ExecutionCreateRequest, error) {
// Fetch the task
task, err := cmdCtx.AdminFetcherExt().FetchTaskVersion(ctx, taskName, executionConfig.Version, project, domain)
if err != nil {
Expand All @@ -69,16 +73,20 @@ func createExecutionRequestForTask(ctx context.Context, taskName string, project
}

// Set both deprecated field and new field for security identity passing
authRole := &admin.AuthRole{
KubernetesServiceAccount: executionConfig.KubeServiceAcct,
AssumableIamRole: executionConfig.IamRoleARN,
}

securityContext := &core.SecurityContext{
RunAs: &core.Identity{
K8SServiceAccount: executionConfig.KubeServiceAcct,
IamRole: executionConfig.IamRoleARN,
},
var securityContext *core.SecurityContext
var authRole *admin.AuthRole

if len(executionConfig.KubeServiceAcct) > 0 || len(executionConfig.IamRoleARN) > 0 {
authRole = &admin.AuthRole{
KubernetesServiceAccount: executionConfig.KubeServiceAcct,
AssumableIamRole: executionConfig.IamRoleARN,
}
securityContext = &core.SecurityContext{
RunAs: &core.Identity{
K8SServiceAccount: executionConfig.KubeServiceAcct,
IamRole: executionConfig.IamRoleARN,
},
}
}

id := &core.Identifier{
Expand All @@ -93,7 +101,7 @@ func createExecutionRequestForTask(ctx context.Context, taskName string, project
}

func relaunchExecution(ctx context.Context, executionName string, project string, domain string,
cmdCtx cmdCore.CommandContext) error {
cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) error {
if executionConfig.DryRun {
logger.Debugf(ctx, "skipping RelaunchExecution request (DryRun)")
return nil
Expand All @@ -113,7 +121,7 @@ func relaunchExecution(ctx context.Context, executionName string, project string
}

func recoverExecution(ctx context.Context, executionName string, project string, domain string,
cmdCtx cmdCore.CommandContext) error {
cmdCtx cmdCore.CommandContext, executionConfig *ExecutionConfig) error {
if executionConfig.DryRun {
logger.Debugf(ctx, "skipping RecoverExecution request (DryRun)")
return nil
Expand Down
125 changes: 121 additions & 4 deletions flytectl/cmd/create/execution_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package create

import (
"errors"
"fmt"
"testing"

"github.com/flyteorg/flytectl/cmd/config"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

var (
Expand Down Expand Up @@ -40,21 +42,23 @@ func createExecutionUtilSetup() {
Domain: config.GetConfig().Domain,
},
}
executionConfig = &ExecutionConfig{}
}

func TestCreateExecutionForRelaunch(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.MockAdminClient.OnRelaunchExecutionMatch(s.Ctx, relaunchRequest).Return(executionCreateResponse, nil)
err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx)
err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
}

func TestCreateExecutionForRelaunchNotFound(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.MockAdminClient.OnRelaunchExecutionMatch(s.Ctx, relaunchRequest).Return(nil, errors.New("unknown execution"))
err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx)
err := relaunchExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)

assert.NotNil(t, err)
assert.Equal(t, err, errors.New("unknown execution"))
}
Expand All @@ -63,15 +67,128 @@ func TestCreateExecutionForRecovery(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.MockAdminClient.OnRecoverExecutionMatch(s.Ctx, recoverRequest).Return(executionCreateResponse, nil)
err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx)
err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
}

func TestCreateExecutionForRecoveryNotFound(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.MockAdminClient.OnRecoverExecutionMatch(s.Ctx, recoverRequest).Return(nil, errors.New("unknown execution"))
err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx)
err := recoverExecution(s.Ctx, "execName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.NotNil(t, err)
assert.Equal(t, err, errors.New("unknown execution"))
}

func TestCreateExecutionRequestForWorkflow(t *testing.T) {
t.Run("successful", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
launchPlan := &admin.LaunchPlan{}
s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil)
execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
assert.NotNil(t, execCreateRequest)
})
t.Run("failed literal conversion", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
launchPlan := &admin.LaunchPlan{
Spec: &admin.LaunchPlanSpec{
DefaultInputs: &core.ParameterMap{
Parameters: map[string]*core.Parameter{"nilparam": nil},
},
},
}
s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil)
execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.NotNil(t, err)
assert.Nil(t, execCreateRequest)
assert.Equal(t, fmt.Errorf("parameter [nilparam] has nil Variable"), err)
})
t.Run("failed fetch", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed"))
execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.NotNil(t, err)
assert.Nil(t, execCreateRequest)
assert.Equal(t, err, errors.New("failed"))
})
t.Run("with security context", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
executionConfig.KubeServiceAcct = "default"
launchPlan := &admin.LaunchPlan{}
s.FetcherExt.OnFetchLPVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(launchPlan, nil)
s.MockAdminClient.OnGetLaunchPlanMatch(s.Ctx, mock.Anything).Return(launchPlan, nil)
execCreateRequest, err := createExecutionRequestForWorkflow(s.Ctx, "wfName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
assert.NotNil(t, execCreateRequest)
executionConfig.KubeServiceAcct = ""
})
}

func TestCreateExecutionRequestForTask(t *testing.T) {
t.Run("successful", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
task := &admin.Task{
Id: &core.Identifier{
Name: "taskName",
},
}
s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil)
execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
assert.NotNil(t, execCreateRequest)
})
t.Run("failed literal conversion", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
task := &admin.Task{
Closure: &admin.TaskClosure{
CompiledTask: &core.CompiledTask{
Template: &core.TaskTemplate{
Interface: &core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"nilvar": nil,
},
},
},
},
},
},
}
s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil)
execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.NotNil(t, err)
assert.Nil(t, execCreateRequest)
assert.Equal(t, fmt.Errorf("variable [nilvar] has nil type"), err)
})
t.Run("failed fetch", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed"))
execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.NotNil(t, err)
assert.Nil(t, execCreateRequest)
assert.Equal(t, err, errors.New("failed"))
})
t.Run("with security context", func(t *testing.T) {
s := setup()
createExecutionUtilSetup()
executionConfig.KubeServiceAcct = "default"
task := &admin.Task{
Id: &core.Identifier{
Name: "taskName",
},
}
s.FetcherExt.OnFetchTaskVersionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(task, nil)
execCreateRequest, err := createExecutionRequestForTask(s.Ctx, "taskName", config.GetConfig().Project, config.GetConfig().Domain, s.CmdCtx, executionConfig)
assert.Nil(t, err)
assert.NotNil(t, execCreateRequest)
executionConfig.KubeServiceAcct = ""
})
}
46 changes: 5 additions & 41 deletions flytectl/cmd/get/execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,6 @@ func getExecutionSetup() {
func TestListExecutionFunc(t *testing.T) {
getExecutionSetup()
s := setup()
ctx := s.Ctx
execListRequest := &admin.ResourceListRequest{
Limit: 100,
SortBy: &admin.Sort{
Key: "created_at",
Direction: admin.Sort_DESCENDING,
},
Id: &admin.NamedEntityIdentifier{
Project: projectValue,
Domain: domainValue,
},
}
executionResponse := &admin.Execution{
Id: &core.WorkflowExecutionIdentifier{
Project: projectValue,
Expand Down Expand Up @@ -72,26 +60,14 @@ func TestListExecutionFunc(t *testing.T) {
executionList := &admin.ExecutionList{
Executions: executions,
}
s.MockAdminClient.OnListExecutionsMatch(mock.Anything, execListRequest).Return(executionList, nil)
s.FetcherExt.OnListExecutionMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executionList, nil)
err := getExecutionFunc(s.Ctx, []string{}, s.CmdCtx)
assert.Nil(t, err)
s.MockAdminClient.AssertCalled(t, "ListExecutions", ctx, execListRequest)
s.FetcherExt.AssertCalled(t, "ListExecution", s.Ctx, projectValue, domainValue, execution.DefaultConfig.Filter)
}

func TestListExecutionFuncWithError(t *testing.T) {
ctx := context.Background()
getExecutionSetup()
execListRequest := &admin.ResourceListRequest{
Limit: 100,
SortBy: &admin.Sort{
Key: "created_at",
},
Id: &admin.NamedEntityIdentifier{
Project: projectValue,
Domain: domainValue,
},
}

_ = &admin.Execution{
Id: &core.WorkflowExecutionIdentifier{
Project: projectValue,
Expand All @@ -118,23 +94,14 @@ func TestListExecutionFuncWithError(t *testing.T) {
}
s := setup()
s.FetcherExt.OnListExecutionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("executions NotFound"))
s.MockAdminClient.OnListExecutionsMatch(mock.Anything, execListRequest).Return(nil, errors.New("executions NotFound"))
err := getExecutionFunc(s.Ctx, []string{}, s.CmdCtx)
assert.NotNil(t, err)
assert.Equal(t, err, errors.New("executions NotFound"))
s.MockAdminClient.AssertCalled(t, "ListExecutions", ctx, execListRequest)
s.FetcherExt.AssertCalled(t, "ListExecution", s.Ctx, projectValue, domainValue, execution.DefaultConfig.Filter)
}

func TestGetExecutionFunc(t *testing.T) {
ctx := context.Background()
getExecutionSetup()
execGetRequest := &admin.WorkflowExecutionGetRequest{
Id: &core.WorkflowExecutionIdentifier{
Project: projectValue,
Domain: domainValue,
Name: executionNameValue,
},
}
executionResponse := &admin.Execution{
Id: &core.WorkflowExecutionIdentifier{
Project: projectValue,
Expand All @@ -161,14 +128,11 @@ func TestGetExecutionFunc(t *testing.T) {
}
args := []string{executionNameValue}
s := setup()
//executionList := &admin.ExecutionList{
// Executions: []*admin.Execution{executionResponse},
//}
s.MockAdminClient.OnGetExecutionMatch(ctx, execGetRequest).Return(executionResponse, nil)

s.FetcherExt.OnFetchExecutionMatch(s.Ctx, mock.Anything, mock.Anything, mock.Anything).Return(executionResponse, nil)
err := getExecutionFunc(s.Ctx, args, s.CmdCtx)
assert.Nil(t, err)
s.MockAdminClient.AssertCalled(t, "GetExecution", ctx, execGetRequest)
s.FetcherExt.AssertCalled(t, "FetchExecution", s.Ctx, executionNameValue, projectValue, domainValue)
}

func TestGetExecutionFuncForDetails(t *testing.T) {
Expand Down
Loading

0 comments on commit 59340ae

Please sign in to comment.