Skip to content

Commit

Permalink
Adding highlight support in PPL with optional arguments and wildcard …
Browse files Browse the repository at this point in the history
…support in SQL and PPL.

Signed-off-by: forestmvey <[email protected]>
  • Loading branch information
forestmvey committed Oct 3, 2022
1 parent 057fa44 commit 28e98d3
Show file tree
Hide file tree
Showing 36 changed files with 711 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) {
for (UnresolvedExpression expr : node.getProjectList()) {
HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child);
child = highlightAnalyzer.analyze(expr, context);

}

List<NamedExpression> namedExpressions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Case;
Expand All @@ -44,7 +45,9 @@
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -91,6 +94,14 @@ public Expression analyze(UnresolvedExpression unresolved, AnalysisContext conte
return unresolved.accept(this, context);
}

/**
 * Rejects alias nodes by raising a semantic error that names the alias.
 *
 * @param node alias node handed to this analyzer
 * @param context analysis context (not consulted here)
 * @return never returns normally
 * @throws SemanticCheckException always, with the alias name in the message
 */
@Override
public Expression visitAlias(Alias node, AnalysisContext context) {
  // Only purpose for this override currently is to avoid null pointer exception when using
  // '-' flag with a highlight call in a fields command.
  String message = String.format("can't resolve Symbol %s in type env", node.getName());
  throw new SemanticCheckException(message);
}

@Override
public Expression visitUnresolvedAttribute(UnresolvedAttribute node, AnalysisContext context) {
return visitIdentifier(node.getAttr(), context);
Expand Down Expand Up @@ -205,7 +216,7 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte
}

@Override
public Expression visitHighlight(HighlightFunction node, AnalysisContext context) {
public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) {
Expression expr = node.getHighlightField().accept(this, context);
return new HighlightExpression(expr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con

@Override
public LogicalPlan visitAlias(Alias node, AnalysisContext context) {
if (!(node.getDelegated() instanceof HighlightFunction)) {
UnresolvedExpression delegated = node.getDelegated();
if (!(delegated instanceof HighlightFunction)) {
return null;
}

HighlightFunction unresolved = (HighlightFunction) node.getDelegated();
HighlightFunction unresolved = (HighlightFunction) delegated;
Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context);
return new LogicalHighlight(child, field);
return new LogicalHighlight(child, field, unresolved.getArguments());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ public T visitAD(AD node, C context) {
return visitChildren(node, context);
}

public T visitHighlight(HighlightFunction node, C context) {
public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}
}
5 changes: 3 additions & 2 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) {
return new When(condition, result);
}

public UnresolvedExpression highlight(UnresolvedExpression fieldName) {
return new HighlightFunction(fieldName);
public UnresolvedExpression highlight(UnresolvedExpression fieldName,
java.util.Map<String, Literal> arguments) {
return new HighlightFunction(fieldName, arguments);
}

public UnresolvedExpression window(UnresolvedExpression function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.ast.expression;

import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
Expand All @@ -21,10 +22,11 @@
@ToString
public class HighlightFunction extends UnresolvedExpression {
private final UnresolvedExpression highlightField;
private final Map<String, Literal> arguments;

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitHighlight(this, context);
return nodeVisitor.visitHighlightFunction(this, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

package org.opensearch.sql.expression;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Getter;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.env.Environment;
Expand All @@ -20,6 +26,7 @@
@Getter
public class HighlightExpression extends FunctionExpression {
private final Expression highlightField;
private final ExprType type;

/**
* HighlightExpression Constructor.
Expand All @@ -28,6 +35,8 @@ public class HighlightExpression extends FunctionExpression {
public HighlightExpression(Expression highlightField) {
super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField));
this.highlightField = highlightField;
this.type = this.highlightField.toString().contains("*")
? ExprCoreType.STRUCT : ExprCoreType.ARRAY;
}

/**
Expand All @@ -37,21 +46,57 @@ public HighlightExpression(Expression highlightField) {
*/
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
String refName = "_highlight" + "." + StringUtils.unquoteText(getHighlightField().toString());
return valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING));
String refName = "_highlight";
// Not a wildcard expression
if (this.type == ExprCoreType.ARRAY) {
refName += "." + StringUtils.unquoteText(getHighlightField().toString());
}
ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING));

// In the event of multiple returned highlights and wildcard being
// used in conjunction with other highlight calls, we need to ensure
// only wildcard regex matching is mapped to wildcard call.
if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) {
value = new ExprTupleValue(
new LinkedHashMap<String, ExprValue>(value.tupleValue()
.entrySet()
.stream()
.filter(s -> matchesHighlightRegex(s.getKey(),
StringUtils.unquoteText(highlightField.toString())))
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue()))));
if (value.tupleValue().isEmpty()) {
value = ExprValueUtils.missingValue();
}
}

return value;
}

/**
* Get type for HighlightExpression.
* @return : String type.
* @return : Expression type.
*/
@Override
public ExprType type() {
return ExprCoreType.ARRAY;
return this.type;
}

/**
 * Accepts an expression visitor by delegating to its highlight handler.
 *
 * @param visitor visitor whose {@code visitHighlight} method is invoked
 * @param context context object passed through to the visitor unchanged
 * @return whatever {@code visitor.visitHighlight(this, context)} returns
 */
@Override
public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitHighlight(this, context);
}

/**
 * Check if field matches the wildcard pattern used in highlight query.
 *
 * <p>Only {@code *} is treated as a wildcard (matching any run of characters, including
 * none). Every other character of the pattern is matched literally, so a {@code .} in a
 * field name such as {@code a.b} no longer matches arbitrary characters, and names that
 * contain regex metacharacters no longer cause a {@code PatternSyntaxException}.
 *
 * @param field Highlight selected field for query
 * @param pattern Wildcard pattern to match field against
 * @return True if the whole field name matches the wildcard pattern
 */
private boolean matchesHighlightRegex(String field, String pattern) {
  // Limit -1 keeps leading/trailing empty segments, so a pattern that starts or ends
  // with '*' still gets its ".*" at that position.
  String[] literals = pattern.split("\\*", -1);
  StringBuilder regex = new StringBuilder();
  for (int i = 0; i < literals.length; i++) {
    if (i > 0) {
      regex.append(".*");
    }
    // Quote each literal segment so regex metacharacters are matched verbatim.
    regex.append(Pattern.quote(literals[i]));
  }
  return Pattern.compile(regex.toString()).matcher(field).matches();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.env.Environment;

Expand All @@ -37,15 +36,6 @@ public void register(BuiltinFunctionRepository repository) {
repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE));
repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE));
repository.register(match_phrase_prefix());
repository.register(highlight());
}

private static FunctionResolver highlight() {
FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName();
FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING));
FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0));
return new DefaultFunctionResolver(functionName,
ImmutableMap.of(functionSignature, functionBuilder));
}

private static FunctionResolver match_bool_prefix() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.expression.Expression;

@EqualsAndHashCode(callSuper = true)
@Getter
@ToString
public class LogicalHighlight extends LogicalPlan {
private final Expression highlightField;
private final Map<String, Literal> arguments;

public LogicalHighlight(LogicalPlan childPlan, Expression field) {
/**
* Constructor of LogicalHighlight.
*/
public LogicalHighlight(LogicalPlan childPlan, Expression highlightField,
Map<String, Literal> arguments) {
super(Collections.singletonList(childPlan));
highlightField = field;
this.highlightField = highlightField;
this.arguments = arguments;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input,
return new LogicalWindow(input, windowFunction, windowDefinition);
}

public LogicalPlan highlight(LogicalPlan input, Expression field) {
return new LogicalHighlight(input, field);
public LogicalPlan highlight(LogicalPlan input, Expression field,
Map<String, Literal> arguments) {
return new LogicalHighlight(input, field, arguments);
}

public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) {
Expand Down
33 changes: 30 additions & 3 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder;
import static org.opensearch.sql.data.model.ExprValueUtils.integerValue;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -270,16 +272,41 @@ public void project_source() {

@Test
public void project_highlight() {
  // The same alias text names the projected column on the expected and the parsed side.
  String highlightAlias = "highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')";

  Map<String, Literal> tagArguments = new HashMap<>();
  tagArguments.put("pre_tags", new Literal("<mark>", DataType.STRING));
  tagArguments.put("post_tags", new Literal("</mark>", DataType.STRING));

  assertAnalyzeEqual(
      LogicalPlanDSL.project(
          LogicalPlanDSL.highlight(
              LogicalPlanDSL.relation("schema", table),
              DSL.literal("fieldA"),
              tagArguments),
          DSL.named(highlightAlias, new HighlightExpression(DSL.literal("fieldA")))),
      AstDSL.projectWithArg(
          AstDSL.relation("schema"),
          AstDSL.defaultFieldsArgs(),
          AstDSL.alias(
              highlightAlias,
              new HighlightFunction(AstDSL.stringLiteral("fieldA"), tagArguments))));
}

@Test
public void project_highlight_wildcard() {
Map<String, Literal> args = new HashMap<>();
assertAnalyzeEqual(
LogicalPlanDSL.project(
LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table),
DSL.literal("fieldA")),
DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA")))
DSL.literal("*"), args),
DSL.named("highlight(*)",
new HighlightExpression(DSL.literal("*")))
),
AstDSL.projectWithArg(
AstDSL.relation("schema"),
AstDSL.defaultFieldsArgs(),
AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA")))
AstDSL.alias("highlight(*)",
new HighlightFunction(AstDSL.stringLiteral("*"), args))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.ast.dsl.AstDSL.field;
Expand All @@ -28,6 +27,7 @@

import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -36,9 +36,11 @@
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -50,7 +52,6 @@
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand Down Expand Up @@ -163,6 +164,17 @@ public void castAnalyzer() {
"boolean_value"), AstDSL.stringLiteral("INTERVAL"))));
}

@Test
public void highlight_throws_semantic_check_exception() {
  // Analyzing an aliased highlight call is expected to fail with a semantic error.
  Map<String, Literal> noArguments = new HashMap<>();
  Alias aliasedHighlight = AstDSL.alias(
      "highlight(invalid_field)",
      new HighlightFunction(AstDSL.stringLiteral("invalid_field"), noArguments));

  assertThrows(SemanticCheckException.class, () -> analyze(aliasedHighlight));
}

@Test
public void case_with_default_result_type_different() {
UnresolvedExpression caseWhen = AstDSL.caseWhen(
Expand Down Expand Up @@ -592,12 +604,6 @@ public void now_as_a_function_not_cached() {
assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue));
}

@Test
void highlight() {
assertAnalyzeEqual(new HighlightExpression(DSL.literal("fieldA")),
new HighlightFunction(stringLiteral("fieldA")));
}

/**
 * Test helper that runs an unresolved expression through {@code expressionAnalyzer}
 * using {@code analysisContext}.
 *
 * @param unresolvedExpression expression to analyze
 * @return the expression returned by the analyzer
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Loading

0 comments on commit 28e98d3

Please sign in to comment.