Skip to content

Commit

Permalink
Fix for first run cache miss for interactive credentials (#38449)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Sep 8, 2023
1 parent 45a08cf commit fb990f2
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 62 deletions.
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Bugs Fixed

- `ManagedIdentityCredential` will fall through to the next credential in the chain in the case that Docker Desktop returns a 403 response when attempting to access the IMDS endpoint. [#38218](https://github.com/Azure/azure-sdk-for-net/issues/38218)
- Fixed an issue where interactive credentials would still prompt on the first GetToken request even when the cache is populated and an AuthenticationRecord is provided. [#38431](https://github.com/Azure/azure-sdk-for-net/issues/38431)

## 1.10.0 (2023-08-14)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ public class AuthorizationCodeCredential : TokenCredential
private readonly string _clientId;
private readonly CredentialPipeline _pipeline;
private AuthenticationRecord _record;
private bool _isCaeEnabledRequestCached = false;
private bool _isCaeDisabledRequestCached = false;
internal MsalConfidentialClient Client { get; }
private readonly string _redirectUri;
private readonly string _tenantId;
Expand Down Expand Up @@ -138,33 +136,15 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
{
using CredentialDiagnosticScope scope = _pipeline.StartGetTokenScope($"{nameof(AuthorizationCodeCredential)}.{nameof(GetToken)}", requestContext);

AccessToken token;
string tenantId = null;
try
{
AccessToken token;
var tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);
var isCachePopulated = _record switch
{
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
_ => false
};
tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);

if (!isCachePopulated)
if (_record is null)
{
AuthenticationResult result = await Client
.AcquireTokenByAuthorizationCodeAsync(requestContext.Scopes, _authCode, tenantId, _redirectUri, requestContext.IsCaeEnabled, async, cancellationToken)
.ConfigureAwait(false);
_record = new AuthenticationRecord(result, _clientId);
if (requestContext.IsCaeEnabled)
{
_isCaeEnabledRequestCached = true;
}
else
{
_isCaeDisabledRequestCached = true;
}

token = new AccessToken(result.AccessToken, result.ExpiresOn);
token = await AcquireTokenWithCode(async, requestContext, token, tenantId, cancellationToken).ConfigureAwait(false);
}
else
{
Expand All @@ -176,10 +156,35 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC

return scope.Succeeded(token);
}
catch (MsalUiRequiredException)
{
// This occurs when we have an auth record but the cae or ncae cache entry is missing
// fall through to the acquire call below
}
catch (Exception e)
{
throw scope.FailWrapAndThrow(e);
}

try
{
token = await AcquireTokenWithCode(async, requestContext, token, tenantId, cancellationToken).ConfigureAwait(false);
return scope.Succeeded(token);
}
catch (Exception e)
{
throw scope.FailWrapAndThrow(e);
}
}

private async Task<AccessToken> AcquireTokenWithCode(bool async, TokenRequestContext requestContext, AccessToken token, string tenantId, CancellationToken cancellationToken)
{
AuthenticationResult result = await Client
.AcquireTokenByAuthorizationCodeAsync(requestContext.Scopes, _authCode, tenantId, _redirectUri, requestContext.IsCaeEnabled, async, cancellationToken)
.ConfigureAwait(false);
_record = new AuthenticationRecord(result, _clientId);
token = new AccessToken(result.AccessToken, result.ExpiresOn);
return token;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ public class DeviceCodeCredential : TokenCredential
internal string ClientId { get; }
internal bool DisableAutomaticAuthentication { get; }
internal AuthenticationRecord Record { get; private set; }
private bool _isCaeEnabledRequestCached = false;
private bool _isCaeDisabledRequestCached = false;
internal Func<DeviceCodeInfo, CancellationToken, Task> DeviceCodeCallback { get; }
internal CredentialPipeline Pipeline { get; }
internal string DefaultScope { get; }
Expand Down Expand Up @@ -211,16 +209,9 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
try
{
Exception inner = null;

var tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);
var isCachePopulated = Record switch
{
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
_ => false
};

if (isCachePopulated)
if (Record is not null)
{
try
{
Expand Down Expand Up @@ -255,15 +246,6 @@ private async Task<AccessToken> GetTokenViaDeviceCodeAsync(TokenRequestContext c
.ConfigureAwait(false);

Record = new AuthenticationRecord(result, ClientId);
if (context.IsCaeEnabled)
{
_isCaeEnabledRequestCached = true;
}
else
{
_isCaeDisabledRequestCached = true;
}

return new AccessToken(result.AccessToken, result.ExpiresOn);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ public class InteractiveBrowserCredential : TokenCredential
internal CredentialPipeline Pipeline { get; }
internal bool DisableAutomaticAuthentication { get; }
internal AuthenticationRecord Record { get; private set; }
internal bool _isCaeEnabledRequestCached = false;
internal bool _isCaeDisabledRequestCached = false;
internal string DefaultScope { get; }

private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to interactively authenticate.";
Expand Down Expand Up @@ -197,13 +195,7 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
Exception inner = null;

var tenantId = TenantIdResolver.Resolve(TenantId ?? Record?.TenantId, requestContext, AdditionallyAllowedTenantIds);
var isCachePopulated = Record switch
{
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
_ => false
};
if (isCachePopulated)
if (Record is not null)
{
try
{
Expand Down Expand Up @@ -246,14 +238,6 @@ private async Task<AccessToken> GetTokenViaBrowserLoginAsync(TokenRequestContext
.ConfigureAwait(false);

Record = new AuthenticationRecord(result, ClientId);
if (context.IsCaeEnabled)
{
_isCaeEnabledRequestCached = true;
}
else
{
_isCaeDisabledRequestCached = true;
}
return new AccessToken(result.AccessToken, result.ExpiresOn);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public override TokenCredential GetTokenCredential(CommonCredentialTestConfig co
IsUnsafeSupportLoggingEnabled = config.IsUnsafeSupportLoggingEnabled,
};
var pipeline = CredentialPipeline.GetInstance(options);
return InstrumentClient(new InteractiveBrowserCredential(config.TenantId, ClientId, options, pipeline, null) { _isCaeDisabledRequestCached = true, _isCaeEnabledRequestCached = true });
return InstrumentClient(new InteractiveBrowserCredential(config.TenantId, ClientId, options, pipeline, null));
}

[Test]
Expand Down

0 comments on commit fb990f2

Please sign in to comment.