From 6b525d43fb48f6620ac2796ec282c4f3680067ee Mon Sep 17 00:00:00 2001 From: Terry Chow <32403408+tkyc@users.noreply.github.com> Date: Tue, 21 May 2024 15:27:11 -0700 Subject: [PATCH] Credential caching (#2415) --- .../jdbc/SQLServerSecurityUtility.java | 120 +++++++++++++++--- 1 file changed, 104 insertions(+), 16 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 3bb2eb32d..734bcf487 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -6,17 +6,22 @@ package com.microsoft.sqlserver.jdbc; import java.security.InvalidKeyException; +import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.text.MessageFormat; import java.util.Arrays; +import java.util.HashMap; import java.util.Optional; import java.util.Iterator; import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenCredential; import com.azure.core.credential.TokenRequestContext; import com.azure.identity.ManagedIdentityCredential; import com.azure.identity.ManagedIdentityCredentialBuilder; @@ -46,6 +51,11 @@ class SQLServerSecurityUtility { // Environment variable for additionally allowed tenants. The tenantIds are comma delimited private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS"; + // Credential Cache for ManagedIdentityCredential and DefaultAzureCredential + private static final HashMap CREDENTIAL_CACHE = new HashMap<>(); + + private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); + private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } @@ -331,16 +341,35 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer */ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, String managedIdentityClientId) throws SQLServerException { - ManagedIdentityCredential mic = null; if (logger.isLoggable(java.util.logging.Level.FINEST)) { logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); } - if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { - mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build(); - } else { - mic = new ManagedIdentityCredentialBuilder().build(); + String key = getHashedSecret( + new String[] {managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()}); + ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key); + + if (null == mic) { + CREDENTIAL_LOCK.lock(); + + try { + mic = (ManagedIdentityCredential) getCredentialFromCache(key); + if (null == mic) { + ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder(); + + if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { + mic = micBuilder.clientId(managedIdentityClientId).build(); + } else { + mic = micBuilder.build(); + } + + Credential credential = new Credential(mic); + CREDENTIAL_CACHE.put(key, credential); + } + } finally { + CREDENTIAL_LOCK.unlock(); + } } TokenRequestContext tokenRequestContext = new TokenRequestContext(); @@ -383,22 +412,49 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); - DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder(); - DefaultAzureCredential dac = null; + int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3; + String[] secrets = new String[secretsLength]; - if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { - dacBuilder.managedIdentityClientId(managedIdentityClientId); + if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { + System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length); } - if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) { - dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath); - } + secrets[0] = DefaultAzureCredential.class.getSimpleName(); + secrets[1] = managedIdentityClientId; + secrets[2] = intellijKeepassPath; - if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { - dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants); - } + String key = getHashedSecret(secrets); + DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key); + + if (null == dac) { + CREDENTIAL_LOCK.lock(); + + try { + dac = (DefaultAzureCredential) getCredentialFromCache(key); + if (null == dac) { + DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder(); + + if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { + dacBuilder.managedIdentityClientId(managedIdentityClientId); + } + + if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) { + dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath); + } + + if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { + dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants); + } + + dac = dacBuilder.build(); - dac = dacBuilder.build(); + Credential credential = new Credential(dac); + CREDENTIAL_CACHE.put(key, credential); + } + } finally { + CREDENTIAL_LOCK.unlock(); + } + } TokenRequestContext tokenRequestContext = new TokenRequestContext(); String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource @@ -430,4 +486,36 @@ private static String[] getAdditonallyAllowedTenants() { return null; } + + private static TokenCredential getCredentialFromCache(String key) { + Credential credential = CREDENTIAL_CACHE.get(key); + + if (null != credential) { + return credential.tokenCredential; + } + + return null; + } + + private static class Credential { + TokenCredential tokenCredential; + + public Credential(TokenCredential tokenCredential) { + this.tokenCredential = tokenCredential; + } + } + + private static String getHashedSecret(String[] secrets) throws SQLServerException { + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + for (String secret : secrets) { + if (null != secret) { + md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); + } + } + return new String(md.digest()); + } catch (NoSuchAlgorithmException e) { + throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); + } + } }