perf(search): Improve search default performance (datahub-project#5311)
jjoyce0510 authored and maggiehays committed Aug 1, 2022
1 parent cb8dcd0 commit b963006
Showing 16 changed files with 491 additions and 104 deletions.
@@ -57,7 +57,7 @@ public CompletableFuture<AutoCompleteMultipleResults> get(DataFetchingEnvironment
environment);
}

// By default, autocomplete only against the set of Searchable Entity Types.
// By default, autocomplete only against the Default Set of Autocomplete entities
return AutocompleteUtils.batchGetAutocompleteResults(
AUTO_COMPLETE_ENTITY_TYPES.stream().map(_typeToEntity::get).collect(Collectors.toList()),
sanitizedQuery,
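The resolver above limits default autocomplete to a fixed list of entity types (AUTO_COMPLETE_ENTITY_TYPES) rather than every searchable type. A hedged sketch of what such a constant could look like; the enum location and the concrete members below are assumptions for illustration, and the commit's actual list may differ:

    import com.google.common.collect.ImmutableList;
    import com.linkedin.datahub.graphql.generated.EntityType; // assumed location of the generated GraphQL enum
    import java.util.List;

    public final class AutoCompleteDefaults {
      // Hypothetical default set -- the real AUTO_COMPLETE_ENTITY_TYPES may contain different types.
      public static final List<EntityType> AUTO_COMPLETE_ENTITY_TYPES = ImmutableList.of(
          EntityType.DATASET,
          EntityType.DASHBOARD,
          EntityType.CHART,
          EntityType.DATA_FLOW,
          EntityType.DATA_JOB);

      private AutoCompleteDefaults() { }
    }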
(next changed file)
@@ -1,13 +1,11 @@
package com.linkedin.metadata.search;

import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.Filter;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.search.aggregator.AllEntitiesSearchAggregator;
import com.linkedin.metadata.search.cache.AllEntitiesSearchAggregatorCache;
import com.linkedin.metadata.search.cache.CachingAllEntitiesSearchAggregator;
import com.linkedin.metadata.search.cache.EntityDocCountCache;
import com.linkedin.metadata.search.cache.EntitySearchServiceCache;
import com.linkedin.metadata.search.client.CachingEntitySearchService;
import com.linkedin.metadata.search.ranker.SearchRanker;
import java.util.List;
import java.util.Map;
@@ -16,30 +14,24 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.CacheManager;


@Slf4j
public class SearchService {
private final EntitySearchService _entitySearchService;
private final AllEntitiesSearchAggregator _aggregator;
private final SearchRanker _searchRanker;

private final CachingEntitySearchService _cachingEntitySearchService;
private final CachingAllEntitiesSearchAggregator _cachingAllEntitiesSearchAggregator;
private final EntityDocCountCache _entityDocCountCache;
private final EntitySearchServiceCache _entitySearchServiceCache;
private final AllEntitiesSearchAggregatorCache _allEntitiesSearchAggregatorCache;
private final SearchRanker _searchRanker;

public SearchService(EntityRegistry entityRegistry, EntitySearchService entitySearchService,
SearchRanker searchRanker, CacheManager cacheManager, int batchSize, boolean enableCache) {
_entitySearchService = entitySearchService;
public SearchService(
EntityDocCountCache entityDocCountCache,
CachingEntitySearchService cachingEntitySearchService,
CachingAllEntitiesSearchAggregator cachingEntitySearchAggregator,
SearchRanker searchRanker) {
_cachingEntitySearchService = cachingEntitySearchService;
_cachingAllEntitiesSearchAggregator = cachingEntitySearchAggregator;
_searchRanker = searchRanker;
_aggregator =
new AllEntitiesSearchAggregator(entityRegistry, entitySearchService, searchRanker, cacheManager, batchSize,
enableCache);
_entityDocCountCache = new EntityDocCountCache(entityRegistry, entitySearchService);
_entitySearchServiceCache = new EntitySearchServiceCache(cacheManager, entitySearchService, batchSize, enableCache);
_allEntitiesSearchAggregatorCache =
new AllEntitiesSearchAggregatorCache(cacheManager, _aggregator, batchSize, enableCache);
_entityDocCountCache = entityDocCountCache;
}

public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
@@ -65,8 +57,8 @@ public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
SearchResult result =
_entitySearchServiceCache.getSearcher(entityName, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(from, size);
_cachingEntitySearchService.search(entityName, input, postFilters, sortCriterion, from, size, searchFlags);

try {
return result.copy().setEntities(new SearchEntityArray(_searchRanker.rank(result.getEntities())));
} catch (Exception e) {
@@ -95,7 +87,6 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnull
log.debug(String.format(
"Searching Search documents entities: %s, input: %s, postFilters: %s, sortCriterion: %s, from: %s, size: %s",
entities, input, postFilters, sortCriterion, from, size));
return _allEntitiesSearchAggregatorCache.getSearcher(entities, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(from, size);
return _cachingAllEntitiesSearchAggregator.getSearchResults(entities, input, postFilters, sortCriterion, from, size, searchFlags);
}
}
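With this refactor, SearchService no longer constructs its own caches from a CacheManager; it receives already-built caching collaborators. A minimal sketch of how the new constructor might be wired, assuming a Spring @Configuration factory; the factory class, method name, and injected beans below are illustrative assumptions, not part of this commit:

    import com.linkedin.metadata.models.registry.EntityRegistry;
    import com.linkedin.metadata.search.EntitySearchService;
    import com.linkedin.metadata.search.SearchService;
    import com.linkedin.metadata.search.cache.CachingAllEntitiesSearchAggregator;
    import com.linkedin.metadata.search.cache.EntityDocCountCache;
    import com.linkedin.metadata.search.client.CachingEntitySearchService;
    import com.linkedin.metadata.search.ranker.SearchRanker;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;

    // Hypothetical wiring -- the factory name and bean sources are assumptions for illustration.
    @Configuration
    public class SearchServiceFactory {

      @Bean
      public SearchService searchService(
          EntityRegistry entityRegistry,
          EntitySearchService entitySearchService,
          CachingEntitySearchService cachingEntitySearchService,
          CachingAllEntitiesSearchAggregator cachingAllEntitiesSearchAggregator,
          SearchRanker searchRanker) {
        // The doc-count cache is now supplied to SearchService instead of being built inside it.
        EntityDocCountCache entityDocCountCache =
            new EntityDocCountCache(entityRegistry, entitySearchService);
        return new SearchService(
            entityDocCountCache,
            cachingEntitySearchService,
            cachingAllEntitiesSearchAggregator,
            searchRanker);
      }
    }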
(next changed file)
@@ -1,6 +1,7 @@
package com.linkedin.metadata.search.aggregator;

import com.codahale.metrics.Timer;
import com.linkedin.data.template.GetMode;
import com.linkedin.data.template.LongMap;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.SearchFlags;
@@ -14,7 +15,7 @@
import com.linkedin.metadata.search.SearchEntityArray;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.metadata.search.SearchResultMetadata;
import com.linkedin.metadata.search.cache.EntitySearchServiceCache;
import com.linkedin.metadata.search.client.CachingEntitySearchService;
import com.linkedin.metadata.search.cache.EntityDocCountCache;
import com.linkedin.metadata.search.ranker.SearchRanker;
import com.linkedin.metadata.search.utils.SearchUtils;
@@ -27,29 +28,36 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.CacheManager;

import static com.linkedin.metadata.search.utils.FilterUtils.rankFilterGroups;


@Slf4j
public class AllEntitiesSearchAggregator {

private static final int DEFAULT_MAX_AGGREGATION_VALUES = 20;

private final EntitySearchService _entitySearchService;
private final SearchRanker _searchRanker;
private final EntityDocCountCache _entityDocCountCache;

private final EntitySearchServiceCache _entitySearchServiceCache;

public AllEntitiesSearchAggregator(EntityRegistry entityRegistry, EntitySearchService entitySearchService,
SearchRanker searchRanker, CacheManager cacheManager, int batchSize, boolean enableCache) {
_entitySearchService = entitySearchService;
_searchRanker = searchRanker;
private final CachingEntitySearchService _cachingEntitySearchService;
private final int _maxAggregationValueCount;

public AllEntitiesSearchAggregator(
EntityRegistry entityRegistry,
EntitySearchService entitySearchService,
CachingEntitySearchService cachingEntitySearchService,
SearchRanker searchRanker) {
_entitySearchService = Objects.requireNonNull(entitySearchService);
_searchRanker = Objects.requireNonNull(searchRanker);
_cachingEntitySearchService = Objects.requireNonNull(cachingEntitySearchService);
_entityDocCountCache = new EntityDocCountCache(entityRegistry, entitySearchService);
_entitySearchServiceCache = new EntitySearchServiceCache(cacheManager, entitySearchService, batchSize, enableCache);
_maxAggregationValueCount = DEFAULT_MAX_AGGREGATION_VALUES; // TODO: Make this externally configurable
}

@Nonnull
@@ -95,10 +103,6 @@ public SearchResult search(@Nonnull List<String> entities, @Nonnull String input
Map<String, Long> numResultsPerEntity = searchResults.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getNumEntities().longValue()));
aggregations.put("entity", new AggregationMetadata().setName("entity")
.setDisplayName("Type")
.setAggregations(new LongMap(numResultsPerEntity))
.setFilterValues(new FilterValueArray(SearchUtil.convertToFilters(numResultsPerEntity))));

for (String entity : searchResults.keySet()) {
SearchResult result = searchResults.get(entity);
@@ -114,10 +118,19 @@ public SearchResult search(@Nonnull List<String> entities, @Nonnull String input
});
}

// Trim the aggregations / filters after merging.
Map<String, AggregationMetadata> finalAggregations = trimMergedAggregations(aggregations);

// Finally, Add a custom Entity aggregation (appears as the first filter) -- this should never be truncated
finalAggregations.put("entity", new AggregationMetadata().setName("entity")
.setDisplayName("Type")
.setAggregations(new LongMap(numResultsPerEntity))
.setFilterValues(new FilterValueArray(SearchUtil.convertToFilters(numResultsPerEntity))));

// 4. Rank results across entities
List<SearchEntity> rankedResult = _searchRanker.rank(matchedResults);
SearchResultMetadata finalMetadata =
new SearchResultMetadata().setAggregations(new AggregationMetadataArray(rankFilterGroups(aggregations)));
new SearchResultMetadata().setAggregations(new AggregationMetadataArray(rankFilterGroups(finalAggregations)));

postProcessTimer.stop();
return new SearchResult().setEntities(new SearchEntityArray(rankedResult))
@@ -143,12 +156,38 @@ private Map<String, SearchResult> getSearchResultsForEachEntity(@Nonnull List<String>
// Query the entity search service for all entities asynchronously
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "searchEntities").time()) {
searchResults = ConcurrencyUtils.transformAndCollectAsync(entities, entity -> new Pair<>(entity,
_entitySearchServiceCache.getSearcher(entity, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(queryFrom, querySize)))
_cachingEntitySearchService.search(entity, input, postFilters, sortCriterion, queryFrom, querySize, searchFlags)))
.stream()
.filter(pair -> pair.getValue().getNumEntities() > 0)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
}
return searchResults;
}

/**
* Simply trims the total aggregation values that are returned to the client based on the SearchFlags which are set
*/
private Map<String, AggregationMetadata> trimMergedAggregations(Map<String, AggregationMetadata> aggregations) {
return aggregations.entrySet().stream().map(
entry -> Pair.of(entry.getKey(), new AggregationMetadata()
.setName(entry.getValue().getName())
.setDisplayName(entry.getValue().getDisplayName(GetMode.NULL))
.setAggregations(entry.getValue().getAggregations())
.setFilterValues(
trimFilterValues(entry.getValue().getFilterValues()))
)
).collect(Collectors.toMap(Pair::getFirst, Pair::getSecond));
}

/**
* Selects the top N filter values AFTER they've been fully merged.
*/
private FilterValueArray trimFilterValues(FilterValueArray original) {
if (original.size() > _maxAggregationValueCount) {
return new FilterValueArray(
original.subList(0, _maxAggregationValueCount)
);
}
return original;
}
}
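The new trimMergedAggregations/trimFilterValues step caps how many filter values each merged aggregation carries back to the client. A small standalone sketch of the same top-N trimming idea; class and method names here are illustrative, not from the commit:

    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    // Illustrative sketch: keep at most N values per merged aggregation, mirroring
    // DEFAULT_MAX_AGGREGATION_VALUES = 20 in the diff above.
    public final class AggregationTrimExample {
      private static final int MAX_AGGREGATION_VALUES = 20;

      static <T> List<T> trim(List<T> mergedValues) {
        if (mergedValues.size() <= MAX_AGGREGATION_VALUES) {
          return mergedValues;
        }
        // Copy the sub-list so the trimmed result does not pin the full merged list in memory.
        return List.copyOf(mergedValues.subList(0, MAX_AGGREGATION_VALUES));
      }

      public static void main(String[] args) {
        List<Integer> merged = IntStream.rangeClosed(1, 50).boxed().collect(Collectors.toList());
        System.out.println(trim(merged).size()); // prints 20
      }
    }

Note that in the commit the synthetic "entity" aggregation is added after trimming, so the per-entity result counts are never truncated.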
(next changed file)
@@ -3,6 +3,7 @@
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.Filter;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.metadata.search.aggregator.AllEntitiesSearchAggregator;
import java.util.List;
import javax.annotation.Nonnull;
@@ -13,19 +14,19 @@


@RequiredArgsConstructor
public class AllEntitiesSearchAggregatorCache {
public class CachingAllEntitiesSearchAggregator {
private static final String ALL_ENTITIES_SEARCH_AGGREGATOR_CACHE_NAME = "allEntitiesSearchAggregator";

private final CacheManager cacheManager;
private final AllEntitiesSearchAggregator aggregator;
private final int batchSize;
private final boolean enableCache;

public CacheableSearcher<?> getSearcher(List<String> entities, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, @Nullable SearchFlags searchFlags) {
public SearchResult getSearchResults(List<String> entities, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
return new CacheableSearcher<>(cacheManager.getCache(ALL_ENTITIES_SEARCH_AGGREGATOR_CACHE_NAME), batchSize,
querySize -> aggregator.search(entities, input, postFilters, sortCriterion, querySize.getFrom(),
querySize.getSize(), searchFlags),
querySize -> Quintet.with(entities, input, postFilters, sortCriterion, querySize), searchFlags, enableCache);
querySize -> Quintet.with(entities, input, postFilters, sortCriterion, querySize), searchFlags, enableCache).getSearchResults(from, size);
}
}
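Where callers previously received a CacheableSearcher and paged it themselves, the renamed CachingAllEntitiesSearchAggregator now returns a SearchResult directly. A hedged sketch of a call site, assuming an aggregator instance is constructed or injected elsewhere; the entity names and arguments are illustrative:

    import com.linkedin.metadata.search.SearchResult;
    import com.linkedin.metadata.search.cache.CachingAllEntitiesSearchAggregator;
    import java.util.Arrays;

    // Illustrative call site -- the aggregator is assumed to be provided elsewhere.
    public class CrossEntitySearchExample {
      private final CachingAllEntitiesSearchAggregator aggregator;

      public CrossEntitySearchExample(CachingAllEntitiesSearchAggregator aggregator) {
        this.aggregator = aggregator;
      }

      public SearchResult firstPage(String query) {
        // Search across two entity indices, no post filters or sort, first 10 hits, default flags.
        return aggregator.getSearchResults(
            Arrays.asList("dataset", "dashboard"), query, null, null, 0, 10, null);
      }
    }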

This file was deleted.

