Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiple cluster credentials handling #276

Merged
merged 2 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 75 additions & 10 deletions Client.UnitTests/CamundaCloudTokenProviderTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
Expand All @@ -21,14 +22,23 @@ public class CamundaCloudTokenProviderTest
private static long ExpiresIn { get; set; }
private static string Token { get; set; }

private static string _requestUri;
private static string _clientId;
private static string _clientSecret;
private static string _audience;

[SetUp]
public void Init()
{
_requestUri = "https://local.de";
_clientId = "ID";
_clientSecret = "SECRET";
_audience = "AUDIENCE";
TokenProvider = new CamundaCloudTokenProviderBuilder()
.UseAuthServer("https://local.de")
.UseClientId("ID")
.UseClientSecret("SECRET")
.UseAudience("AUDIENCE")
.UseAuthServer(_requestUri)
.UseClientId(_clientId)
.UseClientSecret(_clientSecret)
.UseAudience(_audience)
.Build();

MessageHandlerStub = new HttpMessageHandlerStub();
Expand All @@ -50,16 +60,17 @@ private class HttpMessageHandlerStub : HttpMessageHandler
{
public int RequestCount { get; set; }
private bool _disposed = false;

protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
CancellationToken cancellationToken)
{
CheckDisposed();
Assert.AreEqual(request.RequestUri, "https://local.de");
Assert.AreEqual(request.RequestUri, _requestUri);
var content = await request.Content.ReadAsStringAsync();
var jsonObject = JObject.Parse(content);
Assert.AreEqual((string)jsonObject["client_id"], "ID");
Assert.AreEqual((string)jsonObject["client_secret"], "SECRET");
Assert.AreEqual((string)jsonObject["audience"], "AUDIENCE");
Assert.AreEqual((string)jsonObject["client_id"], _clientId);
Assert.AreEqual((string)jsonObject["client_secret"], _clientSecret);
Assert.AreEqual((string)jsonObject["audience"], _audience);

RequestCount++;
var responseMessage = new HttpResponseMessage(HttpStatusCode.OK)
Expand Down Expand Up @@ -117,8 +128,39 @@ public async Task ShouldStoreCredentials()
Assert.AreEqual(1, files.Length);
var tokenFile = files[0];
var content = File.ReadAllText(tokenFile);
var fileToken = JsonConvert.DeserializeObject<CamundaCloudTokenProvider.AccessToken>(content);
Assert.AreEqual(token, fileToken.Token);
var credentials = JsonConvert.DeserializeObject<Dictionary<string, CamundaCloudTokenProvider.AccessToken>>(content);
Assert.AreEqual(credentials["AUDIENCE"].Token, token);
}

[Test]
public async Task ShouldStoreMultipleCredentials()
{
// given
await TokenProvider.GetAccessTokenForRequestAsync();
var otherProvider = new CamundaCloudTokenProviderBuilder()
.UseAuthServer(_requestUri)
.UseClientId(_clientId = "OTHERID")
.UseClientSecret(_clientSecret = "OTHERSECRET")
.UseAudience(_audience = "OTHER_AUDIENCE")
.Build();
otherProvider.SetHttpMessageHandler(MessageHandlerStub);
otherProvider.TokenStoragePath = TokenStoragePath;
Token = "OTHER_TOKEN";

// when
var token = await otherProvider.GetAccessTokenForRequestAsync();

// then
Assert.AreEqual("OTHER_TOKEN", token);
var files = Directory.GetFiles(TokenStoragePath);
Assert.AreEqual(1, files.Length);
var tokenFile = files[0];
var content = File.ReadAllText(tokenFile);
var credentials = JsonConvert.DeserializeObject<Dictionary<string, CamundaCloudTokenProvider.AccessToken>>(content);

Assert.AreEqual(credentials.Count, 2);
Assert.AreEqual(token, credentials["OTHER_AUDIENCE"].Token);
Assert.AreEqual("REQUESTED_TOKEN", credentials["AUDIENCE"].Token);
}

[Test]
Expand Down Expand Up @@ -199,6 +241,29 @@ public async Task ShouldUseCachedFile()
Assert.AreEqual(0, MessageHandlerStub.RequestCount);
}

[Test]
public async Task ShouldNotUseCachedFileForOtherAudience()
{
// given
Token = "STORED_TOKEN";
await TokenProvider.GetAccessTokenForRequestAsync();
var otherProvider = new CamundaCloudTokenProviderBuilder()
.UseAuthServer(_requestUri)
.UseClientId(_clientId = "OTHERID")
.UseClientSecret(_clientSecret = "OTHERSECRET")
.UseAudience(_audience = "OTHER_AUDIENCE")
.Build();
otherProvider.SetHttpMessageHandler(MessageHandlerStub);
otherProvider.TokenStoragePath = TokenStoragePath;
Token = "OTHER_TOKEN";

// when
var token = await otherProvider.GetAccessTokenForRequestAsync();

// then
Assert.AreEqual("OTHER_TOKEN", token);
}

[Test]
public async Task ShouldRequestWhenCachedFileExpired()
{
Expand Down
32 changes: 22 additions & 10 deletions Client/Impl/Builder/CamundaCloudTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Text;
Expand All @@ -16,7 +18,7 @@ public class CamundaCloudTokenProvider : IAccessTokenSupplier, IDisposable
private const string JsonContent =
"{{\"client_id\":\"{0}\",\"client_secret\":\"{1}\",\"audience\":\"{2}\",\"grant_type\":\"client_credentials\"}}";

private const string ZeebeCloudTokenFileName = "cloud.token";
private const string ZeebeCloudTokenFileName = "credentials";

private static readonly string ZeebeRootPath =
Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".zeebe");
Expand Down Expand Up @@ -45,6 +47,7 @@ internal CamundaCloudTokenProvider(
// default client handler
httpClient = new HttpClient(new HttpClientHandler(), disposeHandler: false);
TokenStoragePath = ZeebeRootPath;
Credentials = new Dictionary<string, AccessToken>();
}

public static CamundaCloudTokenProviderBuilder Builder()
Expand All @@ -54,17 +57,18 @@ public static CamundaCloudTokenProviderBuilder Builder()

public string TokenStoragePath { get; set; }
private string TokenFileName => TokenStoragePath + Path.DirectorySeparatorChar + ZeebeCloudTokenFileName;
private AccessToken CurrentAccessToken { get; set; }
private Dictionary<string, AccessToken> Credentials { get; set; }

public Task<string> GetAccessTokenForRequestAsync(
string authUri = null,
CancellationToken cancellationToken = default(CancellationToken))
{
// check in memory
if (CurrentAccessToken != null)
AccessToken currentAccessToken;
if (Credentials.TryGetValue(audience, out currentAccessToken))
{
logger?.LogTrace("Use in memory access token.");
return GetValidToken(CurrentAccessToken);
return GetValidToken(currentAccessToken);
}

// check if token file exists
Expand All @@ -75,9 +79,12 @@ public Task<string> GetAccessTokenForRequestAsync(
logger?.LogTrace("Read cached access token from {tokenFileName}", tokenFileName);
// read token
var content = File.ReadAllText(tokenFileName);
var accessToken = JsonConvert.DeserializeObject<AccessToken>(content);
CurrentAccessToken = accessToken;
return GetValidToken(accessToken);
Credentials = JsonConvert.DeserializeObject<Dictionary<string, AccessToken>>(content);
if (Credentials.TryGetValue(audience, out currentAccessToken))
{
logger?.LogTrace("Found access token in credentials file.");
return GetValidToken(currentAccessToken);
}
}

// request token
Expand Down Expand Up @@ -139,14 +146,19 @@ private async Task<string> RequestAccessTokenAsync()

var result = await httpResponseMessage.Content.ReadAsStringAsync();
var token = ToAccessToken(result);
logger?.LogDebug("Received access token {token}, will backup at {path}.", token, tokenFileName);
File.WriteAllText(tokenFileName, JsonConvert.SerializeObject(token));
CurrentAccessToken = token;
logger?.LogDebug("Received access token for {audience}, will backup at {path}.", audience, tokenFileName);
Credentials[audience] = token;
WriteCredentials();

return token.Token;
}
}

private void WriteCredentials()
{
File.WriteAllText(TokenFileName, JsonConvert.SerializeObject(Credentials));
}

private static AccessToken ToAccessToken(string result)
{
var jsonResult = JObject.Parse(result);
Expand Down