Skip to content

Commit

Permalink
Fix for 3000 - improve error messages when credentials fail to load (#…
Browse files Browse the repository at this point in the history
…3003)

* Fix for 3000 - logging for loading credentials into MSAL and minor for MSI+FIC

* revert msal bump

* Address comment

* Fix a nullable warning

* Updating a few files used to test

---------

Co-authored-by: Jean-Marc Prieur <[email protected]>
  • Loading branch information
bgavrilMS and jmprieur authored Sep 24, 2024
1 parent d4572e7 commit aed3c16
Show file tree
Hide file tree
Showing 16 changed files with 552 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;
using System;

namespace Microsoft.Identity.Web
{
// Log messages for DefaultCredentialsLoader
public partial class DefaultCredentialsLoader
{
/// <summary>
/// Logging infrastructure
/// </summary>
private static class Logger
{
private static readonly Action<ILogger, string, string, bool, Exception?> s_credentialLoadingFailure =
LoggerMessage.Define<string, string, bool>(
LogLevel.Information,
new EventId(
7,
nameof(CredentialLoadingFailure)),
"Failed to load credential {id} from source {sourceType}. Will it be skipped in the future ? {skip}.");

public static void CredentialLoadingFailure(ILogger logger, CredentialDescription cd, Exception? ex)
=> s_credentialLoadingFailure(logger, cd.Id, cd.SourceType.ToString(), cd.Skip, ex);
}
}
}
33 changes: 25 additions & 8 deletions src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Identity.Abstractions;

namespace Microsoft.Identity.Web
{
/// <summary>
/// Default credentials loader.
/// </summary>
public class DefaultCredentialsLoader : ICredentialsLoader
public partial class DefaultCredentialsLoader : ICredentialsLoader
{
ILogger<DefaultCredentialsLoader>? _logger;
private readonly ILogger<DefaultCredentialsLoader> _logger;
private readonly ConcurrentDictionary<string, SemaphoreSlim> _loadingSemaphores = new ConcurrentDictionary<string, SemaphoreSlim>();

/// <summary>
Expand All @@ -24,15 +26,16 @@ public class DefaultCredentialsLoader : ICredentialsLoader
/// <param name="logger"></param>
public DefaultCredentialsLoader(ILogger<DefaultCredentialsLoader>? logger)
{
_logger = logger;
_logger = logger ?? new NullLogger<DefaultCredentialsLoader>();

CredentialSourceLoaders = new Dictionary<CredentialSource, ICredentialSourceLoader>
{
{ CredentialSource.KeyVault, new KeyVaultCertificateLoader() },
{ CredentialSource.Path, new FromPathCertificateLoader() },
{ CredentialSource.StoreWithThumbprint, new StoreWithThumbprintCertificateLoader() },
{ CredentialSource.StoreWithDistinguishedName, new StoreWithDistinguishedNameCertificateLoader() },
{ CredentialSource.Base64Encoded, new Base64EncodedCertificateLoader() },
{ CredentialSource.SignedAssertionFromManagedIdentity, new SignedAssertionFromManagedIdentityCredentialLoader() },
{ CredentialSource.SignedAssertionFromManagedIdentity, new SignedAssertionFromManagedIdentityCredentialLoader(_logger) },
{ CredentialSource.SignedAssertionFilePath, new SignedAssertionFilePathCredentialsLoader(_logger) }
};
}
Expand All @@ -51,7 +54,8 @@ public DefaultCredentialsLoader() : this(null)
public IDictionary<CredentialSource, ICredentialSourceLoader> CredentialSourceLoaders { get; }

/// <inheritdoc/>
/// Load the credentials from the description, if needed.
/// Load the credentials from the description, if needed.
/// Important: Ignores SKIP flag, propagates exceptions.
public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialDescription, CredentialSourceLoaderParameters? parameters = null)
{
_ = Throws.IfNull(credentialDescription);
Expand All @@ -69,7 +73,17 @@ public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialD
if (credentialDescription.CachedValue == null)
{
if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader))
await loader.LoadIfNeededAsync(credentialDescription, parameters);
{
try
{
await loader.LoadIfNeededAsync(credentialDescription, parameters);
}
catch (Exception ex)
{
Logger.CredentialLoadingFailure(_logger, credentialDescription, ex);
throw;
}
}
}
}
finally
Expand All @@ -81,11 +95,15 @@ public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialD
}

/// <inheritdoc/>
public async Task<CredentialDescription?> LoadFirstValidCredentialsAsync(IEnumerable<CredentialDescription> credentialDescriptions, CredentialSourceLoaderParameters? parameters = null)
/// Loads first valid credential which is not marked as Skipped.
public async Task<CredentialDescription?> LoadFirstValidCredentialsAsync(
IEnumerable<CredentialDescription> credentialDescriptions,
CredentialSourceLoaderParameters? parameters = null)
{
foreach (var credentialDescription in credentialDescriptions)
{
await LoadCredentialsIfNeededAsync(credentialDescription, parameters);

if (!credentialDescription.Skip)
{
return credentialDescription;
Expand All @@ -107,6 +125,5 @@ public void ResetCredentials(IEnumerable<CredentialDescription> credentialDescri
}
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ public async Task LoadIfNeededAsync(CredentialDescription credentialDescription,
{
// Given that managed identity can be not available locally, we need to try to get a
// signed assertion, and if it fails, move to the next credentials
_= await signedAssertion!.GetSignedAssertionAsync(null);
_ = await signedAssertion!.GetSignedAssertionAsync(null);
credentialDescription.CachedValue = signedAssertion;
}
catch (Exception)
{
credentialDescription.Skip = true;
}
throw;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
using System.Threading;
using System.Threading.Tasks;
using Azure.Identity;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;
using Microsoft.Identity.Client;

namespace Microsoft.Identity.Web
{
internal class SignedAssertionFromManagedIdentityCredentialLoader : ICredentialSourceLoader
{
private readonly ILogger<DefaultCredentialsLoader> _logger;

public SignedAssertionFromManagedIdentityCredentialLoader(ILogger<DefaultCredentialsLoader> logger)
{
_logger = logger;
}

public CredentialSource CredentialSource => CredentialSource.SignedAssertionFromManagedIdentity;

public async Task LoadIfNeededAsync(CredentialDescription credentialDescription, CredentialSourceLoaderParameters? credentialSourceLoaderParameters)
Expand All @@ -19,16 +28,19 @@ public async Task LoadIfNeededAsync(CredentialDescription credentialDescription,
ManagedIdentityClientAssertion? managedIdentityClientAssertion = credentialDescription.CachedValue as ManagedIdentityClientAssertion;
if (credentialDescription.CachedValue == null)
{
managedIdentityClientAssertion = new ManagedIdentityClientAssertion(credentialDescription.ManagedIdentityClientId, credentialDescription.TokenExchangeUrl);
managedIdentityClientAssertion = new ManagedIdentityClientAssertion(
credentialDescription.ManagedIdentityClientId,
credentialDescription.TokenExchangeUrl,
_logger);
}
try
{
// Given that managed identity can be not available locally, we need to try to get a
// signed assertion, and if it fails, move to the next credentials
_= await managedIdentityClientAssertion!.GetSignedAssertionAsync(null);
_ = await managedIdentityClientAssertion!.GetSignedAssertionAsync(null);
credentialDescription.CachedValue = managedIdentityClientAssertion;
}
catch (AuthenticationFailedException)
catch (MsalServiceException)
{
credentialDescription.Skip = true;
throw;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Web.Certificateless;

namespace Microsoft.Identity.Web
Expand All @@ -16,21 +19,16 @@ public class ManagedIdentityClientAssertion : ClientAssertionProviderBase
{
IManagedIdentityApplication _managedIdentityApplication;
private readonly string _tokenExchangeUrl;
private readonly ILogger? _logger;

/// <summary>
/// See https://aka.ms/ms-id-web/certificateless.
/// </summary>
/// <param name="managedIdentityClientId">Optional ClientId of the Managed Identity</param>
public ManagedIdentityClientAssertion(string? managedIdentityClientId)
public ManagedIdentityClientAssertion(string? managedIdentityClientId) :
this(managedIdentityClientId, tokenExchangeUrl: null, logger: null)
{
var id = ManagedIdentityId.SystemAssigned;
if (!string.IsNullOrEmpty(managedIdentityClientId))
{
id = ManagedIdentityId.WithUserAssignedClientId(managedIdentityClientId);
}

_managedIdentityApplication = ManagedIdentityApplicationBuilder.Create(id).Build();
_tokenExchangeUrl = CertificatelessConstants.DefaultTokenExchangeUrl;
}

/// <summary>
Expand All @@ -39,9 +37,38 @@ public ManagedIdentityClientAssertion(string? managedIdentityClientId)
/// <param name="managedIdentityClientId">Optional ClientId of the Managed Identity</param>
/// <param name="tokenExchangeUrl">Optional audience of the token to be requested from Managed Identity. Default value is "api://AzureADTokenExchange".
/// This value is different on clouds other than Azure Public</param>
public ManagedIdentityClientAssertion(string? managedIdentityClientId, string? tokenExchangeUrl) : this (managedIdentityClientId)
public ManagedIdentityClientAssertion(string? managedIdentityClientId, string? tokenExchangeUrl) :
this(managedIdentityClientId, tokenExchangeUrl, null)
{
}

/// <summary>
/// See https://aka.ms/ms-id-web/certificateless.
/// </summary>
/// <param name="managedIdentityClientId">Optional ClientId of the Managed Identity</param>
/// <param name="tokenExchangeUrl">Optional audience of the token to be requested from Managed Identity. Default value is "api://AzureADTokenExchange".
/// This value is different on clouds other than Azure Public</param>
/// <param name="logger">A logger</param>
public ManagedIdentityClientAssertion(string? managedIdentityClientId, string? tokenExchangeUrl, ILogger? logger)
{
_tokenExchangeUrl = tokenExchangeUrl ?? CertificatelessConstants.DefaultTokenExchangeUrl;
_logger = logger;

var id = ManagedIdentityId.SystemAssigned;
if (!string.IsNullOrEmpty(managedIdentityClientId))
{
id = ManagedIdentityId.WithUserAssignedClientId(managedIdentityClientId);
}

var builder = ManagedIdentityApplicationBuilder.Create(id);
if (_logger != null)
{
builder = builder.WithLogging(Log, ConvertMicrosoftExtensionsLogLevelToMsal(_logger), enablePiiLogging: false);
_logger.LogInformation($"ManagedIdentityClientAssertion with tokenExchangeUrl={_tokenExchangeUrl}");
}

_managedIdentityApplication = builder
.Build();
}

/// <summary>
Expand All @@ -58,5 +85,57 @@ protected override async Task<ClientAssertion> GetClientAssertionAsync(Assertion

return new ClientAssertion(result.AccessToken, result.ExpiresOn);
}

private void Log(
Client.LogLevel level,
string message,
bool containsPii)
{
switch (level)
{
case Client.LogLevel.Always:
_logger.LogInformation(message);
break;
case Client.LogLevel.Error:
_logger.LogError(message);
break;
case Client.LogLevel.Warning:
_logger.LogWarning(message);
break;
case Client.LogLevel.Info:
_logger.LogInformation(message);
break;
case Client.LogLevel.Verbose:
_logger.LogDebug(message);
break;
}
}

private Client.LogLevel? ConvertMicrosoftExtensionsLogLevelToMsal(ILogger logger)
{
if (logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Debug)
|| logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Trace))
{
return Client.LogLevel.Verbose;
}
else if (logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Information))
{
return Client.LogLevel.Info;
}
else if (logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Warning))
{
return Client.LogLevel.Warning;
}
else if (logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Error)
|| logger.IsEnabled(Microsoft.Extensions.Logging.LogLevel.Critical))
{
return Client.LogLevel.Error;
}
else
{
return null;
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;

namespace Microsoft.Identity.Web
{
Expand Down Expand Up @@ -40,6 +41,62 @@ internal static class Logger
LoggingEventId.UsingCertThumbprint,
"[MsIdWeb] Using certificate Thumbprint={certThumbprint} as client credentials. ");

private static readonly Action<ILogger, string, string, Exception?> s_credentialAttempt =
LoggerMessage.Define<string, string>(
LogLevel.Information,
LoggingEventId.CredentialLoadAttempt,
"[MsIdWeb] Attempting to load the credential from the CredentialDescription with Id={Id} and Skip={Skip} . ");

private static readonly Action<ILogger, string, string, Exception?> s_credentialAttemptFailed =
LoggerMessage.Define<string, string>(
LogLevel.Information,
LoggingEventId.CredentialLoadAttemptFailed,
"[MsIdWeb] Loading the credential from CredentialDescription Id={Id} failed. Will the credential be re-attempted? - {Skip}.");

/// <summary>
/// Logger for attempting to use a CredentialDescription with MSAL
/// </summary>
/// <param name="logger"></param>
/// <param name="certificateDescription"></param>
/// <param name="ex"></param>
public static void AttemptToLoadCredentialsFailed(
ILogger logger,
CredentialDescription certificateDescription,
Exception ex) =>
s_credentialAttemptFailed(
logger,
certificateDescription.Id,
certificateDescription.Skip.ToString(),
ex);

/// <summary>
/// Logger for attempting to use a CredentialDescription with MSAL
/// </summary>
/// <param name="logger"></param>
/// <param name="certificateDescription"></param>
public static void AttemptToLoadCredentials(
ILogger logger,
CredentialDescription certificateDescription) =>
s_credentialAttempt(
logger,
certificateDescription.Id,
certificateDescription.Skip.ToString(),
default!);

/// <summary>
/// Logger for attempting to use a CredentialDescription with MSAL
/// </summary>
/// <param name="logger"></param>
/// <param name="certificateDescription"></param>
public static void FailedToLoadCredentials(
ILogger logger,
CredentialDescription certificateDescription) =>
s_credentialAttemptFailed(
logger,
certificateDescription.Id,
certificateDescription.Skip.ToString(),
default!);

/// <summary>
/// Logger for handling information specific to ConfidentialClientApplicationBuilderExtension.
/// </summary>
Expand Down
Loading

0 comments on commit aed3c16

Please sign in to comment.