From 04c6cba0403dad52c259971e3cf08bcaf6cb2da9 Mon Sep 17 00:00:00 2001 From: David Gamero Date: Wed, 20 Nov 2024 09:22:45 -0500 Subject: [PATCH] add selectors and azclient cleanup (#425) --- .github/workflows/integration-linux.yml | 2 - .gitignore | 1 + ...nerate-workflow.go => generateworkflow.go} | 2 +- cmd/setup-gh.go | 253 ------------- cmd/setup-gh_test.go | 28 -- cmd/setupgh.go | 352 ++++++++++++++++++ cmd/setupgh_test.go | 147 ++++++++ go.mod | 3 +- go.sum | 10 +- pkg/linguist/linguist_test.go | 2 +- pkg/prompts/prompts.go | 9 +- pkg/providers/az-client.go | 135 ++++++- pkg/providers/azcli.go | 16 +- pkg/providers/azure.go | 121 ++---- pkg/providers/azure_test.go | 139 ------- pkg/providers/commandrunner.go | 33 ++ pkg/providers/ghcli.go | 85 +++-- pkg/providers/ghcli_test.go | 8 +- 18 files changed, 775 insertions(+), 571 deletions(-) rename cmd/{generate-workflow.go => generateworkflow.go} (100%) delete mode 100644 cmd/setup-gh.go delete mode 100644 cmd/setup-gh_test.go create mode 100644 cmd/setupgh.go create mode 100644 cmd/setupgh_test.go delete mode 100644 pkg/providers/azure_test.go create mode 100644 pkg/providers/commandrunner.go diff --git a/.github/workflows/integration-linux.yml b/.github/workflows/integration-linux.yml index 23a3d5d0..16a968c5 100644 --- a/.github/workflows/integration-linux.yml +++ b/.github/workflows/integration-linux.yml @@ -444,8 +444,6 @@ jobs: npm install -g ajv-cli@5.0.0 ajv validate -s test/update_dry_run_schema.json -d test/temp/update_dry_run.json - run: ./draft -v update -d ./langtest/ -a webapp_routing --variable ingress-tls-cert-keyvault-uri=test.cert.keyvault.uri --variable ingress-use-osm-mtls=true --variable ingress-host=host1 - - name: print manifests - run: cat ./langtest/manifests/* - name: start minikube id: minikube uses: medyagh/setup-minikube@master diff --git a/.gitignore b/.gitignore index ef9c2602..e4a102d6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ vendor/ /pkg/addons/addons/ .DS_Store draft +draft.exe langtest .vscode/ /pkg/languages/builders diff --git a/cmd/generate-workflow.go b/cmd/generateworkflow.go similarity index 100% rename from cmd/generate-workflow.go rename to cmd/generateworkflow.go index c47bb0b4..1aa7f1f6 100644 --- a/cmd/generate-workflow.go +++ b/cmd/generateworkflow.go @@ -8,10 +8,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/Azure/draft/pkg/cmdhelpers" "github.com/Azure/draft/pkg/handlers" "github.com/Azure/draft/pkg/prompts" "github.com/Azure/draft/pkg/templatewriter" - "github.com/Azure/draft/pkg/cmdhelpers" "github.com/Azure/draft/pkg/templatewriter/writers" ) diff --git a/cmd/setup-gh.go b/cmd/setup-gh.go deleted file mode 100644 index 8c0fdb44..00000000 --- a/cmd/setup-gh.go +++ /dev/null @@ -1,253 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" - "github.com/Azure/draft/pkg/cred" - "github.com/Azure/draft/pkg/prompts" - "github.com/manifoldco/promptui" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/Azure/draft/pkg/providers" - "github.com/Azure/draft/pkg/spinner" -) - -func newSetUpCmd() *cobra.Command { - sc := &providers.SetUpCmd{} - - // setup-ghCmd represents the setup-gh command - var cmd = &cobra.Command{ - Use: "setup-gh", - Short: "Automates the Github OIDC setup process", - Long: `This command will automate the Github OIDC setup process by creating an Azure Active Directory -application and service principle, and will configure that application to trust github.`, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - providers.EnsureAzCli() - providers.EnsureGhCli() - - azCred, err := cred.GetCred() - if err != nil { - return fmt.Errorf("getting credentials: %w", err) - } - - client, err := armsubscription.NewTenantsClient(azCred, nil) - if err != nil { - return fmt.Errorf("creating tenants client: %w", err) - } - - sc.AzClient.AzTenantClient = client - - roleAssignmentClient, err := armauthorization.NewRoleAssignmentsClient(sc.SubscriptionID, azCred, nil) - if err != nil { - return fmt.Errorf("getting role assignment client: %w", err) - } - - sc.AzClient.RoleAssignClient = roleAssignmentClient - - err = fillSetUpConfig(sc) - if err != nil { - return fmt.Errorf("filling setup config: %w", err) - } - - s := spinner.CreateSpinner("--> Setting up Github OIDC...") - s.Start() - err = runProviderSetUp(ctx, sc, s) - s.Stop() - if err != nil { - return err - } - - log.Info("Draft has successfully set up Github OIDC for your project 😃") - log.Info("Use 'draft generate-workflow' to generate a Github workflow to build and deploy an application on AKS.") - - return nil - }, - } - - f := cmd.Flags() - f.StringVarP(&sc.AppName, "app", "a", emptyDefaultFlagValue, "specify the Azure Active Directory application name") - f.StringVarP(&sc.SubscriptionID, "subscription-id", "s", emptyDefaultFlagValue, "specify the Azure subscription ID") - f.StringVarP(&sc.ResourceGroupName, "resource-group", "r", emptyDefaultFlagValue, "specify the Azure resource group name") - f.StringVarP(&sc.Repo, "gh-repo", "g", emptyDefaultFlagValue, "specify the github repository link") - sc.Provider = provider - return cmd -} - -func fillSetUpConfig(sc *providers.SetUpCmd) error { - if sc.AppName == "" { - sc.AppName = getAppName() - } - - if sc.SubscriptionID == "" { - if strings.ToLower(sc.Provider) == "azure" { - currentSub, err := providers.GetCurrentAzSubscriptionLabel() - if err != nil { - return fmt.Errorf("getting current subscription ID: %w", err) - } - - subLabels, err := providers.GetAzSubscriptionLabels() - if err != nil { - return fmt.Errorf("getting subscription labels: %w", err) - } - - sc.SubscriptionID, err = getAzSubscriptionId(subLabels, currentSub) - if err != nil { - return fmt.Errorf("getting subscription ID: %w", err) - } - } else { - sc.SubscriptionID = getSubscriptionID() - } - } - - if sc.ResourceGroupName == "" { - sc.ResourceGroupName = getResourceGroup() - } - - if sc.Repo == "" { - sc.Repo = getGhRepo() - } - - return nil -} - -func runProviderSetUp(ctx context.Context, sc *providers.SetUpCmd, s spinner.Spinner) error { - provider := strings.ToLower(sc.Provider) - if provider == "azure" { - // call azure provider logic - return providers.InitiateAzureOIDCFlow(ctx, sc, s) - - } else { - // call logic for user-submitted provider - fmt.Printf("The provider is %v\n", sc.Provider) - } - - return nil -} - -func getAppName() string { - validate := func(input string) error { - if input == "" { - return errors.New("Invalid app name") - } - return nil - } - - prompt := promptui.Prompt{ - Label: "Enter app registration name", - Validate: validate, - } - - result, err := prompt.Run() - - if err != nil { - return err.Error() - } - - return result -} - -func getSubscriptionID() string { - validate := func(input string) error { - if input == "" { - return errors.New("Invalid subscription id") - } - return nil - } - - prompt := promptui.Prompt{ - Label: "Enter subscription ID", - Validate: validate, - } - - result, err := prompt.Run() - - if err != nil { - return err.Error() - } - - return result -} - -func getResourceGroup() string { - validate := func(input string) error { - if input == "" { - return errors.New("Invalid resource group name") - } - return nil - } - - prompt := promptui.Prompt{ - Label: "Enter resource group name", - Validate: validate, - } - - result, err := prompt.Run() - - if err != nil { - return err.Error() - } - - return result -} - -func getGhRepo() string { - validate := func(input string) error { - if !strings.Contains(input, "/") { - return errors.New("Github repo cannot be empty") - } - - return nil - } - - repoPrompt := promptui.Prompt{ - Label: "Enter github organization and repo (organization/repoName)", - Validate: validate, - } - - repo, err := repoPrompt.Run() - if err != nil { - return err.Error() - } - - return repo -} - -func getCloudProvider() string { - selection := &promptui.Select{ - Label: "What cloud provider would you like to use?", - Items: []string{"azure"}, - } - - _, selectResponse, err := selection.Run() - if err != nil { - return err.Error() - } - - return selectResponse -} - -func getAzSubscriptionId(subLabels []providers.SubLabel, currentSub providers.SubLabel) (string, error) { - subLabel, err := prompts.Select("Please choose the subscription ID you would like to use", subLabels, &prompts.SelectOpt[providers.SubLabel]{ - Field: func(subLabel providers.SubLabel) string { - return subLabel.Name + " (" + subLabel.ID + ")" - }, - Default: ¤tSub, - }) - if err != nil { - return "", fmt.Errorf("selecting subscription ID: %w", err) - } - - return subLabel.ID, nil -} - -func init() { - rootCmd.AddCommand(newSetUpCmd()) -} diff --git a/cmd/setup-gh_test.go b/cmd/setup-gh_test.go deleted file mode 100644 index 9d2ac247..00000000 --- a/cmd/setup-gh_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package cmd - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/Azure/draft/pkg/providers" - "github.com/Azure/draft/pkg/spinner" -) - -func TestSetUpConfig(t *testing.T) { - ctx := context.Background() - mockSetUpCmd := &providers.SetUpCmd{} - mockSetUpCmd.AppName = "testingSetUpCommand" - mockSetUpCmd.Provider = "Google" - mockSetUpCmd.Repo = "test/repo" - mockSetUpCmd.ResourceGroupName = "testResourceGroup" - mockSetUpCmd.SubscriptionID = "123456789" - s := spinner.CreateSpinner("--> Setting up Github OIDC...") - - fillSetUpConfig(mockSetUpCmd) - - err := runProviderSetUp(ctx, mockSetUpCmd, s) - - assert.True(t, err == nil) -} diff --git a/cmd/setupgh.go b/cmd/setupgh.go new file mode 100644 index 00000000..44f16ccd --- /dev/null +++ b/cmd/setupgh.go @@ -0,0 +1,352 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "unicode" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" + "github.com/Azure/draft/pkg/cred" + "github.com/Azure/draft/pkg/prompts" + "github.com/manifoldco/promptui" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/Azure/draft/pkg/providers" + "github.com/Azure/draft/pkg/spinner" + "k8s.io/apimachinery/pkg/util/validation" +) + +func newSetUpCmd() *cobra.Command { + sc := &providers.SetUpCmd{} + + // setup-ghCmd represents the setup-gh command + var cmd = &cobra.Command{ + Use: "setup-gh", + Short: "Automates the Github OIDC setup process", + Long: `This command will automate the Github OIDC setup process by creating an Azure Active Directory +application and service principle, and will configure that application to trust github.`, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + gh := providers.NewGhClient() + providers.EnsureAzCli() + + azCred, err := cred.GetCred() + if err != nil { + return fmt.Errorf("getting credentials: %w", err) + } + az, err := providers.NewAzClient(azCred) + if err != nil { + return fmt.Errorf("creating azure client: %w", err) + } + sc.AzClient = az + + err = fillSetUpConfig(sc, gh) + if err != nil { + return fmt.Errorf("filling setup config: %w", err) + } + + s := spinner.CreateSpinner("--> Setting up Github OIDC...") + s.Start() + err = runProviderSetUp(ctx, sc, s, gh) + s.Stop() + if err != nil { + return err + } + + log.Info("Draft has successfully set up Github OIDC for your project 😃") + log.Info("Use 'draft generate-workflow' to generate a Github workflow to build and deploy an application on AKS.") + + return nil + }, + } + + f := cmd.Flags() + f.StringVarP(&sc.AppName, "app", "a", emptyDefaultFlagValue, "specify the Azure Active Directory application name") + f.StringVarP(&sc.SubscriptionID, "subscription-id", "s", emptyDefaultFlagValue, "specify the Azure subscription ID") + f.StringVarP(&sc.ResourceGroupName, "resource-group", "r", emptyDefaultFlagValue, "specify the Azure resource group name") + f.StringVarP(&sc.Repo, "gh-repo", "g", emptyDefaultFlagValue, "specify the github repository link") + sc.Provider = provider + return cmd +} + +func fillSetUpConfig(sc *providers.SetUpCmd, gh providers.GhClient) error { + if sc.TenantId == "" { + tenandId, err := providers.PromptTenantId(sc.AzClient, context.Background()) + if err != nil { + return fmt.Errorf("prompting tenant ID: %w", err) + } + sc.TenantId = tenandId + } + + if sc.AppName == "" { + // Set the application name; default is the current directory name plus "-workflow". + + // get the current directory name + currentDir, err := os.Getwd() + if err != nil { + return fmt.Errorf("getting current directory: %w", err) + } + defaultAppName := fmt.Sprintf("%s-workflow", filepath.Base(currentDir)) + defaultAppName, err = toValidAppName(defaultAppName) + if err != nil { + log.Debugf("unable to convert default app name %q to a valid name: %v", defaultAppName, err) + log.Debugf("using default app name %q", defaultAppName) + defaultAppName = "my-workflow" + } + + appName, err := PromptAppName(sc.AzClient, defaultAppName) + if err != nil { + return fmt.Errorf("prompting app name: %w", err) + } + sc.AppName = appName + } + + if sc.SubscriptionID == "" { + if strings.ToLower(sc.Provider) == "azure" { + currentSub, err := providers.GetCurrentAzSubscriptionLabel() + if err != nil { + return fmt.Errorf("getting current subscription ID: %w", err) + } + + subLabels, err := providers.GetAzSubscriptionLabels() + if err != nil { + return fmt.Errorf("getting subscription labels: %w", err) + } + + sc.SubscriptionID, err = getAzSubscriptionId(subLabels, currentSub) + if err != nil { + return fmt.Errorf("getting subscription ID: %w", err) + } + } else { + sc.SubscriptionID = getSubscriptionID() + } + } + + if sc.ResourceGroupName == "" { + rg, err := PromptResourceGroup(sc.AzClient, sc.SubscriptionID) + if err != nil { + return fmt.Errorf("getting resource group: %w", err) + } + sc.ResourceGroupName = *rg.Name + } + + if sc.Repo == "" { + repo, err := PromptGitHubRepoWithOwner(gh) + if err != nil { + return fmt.Errorf("failed to prompt for GitHub repository: %w", err) + } + if repo == "" { + return errors.New("github repo cannot be empty") + } + sc.Repo = repo + } + + return nil +} + +func toValidAppName(name string) (string, error) { + // replace all underscores with hyphens + cleanedName := strings.ReplaceAll(name, "_", "-") + // replace all spaces with hyphens + cleanedName = strings.ReplaceAll(cleanedName, " ", "-") + + // remove leading non-alphanumeric characters + for i, r := range cleanedName { + if unicode.IsLetter(r) || unicode.IsNumber(r) { + cleanedName = cleanedName[i:] + break + } + } + + // remove trailing non-alphanumeric characters + for i := len(cleanedName) - 1; i >= 0; i-- { + r := rune(cleanedName[i]) + if unicode.IsLetter(r) || unicode.IsNumber(r) { + cleanedName = cleanedName[:i+1] + break + } + } + + // remove all characters except alphanumeric, '-', '.' + var builder strings.Builder + for _, r := range cleanedName { + if unicode.IsLetter(r) || unicode.IsNumber(r) || r == '-' { + builder.WriteRune(r) + } + } + + // lowercase the name + cleanedName = strings.ToLower(builder.String()) + if err := ValidateAppName(cleanedName);err != nil { + return "", fmt.Errorf("app name '%s' could not be converted to a valid name: %w", name, err) + } + return cleanedName, nil +} + +func runProviderSetUp(ctx context.Context, sc *providers.SetUpCmd, s spinner.Spinner, gh providers.GhClient) error { + provider := strings.ToLower(sc.Provider) + if provider == "azure" { + // call azure provider logic + return providers.InitiateAzureOIDCFlow(ctx, sc, s, gh) + + } else { + // call logic for user-submitted provider + fmt.Printf("The provider is %v\n", sc.Provider) + } + + return nil +} + +func ValidateAppName(name string) error { + errors := validation.IsDNS1123Label(name) + if len(errors) > 0 { + return fmt.Errorf("invalid app name: %s", strings.Join(errors, ", ")) + } + return nil +} + +func PromptAppName(az providers.AzClientInterface, defaultAppName string) (string, error) { + appNamePrompt := &promptui.Prompt{ + Label: "Enter app registration name", + Validate: ValidateAppName, + Default: defaultAppName, + } + appName, err := appNamePrompt.Run() + + if err != nil { + return "", err + } + + if providers.AzAppExists(appName) { + confirmAppExistsPrompt := promptui.Prompt{ + Label: "An app with this name already exists. Would you like to use it?", + IsConfirm: true, + } + _, err := confirmAppExistsPrompt.Run() + if err != nil { + return PromptAppName(az, defaultAppName) + } + } else { + log.Debugf("App %q does not exist, will be created", appName) + } + + return appName, nil +} + +func getSubscriptionID() string { + validate := func(input string) error { + if input == "" { + return errors.New("invalid subscription id") + } + return nil + } + + prompt := promptui.Prompt{ + Label: "Enter subscription ID", + Validate: validate, + } + + result, err := prompt.Run() + + if err != nil { + return err.Error() + } + + return result +} + +func PromptResourceGroup(az providers.AzClientInterface, subscriptionID string) (armresources.ResourceGroup, error) { + var rg armresources.ResourceGroup + log.Println("Fetching resource groups...") + rgs, err := az.ListResourceGroups(context.Background(), subscriptionID) + if err != nil { + return rg, fmt.Errorf("listing resource groups: %w", err) + } + + rg, err = prompts.Select("Please choose the resource group you would like to use", rgs, &prompts.SelectOpt[armresources.ResourceGroup]{ + Field: func(rg armresources.ResourceGroup) string { + return *rg.Name + " (" + *rg.Location + ")" + }, + }) + if err != nil { + return rg, fmt.Errorf("selecting resource group: %w", err) + } + + return rg, nil +} + +func PromptGitHubRepoWithOwner(gh providers.GhClient) (string, error) { + defaultRepoNameWithOwner, err := gh.GetRepoNameWithOwner() + if err != nil { + return "", err + } + log.Println("Prompting for github repo with owner name...") + repoPrompt := promptui.Prompt{ + Label: "Enter github organization and repo organization and repoName", + Validate: func(input string) error { + if !strings.Contains(input, "/") { + return errors.New("github repo cannot be empty") + } + return nil + }, + Default: defaultRepoNameWithOwner, + } + + repo, err := repoPrompt.Run() + if err != nil { + return "", fmt.Errorf("running repo name with owner prompt: %w", err) + } + + log.Debug("Validating github repo...") + if err := gh.IsValidGhRepo(repo); err != nil { + confirmMissingRepoPrompt := promptui.Prompt{ + Label: "Unable to confirm this repo exists. Do you want to proceed anyway?", + IsConfirm: true, + } + _, err := confirmMissingRepoPrompt.Run() + if err != nil { + return PromptGitHubRepoWithOwner(gh) + } + } else { + log.Debugf("Github repo %q is valid", repo) + } + return repo, nil +} + +func getCloudProvider() string { + selection := &promptui.Select{ + Label: "What cloud provider would you like to use?", + Items: []string{"azure"}, + } + + _, selectResponse, err := selection.Run() + if err != nil { + return err.Error() + } + + return selectResponse +} + +func getAzSubscriptionId(subLabels []providers.SubLabel, currentSub providers.SubLabel) (string, error) { + subLabel, err := prompts.Select("Please choose the subscription ID you would like to use", subLabels, &prompts.SelectOpt[providers.SubLabel]{ + Field: func(subLabel providers.SubLabel) string { + return subLabel.Name + " (" + subLabel.ID + ")" + }, + Default: ¤tSub, + }) + if err != nil { + return "", fmt.Errorf("selecting subscription ID: %w", err) + } + + return subLabel.ID, nil +} + +func init() { + rootCmd.AddCommand(newSetUpCmd()) +} diff --git a/cmd/setupgh_test.go b/cmd/setupgh_test.go new file mode 100644 index 00000000..d049ae4e --- /dev/null +++ b/cmd/setupgh_test.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/Azure/draft/pkg/providers" + "github.com/Azure/draft/pkg/spinner" +) + +func TestSetUpConfig(t *testing.T) { + ctx := context.Background() + mockSetUpCmd := &providers.SetUpCmd{} + mockSetUpCmd.AppName = "testingSetUpCommand" + mockSetUpCmd.Provider = "Google" + mockSetUpCmd.Repo = "test/repo" + mockSetUpCmd.ResourceGroupName = "testResourceGroup" + mockSetUpCmd.SubscriptionID = "123456789" + mockSetUpCmd.TenantId = "123456789" + s := spinner.CreateSpinner("--> Setting up Github OIDC...") + + gh := &providers.GhCliClient{} + fillSetUpConfig(mockSetUpCmd, gh) + + err := runProviderSetUp(ctx, mockSetUpCmd, s, gh) + + assert.True(t, err == nil) +} + +func TestToValidAppName(t *testing.T) { + testCases := []struct { + testCaseName string + nameInput string + expected string + shouldError bool + }{ + { + testCaseName: "valid name", + nameInput: "valid-name", + expected: "valid-name", + }, + { + testCaseName: "name with spaces", + nameInput: "name with spaces", + expected: "name-with-spaces", + }, + { + testCaseName: "name with special characters", + nameInput: "name!@#$%^&*()", + expected: "name", + }, + { + testCaseName: "cannot start with a period", + nameInput: ".name", + expected: "name", + }, + { + testCaseName: "cannot start or end with hyphen", + nameInput: "----name--", + expected: "name", + }, + { + testCaseName: "name that can't be made valid", + expected: ".**(__)", + shouldError: true, + }, + { + testCaseName: "hypens allowed in the middle", + nameInput: "name-name-name-name", + expected: "name-name-name-name", + }, { + testCaseName: "remove dots in the middle", + nameInput: "name.name-name-name", + expected: "namename-name-name", + }, + { + testCaseName: "no capital letters", + nameInput: "NaMe", + expected: "name", + }, + } + + for _, tc := range testCases { + t.Run(tc.testCaseName, func(t *testing.T) { + result, err := toValidAppName(tc.nameInput) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.Equal(t, tc.expected, result) + if err != nil { + t.Errorf("expected no error, got %s", err) + } + } + }) + } + +} + +func TestValidateAppName(t *testing.T) { + cases := []struct { + name string + input string + expectedError bool + }{ + { + name: "valid name", + input: "valid-name", + }, + { + name: "name with spaces", + input: "name with spaces", + expectedError: true, + }, + { + name: "name with special characters", + input: "name!@#$%^&*()", + expectedError: true, + }, + { + name: "cannot start with a period", + input: ".name", + expectedError: true, + }, + { + name: "cannot start or end with hyphen", + input: "----name--", + expectedError: true, + }, + { + name: "cannot end with a period", + input: "name-1-a.", + expectedError: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateAppName(tc.input) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/go.mod b/go.mod index 8d8493e0..455d05c9 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0 github.com/blang/semver/v4 v4.0.0 github.com/briandowns/spinner v1.23.1 @@ -24,7 +25,6 @@ require ( github.com/stretchr/testify v1.9.0 github.com/yannh/kubeconform v0.6.7 go.uber.org/mock v0.4.0 - golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f golang.org/x/mod v0.20.0 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 @@ -142,6 +142,7 @@ require ( go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.starlark.net v0.0.0-20230525235612-a134d8f9ddca // indirect golang.org/x/crypto v0.28.0 // indirect + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f // indirect golang.org/x/net v0.30.0 // indirect golang.org/x/oauth2 v0.19.0 // indirect golang.org/x/sync v0.8.0 // indirect diff --git a/go.sum b/go.sum index 2c32cd87..63dd432d 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,12 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xP github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 h1:qiir/pptnHqp6hV8QwV+IExYIf6cPsXBfUDUXQ27t2Y= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2/go.mod h1:jVRrRDLCOuif95HDYC23ADTMlvahB7tMdl519m9Iyjc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0/go.mod h1:mLfWfj8v3jfWKsL9G4eoBoXVcsqcIUTapmdKy7uGOp0= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0 h1:UrGzkHueDwAWDdjQxC+QaXHd4tVCkISYE9j7fSSXF8k= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0/go.mod h1:qskvSQeW+cxEE2bcKYyKimB1/KiQ9xpJ99bcHY0BX6c= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= @@ -507,8 +513,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -544,8 +548,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= diff --git a/pkg/linguist/linguist_test.go b/pkg/linguist/linguist_test.go index 8c1e74f1..cecc8e12 100644 --- a/pkg/linguist/linguist_test.go +++ b/pkg/linguist/linguist_test.go @@ -60,7 +60,7 @@ func TestGitAttributes(t *testing.T) { } } -//TestDirectoryIsIgnored checks to see if directory paths such as 'docs/' are ignored from being classified by linguist when added to the "ignore" list. +// TestDirectoryIsIgnored checks to see if directory paths such as 'docs/' are ignored from being classified by linguist when added to the "ignore" list. func TestDirectoryIsIgnored(t *testing.T) { path := filepath.Join("testdirs", "app-documentation") // populate isIgnored diff --git a/pkg/prompts/prompts.go b/pkg/prompts/prompts.go index f164bf94..39f8385e 100644 --- a/pkg/prompts/prompts.go +++ b/pkg/prompts/prompts.go @@ -250,7 +250,14 @@ func Select[T any](label string, items []T, opt *SelectOpt[T]) (T, error) { selection := strings.ToLower(str) search = strings.ToLower(search) - return strings.Contains(selection, search) + searchWords := strings.Split(search, " ") + + for _, word := range searchWords { + if !strings.Contains(selection, word) { + return false + } + } + return true } // sort the default selection to top if exists diff --git a/pkg/providers/az-client.go b/pkg/providers/az-client.go index b20f9c28..4e14d219 100644 --- a/pkg/providers/az-client.go +++ b/pkg/providers/az-client.go @@ -1,32 +1,139 @@ package providers import ( + "context" + "errors" "fmt" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" - "github.com/Azure/draft/pkg/cred" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" ) +//go:generate mockgen -source=./az-client.go -destination=./mock/az-client.go . +type AzClientInterface interface { + ListResourceGroups(ctx context.Context, subscriptionID string) ([]armresources.ResourceGroup, error) + ListTenants(ctx context.Context) ([]armsubscription.TenantIDDescription, error) + assignSpRole(ctx context.Context, subscriptionId, resourceGroup, servicePrincipalObjectID, roleId string) error +} + +// assert AzClient implements AzClientInterface +var _ AzClientInterface = &AzClient{} + +// AzClient is a struct that contains the Azure client and its dependencies +// It is used to interact with Azure resources +// Create a new AzClient with NewAzClient type AzClient struct { - AzTenantClient azTenantClient - RoleAssignClient *armauthorization.RoleAssignmentsClient + Credential *azidentity.DefaultAzureCredential + TenantClient *armsubscription.TenantsClient + RoleAssignClient *armauthorization.RoleAssignmentsClient + ResourceGroupClient *armresources.ResourceGroupsClient } -//go:generate mockgen -source=./az-client.go -destination=./mock/az-client.go . -type azTenantClient interface { - NewListPager(options *armsubscription.TenantsClientListOptions) *runtime.Pager[armsubscription.TenantsClientListResponse] +func NewAzClient(cred *azidentity.DefaultAzureCredential) (*AzClient, error) { + azClient := &AzClient{ + Credential: cred, + } + return azClient, nil } -func createRoleAssignmentClient(subscriptionId string) (*armauthorization.RoleAssignmentsClient, error) { - credential, err := cred.GetCred() - if err != nil { - return nil, fmt.Errorf("failed to get credentials: %w", err) +func (az *AzClient) ListResourceGroups(ctx context.Context, subscriptionID string) ([]armresources.ResourceGroup, error) { + log.Debug("listing Azure resource groups for subscription ", subscriptionID) + if az.ResourceGroupClient == nil { + c, err := armresources.NewResourceGroupsClient(subscriptionID, az.Credential, nil) + if err != nil { + return nil, fmt.Errorf("failed to create resource group client: %w", err) + } + az.ResourceGroupClient = c + } + + var rgs []armresources.ResourceGroup + pager := az.ResourceGroupClient.NewListPager(nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("listing resource groups page: %w", err) + } + + for _, rg := range page.Value { + if rg == nil { + return nil, errors.New("nil rg") + } + + rgs = append(rgs, *rg) + } + } + + return rgs, nil +} + +func (az *AzClient) ListTenants(ctx context.Context) ([]armsubscription.TenantIDDescription, error) { + log.Debug("Starting to list Azure Tenants") + + // Initialize the tenant slice to store the results. + tenants := make([]armsubscription.TenantIDDescription, 0) + + if az.TenantClient == nil { + c, err := armsubscription.NewTenantsClient(az.Credential, nil) + if err != nil { + return nil, fmt.Errorf("failed to create tenant client: %w", err) + } + az.TenantClient = c + } + pager := az.TenantClient.NewListPager(nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("listing tenants page: %w", err) + } + + for _, t := range page.Value { + if t == nil { + return nil, errors.New("nil tenant") // this should never happen but it's good to check just in case + } + tenants = append(tenants, *t) + } } - client, err := armauthorization.NewRoleAssignmentsClient(subscriptionId, credential, nil) + log.Debugf("Successfully listed %d Azure tenants", len(tenants)) + return tenants, nil +} + +func (az *AzClient) assignSpRole(ctx context.Context, subscriptionId, resourceGroup, servicePrincipalObjectID, roleId string) error { + log.Debug("Assigning contributor role to service principal...") + if az.RoleAssignClient == nil { + c, err := armauthorization.NewRoleAssignmentsClient(subscriptionId, az.Credential, nil) + if err != nil { + return fmt.Errorf("failed to create role assignment client: %w", err) + } + az.RoleAssignClient = c + } + + scope := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s", subscriptionId, resourceGroup) + objectID := servicePrincipalObjectID + raUid := uuid.New().String() + + fullAssignmentId := fmt.Sprintf("/%s/providers/Microsoft.Authorization/roleAssignments/%s", scope, raUid) + fullDefinitionId := fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", roleId) + + principalType := armauthorization.PrincipalTypeServicePrincipal + parameters := armauthorization.RoleAssignmentCreateParameters{ + Properties: &armauthorization.RoleAssignmentProperties{ + PrincipalID: &objectID, + RoleDefinitionID: &fullDefinitionId, + PrincipalType: &principalType, + }, + } + + _, err := az.RoleAssignClient.CreateByID(ctx, fullAssignmentId, parameters, nil) if err != nil { - return nil, fmt.Errorf("failed to create role assignment client: %w", err) + return fmt.Errorf("creating role assignment: %w", err) } - return client, nil + + log.Debug("Role assigned successfully!") + return nil } diff --git a/pkg/providers/azcli.go b/pkg/providers/azcli.go index 61393edb..c9851713 100644 --- a/pkg/providers/azcli.go +++ b/pkg/providers/azcli.go @@ -95,7 +95,6 @@ func IsLoggedInToAz() bool { } func EnsureAzCliLoggedIn() { - EnsureAzCliInstalled() if !IsLoggedInToAz() { if err := LogInToAz(); err != nil { log.Fatal("Error: unable to log in to Azure") @@ -170,6 +169,7 @@ func isValidResourceGroup( } func AzAppExists(appName string) bool { + log.Debugf("Checking if app %q exists...", appName) filter := fmt.Sprintf("displayName eq '%s'", appName) checkAppExistsCmd := exec.Command("az", "ad", "app", "list", "--only-show-errors", "--filter", filter, "--query", "[].appId") out, err := checkAppExistsCmd.CombinedOutput() @@ -228,13 +228,6 @@ func AzAksExists(aksName string, resourceGroup string) bool { } func GetCurrentAzSubscriptionLabel() (SubLabel, error) { - EnsureAzCliInstalled() - if !IsLoggedInToAz() { - if err := LogInToAz(); err != nil { - return SubLabel{}, fmt.Errorf("failed to log in to Azure CLI: %v", err) - } - } - getAccountCmd := exec.Command("az", "account", "show", "--query", "{id: id, name: name}") out, err := getAccountCmd.CombinedOutput() if err != nil { @@ -252,13 +245,6 @@ func GetCurrentAzSubscriptionLabel() (SubLabel, error) { } func GetAzSubscriptionLabels() ([]SubLabel, error) { - EnsureAzCliInstalled() - if !IsLoggedInToAz() { - if err := LogInToAz(); err != nil { - return nil, fmt.Errorf("failed to log in to Azure CLI: %v", err) - } - } - getAccountCmd := exec.Command("az", "account", "list", "--all", "--query", "[].{id: id, name: name}") out, err := getAccountCmd.CombinedOutput() diff --git a/pkg/providers/azure.go b/pkg/providers/azure.go index dfc2a8e4..9aa70235 100644 --- a/pkg/providers/azure.go +++ b/pkg/providers/azure.go @@ -8,10 +8,9 @@ import ( "os/exec" "time" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" - "github.com/google/uuid" + "github.com/Azure/draft/pkg/prompts" "github.com/Azure/draft/pkg/spinner" bo "github.com/cenkalti/backoff/v4" @@ -25,42 +24,39 @@ type SetUpCmd struct { Provider string Repo string appId string - tenantId string + TenantId string appObjectId string spObjectId string - AzClient AzClient + AzClient AzClientInterface } -func InitiateAzureOIDCFlow(ctx context.Context, sc *SetUpCmd, s spinner.Spinner) error { +const AKS_CLUSTER_ADMIN_ROLE_ID = "b1ff04bb-8a4e-4dc4-8eb5-8693973ce19b" + +func InitiateAzureOIDCFlow(ctx context.Context, sc *SetUpCmd, s spinner.Spinner, gh GhClient) error { log.Debug("Commencing github connection with azure...") - EnsureGhCliInstalled() - EnsureGhCliLoggedIn() s.Start() - if err := sc.ValidateSetUpConfig(); err != nil { + if err := sc.ValidateSetUpConfig(gh); err != nil { return err } - if AzAppExists(sc.AppName) { - return errors.New("app already exists") - } else if err := sc.createAzApp(); err != nil { - return err + if !AzAppExists(sc.AppName) { + err := sc.createAzApp() + if err != nil { + return err + } } if err := sc.CreateServicePrincipal(); err != nil { return err } - if err := sc.getTenantId(ctx); err != nil { - return err - } - if err := sc.getAppObjectId(); err != nil { return err } - if err := sc.assignSpRole(ctx); err != nil { + if err := sc.AzClient.assignSpRole(ctx, sc.SubscriptionID, sc.ResourceGroupName, sc.spObjectId, AKS_CLUSTER_ADMIN_ROLE_ID); err != nil { return err } @@ -129,7 +125,7 @@ func (sc *SetUpCmd) createAzApp() error { } func (sc *SetUpCmd) CreateServicePrincipal() error { - log.Debug("Creating Azure service principal...") + log.Debug("creating Azure service principal...") start := time.Now() log.Debug(start) @@ -141,7 +137,7 @@ func (sc *SetUpCmd) CreateServicePrincipal() error { return err } - log.Debug("Checking sp was created...") + log.Debug("checking sp was created...") if sc.ServicePrincipalExists() { log.Debug("Service principal created successfully!") end := time.Since(start) @@ -164,85 +160,36 @@ func (sc *SetUpCmd) CreateServicePrincipal() error { return nil } -func (sc *SetUpCmd) assignSpRole(ctx context.Context) error { - log.Debug("Assigning contributor role to service principal...") - - roleAssignClient, err := createRoleAssignmentClient(sc.SubscriptionID) - if err != nil { - return fmt.Errorf("creating role assignment client: %w", err) - } - - scope := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s", sc.SubscriptionID, sc.ResourceGroupName) - objectID := sc.spObjectId - roleId := "b24988ac-6180-42a0-ab88-20f7382dd24c" // Contributor role ID - raUid := uuid.New().String() - - fullAssignmentId := fmt.Sprintf("/%s/providers/Microsoft.Authorization/roleAssignments/%s", scope, raUid) - fullDefinitionId := fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", roleId) - - principalType := armauthorization.PrincipalTypeServicePrincipal - parameters := armauthorization.RoleAssignmentCreateParameters{ - Properties: &armauthorization.RoleAssignmentProperties{ - PrincipalID: &objectID, - RoleDefinitionID: &fullDefinitionId, - PrincipalType: &principalType, - }, - } - - _, err = roleAssignClient.CreateByID(ctx, fullAssignmentId, parameters, nil) - if err != nil { - return fmt.Errorf("creating role assignment: %w", err) - } - - log.Debug("Role assigned successfully!") - return nil -} - -func (sc *SetUpCmd) getTenantId(ctx context.Context) error { +// Prompt the user to select a tenant ID if there are multiple tenants, or return the only tenant ID if there is only one +func PromptTenantId(azc AzClientInterface, ctx context.Context) (string, error) { log.Debug("getting Azure tenant ID") - tenants, err := sc.listTenants(ctx) + selectedTenant := "" + tenants, err := azc.ListTenants(ctx) if err != nil { - return fmt.Errorf("listing tenants: %w", err) + return selectedTenant, fmt.Errorf("listing tenants: %w", err) } if len(tenants) == 0 { - return errors.New("no tenants found") + return selectedTenant, errors.New("no tenants found") } - if len(tenants) > 1 { - return errors.New("multiple tenants found") - } - sc.tenantId = *tenants[0].TenantID - - return nil -} -func (sc *SetUpCmd) listTenants(ctx context.Context) ([]armsubscription.TenantIDDescription, error) { - log.Debug("listing Azure subscriptions") - - var tenants []armsubscription.TenantIDDescription - - pager := sc.AzClient.AzTenantClient.NewListPager(nil) - - for pager.More() { - page, err := pager.NextPage(ctx) - if err != nil { - return nil, fmt.Errorf("listing tenants page: %w", err) - } - - for _, t := range page.Value { - if t == nil { - return nil, errors.New("nil tenant") // this should never happen but it's good to check just in case - } - tenants = append(tenants, *t) + if len(tenants) == 1 { + if tenants[0].TenantID == nil { + return selectedTenant, errors.New("nil tenant ID") } + selectedTenant = *tenants[0].TenantID + log.Debugf("Selected only tenant ID found: %s", selectedTenant) + return selectedTenant, nil + } + if len(tenants) > 1 { + prompts.Select[armsubscription.TenantIDDescription]("Select the tenant you want to use", tenants, &prompts.SelectOpt[armsubscription.TenantIDDescription]{}) } - log.Debug("finished listing Azure tenants") - return tenants, nil + return selectedTenant, nil } -func (sc *SetUpCmd) ValidateSetUpConfig() error { +func (sc *SetUpCmd) ValidateSetUpConfig(gh GhClient) error { log.Debug("Checking that provided information is valid...") if err := IsSubscriptionIdValid(sc.SubscriptionID); err != nil { @@ -257,7 +204,7 @@ func (sc *SetUpCmd) ValidateSetUpConfig() error { return errors.New("invalid app name") } - if err := isValidGhRepo(sc.Repo); err != nil { + if err := gh.IsValidGhRepo(sc.Repo); err != nil { return err } @@ -374,7 +321,7 @@ func (sc *SetUpCmd) setAzSubscriptionId() error { func (sc *SetUpCmd) setAzTenantId() error { log.Debug("Setting AZURE_TENANT_ID in github...") - setTenantIdCmd := exec.Command("gh", "secret", "set", "AZURE_TENANT_ID", "-b", sc.tenantId, "--repo", sc.Repo) + setTenantIdCmd := exec.Command("gh", "secret", "set", "AZURE_TENANT_ID", "-b", sc.TenantId, "--repo", sc.Repo) out, err := setTenantIdCmd.CombinedOutput() if err != nil { log.Printf("%s\n", out) diff --git a/pkg/providers/azure_test.go b/pkg/providers/azure_test.go deleted file mode 100644 index 082f38b2..00000000 --- a/pkg/providers/azure_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package providers - -import ( - "context" - "errors" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" - mock_providers "github.com/Azure/draft/pkg/providers/mock" - "go.uber.org/mock/gomock" - "strings" - "testing" -) - -func setupMockClientAndPager(ctrl *gomock.Controller, responses []armsubscription.TenantsClientListResponse) *mock_providers.MockazTenantClient { - mockClient := mock_providers.NewMockazTenantClient(ctrl) - - // Define a minimal paging handler function that returns the provided responses - mockPagerHandler := runtime.PagingHandler[armsubscription.TenantsClientListResponse]{ - More: func(t armsubscription.TenantsClientListResponse) bool { return false }, - Fetcher: func(ctx context.Context, response *armsubscription.TenantsClientListResponse) (armsubscription.TenantsClientListResponse, error) { - if len(responses) == 0 { - return armsubscription.TenantsClientListResponse{}, nil - } - resp := responses[0] - responses = responses[1:] - return resp, nil - }, - Tracer: tracing.Tracer{}, - } - - // Create a mock pager with the paging handler - mockPager := runtime.NewPager[armsubscription.TenantsClientListResponse](mockPagerHandler) - - mockClient.EXPECT().NewListPager(gomock.Nil()).Return(mockPager).Times(1) - - return mockClient -} - -func TestGetTenantId(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Define test data - testId := "00000000-0000-0000-0000-000000000000" - testTenantId := "/tenants/00000000-0000-0000-0000-000000000000" - testNextLink := "https://pkg.go.dev/github.com" - testTenantDesc := armsubscription.TenantIDDescription{ID: &testId, TenantID: &testTenantId} - testTenantDescArray := []*armsubscription.TenantIDDescription{&testTenantDesc} - testTenantListResult := armsubscription.TenantListResult{NextLink: &testNextLink, Value: testTenantDescArray} - responses := []armsubscription.TenantsClientListResponse{{testTenantListResult}} - - mockClient := setupMockClientAndPager(ctrl, responses) - - sc := &SetUpCmd{ - AzClient: AzClient{ - AzTenantClient: mockClient, - }, - } - - err := sc.getTenantId(context.Background()) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } -} - -// Test case for the getTenantId function when listing tenants encounters an error -func TestGetTenantId_ListTenantsError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Setup mock client and pager to return an error when listing tenants - mockClient := mock_providers.NewMockazTenantClient(ctrl) - mockPager := runtime.NewPager[armsubscription.TenantsClientListResponse](runtime.PagingHandler[armsubscription.TenantsClientListResponse]{ - More: func(t armsubscription.TenantsClientListResponse) bool { return false }, - Fetcher: func(ctx context.Context, response *armsubscription.TenantsClientListResponse) (armsubscription.TenantsClientListResponse, error) { - return armsubscription.TenantsClientListResponse{}, errors.New("error listing tenants") - }, - Tracer: tracing.Tracer{}, - }) - mockClient.EXPECT().NewListPager(gomock.Nil()).Return(mockPager).Times(1) - - sc := &SetUpCmd{ - AzClient: AzClient{ - AzTenantClient: mockClient, - }, - } - - err := sc.getTenantId(context.Background()) - - if err == nil || !strings.Contains(err.Error(), "error listing tenants") { - t.Errorf("Expected error listing tenants, got: %v", err) - } -} - -// Test case for the getTenantId function when tenant list is empty -func TestGetTenantId_EmptyTenantList(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Setup mock client and pager with no responses - mockClient := setupMockClientAndPager(ctrl, []armsubscription.TenantsClientListResponse{}) - - sc := &SetUpCmd{ - AzClient: AzClient{ - AzTenantClient: mockClient, - }, - } - - err := sc.getTenantId(context.Background()) - - if err == nil || !strings.Contains(err.Error(), "no tenants found") { - t.Errorf("Expected error no tenants found, got: %v", err) - } -} - -// Test case for the getTenantId function when a nil tenant is encountered in the list -func TestGetTenantId_NilTenantInList(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Define test data with a nil tenant in the list - testTenantDescArray := []*armsubscription.TenantIDDescription{nil} - - mockClient := setupMockClientAndPager(ctrl, []armsubscription.TenantsClientListResponse{{TenantListResult: armsubscription.TenantListResult{Value: testTenantDescArray}}}) - - sc := &SetUpCmd{ - AzClient: AzClient{ - AzTenantClient: mockClient, - }, - } - - err := sc.getTenantId(context.Background()) - - if err == nil || !strings.Contains(err.Error(), "nil tenant") { - t.Errorf("Expected error nil tenant, got: %v", err) - } -} diff --git a/pkg/providers/commandrunner.go b/pkg/providers/commandrunner.go new file mode 100644 index 00000000..543ed763 --- /dev/null +++ b/pkg/providers/commandrunner.go @@ -0,0 +1,33 @@ +package providers + +import ( + "errors" + "os/exec" +) + +// CommandRunner is an interface for executing commands and getting the output/error +type CommandRunner interface { + RunCommand(...string) (string, error) +} + +type DefaultCommandRunner struct{} +var _ CommandRunner = &DefaultCommandRunner{} + +func (d *DefaultCommandRunner) RunCommand(args ...string) (string, error) { + cmd := exec.Command(args[0], args[1:]...) + out, err := cmd.CombinedOutput() + return string(out), err +} + +type FakeCommandRunner struct { + Output string + ErrStr string +} +var _ CommandRunner = &FakeCommandRunner{} + +func (f *FakeCommandRunner) RunCommand(args ...string) (string, error) { + if f.ErrStr != "" { + return f.Output, errors.New(f.ErrStr) + } + return f.Output, nil +} diff --git a/pkg/providers/ghcli.go b/pkg/providers/ghcli.go index aa814a35..a303e5fd 100644 --- a/pkg/providers/ghcli.go +++ b/pkg/providers/ghcli.go @@ -2,8 +2,7 @@ package providers import ( "fmt" - "os" - "os/exec" + "strings" log "github.com/sirupsen/logrus" ) @@ -14,15 +13,41 @@ type SubLabel struct { } // EnsureGhCliInstalled ensures that the Github CLI is installed and the user is logged in -func EnsureGhCli() { - EnsureGhCliInstalled() +func (gh GhCliClient) EnsureGhCli() { + gh.EnsureGhCliInstalled() + gh.EnsureGhCliLoggedIn() +} + +type GhClient interface { + EnsureGhCli() EnsureGhCliLoggedIn() + IsLoggedInToGh() bool + LogInToGh() error + IsValidGhRepo(repo string) error + GetRepoNameWithOwner() (string, error) +} + +var _ GhClient = &GhCliClient{} + +type GhCliClient struct { + CommandRunner CommandRunner +} + +func NewGhClient() *GhCliClient { + gh := &GhCliClient{ + CommandRunner: &DefaultCommandRunner{}, + } + gh.EnsureGhCli() + return gh } -func EnsureGhCliInstalled() { +func (gh GhCliClient) exec(args ...string) (string, error) { + return gh.CommandRunner.RunCommand(args...) +} + +func (gh GhCliClient) EnsureGhCliInstalled() { log.Debug("Checking that github cli is installed...") - ghCmd := exec.Command("gh") - _, err := ghCmd.CombinedOutput() + _, err := gh.exec("gh") if err != nil { log.Fatal("Error: The github cli is required to complete this process. Find installation instructions at this link: https://github.com/cli/cli#installation") } @@ -30,19 +55,18 @@ func EnsureGhCliInstalled() { log.Debug("Github cli found!") } -func EnsureGhCliLoggedIn() { - EnsureGhCliInstalled() - if !IsLoggedInToGh() { - if err := LogInToGh(); err != nil { +func (gh GhCliClient) EnsureGhCliLoggedIn() { + gh.EnsureGhCliInstalled() + if !gh.IsLoggedInToGh() { + if err := gh.LogInToGh(); err != nil { log.Fatal("Error: unable to log in to github") } } } -func IsLoggedInToGh() bool { +func (gh GhCliClient) IsLoggedInToGh() bool { log.Debug("Checking that user is logged in to github...") - ghCmd := exec.Command("gh", "auth", "status") - out, err := ghCmd.CombinedOutput() + out, err := gh.exec("gh", "auth", "status") if err != nil { fmt.Printf(string(out)) return false @@ -53,13 +77,9 @@ func IsLoggedInToGh() bool { } -func LogInToGh() error { +func (gh GhCliClient) LogInToGh() error { log.Debug("Logging user in to github...") - ghCmd := exec.Command("gh", "auth", "login") - ghCmd.Stdin = os.Stdin - ghCmd.Stdout = os.Stdout - ghCmd.Stderr = os.Stderr - err := ghCmd.Run() + _, err := gh.exec("gh", "auth", "login") if err != nil { return err } @@ -67,12 +87,29 @@ func LogInToGh() error { return nil } -func isValidGhRepo(repo string) error { - listReposCmd := exec.Command("gh", "repo", "view", repo) - _, err := listReposCmd.CombinedOutput() +func (gh GhCliClient) IsValidGhRepo(repo string) error { + _, err := gh.exec("gh", "repo", "view", repo) if err != nil { - log.Fatal("Github repo not found") + log.Debug("Github repo " + repo + "not found") return err } return nil } + +func (gh GhCliClient) GetRepoNameWithOwner() (string, error) { + repoNameWithOwner := "" + out, err := gh.exec("gh", "repo", "view", "--json", "nameWithOwner", "-q", ".nameWithOwner") + if err != nil { + log.Fatal("getting github repo name with owner") + return repoNameWithOwner, err + } + if out == "" { + log.Fatal("github repo name empty from gh cli") + return repoNameWithOwner, fmt.Errorf("github repo name empty from gh cli") + } + + repoNameWithOwner = string(out) + repoNameWithOwner = strings.TrimSpace(repoNameWithOwner) + log.Debug("retrieved repoNameWithOwner from gh cli: ", repoNameWithOwner) + return repoNameWithOwner, nil +} diff --git a/pkg/providers/ghcli_test.go b/pkg/providers/ghcli_test.go index 64a95577..f0afdfe3 100644 --- a/pkg/providers/ghcli_test.go +++ b/pkg/providers/ghcli_test.go @@ -5,5 +5,11 @@ import ( ) func TestHasGhCli(t *testing.T) { - EnsureGhCliInstalled() + cr := &FakeCommandRunner{ + Output: "gh version 1.0.0", + } + gh := &GhCliClient{ + CommandRunner: cr, + } + gh.EnsureGhCliInstalled() }