Skip to content

Commit

Permalink
Make 'AWSClient.AccountID' a getter.
Browse files Browse the repository at this point in the history
  • Loading branch information
ewbankkit committed Nov 27, 2024
1 parent 398878c commit fa813c5
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
40 changes: 20 additions & 20 deletions internal/acctest/acctest.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func PreCheck(ctx context.Context, t *testing.T) {
}

// ProviderAccountID returns the account ID of an AWS provider
func ProviderAccountID(provider *schema.Provider) string {
func ProviderAccountID(ctx context.Context, provider *schema.Provider) string {
if provider == nil {
log.Print("[DEBUG] Unable to read account ID from test provider: empty provider")
return ""
Expand All @@ -335,7 +335,7 @@ func ProviderAccountID(provider *schema.Provider) string {
log.Print("[DEBUG] Unable to read account ID from test provider: non-AWS or unconfigured AWS provider")
return ""
}
return client.AccountID
return client.AccountID(ctx)
}

// CheckDestroyNoop is a TestCheckFunc to be used as a TestCase's CheckDestroy when no such check can be made.
Expand All @@ -355,17 +355,17 @@ func CheckSleep(t *testing.T, d time.Duration) resource.TestCheckFunc {
}

// CheckResourceAttrAccountID ensures the Terraform state exactly matches the account ID
func CheckResourceAttrAccountID(resourceName, attributeName string) resource.TestCheckFunc {
func CheckResourceAttrAccountID(ctx context.Context, resourceName, attributeName string) resource.TestCheckFunc {
return func(s *terraform.State) error {
return resource.TestCheckResourceAttr(resourceName, attributeName, AccountID())(s)
return resource.TestCheckResourceAttr(resourceName, attributeName, AccountID(ctx))(s)
}
}

// CheckResourceAttrRegionalARN ensures the Terraform state exactly matches a formatted ARN with region
func CheckResourceAttrRegionalARN(resourceName, attributeName, arnService, arnResource string) resource.TestCheckFunc {
func CheckResourceAttrRegionalARN(ctx context.Context, resourceName, attributeName, arnService, arnResource string) resource.TestCheckFunc {
return func(s *terraform.State) error {
attributeValue := arn.ARN{
AccountID: AccountID(),
AccountID: AccountID(ctx),
Partition: Partition(),
Region: Region(),
Resource: arnResource,
Expand Down Expand Up @@ -458,10 +458,10 @@ func MatchResourceAttrAccountID(resourceName, attributeName string) resource.Tes
}

// MatchResourceAttrRegionalARN ensures the Terraform state regexp matches a formatted ARN with region
func MatchResourceAttrRegionalARN(resourceName, attributeName, arnService string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
func MatchResourceAttrRegionalARN(ctx context.Context, resourceName, attributeName, arnService string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
return func(s *terraform.State) error {
arnRegexp := arn.ARN{
AccountID: AccountID(),
AccountID: AccountID(ctx),
Partition: Partition(),
Region: Region(),
Resource: arnResourceRegexp.String(),
Expand All @@ -479,10 +479,10 @@ func MatchResourceAttrRegionalARN(resourceName, attributeName, arnService string
}

// MatchResourceAttrRegionalARNRegion ensures the Terraform state regexp matches a formatted ARN with the specified region
func MatchResourceAttrRegionalARNRegion(resourceName, attributeName, arnService, region string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
func MatchResourceAttrRegionalARNRegion(ctx context.Context, resourceName, attributeName, arnService, region string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
return func(s *terraform.State) error {
arnRegexp := arn.ARN{
AccountID: AccountID(),
AccountID: AccountID(ctx),
Partition: Partition(),
Region: region,
Resource: arnResourceRegexp.String(),
Expand Down Expand Up @@ -570,19 +570,19 @@ func MatchResourceAttrGlobalHostname(resourceName, attributeName, serviceName st
}
}

func globalARNValue(arnService, arnResource string) string {
func globalARNValue(ctx context.Context, arnService, arnResource string) string {
return arn.ARN{
AccountID: AccountID(),
AccountID: AccountID(ctx),
Partition: Partition(),
Resource: arnResource,
Service: arnService,
}.String()
}

// CheckResourceAttrGlobalARN ensures the Terraform state exactly matches a formatted ARN without region
func CheckResourceAttrGlobalARN(resourceName, attributeName, arnService, arnResource string) resource.TestCheckFunc {
func CheckResourceAttrGlobalARN(ctx context.Context, resourceName, attributeName, arnService, arnResource string) resource.TestCheckFunc {
return func(s *terraform.State) error {
return resource.TestCheckResourceAttr(resourceName, attributeName, globalARNValue(arnService, arnResource))(s)
return resource.TestCheckResourceAttr(resourceName, attributeName, globalARNValue(ctx, arnService, arnResource))(s)
}
}

Expand Down Expand Up @@ -612,10 +612,10 @@ func CheckResourceAttrGlobalARNAccountID(resourceName, attributeName, accountID,
}

// MatchResourceAttrGlobalARN ensures the Terraform state regexp matches a formatted ARN without region
func MatchResourceAttrGlobalARN(resourceName, attributeName, arnService string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
func MatchResourceAttrGlobalARN(ctx context.Context, resourceName, attributeName, arnService string, arnResourceRegexp *regexp.Regexp) resource.TestCheckFunc {
return func(s *terraform.State) error {
arnRegexp := arn.ARN{
AccountID: AccountID(),
AccountID: AccountID(ctx),
Partition: Partition(),
Resource: arnResourceRegexp.String(),
Service: arnService,
Expand Down Expand Up @@ -892,8 +892,8 @@ func PrimaryInstanceState(s *terraform.State, name string) (*terraform.InstanceS

// AccountID returns the account ID of Provider
// Must be used within a resource.TestCheckFunc
func AccountID() string {
return ProviderAccountID(Provider)
func AccountID(ctx context.Context) string {
return ProviderAccountID(ctx, Provider)
}

func Region() string {
Expand Down Expand Up @@ -2624,14 +2624,14 @@ func CheckVPCExists(ctx context.Context, n string, v *ec2types.Vpc) resource.Tes
}
}

func CheckCallerIdentityAccountID(n string) resource.TestCheckFunc {
func CheckCallerIdentityAccountID(ctx context.Context, n string) resource.TestCheckFunc {
return func(s *terraform.State) error {
rs, ok := s.RootModule().Resources[n]
if !ok {
return fmt.Errorf("can't find AccountID resource: %s", n)
}

expected := Provider.Meta().(*conns.AWSClient).AccountID
expected := Provider.Meta().(*conns.AWSClient).AccountID(ctx)
if rs.Primary.Attributes["account_id"] != expected {
return fmt.Errorf("incorrect Account ID: expected %q, got %q", expected, rs.Primary.Attributes["account_id"])
}
Expand Down
5 changes: 3 additions & 2 deletions internal/acctest/known_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package acctest

import (
"context"
"fmt"

"github.com/hashicorp/terraform-plugin-testing/knownvalue"
Expand All @@ -23,7 +24,7 @@ func (v globalARNCheck) CheckValue(other any) error {
return fmt.Errorf("expected string value for GlobalARN check, got: %T", other)
}

arnValue := globalARNValue(v.arnService, v.arnResource)
arnValue := globalARNValue(context.Background(), v.arnService, v.arnResource)

if otherVal != arnValue {
return fmt.Errorf("expected value %s for GlobalARN check, got: %s", arnValue, otherVal)
Expand All @@ -34,7 +35,7 @@ func (v globalARNCheck) CheckValue(other any) error {

// String returns the string representation of the value.
func (v globalARNCheck) String() string {
return globalARNValue(v.arnService, v.arnResource)
return globalARNValue(context.Background(), v.arnService, v.arnResource)
}

func GlobalARN(arnService, arnResource string) globalARNCheck {
Expand Down
19 changes: 12 additions & 7 deletions internal/conns/awsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import (
)

type AWSClient struct {
AccountID string
defaultTagsConfig *tftags.DefaultConfig
ignoreTagsConfig *tftags.IgnoreConfig
Region string
ServicePackages map[string]ServicePackage
Region string
ServicePackages map[string]ServicePackage

accountID string
awsConfig *aws.Config
clients map[string]any
conns map[string]any
defaultTagsConfig *tftags.DefaultConfig
endpoints map[string]string // From provider configuration.
httpClient *http.Client
ignoreTagsConfig *tftags.IgnoreConfig
lock sync.Mutex
logger baselogging.Logger
partition endpoints.Partition
Expand Down Expand Up @@ -76,6 +76,11 @@ func (c *AWSClient) Endpoints(context.Context) map[string]string {
return maps.Clone(c.endpoints)
}

// AccountID returns the configured AWS account ID.
func (c *AWSClient) AccountID(context.Context) string {
return c.accountID
}

// Partition returns the ID of the configured AWS partition.
func (c *AWSClient) Partition(context.Context) string {
return c.partition.ID()
Expand All @@ -94,7 +99,7 @@ func (c *AWSClient) RegionalARN(ctx context.Context, service, resource string) s
Partition: c.Partition(ctx),
Service: service,
Region: c.Region,
AccountID: c.AccountID,
AccountID: c.AccountID(ctx),
Resource: resource,
}.String()
}
Expand Down Expand Up @@ -209,7 +214,7 @@ func (c *AWSClient) DefaultKMSKeyPolicy(ctx context.Context) string {
}
]
}
`, c.Partition(ctx), c.AccountID)
`, c.Partition(ctx), c.AccountID(ctx))
}

// GlobalAcceleratorHostedZoneID returns the Route 53 hosted zone ID
Expand Down
2 changes: 1 addition & 1 deletion internal/conns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
}
}

client.AccountID = accountID
client.accountID = accountID
client.defaultTagsConfig = c.DefaultTagsConfig
client.ignoreTagsConfig = c.IgnoreTagsConfig
client.Region = c.Region
Expand Down
3 changes: 2 additions & 1 deletion internal/provider/provider_acc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ func TestAccProvider_Region_stsRegion(t *testing.T) {
// For historical reasons, ignore a single empty `assume_role` block
func TestAccProvider_AssumeRole_empty(t *testing.T) {
ctx := acctest.Context(t)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t),
Expand All @@ -656,7 +657,7 @@ func TestAccProvider_AssumeRole_empty(t *testing.T) {
{
Config: testAccProviderConfig_assumeRoleEmpty,
Check: resource.ComposeTestCheckFunc(
acctest.CheckCallerIdentityAccountID("data.aws_caller_identity.current"),
acctest.CheckCallerIdentityAccountID(ctx, "data.aws_caller_identity.current"),
),
},
},
Expand Down

0 comments on commit fa813c5

Please sign in to comment.