Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(search): Improve search default performance #5311

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public CompletableFuture<AutoCompleteMultipleResults> get(DataFetchingEnvironmen
environment);
}

// By default, autocomplete only against the set of Searchable Entity Types.
// By default, autocomplete only against the Default Set of Autocomplete entities
return AutocompleteUtils.batchGetAutocompleteResults(
AUTO_COMPLETE_ENTITY_TYPES.stream().map(_typeToEntity::get).collect(Collectors.toList()),
sanitizedQuery,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package com.linkedin.metadata.search;

import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.Filter;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.search.aggregator.AllEntitiesSearchAggregator;
import com.linkedin.metadata.search.cache.AllEntitiesSearchAggregatorCache;
import com.linkedin.metadata.search.cache.CachingAllEntitiesSearchAggregator;
import com.linkedin.metadata.search.cache.EntityDocCountCache;
import com.linkedin.metadata.search.cache.EntitySearchServiceCache;
import com.linkedin.metadata.search.client.CachingEntitySearchService;
import com.linkedin.metadata.search.ranker.SearchRanker;
import java.util.List;
import java.util.Map;
Expand All @@ -16,30 +14,24 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.CacheManager;


@Slf4j
public class SearchService {
private final EntitySearchService _entitySearchService;
private final AllEntitiesSearchAggregator _aggregator;
private final SearchRanker _searchRanker;

private final CachingEntitySearchService _cachingEntitySearchService;
private final CachingAllEntitiesSearchAggregator _cachingAllEntitiesSearchAggregator;
private final EntityDocCountCache _entityDocCountCache;
private final EntitySearchServiceCache _entitySearchServiceCache;
private final AllEntitiesSearchAggregatorCache _allEntitiesSearchAggregatorCache;
private final SearchRanker _searchRanker;

public SearchService(EntityRegistry entityRegistry, EntitySearchService entitySearchService,
SearchRanker searchRanker, CacheManager cacheManager, int batchSize, boolean enableCache) {
_entitySearchService = entitySearchService;
public SearchService(
EntityDocCountCache entityDocCountCache,
CachingEntitySearchService cachingEntitySearchService,
CachingAllEntitiesSearchAggregator cachingEntitySearchAggregator,
SearchRanker searchRanker) {
_cachingEntitySearchService = cachingEntitySearchService;
_cachingAllEntitiesSearchAggregator = cachingEntitySearchAggregator;
_searchRanker = searchRanker;
_aggregator =
new AllEntitiesSearchAggregator(entityRegistry, entitySearchService, searchRanker, cacheManager, batchSize,
enableCache);
_entityDocCountCache = new EntityDocCountCache(entityRegistry, entitySearchService);
_entitySearchServiceCache = new EntitySearchServiceCache(cacheManager, entitySearchService, batchSize, enableCache);
_allEntitiesSearchAggregatorCache =
new AllEntitiesSearchAggregatorCache(cacheManager, _aggregator, batchSize, enableCache);
_entityDocCountCache = entityDocCountCache;
}

public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
Expand All @@ -65,8 +57,8 @@ public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
SearchResult result =
_entitySearchServiceCache.getSearcher(entityName, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(from, size);
_cachingEntitySearchService.search(entityName, input, postFilters, sortCriterion, from, size, searchFlags);

try {
return result.copy().setEntities(new SearchEntityArray(_searchRanker.rank(result.getEntities())));
} catch (Exception e) {
Expand Down Expand Up @@ -95,7 +87,6 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnul
log.debug(String.format(
"Searching Search documents entities: %s, input: %s, postFilters: %s, sortCriterion: %s, from: %s, size: %s",
entities, input, postFilters, sortCriterion, from, size));
return _allEntitiesSearchAggregatorCache.getSearcher(entities, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(from, size);
return _cachingAllEntitiesSearchAggregator.getSearchResults(entities, input, postFilters, sortCriterion, from, size, searchFlags);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.metadata.search.aggregator;

import com.codahale.metrics.Timer;
import com.linkedin.data.template.GetMode;
import com.linkedin.data.template.LongMap;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.SearchFlags;
Expand All @@ -14,7 +15,7 @@
import com.linkedin.metadata.search.SearchEntityArray;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.metadata.search.SearchResultMetadata;
import com.linkedin.metadata.search.cache.EntitySearchServiceCache;
import com.linkedin.metadata.search.client.CachingEntitySearchService;
import com.linkedin.metadata.search.cache.EntityDocCountCache;
import com.linkedin.metadata.search.ranker.SearchRanker;
import com.linkedin.metadata.search.utils.SearchUtils;
Expand All @@ -27,29 +28,36 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.CacheManager;

import static com.linkedin.metadata.search.utils.FilterUtils.rankFilterGroups;


@Slf4j
public class AllEntitiesSearchAggregator {

private static final int DEFAULT_MAX_AGGREGATION_VALUES = 20;

private final EntitySearchService _entitySearchService;
private final SearchRanker _searchRanker;
private final EntityDocCountCache _entityDocCountCache;

private final EntitySearchServiceCache _entitySearchServiceCache;

public AllEntitiesSearchAggregator(EntityRegistry entityRegistry, EntitySearchService entitySearchService,
SearchRanker searchRanker, CacheManager cacheManager, int batchSize, boolean enableCache) {
_entitySearchService = entitySearchService;
_searchRanker = searchRanker;
private final CachingEntitySearchService _cachingEntitySearchService;
private final int _maxAggregationValueCount;

public AllEntitiesSearchAggregator(
EntityRegistry entityRegistry,
EntitySearchService entitySearchService,
CachingEntitySearchService cachingEntitySearchService,
SearchRanker searchRanker) {
_entitySearchService = Objects.requireNonNull(entitySearchService);
_searchRanker = Objects.requireNonNull(searchRanker);
_cachingEntitySearchService = Objects.requireNonNull(cachingEntitySearchService);
_entityDocCountCache = new EntityDocCountCache(entityRegistry, entitySearchService);
_entitySearchServiceCache = new EntitySearchServiceCache(cacheManager, entitySearchService, batchSize, enableCache);
_maxAggregationValueCount = DEFAULT_MAX_AGGREGATION_VALUES; // TODO: Make this externally configurable
}

@Nonnull
Expand Down Expand Up @@ -95,10 +103,6 @@ public SearchResult search(@Nonnull List<String> entities, @Nonnull String input
Map<String, Long> numResultsPerEntity = searchResults.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getNumEntities().longValue()));
aggregations.put("entity", new AggregationMetadata().setName("entity")
.setDisplayName("Type")
.setAggregations(new LongMap(numResultsPerEntity))
.setFilterValues(new FilterValueArray(SearchUtil.convertToFilters(numResultsPerEntity))));

for (String entity : searchResults.keySet()) {
SearchResult result = searchResults.get(entity);
Expand All @@ -114,10 +118,19 @@ public SearchResult search(@Nonnull List<String> entities, @Nonnull String input
});
}

// Trim the aggregations / filters after merging.
Map<String, AggregationMetadata> finalAggregations = trimMergedAggregations(aggregations);

// Finally, add the custom "entity" aggregation (shown as the first filter). It is added after trimming so that it is never truncated.
finalAggregations.put("entity", new AggregationMetadata().setName("entity")
.setDisplayName("Type")
.setAggregations(new LongMap(numResultsPerEntity))
.setFilterValues(new FilterValueArray(SearchUtil.convertToFilters(numResultsPerEntity))));

// 4. Rank results across entities
List<SearchEntity> rankedResult = _searchRanker.rank(matchedResults);
SearchResultMetadata finalMetadata =
new SearchResultMetadata().setAggregations(new AggregationMetadataArray(rankFilterGroups(aggregations)));
new SearchResultMetadata().setAggregations(new AggregationMetadataArray(rankFilterGroups(finalAggregations)));

postProcessTimer.stop();
return new SearchResult().setEntities(new SearchEntityArray(rankedResult))
Expand All @@ -143,12 +156,38 @@ private Map<String, SearchResult> getSearchResultsForEachEntity(@Nonnull List<St
// Query the entity search service for all entities asynchronously
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "searchEntities").time()) {
searchResults = ConcurrencyUtils.transformAndCollectAsync(entities, entity -> new Pair<>(entity,
_entitySearchServiceCache.getSearcher(entity, input, postFilters, sortCriterion, searchFlags)
.getSearchResults(queryFrom, querySize)))
_cachingEntitySearchService.search(entity, input, postFilters, sortCriterion, queryFrom, querySize, searchFlags)))
.stream()
.filter(pair -> pair.getValue().getNumEntities() > 0)
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
}
return searchResults;
}

/**
 * Rebuilds every merged aggregation with its filter values capped at {@code _maxAggregationValueCount}.
 *
 * <p>The name, display name and raw aggregation counts of each entry are carried over unchanged; only the
 * filter values are passed through {@link #trimFilterValues(FilterValueArray)}. The cap is the fixed
 * per-instance limit, not something derived from the request.
 *
 * @param aggregations merged aggregations keyed by aggregation name
 * @return a new, mutable map with the same keys, each mapped to a freshly built {@link AggregationMetadata}
 */
private Map<String, AggregationMetadata> trimMergedAggregations(Map<String, AggregationMetadata> aggregations) {
  // An explicit HashMap is used because the caller adds the "entity" aggregation to the returned map;
  // Collectors.toMap gives no guarantee that its result is mutable.
  final Map<String, AggregationMetadata> trimmed = new HashMap<>();
  for (Map.Entry<String, AggregationMetadata> entry : aggregations.entrySet()) {
    final AggregationMetadata merged = entry.getValue();
    // NOTE(review): GetMode.NULL yields null when displayName is unset -- confirm setDisplayName tolerates null.
    trimmed.put(entry.getKey(), new AggregationMetadata()
        .setName(merged.getName())
        .setDisplayName(merged.getDisplayName(GetMode.NULL))
        .setAggregations(merged.getAggregations())
        .setFilterValues(trimFilterValues(merged.getFilterValues())));
  }
  return trimmed;
}

/**
 * Caps an already-merged list of filter values at {@code _maxAggregationValueCount} entries.
 *
 * @param original the merged filter values
 * @return {@code original} itself when it is within the limit; otherwise a new array holding only its
 *     first {@code _maxAggregationValueCount} elements
 */
private FilterValueArray trimFilterValues(FilterValueArray original) {
  final boolean withinLimit = original.size() <= _maxAggregationValueCount;
  if (withinLimit) {
    return original;
  }
  // Keep the leading entries only; the order is whatever the merge step produced.
  return new FilterValueArray(original.subList(0, _maxAggregationValueCount));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.Filter;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.metadata.search.aggregator.AllEntitiesSearchAggregator;
import java.util.List;
import javax.annotation.Nonnull;
Expand All @@ -13,19 +14,19 @@


@RequiredArgsConstructor
public class AllEntitiesSearchAggregatorCache {
public class CachingAllEntitiesSearchAggregator {
private static final String ALL_ENTITIES_SEARCH_AGGREGATOR_CACHE_NAME = "allEntitiesSearchAggregator";

private final CacheManager cacheManager;
private final AllEntitiesSearchAggregator aggregator;
private final int batchSize;
private final boolean enableCache;

public CacheableSearcher<?> getSearcher(List<String> entities, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, @Nullable SearchFlags searchFlags) {
public SearchResult getSearchResults(List<String> entities, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
return new CacheableSearcher<>(cacheManager.getCache(ALL_ENTITIES_SEARCH_AGGREGATOR_CACHE_NAME), batchSize,
querySize -> aggregator.search(entities, input, postFilters, sortCriterion, querySize.getFrom(),
querySize.getSize(), searchFlags),
querySize -> Quintet.with(entities, input, postFilters, sortCriterion, querySize), searchFlags, enableCache);
querySize -> Quintet.with(entities, input, postFilters, sortCriterion, querySize), searchFlags, enableCache).getSearchResults(from, size);
}
}

This file was deleted.

Loading