Skip to content

Commit

Permalink
feat(entity-client): restli batchGetV2 batchSize fix and concurrency (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
david-leifker authored Jun 6, 2024
1 parent 3b8cda6 commit 3dd1c4c
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 88 deletions.
4 changes: 3 additions & 1 deletion datahub-frontend/app/auth/AuthModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class AuthModule extends AbstractModule {
private static final String ENTITY_CLIENT_RETRY_INTERVAL = "entityClient.retryInterval";
private static final String ENTITY_CLIENT_NUM_RETRIES = "entityClient.numRetries";
private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE = "entityClient.restli.get.batchSize";
private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY = "entityClient.restli.get.batchConcurrency";
private static final String GET_SSO_SETTINGS_ENDPOINT = "auth/getSsoSettings";

private final com.typesafe.config.Config _configs;
Expand Down Expand Up @@ -208,7 +209,8 @@ protected SystemEntityClient provideEntityClient(
new ExponentialBackoff(_configs.getInt(ENTITY_CLIENT_RETRY_INTERVAL)),
_configs.getInt(ENTITY_CLIENT_NUM_RETRIES),
configurationProvider.getCache().getClient().getEntityClient(),
Math.max(1, _configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE)));
Math.max(1, _configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE)),
Math.max(1, _configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY)));
}

@Provides
Expand Down
6 changes: 4 additions & 2 deletions datahub-frontend/conf/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,7 @@ entityClient.retryInterval = 2
entityClient.retryInterval = ${?ENTITY_CLIENT_RETRY_INTERVAL}
entityClient.numRetries = 3
entityClient.numRetries = ${?ENTITY_CLIENT_NUM_RETRIES}
entityClient.restli.get.batchSize = 100
entityClient.restli.get.batchSize = ${?ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE}
entityClient.restli.get.batchSize = 50
entityClient.restli.get.batchSize = ${?ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE}
entityClient.restli.get.batchConcurrency = 2
entityClient.restli.get.batchConcurrency = ${?ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public SystemEntityClient systemEntityClient(
new ExponentialBackoff(1),
1,
configurationProvider.getCache().getClient().getEntityClient(),
1);
1,
2);
}

@MockBean public Database ebeanServer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ entityClient:
restli:
get:
batchSize: ${ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE:100} # limited to prevent exceeding restli URI size limit
batchConcurrency: ${ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY:2} # parallel threads

usageClient:
retryInterval: ${USAGE_CLIENT_RETRY_INTERVAL:2}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public EntityClient entityClient(
@Value("${datahub.gms.sslContext.protocol}") String gmsSslProtocol,
@Value("${entityClient.retryInterval:2}") int retryInterval,
@Value("${entityClient.numRetries:3}") int numRetries,
final @Value("${entityClient.restli.get.batchSize:150}") int batchGetV2Size) {
final @Value("${entityClient.restli.get.batchSize}") int batchGetV2Size,
final @Value("${entityClient.restli.get.batchConcurrency}") int batchGetV2Concurrency) {
final Client restClient;
if (gmsUri != null) {
restClient = DefaultRestliClientFactory.getRestLiClient(URI.create(gmsUri), gmsSslProtocol);
Expand All @@ -39,7 +40,11 @@ public EntityClient entityClient(
DefaultRestliClientFactory.getRestLiClient(gmsHost, gmsPort, gmsUseSSL, gmsSslProtocol);
}
return new RestliEntityClient(
restClient, new ExponentialBackoff(retryInterval), numRetries, batchGetV2Size);
restClient,
new ExponentialBackoff(retryInterval),
numRetries,
batchGetV2Size,
batchGetV2Concurrency);
}

@Bean("systemEntityClient")
Expand All @@ -53,7 +58,8 @@ public SystemEntityClient systemEntityClient(
@Value("${entityClient.retryInterval:2}") int retryInterval,
@Value("${entityClient.numRetries:3}") int numRetries,
final EntityClientCacheConfig entityClientCacheConfig,
final @Value("${entityClient.restli.get.batchSize:150}") int batchGetV2Size) {
final @Value("${entityClient.restli.get.batchSize}") int batchGetV2Size,
final @Value("${entityClient.restli.get.batchConcurrency}") int batchGetV2Concurrency) {

final Client restClient;
if (gmsUri != null) {
Expand All @@ -67,6 +73,7 @@ public SystemEntityClient systemEntityClient(
new ExponentialBackoff(retryInterval),
numRetries,
entityClientCacheConfig,
batchGetV2Size);
batchGetV2Size,
batchGetV2Concurrency);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.mail.MethodNotSupportedException;
Expand All @@ -110,14 +115,17 @@ public class RestliEntityClient extends BaseClient implements EntityClient {
private static final RunsRequestBuilders RUNS_REQUEST_BUILDERS = new RunsRequestBuilders();

private final int batchGetV2Size;
private final int batchGetV2Concurrency;

public RestliEntityClient(
@Nonnull final Client restliClient,
@Nonnull final BackoffPolicy backoffPolicy,
int retryCount,
int batchGetV2Size) {
int batchGetV2Size,
int batchGetV2Concurrency) {
super(restliClient, backoffPolicy, retryCount);
this.batchGetV2Size = Math.max(1, batchGetV2Size);
this.batchGetV2Concurrency = batchGetV2Concurrency;
}

@Override
Expand Down Expand Up @@ -150,7 +158,6 @@ public Entity get(@Nonnull OperationContext opContext, @Nonnull final Urn urn)
* <p>Batch get a set of {@link Entity} objects by urn.
*
* @param urns the urns of the entities to batch get
* @param authentication the authentication to include in the request to the Metadata Service
* @throws RemoteInvocationException when unable to execute request
*/
@Override
Expand Down Expand Up @@ -216,40 +223,54 @@ public Map<Urn, EntityResponse> batchGetV2(
throws RemoteInvocationException, URISyntaxException {

Map<Urn, EntityResponse> responseMap = new HashMap<>();
ExecutorService executor = Executors.newFixedThreadPool(Math.max(1, batchGetV2Concurrency));

Iterators.partition(urns.iterator(), batchGetV2Size)
.forEachRemaining(
batch -> {
try {
final EntitiesV2BatchGetRequestBuilder requestBuilder =
ENTITIES_V2_REQUEST_BUILDERS
.batchGet()
.aspectsParam(aspectNames)
.ids(batch.stream().map(Urn::toString).collect(Collectors.toList()));

responseMap.putAll(
sendClientRequest(requestBuilder, opContext.getSessionAuthentication())
.getEntity()
.getResults()
.entrySet()
.stream()
.collect(
Collectors.toMap(
entry -> {
try {
return Urn.createFromString(entry.getKey());
} catch (URISyntaxException e) {
throw new RuntimeException(
String.format(
"Failed to bind urn string with value %s into urn",
entry.getKey()));
}
},
entry -> entry.getValue().getEntity())));
} catch (RemoteInvocationException e) {
throw new RuntimeException(e);
}
});
try {
Iterable<List<Urn>> iterable = () -> Iterators.partition(urns.iterator(), batchGetV2Size);
List<Future<Map<Urn, EntityResponse>>> futures =
StreamSupport.stream(iterable.spliterator(), false)
.map(
batch ->
executor.submit(
() -> {
try {
log.debug("Executing batchGetV2 with batch size: {}", batch.size());
final EntitiesV2BatchGetRequestBuilder requestBuilder =
ENTITIES_V2_REQUEST_BUILDERS
.batchGet()
.aspectsParam(aspectNames)
.ids(
batch.stream()
.map(Urn::toString)
.collect(Collectors.toList()));

return sendClientRequest(
requestBuilder, opContext.getSessionAuthentication())
.getEntity()
.getResults()
.entrySet()
.stream()
.collect(
Collectors.toMap(
entry -> UrnUtils.getUrn(entry.getKey()),
entry -> entry.getValue().getEntity()));
} catch (RemoteInvocationException e) {
throw new RuntimeException(e);
}
}))
.collect(Collectors.toList());

futures.forEach(
result -> {
try {
responseMap.putAll(result.get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
});
} finally {
executor.shutdown();
}

return responseMap;
}
Expand All @@ -260,7 +281,6 @@ public Map<Urn, EntityResponse> batchGetV2(
* @param entityName the entity type to fetch
* @param versionedUrns the urns of the entities to batch get
* @param aspectNames the aspect names to batch get
* @param authentication the authentication to include in the request to the Metadata Service
* @throws RemoteInvocationException when unable to execute request
*/
@Override
Expand All @@ -272,39 +292,62 @@ public Map<Urn, EntityResponse> batchGetVersionedV2(
@Nullable final Set<String> aspectNames) {

Map<Urn, EntityResponse> responseMap = new HashMap<>();
ExecutorService executor = Executors.newFixedThreadPool(Math.max(1, batchGetV2Concurrency));

Iterators.partition(versionedUrns.iterator(), batchGetV2Size)
.forEachRemaining(
batch -> {
final EntitiesVersionedV2BatchGetRequestBuilder requestBuilder =
ENTITIES_VERSIONED_V2_REQUEST_BUILDERS
.batchGet()
.aspectsParam(aspectNames)
.entityTypeParam(entityName)
.ids(
batch.stream()
.map(
versionedUrn ->
com.linkedin.common.urn.VersionedUrn.of(
versionedUrn.getUrn().toString(),
versionedUrn.getVersionStamp()))
.collect(Collectors.toSet()));

try {
responseMap.putAll(
sendClientRequest(requestBuilder, opContext.getSessionAuthentication())
.getEntity()
.getResults()
.entrySet()
.stream()
.collect(
Collectors.toMap(
entry -> UrnUtils.getUrn(entry.getKey().getUrn()),
entry -> entry.getValue().getEntity())));
} catch (RemoteInvocationException e) {
throw new RuntimeException(e);
}
});
try {
Iterable<List<VersionedUrn>> iterable =
() -> Iterators.partition(versionedUrns.iterator(), batchGetV2Size);
List<Future<Map<Urn, EntityResponse>>> futures =
StreamSupport.stream(iterable.spliterator(), false)
.map(
batch ->
executor.submit(
() -> {
try {
log.debug(
"Executing batchGetVersionedV2 with batch size: {}",
batch.size());
final EntitiesVersionedV2BatchGetRequestBuilder requestBuilder =
ENTITIES_VERSIONED_V2_REQUEST_BUILDERS
.batchGet()
.aspectsParam(aspectNames)
.entityTypeParam(entityName)
.ids(
batch.stream()
.map(
versionedUrn ->
com.linkedin.common.urn.VersionedUrn.of(
versionedUrn.getUrn().toString(),
versionedUrn.getVersionStamp()))
.collect(Collectors.toSet()));

return sendClientRequest(
requestBuilder, opContext.getSessionAuthentication())
.getEntity()
.getResults()
.entrySet()
.stream()
.collect(
Collectors.toMap(
entry -> UrnUtils.getUrn(entry.getKey().getUrn()),
entry -> entry.getValue().getEntity()));
} catch (RemoteInvocationException e) {
throw new RuntimeException(e);
}
}))
.collect(Collectors.toList());

futures.forEach(
result -> {
try {
responseMap.putAll(result.get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
});
} finally {
executor.shutdown();
}

return responseMap;
}
Expand Down Expand Up @@ -955,7 +998,6 @@ public VersionedAspect getAspectOrNull(
* @param startTimeMillis the earliest desired event time of the aspect value in milliseconds.
* @param endTimeMillis the latest desired event time of the aspect value in milliseconds.
* @param limit the maximum number of desired aspect values.
* @param authentication the actor associated with the request [internal]
* @return the list of EnvelopedAspect values satisfying the input parameters.
* @throws RemoteInvocationException on remote request error.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ public SystemRestliEntityClient(
@Nonnull final BackoffPolicy backoffPolicy,
int retryCount,
EntityClientCacheConfig cacheConfig,
int batchGetV2Size) {
super(restliClient, backoffPolicy, retryCount, batchGetV2Size);
int batchGetV2Size,
int batchGetV2Concurrency) {
super(restliClient, backoffPolicy, retryCount, batchGetV2Size, batchGetV2Concurrency);
this.operationContextMap = CacheBuilder.newBuilder().maximumSize(500).build();
this.entityClientCache = buildEntityClientCache(SystemRestliEntityClient.class, cacheConfig);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void testZeroRetry() throws RemoteInvocationException {
when(mockRestliClient.sendRequest(any(ActionRequest.class))).thenReturn(mockFuture);

RestliEntityClient testClient =
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 0, 10);
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 0, 10, 2);
testClient.sendClientRequest(testRequestBuilder, AUTH);
// Expected 1 actual try and 0 retries
verify(mockRestliClient).sendRequest(any(ActionRequest.class));
Expand All @@ -56,7 +56,7 @@ public void testMultipleRetries() throws RemoteInvocationException {
.thenReturn(mockFuture);

RestliEntityClient testClient =
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 1, 10);
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 1, 10, 2);
testClient.sendClientRequest(testRequestBuilder, AUTH);
// Expected 1 actual try and 1 retries
verify(mockRestliClient, times(2)).sendRequest(any(ActionRequest.class));
Expand All @@ -73,7 +73,7 @@ public void testNonRetry() {
.thenThrow(new RuntimeException(new RequiredFieldNotPresentException("value")));

RestliEntityClient testClient =
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 1, 10);
new RestliEntityClient(mockRestliClient, new ExponentialBackoff(1), 1, 10, 2);
assertThrows(
RuntimeException.class, () -> testClient.sendClientRequest(testRequestBuilder, AUTH));
}
Expand Down
Loading

0 comments on commit 3dd1c4c

Please sign in to comment.