Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graph-retriever): implement graph retriever #10241

Merged
merged 9 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
11 changes: 6 additions & 5 deletions datahub-frontend/app/auth/AuthModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.google.inject.name.Named;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.entity.client.SystemRestliEntityClient;
import com.linkedin.metadata.models.registry.EmptyEntityRegistry;
import com.linkedin.metadata.restli.DefaultRestliClientFactory;
import com.linkedin.parseq.retry.backoff.ExponentialBackoff;
import com.linkedin.util.Configuration;
Expand Down Expand Up @@ -112,7 +113,7 @@ protected void configure() {
.toConstructor(
SsoCallbackController.class.getConstructor(
SsoManager.class,
Authentication.class,
OperationContext.class,
SystemEntityClient.class,
AuthServiceClient.class,
org.pac4j.core.config.Config.class,
Expand Down Expand Up @@ -164,8 +165,9 @@ protected Authentication provideSystemAuthentication() {
@Provides
@Singleton
@Named("systemOperationContext")
protected OperationContext provideOperationContext(final Authentication systemAuthentication,
final ConfigurationProvider configurationProvider) {
protected OperationContext provideOperationContext(
final Authentication systemAuthentication,
final ConfigurationProvider configurationProvider) {
ActorContext systemActorContext =
ActorContext.builder()
.systemAuth(true)
Expand All @@ -180,7 +182,7 @@ protected OperationContext provideOperationContext(final Authentication systemAu
.operationContextConfig(systemConfig)
.systemActorContext(systemActorContext)
.searchContext(SearchContext.EMPTY)
.entityRegistryContext(EntityRegistryContext.EMPTY)
.entityRegistryContext(EntityRegistryContext.builder().build(EmptyEntityRegistry.EMPTY))
// Authorizer.EMPTY doesn't actually apply to system auth
.authorizerContext(AuthorizerContext.builder().authorizer(Authorizer.EMPTY).build())
.build(systemAuthentication);
Expand All @@ -200,7 +202,6 @@ protected SystemEntityClient provideEntityClient(
@Named("systemOperationContext") final OperationContext systemOperationContext,
final ConfigurationProvider configurationProvider) {
return new SystemRestliEntityClient(
systemOperationContext,
buildRestliClient(),
new ExponentialBackoff(_configs.getInt(ENTITY_CLIENT_RETRY_INTERVAL)),
_configs.getInt(ENTITY_CLIENT_NUM_RETRIES),
Expand Down
80 changes: 41 additions & 39 deletions datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import auth.CookieConfigs;
import auth.sso.SsoManager;
import client.AuthServiceClient;
import com.datahub.authentication.Authentication;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linkedin.common.AuditStamp;
Expand Down Expand Up @@ -57,6 +56,8 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import io.datahubproject.metadata.context.OperationContext;
import lombok.extern.slf4j.Slf4j;
import org.pac4j.core.config.Config;
import org.pac4j.core.context.Cookie;
Expand All @@ -69,6 +70,8 @@
import org.pac4j.play.PlayWebContext;
import play.mvc.Result;

import javax.annotation.Nonnull;

/**
* This class contains the logic that is executed when an OpenID Connect Identity Provider redirects
* back to D DataHub after an authentication attempt.
Expand All @@ -82,23 +85,23 @@
@Slf4j
public class OidcCallbackLogic extends DefaultCallbackLogic<Result, PlayWebContext> {

private final SsoManager _ssoManager;
private final SystemEntityClient _entityClient;
private final Authentication _systemAuthentication;
private final AuthServiceClient _authClient;
private final CookieConfigs _cookieConfigs;
private final SsoManager ssoManager;
private final SystemEntityClient systemEntityClient;
private final OperationContext systemOperationContext;
private final AuthServiceClient authClient;
private final CookieConfigs cookieConfigs;

public OidcCallbackLogic(
final SsoManager ssoManager,
final Authentication systemAuthentication,
final OperationContext systemOperationContext,
final SystemEntityClient entityClient,
final AuthServiceClient authClient,
final CookieConfigs cookieConfigs) {
_ssoManager = ssoManager;
_systemAuthentication = systemAuthentication;
_entityClient = entityClient;
_authClient = authClient;
_cookieConfigs = cookieConfigs;
this.ssoManager = ssoManager;
this.systemOperationContext = systemOperationContext;
systemEntityClient = entityClient;
this.authClient = authClient;
this.cookieConfigs = cookieConfigs;
}

@Override
Expand Down Expand Up @@ -131,8 +134,8 @@ public Result perform(
}

// By this point, we know that OIDC is the enabled provider.
final OidcConfigs oidcConfigs = (OidcConfigs) _ssoManager.getSsoProvider().configs();
return handleOidcCallback(oidcConfigs, result, context, getProfileManager(context));
final OidcConfigs oidcConfigs = (OidcConfigs) ssoManager.getSsoProvider().configs();
return handleOidcCallback(systemOperationContext, oidcConfigs, result, getProfileManager(context));
}

@SuppressWarnings("unchecked")
Expand All @@ -153,9 +156,9 @@ private void setContextRedirectUrl(PlayWebContext context) {
}

private Result handleOidcCallback(
final OperationContext opContext,
final OidcConfigs oidcConfigs,
final Result result,
final PlayWebContext context,
final ProfileManager<UserProfile> profileManager) {

log.debug("Beginning OIDC Callback Handling...");
Expand All @@ -177,23 +180,23 @@ private Result handleOidcCallback(
if (oidcConfigs.isJitProvisioningEnabled()) {
log.debug("Just-in-time provisioning is enabled. Beginning provisioning process...");
CorpUserSnapshot extractedUser = extractUser(corpUserUrn, profile);
tryProvisionUser(extractedUser);
tryProvisionUser(opContext, extractedUser);
if (oidcConfigs.isExtractGroupsEnabled()) {
// Extract groups & provision them.
List<CorpGroupSnapshot> extractedGroups = extractGroups(profile);
tryProvisionGroups(extractedGroups);
tryProvisionGroups(opContext, extractedGroups);
// Add users to groups on DataHub. Note that this clears existing group membership for a
// user if it already exists.
updateGroupMembership(corpUserUrn, createGroupMembership(extractedGroups));
updateGroupMembership(opContext, corpUserUrn, createGroupMembership(extractedGroups));
}
} else if (oidcConfigs.isPreProvisioningRequired()) {
// We should only allow logins for user accounts that have been pre-provisioned
log.debug("Pre Provisioning is required. Beginning validation of extracted user...");
verifyPreProvisionedUser(corpUserUrn);
verifyPreProvisionedUser(opContext, corpUserUrn);
}
// Update user status to active on login.
// If we want to prevent certain users from logging in, here's where we'll want to do it.
setUserStatus(
setUserStatus(opContext,
corpUserUrn,
new CorpUserStatus()
.setStatus(Constants.CORP_USER_STATUS_ACTIVE)
Expand All @@ -209,15 +212,15 @@ private Result handleOidcCallback(
}

// Successfully logged in - Generate GMS login token
final String accessToken = _authClient.generateSessionTokenForUser(corpUserUrn.getId());
final String accessToken = authClient.generateSessionTokenForUser(corpUserUrn.getId());
return result
.withSession(createSessionMap(corpUserUrn.toString(), accessToken))
.withCookies(
createActorCookie(
corpUserUrn.toString(),
_cookieConfigs.getTtlInHours(),
_cookieConfigs.getAuthCookieSameSite(),
_cookieConfigs.getAuthCookieSecure()));
cookieConfigs.getTtlInHours(),
cookieConfigs.getAuthCookieSameSite(),
cookieConfigs.getAuthCookieSecure()));
}
return internalServerError(
"Failed to authenticate current user. Cannot find valid identity provider profile in session.");
Expand Down Expand Up @@ -331,7 +334,7 @@ private List<CorpGroupSnapshot> extractGroups(CommonProfile profile) {
String.format(
"Attempting to extract groups from OIDC profile %s",
profile.getAttributes().toString()));
final OidcConfigs configs = (OidcConfigs) _ssoManager.getSsoProvider().configs();
final OidcConfigs configs = (OidcConfigs) ssoManager.getSsoProvider().configs();

// First, attempt to extract a list of groups from the profile, using the group name attribute
// config.
Expand Down Expand Up @@ -400,13 +403,13 @@ private GroupMembership createGroupMembership(final List<CorpGroupSnapshot> extr
return groupMembershipAspect;
}

private void tryProvisionUser(CorpUserSnapshot corpUserSnapshot) {
private void tryProvisionUser(@Nonnull OperationContext opContext, CorpUserSnapshot corpUserSnapshot) {

log.debug(String.format("Attempting to provision user with urn %s", corpUserSnapshot.getUrn()));

// 1. Check if this user already exists.
try {
final Entity corpUser = _entityClient.get(corpUserSnapshot.getUrn(), _systemAuthentication);
final Entity corpUser = systemEntityClient.get(opContext, corpUserSnapshot.getUrn());
final CorpUserSnapshot existingCorpUserSnapshot = corpUser.getValue().getCorpUserSnapshot();

log.debug(String.format("Fetched GMS user with urn %s", corpUserSnapshot.getUrn()));
Expand All @@ -420,7 +423,7 @@ private void tryProvisionUser(CorpUserSnapshot corpUserSnapshot) {
// 2. The user does not exist. Provision them.
final Entity newEntity = new Entity();
newEntity.setValue(Snapshot.create(corpUserSnapshot));
_entityClient.update(newEntity, _systemAuthentication);
systemEntityClient.update(opContext, newEntity);
log.debug(String.format("Successfully provisioned user %s", corpUserSnapshot.getUrn()));
}
log.debug(
Expand All @@ -434,7 +437,7 @@ private void tryProvisionUser(CorpUserSnapshot corpUserSnapshot) {
}
}

private void tryProvisionGroups(List<CorpGroupSnapshot> corpGroups) {
private void tryProvisionGroups(@Nonnull OperationContext opContext, List<CorpGroupSnapshot> corpGroups) {

log.debug(
String.format(
Expand All @@ -446,7 +449,7 @@ private void tryProvisionGroups(List<CorpGroupSnapshot> corpGroups) {
final Set<Urn> urnsToFetch =
corpGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toSet());
final Map<Urn, Entity> existingGroups =
_entityClient.batchGet(urnsToFetch, _systemAuthentication);
systemEntityClient.batchGet(opContext, urnsToFetch);

log.debug(String.format("Fetched GMS groups with urns %s", existingGroups.keySet()));

Expand Down Expand Up @@ -484,11 +487,10 @@ private void tryProvisionGroups(List<CorpGroupSnapshot> corpGroups) {
log.debug(String.format("Provisioning groups with urns %s", groupsToCreateUrns));

// Now batch create all entities identified to create.
_entityClient.batchUpdate(
systemEntityClient.batchUpdate(opContext,
groupsToCreate.stream()
.map(groupSnapshot -> new Entity().setValue(Snapshot.create(groupSnapshot)))
.collect(Collectors.toSet()),
_systemAuthentication);
.collect(Collectors.toSet()));

log.debug(String.format("Successfully provisioned groups with urns %s", groupsToCreateUrns));
} catch (RemoteInvocationException e) {
Expand All @@ -501,7 +503,7 @@ private void tryProvisionGroups(List<CorpGroupSnapshot> corpGroups) {
}
}

private void updateGroupMembership(Urn urn, GroupMembership groupMembership) {
private void updateGroupMembership(@Nonnull OperationContext opContext, Urn urn, GroupMembership groupMembership) {
log.debug(String.format("Updating group membership for user %s", urn));
final MetadataChangeProposal proposal = new MetadataChangeProposal();
proposal.setEntityUrn(urn);
Expand All @@ -510,18 +512,18 @@ private void updateGroupMembership(Urn urn, GroupMembership groupMembership) {
proposal.setAspect(GenericRecordUtils.serializeAspect(groupMembership));
proposal.setChangeType(ChangeType.UPSERT);
try {
_entityClient.ingestProposal(proposal, _systemAuthentication);
systemEntityClient.ingestProposal(opContext, proposal);
} catch (RemoteInvocationException e) {
throw new RuntimeException(
String.format("Failed to update group membership for user with urn %s", urn), e);
}
}

private void verifyPreProvisionedUser(CorpuserUrn urn) {
private void verifyPreProvisionedUser(@Nonnull OperationContext opContext, CorpuserUrn urn) {
// Validate that the user exists in the system (there is more than just a key aspect for them,
// as of today).
try {
final Entity corpUser = _entityClient.get(urn, _systemAuthentication);
final Entity corpUser = systemEntityClient.get(opContext, urn);

log.debug(String.format("Fetched GMS user with urn %s", urn));

Expand All @@ -543,15 +545,15 @@ private void verifyPreProvisionedUser(CorpuserUrn urn) {
}
}

private void setUserStatus(final Urn urn, final CorpUserStatus newStatus) throws Exception {
private void setUserStatus(@Nonnull OperationContext opContext, final Urn urn, final CorpUserStatus newStatus) throws Exception {
// Update status aspect to be active.
final MetadataChangeProposal proposal = new MetadataChangeProposal();
proposal.setEntityUrn(urn);
proposal.setEntityType(Constants.CORP_USER_ENTITY_NAME);
proposal.setAspectName(Constants.CORP_USER_STATUS_ASPECT_NAME);
proposal.setAspect(GenericRecordUtils.serializeAspect(newStatus));
proposal.setChangeType(ChangeType.UPSERT);
_entityClient.ingestProposal(proposal, _systemAuthentication);
systemEntityClient.ingestProposal(opContext, proposal);
}

private Optional<String> extractRegexGroup(final String patternStr, final String target) {
Expand Down
11 changes: 7 additions & 4 deletions datahub-frontend/app/controllers/SsoCallbackController.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import java.util.concurrent.CompletionStage;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import javax.inject.Named;

import io.datahubproject.metadata.context.OperationContext;
import lombok.extern.slf4j.Slf4j;
import org.pac4j.core.client.Client;
import org.pac4j.core.client.Clients;
Expand Down Expand Up @@ -43,7 +46,7 @@ public class SsoCallbackController extends CallbackController {
@Inject
public SsoCallbackController(
@Nonnull SsoManager ssoManager,
@Nonnull Authentication systemAuthentication,
@Named("systemOperationContext") @Nonnull OperationContext systemOperationContext,
@Nonnull SystemEntityClient entityClient,
@Nonnull AuthServiceClient authClient,
@Nonnull Config config,
Expand All @@ -55,7 +58,7 @@ public SsoCallbackController(
setCallbackLogic(
new SsoCallbackLogic(
ssoManager,
systemAuthentication,
systemOperationContext,
entityClient,
authClient,
new CookieConfigs(configs)));
Expand Down Expand Up @@ -96,13 +99,13 @@ public class SsoCallbackLogic implements CallbackLogic<Result, PlayWebContext> {

SsoCallbackLogic(
final SsoManager ssoManager,
final Authentication systemAuthentication,
final OperationContext systemOperationContext,
final SystemEntityClient entityClient,
final AuthServiceClient authClient,
final CookieConfigs cookieConfigs) {
_oidcCallbackLogic =
new OidcCallbackLogic(
ssoManager, systemAuthentication, entityClient, authClient, cookieConfigs);
ssoManager, systemOperationContext, entityClient, authClient, cookieConfigs);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@
import com.linkedin.datahub.graphql.types.view.DataHubViewType;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.metadata.client.UsageStatsJavaClient;
import com.linkedin.metadata.config.DataHubConfiguration;
import com.linkedin.metadata.config.IngestionConfiguration;
import com.linkedin.metadata.config.TestsConfiguration;
Expand All @@ -372,7 +373,6 @@
import com.linkedin.metadata.timeline.TimelineService;
import com.linkedin.metadata.timeseries.TimeseriesAspectService;
import com.linkedin.metadata.version.GitVersion;
import com.linkedin.usage.UsageClient;
import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
Expand Down Expand Up @@ -411,7 +411,7 @@ public class GmsGraphQLEngine {
private final EntityClient entityClient;
private final SystemEntityClient systemEntityClient;
private final GraphClient graphClient;
private final UsageClient usageClient;
private final UsageStatsJavaClient usageClient;
private final SiblingGraphService siblingGraphService;

private final EntityService entityService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.linkedin.datahub.graphql.featureflags.FeatureFlags;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.metadata.client.UsageStatsJavaClient;
import com.linkedin.metadata.config.DataHubConfiguration;
import com.linkedin.metadata.config.IngestionConfiguration;
import com.linkedin.metadata.config.TestsConfiguration;
Expand All @@ -35,7 +36,6 @@
import com.linkedin.metadata.timeline.TimelineService;
import com.linkedin.metadata.timeseries.TimeseriesAspectService;
import com.linkedin.metadata.version.GitVersion;
import com.linkedin.usage.UsageClient;
import io.datahubproject.metadata.services.RestrictedService;
import io.datahubproject.metadata.services.SecretService;
import lombok.Data;
Expand All @@ -45,7 +45,7 @@ public class GmsGraphQLEngineArgs {
EntityClient entityClient;
SystemEntityClient systemEntityClient;
GraphClient graphClient;
UsageClient usageClient;
UsageStatsJavaClient usageClient;
AnalyticsService analyticsService;
EntityService entityService;
RecommendationsService recommendationsService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ public CompletableFuture<SubTypes> get(DataFetchingEnvironment environment) thro
EntityResponse entityResponse =
_entityClient
.batchGetV2(
context.getOperationContext(),
urn.getEntityType(),
Collections.singleton(urn),
Collections.singleton(_aspectName),
context.getAuthentication())
Collections.singleton(_aspectName))
.get(urn);
if (entityResponse != null && entityResponse.getAspects().containsKey(_aspectName)) {
subType =
Expand Down
Loading
Loading