Skip to content

Commit

Permalink
add account id to supproted identity providers
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Feb 6, 2025
1 parent 1bd5bdf commit afaba87
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 15 deletions.
39 changes: 39 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ namespace Aws
{
}

/**
* Initializes object with accessKeyId, secretKey, sessionToken and expiration date.
*/
AWSCredentials(const Aws::String& accessKeyId,
const Aws::String& secretKey,
const Aws::String& sessionToken,
Aws::Utils::DateTime expiration,
const Aws::String& accountId)
: m_accessKeyId(accessKeyId),
m_secretKey(secretKey),
m_sessionToken(sessionToken),
m_expiration(expiration),
m_accountId(accountId) {}

bool operator == (const AWSCredentials& other) const
{
return m_accessKeyId == other.m_accessKeyId
Expand Down Expand Up @@ -109,6 +123,14 @@ namespace Aws
return m_expiration;
}

/**
* Gets the underlying account id
*/
inline const Aws::String GetAccountId() const
{
return m_accountId;
}

/**
* Sets the underlying access key credential. Copies from parameter accessKeyId.
*/
Expand All @@ -133,6 +155,14 @@ namespace Aws
m_sessionToken = sessionToken;
}

/**
* Sets the underlying account id. Copies from parameter accountId
*/
inline void SetAccountId(const Aws::String& accountId)
{
m_accountId = accountId;
}


/**
* Sets the underlying access key credential. Copies from parameter accessKeyId.
Expand All @@ -158,6 +188,14 @@ namespace Aws
m_sessionToken = sessionToken;
}

/**
* Sets the underlying account id. Copies from parameter accountId
*/
inline void SetExpiration(const char* accountId)
{
m_accountId = accountId;
}

/**
* Sets the expiration date of the credential
*/
Expand All @@ -171,6 +209,7 @@ namespace Aws
Aws::String m_secretKey;
Aws::String m_sessionToken;
Aws::Utils::DateTime m_expiration;
Aws::String m_accountId;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,28 @@ namespace smithy {
class AwsCredentialIdentity : public AwsCredentialIdentityBase {
public:
AwsCredentialIdentity(const Aws::String& accessKeyId,
const Aws::String& secretAccessKey,
const Aws::Crt::Optional<Aws::String>& sessionToken,
const Aws::Crt::Optional<AwsIdentity::DateTime>& expiration)
: m_accessKeyId(accessKeyId), m_secretAccessKey(secretAccessKey),
m_sessionToken(sessionToken), m_expiration(expiration) {}
const Aws::String& secretAccessKey,
const Aws::Crt::Optional<Aws::String>& sessionToken,
const Aws::Crt::Optional<AwsIdentity::DateTime>& expiration,
const Aws::Crt::Optional<Aws::String>& accountId)
: m_accessKeyId(accessKeyId),
m_secretAccessKey(secretAccessKey),
m_sessionToken(sessionToken),
m_expiration(expiration),
m_accountId({accountId}) {}

Aws::String accessKeyId() const override;
Aws::String secretAccessKey() const override;
Aws::Crt::Optional<Aws::String> sessionToken() const override;
Aws::Crt::Optional<AwsIdentity::DateTime> expiration() const override;
Aws::Crt::Optional<Aws::String> accountId() const override;

protected:
Aws::String m_accessKeyId;
Aws::String m_secretAccessKey;
Aws::Crt::Optional<Aws::String> m_sessionToken;
Aws::Crt::Optional<AwsIdentity::DateTime> m_expiration;
Aws::Crt::Optional<Aws::String> m_accountId;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ namespace smithy {
virtual Aws::Crt::Optional<DateTime> expiration() const {
return Aws::Crt::Optional<DateTime>();
};

virtual Aws::Crt::Optional<Aws::String> accountId() const {
return Aws::Crt::Optional<Aws::String>{};
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ namespace smithy {
inline Aws::Crt::Optional<AwsIdentity::DateTime> AwsCredentialIdentity::expiration() const {
return m_expiration;
}

inline Aws::Crt::Optional<Aws::String> AwsCredentialIdentity::accountId() const {
return m_sessionToken;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ namespace smithy

const auto fetchedCreds = m_credentialsProvider->GetAWSCredentials();

auto smithyCreds = Aws::MakeUnique<AwsCredentialIdentity>("DefaultAwsCredentialIdentityResolver",
fetchedCreds.GetAWSAccessKeyId(), fetchedCreds.GetAWSSecretKey(),
fetchedCreds.GetSessionToken(), fetchedCreds.GetExpiration());
auto smithyCreds = Aws::MakeUnique<AwsCredentialIdentity>("AwsCredentialsProviderIdentityResolver",
fetchedCreds.GetAWSAccessKeyId(),
fetchedCreds.GetAWSSecretKey(),
fetchedCreds.GetSessionToken(),
fetchedCreds.GetExpiration(),
fetchedCreds.GetAccountId());

return {std::move(smithyCreds)};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ namespace smithy {
legacyCreds.GetAWSAccessKeyId(),
legacyCreds.GetAWSSecretKey(),
legacyCreds.GetSessionToken().empty()? Aws::Crt::Optional<Aws::String>() : legacyCreds.GetSessionToken(),
legacyCreds.GetExpiration());
legacyCreds.GetExpiration(),
legacyCreds.GetAccountId().empty()? Aws::Crt::Optional<Aws::String>() : legacyCreds.GetSessionToken());

return ResolveIdentityFutureOutcome(std::move(smithyCreds));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ namespace smithy
AWS_UNREFERENCED_PARAM(identityProperties);
AWS_UNREFERENCED_PARAM(additionalParameters);

auto smithyCreds = Aws::MakeUnique<AwsCredentialIdentity>("DefaultAwsCredentialIdentityResolver",
m_credentials.GetAWSAccessKeyId(), m_credentials.GetAWSSecretKey(),
m_credentials.GetSessionToken(), m_credentials.GetExpiration());
auto smithyCreds = Aws::MakeUnique<AwsCredentialIdentity>("SimpleAwsCredentialIdentityResolver",
m_credentials.GetAWSAccessKeyId(),
m_credentials.GetAWSSecretKey(),
m_credentials.GetSessionToken().empty()? Aws::Crt::Optional<Aws::String>() : m_credentials.GetSessionToken(),
m_credentials.GetExpiration(),
m_credentials.GetAccountId().empty()? Aws::Crt::Optional<Aws::String>() : m_credentials.GetAccountId());

return {std::move(smithyCreds)};
}
Expand Down
13 changes: 13 additions & 0 deletions src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ using Aws::Utils::Threading::WriterLockGuard;
static const char ACCESS_KEY_ENV_VAR[] = "AWS_ACCESS_KEY_ID";
static const char SECRET_KEY_ENV_VAR[] = "AWS_SECRET_ACCESS_KEY";
static const char SESSION_TOKEN_ENV_VAR[] = "AWS_SESSION_TOKEN";
static const char ACCOUNT_ID_ENV_VAR[] = "AWS_ACCOUNT_ID";
static const char DEFAULT_PROFILE[] = "default";
static const char AWS_PROFILE_ENV_VAR[] = "AWS_PROFILE";
static const char AWS_PROFILE_DEFAULT_ENV_VAR[] = "AWS_DEFAULT_PROFILE";
Expand Down Expand Up @@ -91,6 +92,14 @@ AWSCredentials EnvironmentAWSCredentialsProvider::GetAWSCredentials()
credentials.SetSessionToken(sessionToken);
AWS_LOGSTREAM_DEBUG(ENVIRONMENT_LOG_TAG, "Found sessionToken");
}

const auto accountId = Aws::Environment::GetEnv(ACCOUNT_ID_ENV_VAR);

if (!accountId.empty())
{
credentials.SetAccountId(accountId);
AWS_LOGSTREAM_DEBUG(ENVIRONMENT_LOG_TAG, "Found accountId");
}
}

return credentials;
Expand Down Expand Up @@ -404,6 +413,10 @@ AWSCredentials Aws::Auth::GetCredentialsFromProcess(const Aws::String& process)
credentials.SetExpiration(Aws::Utils::DateTime::Now());
}
}
if (credentialsView.KeyExists("AccountId"))
{
credentials.SetAccountId(credentialsView.GetString("AccountId"));
}
else
{
credentials.SetExpiration((std::chrono::time_point<std::chrono::system_clock>::max)());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,19 @@ void GeneralHTTPCredentialsProvider::Reload()
return;
}

Aws::String accessKey, secretKey, token;
Aws::String accessKey, secretKey, token, accountId;
Utils::Json::JsonView credentialsView(credentialsDoc);
accessKey = credentialsView.GetString("AccessKeyId");
secretKey = credentialsView.GetString("SecretAccessKey");
token = credentialsView.GetString("Token");
accountId = credentialsView.GetString("AccountId");
AWS_LOGSTREAM_DEBUG(GEN_HTTP_LOG_TAG, "Successfully pulled credentials from metadata service with access key " << accessKey);

m_credentials.SetAWSAccessKeyId(accessKey);
m_credentials.SetAWSSecretKey(secretKey);
m_credentials.SetSessionToken(token);
m_credentials.SetExpiration(Aws::Utils::DateTime(credentialsView.GetString("Expiration"), Aws::Utils::DateFormat::ISO_8601));
m_credentials.SetAccountId(accountId);
AWSCredentialsProvider::Reload();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ namespace Aws
}

auto accessKeyIdIter = currentKeyValues.find(ACCESS_KEY_ID_KEY);
Aws::String accessKey, secretKey, sessionToken;
Aws::String accessKey, secretKey, sessionToken, accountId;
if (accessKeyIdIter != currentKeyValues.end())
{
accessKey = accessKeyIdIter->second;
Expand All @@ -467,7 +467,18 @@ namespace Aws
sessionToken = sessionTokenIter->second;
}

profile.SetCredentials(Aws::Auth::AWSCredentials(accessKey, secretKey, sessionToken));
const auto accountIdIter = currentKeyValues.find("aws_account_id");

if (accountIdIter != currentKeyValues.end())
{
accountId = accountIdIter->second;
}

profile.SetCredentials(Aws::Auth::AWSCredentials(accessKey,
secretKey,
sessionToken,
DateTime{std::chrono::time_point<std::chrono::system_clock>::max()},
accountId));
}

if (!profile.GetSsoStartUrl().empty() || !profile.GetSsoRegion().empty()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aws/core/http/HttpClientFactory.h>
#include <aws/core/http/HttpResponse.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/ARN.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/platform/Environment.h>
Expand Down Expand Up @@ -588,6 +589,11 @@ namespace Aws
{
result.creds.SetExpiration(DateTime(StringUtils::Trim(expirationNode.GetText().c_str()).c_str(), DateFormat::ISO_8601));
}
XmlNode assumeRoleUser = credentialsNode.FirstChild("AssumedRoleUser");
if (!assumeRoleUser.IsNull())
{
result.creds.SetAccountId(ARN{assumeRoleUser.GetText()}.GetAccountId());
}
}
}
return result;
Expand Down Expand Up @@ -670,6 +676,7 @@ namespace Aws
creds.SetAWSSecretKey(roleCredentials.GetString("secretAccessKey"));
creds.SetSessionToken(roleCredentials.GetString("sessionToken"));
creds.SetExpiration(roleCredentials.GetInt64("expiration"));
creds.SetAccountId(roleCredentials.GetString("accountId"));
SSOCredentialsClient::SSOGetRoleCredentialsResult result;
result.creds = creds;
return result;
Expand Down

0 comments on commit afaba87

Please sign in to comment.