Skip to content

Commit

Permalink
Validate field and fields parameters in relevance search functions (#…
Browse files Browse the repository at this point in the history
…1067)

Change relevance functions that query fields to throw a SemanticCheckException when a field is queried that does not exist.

Signed-off-by: forestmvey <[email protected]>
  • Loading branch information
forestmvey authored Dec 20, 2022
1 parent 94b6bec commit d03c176
Show file tree
Hide file tree
Showing 18 changed files with 259 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
Expand Down Expand Up @@ -48,46 +44,46 @@ public void register(BuiltinFunctionRepository repository) {

private static FunctionResolver match_bool_prefix() {
FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName();
return new RelevanceFunctionResolver(name, STRING);
return new RelevanceFunctionResolver(name);
}

private static FunctionResolver match(BuiltinFunctionName match) {
FunctionName funcName = match.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver match_phrase_prefix() {
FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) {
FunctionName funcName = matchPhrase.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) {
return new RelevanceFunctionResolver(multiMatchName.getName(), STRUCT);
return new RelevanceFunctionResolver(multiMatchName.getName());
}

private static FunctionResolver simple_query_string() {
FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName();
return new RelevanceFunctionResolver(funcName, STRUCT);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver query() {
FunctionName funcName = BuiltinFunctionName.QUERY.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver query_string() {
FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName();
return new RelevanceFunctionResolver(funcName, STRUCT);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) {
FunctionName funcName = wildcardQuery.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

public static class OpenSearchFunction extends FunctionExpression {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,13 @@ public class RelevanceFunctionResolver
@Getter
private final FunctionName functionName;

@Getter
private final ExprType declaredFirstParamType;

@Override
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
if (!unresolvedSignature.getFunctionName().equals(functionName)) {
throw new SemanticCheckException(String.format("Expected '%s' but got '%s'",
functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName()));
}
List<ExprType> paramTypes = unresolvedSignature.getParamTypeList();
ExprType providedFirstParamType = paramTypes.get(0);

// Check if the first parameter is of the specified type.
if (!declaredFirstParamType.equals(providedFirstParamType)) {
throw new SemanticCheckException(
getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType));
}

// Check if all but the first parameter are of type STRING.
for (int i = 1; i < paramTypes.size(); i++) {
ExprType paramType = paramTypes.get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ public void named_non_parse_expression() {
void match_bool_prefix_expression() {
assertAnalyzeEqual(
DSL.match_bool_prefix(
DSL.namedArgument("field", DSL.literal("fieldA")),
DSL.namedArgument("field", DSL.literal("field_value1")),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("match_bool_prefix",
AstDSL.unresolvedArg("field", stringLiteral("fieldA")),
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand Down Expand Up @@ -418,11 +418,11 @@ void multi_match_expression() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -432,12 +432,12 @@ void multi_match_expression_with_params() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -448,12 +448,12 @@ void multi_match_expression_two_fields() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -463,11 +463,11 @@ void simple_query_string_expression() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -477,12 +477,12 @@ void simple_query_string_expression_with_params() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -493,12 +493,12 @@ void simple_query_string_expression_two_fields() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -517,11 +517,11 @@ void query_string_expression() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand All @@ -531,12 +531,12 @@ void query_string_expression_with_params() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value")),
DSL.namedArgument("escape", DSL.literal("false"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")),
AstDSL.unresolvedArg("escape", stringLiteral("false"))));
}
Expand All @@ -547,12 +547,12 @@ void query_string_expression_two_fields() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand Down Expand Up @@ -588,7 +588,7 @@ void wildcard_query_expression_all_params() {
public void match_phrase_prefix_all_params() {
assertAnalyzeEqual(
DSL.match_phrase_prefix(
DSL.namedArgument("field", "test"),
DSL.namedArgument("field", "field_value1"),
DSL.namedArgument("query", "search query"),
DSL.namedArgument("slop", "3"),
DSL.namedArgument("boost", "1.5"),
Expand All @@ -597,7 +597,7 @@ public void match_phrase_prefix_all_params() {
DSL.namedArgument("zero_terms_query", "NONE")
),
AstDSL.function("match_phrase_prefix",
unresolvedArg("field", stringLiteral("test")),
unresolvedArg("field", stringLiteral("field_value1")),
unresolvedArg("query", stringLiteral("search query")),
unresolvedArg("slop", stringLiteral("3")),
unresolvedArg("boost", stringLiteral("1.5")),
Expand Down
2 changes: 2 additions & 0 deletions core/src/test/java/org/opensearch/sql/config/TestConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public class TestConfig {
.put("struct_value", ExprCoreType.STRUCT)
.put("array_value", ExprCoreType.ARRAY)
.put("timestamp_value", ExprCoreType.TIMESTAMP)
.put("field_value1", ExprCoreType.STRING)
.put("field_value2", ExprCoreType.STRING)
.build();

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class RelevanceFunctionResolverTest {

@BeforeEach
void setUp() {
resolver = new RelevanceFunctionResolver(sampleFuncName, STRING);
resolver = new RelevanceFunctionResolver(sampleFuncName);
}

@Test
Expand All @@ -44,15 +44,6 @@ void resolve_invalid_name_test() {
exception.getMessage());
}

@Test
void resolve_invalid_first_param_type_test() {
var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER));
Exception exception = assertThrows(SemanticCheckException.class,
() -> resolver.resolve(sig));
assertEquals("Expected type STRING instead of INTEGER for parameter #1",
exception.getMessage());
}

@Test
void resolve_invalid_third_param_type_test() {
var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING));
Expand Down
37 changes: 37 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class MatchIT extends SQLIntegTestCase {
@Override
Expand All @@ -36,6 +37,42 @@ public void match_in_having() throws IOException {
verifyDataRows(result, rows("Bates"));
}

@Test
public void missing_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void missing_quoted_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match('invalid', 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void missing_backtick_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match(`invalid`, 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void matchquery_in_where() throws IOException {
JSONObject result = executeJdbcRequest("SELECT firstname FROM " + TEST_INDEX_ACCOUNT + " WHERE matchquery(lastname, 'Bates')");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.ReferenceExpression;

/**
* Base class to represent builder class for relevance queries like match_query, match_bool_prefix,
Expand All @@ -36,7 +37,7 @@ protected T createQueryBuilder(List<NamedArgumentExpression> arguments) {
.orElseThrow(() -> new SemanticCheckException("'query' parameter is missing"));

return createBuilder(
field.getValue().valueOf().stringValue(),
((ReferenceExpression)field.getValue()).getAttr(),
query.getValue().valueOf().stringValue());
}

Expand Down
Loading

0 comments on commit d03c176

Please sign in to comment.