Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(oidc settings): effective JWS algorithm setting #9712

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions datahub-frontend/app/auth/AuthUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ public class AuthUtils {
public static final String USE_NONCE = "useNonce";
public static final String READ_TIMEOUT = "readTimeout";
public static final String EXTRACT_JWT_ACCESS_TOKEN_CLAIMS = "extractJwtAccessTokenClaims";
// Retained for backwards compatibility
public static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
public static final String PREFERRED_JWS_ALGORITHM_2 = "preferredJwsAlgorithm2";

/**
* Determines whether the inbound request should be forward to downstream Metadata Service. Today,
Expand Down
4 changes: 2 additions & 2 deletions datahub-frontend/app/auth/sso/oidc/OidcConfigs.java
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ public Builder from(final com.typesafe.config.Config configs, final String ssoSe
extractJwtAccessTokenClaims =
Optional.of(jsonNode.get(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS).asBoolean());
}
if (jsonNode.has(OIDC_PREFERRED_JWS_ALGORITHM)) {
preferredJwsAlgorithm = Optional.of(jsonNode.get(OIDC_PREFERRED_JWS_ALGORITHM).asText());
if (jsonNode.has(PREFERRED_JWS_ALGORITHM_2)) {
preferredJwsAlgorithm = Optional.of(jsonNode.get(PREFERRED_JWS_ALGORITHM_2).asText());
} else {
preferredJwsAlgorithm =
Optional.ofNullable(getOptional(configs, OIDC_PREFERRED_JWS_ALGORITHM, null));
Expand Down
3 changes: 3 additions & 0 deletions datahub-frontend/play.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ play {
test {
useJUnitPlatform()

testLogging.showStandardStreams = true
testLogging.exceptionFormat = 'full'

def playJava17CompatibleJvmArgs = [
"--add-opens=java.base/java.lang=ALL-UNNAMED",
//"--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
Expand Down
24 changes: 24 additions & 0 deletions datahub-frontend/test/security/OidcConfigurationTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package security;

import static auth.AuthUtils.*;
import static auth.sso.oidc.OidcConfigs.*;
import static org.junit.jupiter.api.Assertions.assertEquals;

Expand All @@ -24,6 +25,7 @@
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.pac4j.oidc.client.OidcClient;
import org.json.JSONObject;

public class OidcConfigurationTest {

Expand Down Expand Up @@ -317,4 +319,26 @@ public void readTimeoutPropagation() {
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals(10000, ((OidcClient) oidcProvider.client()).getConfiguration().getReadTimeout());
}

@Test
public void readPreferredJwsAlgorithmPropagationFromConfig() {
final String SSO_SETTINGS_JSON_STR = new JSONObject().put(PREFERRED_JWS_ALGORITHM, "HS256").toString();
CONFIG.withValue(OIDC_PREFERRED_JWS_ALGORITHM, ConfigValueFactory.fromAnyRef("RS256"));
OidcConfigs.Builder oidcConfigsBuilder = new OidcConfigs.Builder();
oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR);
OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder);
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals("RS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString());
}

@Test
public void readPreferredJwsAlgorithmPropagationFromJSON() {
final String SSO_SETTINGS_JSON_STR = new JSONObject().put(PREFERRED_JWS_ALGORITHM, "Unused").put(PREFERRED_JWS_ALGORITHM_2, "HS256").toString();
CONFIG.withValue(OIDC_PREFERRED_JWS_ALGORITHM, ConfigValueFactory.fromAnyRef("RS256"));
OidcConfigs.Builder oidcConfigsBuilder = new OidcConfigs.Builder();
oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR);
OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder);
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals("HS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ record OidcSettings {
extractJwtAccessTokenClaims: optional boolean

/**
* ADVANCED. Which jws algorithm to use.
* ADVANCED. Which jws algorithm to use. Unused.
*/
preferredJwsAlgorithm: optional string
}

/**
* ADVANCED. Which jws algorithm to use.
*/
preferredJwsAlgorithm2: optional string
}
8 changes: 8 additions & 0 deletions metadata-service/auth-servlet-impl/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,12 @@ dependencies {
compileOnly externalDependency.lombok

annotationProcessor externalDependency.lombok

testImplementation externalDependency.testng
testImplementation externalDependency.springBootTest
}

test {
testLogging.showStandardStreams = true
testLogging.exceptionFormat = 'full'
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ public class AuthServiceController {
private static final String USE_NONCE = "useNonce";
private static final String READ_TIMEOUT = "readTimeout";
private static final String EXTRACT_JWT_ACCESS_TOKEN_CLAIMS = "extractJwtAccessTokenClaims";
// Retained for backwards compatibility
private static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
private static final String PREFERRED_JWS_ALGORITHM_2 = "preferredJwsAlgorithm2";

@Inject StatelessTokenService _statelessTokenService;

Expand Down Expand Up @@ -514,8 +516,8 @@ private void buildOidcSettingsResponse(JSONObject json, final OidcSettings oidcS
if (oidcSettings.hasExtractJwtAccessTokenClaims()) {
json.put(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS, oidcSettings.isExtractJwtAccessTokenClaims());
}
if (oidcSettings.hasPreferredJwsAlgorithm()) {
json.put(PREFERRED_JWS_ALGORITHM, oidcSettings.getPreferredJwsAlgorithm());
if (oidcSettings.hasPreferredJwsAlgorithm2()) {
json.put(PREFERRED_JWS_ALGORITHM, oidcSettings.getPreferredJwsAlgorithm2());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package com.datahub.auth.authentication;

import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_INFO_ASPECT_NAME;
import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_URN;
import static org.mockito.Mockito.when;
import static org.testng.Assert.*;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.settings.global.GlobalSettingsInfo;
import com.linkedin.settings.global.OidcSettings;
import com.linkedin.settings.global.SsoSettings;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.testng.AbstractTestNGSpringContextTests;
import org.springframework.web.servlet.DispatcherServlet;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

@SpringBootTest(classes = {DispatcherServlet.class})
@ComponentScan(basePackages = {"com.datahub.auth.authentication"})
@Import({AuthServiceTestConfiguration.class})
public class AuthServiceControllerTest extends AbstractTestNGSpringContextTests {
@BeforeTest
public void disableAssert() {
PathSpecBasedSchemaAnnotationVisitor.class
.getClassLoader()
.setClassAssertionStatus(PathSpecBasedSchemaAnnotationVisitor.class.getName(), false);
}

@Autowired private AuthServiceController authServiceController;
@Autowired private EntityService mockEntityService;

private final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";

@Test
public void initTest() {
assertNotNull(authServiceController);
assertNotNull(mockEntityService);
}

@Test
public void oldPreferredJwsAlgorithmIsNotReturned() throws IOException {
OidcSettings mockOidcSettings =
new OidcSettings()
.setEnabled(true)
.setClientId("1")
.setClientSecret("2")
.setDiscoveryUri("http://localhost")
.setPreferredJwsAlgorithm("test");
SsoSettings mockSsoSettings =
new SsoSettings().setBaseUrl("http://localhost").setOidcSettings(mockOidcSettings);
GlobalSettingsInfo mockGlobalSettingsInfo = new GlobalSettingsInfo().setSso(mockSsoSettings);

when(mockEntityService.getLatestAspect(GLOBAL_SETTINGS_URN, GLOBAL_SETTINGS_INFO_ASPECT_NAME))
.thenReturn(mockGlobalSettingsInfo);

ResponseEntity<String> httpResponse = authServiceController.getSsoSettings(null).join();
assertEquals(httpResponse.getStatusCode(), HttpStatus.OK);

JsonNode jsonNode = new ObjectMapper().readTree(httpResponse.getBody());
assertFalse(jsonNode.has(PREFERRED_JWS_ALGORITHM));
}

@Test
public void newPreferredJwsAlgorithmIsReturned() throws IOException {
OidcSettings mockOidcSettings =
new OidcSettings()
.setEnabled(true)
.setClientId("1")
.setClientSecret("2")
.setDiscoveryUri("http://localhost")
.setPreferredJwsAlgorithm("jws1")
.setPreferredJwsAlgorithm2("jws2");
SsoSettings mockSsoSettings =
new SsoSettings().setBaseUrl("http://localhost").setOidcSettings(mockOidcSettings);
GlobalSettingsInfo mockGlobalSettingsInfo = new GlobalSettingsInfo().setSso(mockSsoSettings);

when(mockEntityService.getLatestAspect(GLOBAL_SETTINGS_URN, GLOBAL_SETTINGS_INFO_ASPECT_NAME))
.thenReturn(mockGlobalSettingsInfo);

ResponseEntity<String> httpResponse = authServiceController.getSsoSettings(null).join();
assertEquals(httpResponse.getStatusCode(), HttpStatus.OK);

JsonNode jsonNode = new ObjectMapper().readTree(httpResponse.getBody());
assertTrue(jsonNode.has(PREFERRED_JWS_ALGORITHM));
assertEquals(jsonNode.get(PREFERRED_JWS_ALGORITHM).asText(), "jws2");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.datahub.auth.authentication;

import com.datahub.authentication.Authentication;
import com.datahub.authentication.invite.InviteTokenService;
import com.datahub.authentication.token.StatelessTokenService;
import com.datahub.authentication.user.NativeUserService;
import com.datahub.telemetry.TrackingService;
import com.linkedin.gms.factory.config.ConfigurationProvider;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.secret.SecretService;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.MockBean;

@TestConfiguration
public class AuthServiceTestConfiguration {
@MockBean StatelessTokenService _statelessTokenService;

@MockBean Authentication _systemAuthentication;

@MockBean(name = "configurationProvider")
ConfigurationProvider _configProvider;

@MockBean NativeUserService _nativeUserService;

@MockBean EntityService _entityService;

@MockBean SecretService _secretService;

@MockBean InviteTokenService _inviteTokenService;

@MockBean TrackingService _trackingService;
}
Loading