Skip to content

Commit

Permalink
[CALCITE-1581] UDTF like in hive
Browse files Browse the repository at this point in the history
remove unused empty line

remove test code  & add comment

add comment

check for more than one table function in select

fix test sql
  • Loading branch information
pengzhiwei committed Mar 29, 2019
1 parent 0537f27 commit 635914d
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 6 deletions.
31 changes: 27 additions & 4 deletions core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -1560,15 +1560,38 @@ List<SqlNode> SelectList() :
SqlNode SelectItem() :
{
SqlNode e;
final SqlIdentifier id;
SqlIdentifier id;
final List<SqlNode> ids = new ArrayList();
final Span s = span();
}
{
e = SelectExpression()
[
[ <AS> ]
id = SimpleIdentifier() {
e = SqlStdOperatorTable.AS.createCall(span().end(e), e, id);
}
(
(
id = SimpleIdentifier()
{
e = SqlStdOperatorTable.AS.createCall(s.end(e), e, id);
}
)
|
(
<LPAREN>
id = SimpleIdentifier()
{
ids.add(id);
}
( <COMMA> id = SimpleIdentifier() { ids.add(id);} )*
<RPAREN>
{
if (!this.conformance.allowSelectTableFunction()) {
throw new ParseException(RESOURCE.notAllowTableFunctionInSelect().str());
}
e = SqlStdOperatorTable.AS.createCall(s.end(e), e, new SqlNodeList(ids, s.end(e)));
}
)
)
]
{
return e;
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/org/apache/calcite/runtime/CalciteResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,16 @@ ExInst<CalciteException> invalidTypesForComparison(String clazzName0, String op,

@BaseMessage("Not a valid input for JSON_LENGTH: ''{0}''")
ExInst<CalciteException> invalidInputForJsonLength(String value);

@BaseMessage("Table Function is not allowed in select list in current SQL conformance level")
ExInst<SqlValidatorException> notAllowTableFunctionInSelect();

@BaseMessage("''{0}'' should be a table function")
ExInst<SqlValidatorException> exceptTableFunction(String name);

@BaseMessage("Only one table function allowed in select list")
ExInst<SqlValidatorException> onlyOneTableFunctionAllowedInSelect();

}

// End CalciteResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ public boolean shouldConvertRaggedUnionTypesToVarying() {
public boolean allowExtendedTrim() {
return SqlConformanceEnum.DEFAULT.allowExtendedTrim();
}

public boolean allowSelectTableFunction() {
return SqlConformanceEnum.DEFAULT.allowSelectTableFunction();
}
}

// End SqlAbstractConformance.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public interface SqlConformance {
@Deprecated // to be removed before 2.0
SqlConformanceEnum PRAGMATIC_2003 = SqlConformanceEnum.PRAGMATIC_2003;

SqlConformanceEnum HIVE = SqlConformanceEnum.HIVE;
/**
* Whether this dialect supports features from a wide variety of
* dialects. This is enabled for the Babel parser, disabled otherwise.
Expand Down Expand Up @@ -393,6 +394,14 @@ public interface SqlConformance {
* false otherwise.
*/
boolean allowExtendedTrim();

/**
* Whether Select can contain a table function.
* <p>For example, consider the query
* <blockquote><pre> SELECT split(col) as (f0, f1) from a </pre> </blockquote>
* @return
*/
boolean allowSelectTableFunction();
}

// End SqlConformance.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ public enum SqlConformanceEnum implements SqlConformance {

/** Conformance value that instructs Calcite to use SQL semantics
* consistent with Microsoft SQL Server version 2008. */
SQL_SERVER_2008;
SQL_SERVER_2008,

/** Conformance value that instructs Calcite to use SQL semantics
* consistent with Hive version. */
HIVE;

public boolean isLiberal() {
switch (this) {
Expand Down Expand Up @@ -304,6 +308,14 @@ public boolean allowExtendedTrim() {
}
}

public boolean allowSelectTableFunction() {
switch (this) {
case HIVE:
return true;
default:
return false;
}
}
}

// End SqlConformanceEnum.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ protected SqlDelegatingConformance(SqlConformance delegate) {
return delegate.allowNiladicParentheses();
}

@Override public boolean allowSelectTableFunction() {
return delegate.allowSelectTableFunction();
}
}

// End SqlDelegatingConformance.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.calcite.sql.SqlAccessEnum;
import org.apache.calcite.sql.SqlAccessType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
Expand All @@ -61,6 +62,7 @@
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlMatchRecognize;
import org.apache.calcite.sql.SqlMerge;
Expand All @@ -81,6 +83,7 @@
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.SqlWithItem;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlCollectionTableOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.AssignableOperandTypeChecker;
Expand Down Expand Up @@ -243,6 +246,8 @@ public class SqlValidatorImpl implements SqlValidatorWithHints {
private int nextGeneratedId;
protected final RelDataTypeFactory typeFactory;

private int nextTableFunctionNameId;

/** The type of dynamic parameters until a type is imposed on them. */
protected final RelDataType unknownType;
private final RelDataType booleanType;
Expand Down Expand Up @@ -638,6 +643,9 @@ public SqlNode validate(SqlNode topNode) {

public List<SqlMoniker> lookupHints(SqlNode topNode, SqlParserPos pos) {
SqlValidatorScope scope = new EmptyScope(this);
if (conformance.allowSelectTableFunction()) {
topNode = performHiveUdtfRewrite(topNode);
}
SqlNode outermostNode = performUnconditionalRewrites(topNode, false);
cursorSet.add(outermostNode);
if (outermostNode.isA(SqlKind.TOP_LEVEL)) {
Expand Down Expand Up @@ -919,6 +927,9 @@ public SqlNode validateParameterizedExpression(
private SqlNode validateScopedExpression(
SqlNode topNode,
SqlValidatorScope scope) {
if (conformance.allowSelectTableFunction()) {
topNode = performHiveUdtfRewrite(topNode);
}
SqlNode outermostNode = performUnconditionalRewrites(topNode, false);
cursorSet.add(outermostNode);
top = outermostNode;
Expand Down Expand Up @@ -1320,6 +1331,214 @@ protected SqlNode performUnconditionalRewrites(
return node;
}

/**
* Rewrite Hive like udtf grammar to the LATERAL TABLE.
* <p> eg. rewrite the
* "select a.id, table_func(a.id) as (f0,f1) from a"
* to
* "select a.id, _table_function_0.f0,_table_function_0.f1 from a,
* lateral table(table_func(a.id)) as _table_function_0(f0,f1)"
*
* @param node the SqlNode to rewrite
* @return the SqlNode after rewrite
*/
private SqlNode performHiveUdtfRewrite(SqlNode node) {
// Mapping of SqlSelect and it's TableFunction Info
Map<SqlSelect, TableFunctionInfo> select2TableFunctionInfos = new HashMap<>();
return performHiveUdtfRewriteInternal(node, select2TableFunctionInfos);
}

private SqlNode performHiveUdtfRewriteInternal(SqlNode current,
Map<SqlSelect, TableFunctionInfo> select2TableFunctionInfos) {
// do the rewrite for SqlSelect
if (current instanceof SqlSelect) {
SqlSelect select = (SqlSelect) current;
// rewrite select items
SqlNodeList newSelectItem =
performRewriteForSelectItem(select,
select.getSelectList(), select2TableFunctionInfos);

TableFunctionInfo tableFunctionInfo = select2TableFunctionInfos.get(select);
// if the select items contain a table function,
// join the from node with the table function.
if (tableFunctionInfo != null && select.getFrom() != null) {
SqlBasicCall joinRight = createLateralTable(tableFunctionInfo);
SqlNode newFrom = new SqlJoin(
SqlParserPos.ZERO,
select.getFrom(),
SqlLiteral.createBoolean(false, SqlParserPos.ZERO),
SqlLiteral.createSymbol(JoinType.COMMA, SqlParserPos.ZERO),
joinRight,
SqlLiteral.createSymbol(JoinConditionType.NONE, SqlParserPos.ZERO), null);
select.setSelectList(newSelectItem);
select.setFrom(newFrom);
}
}
// recursive all sub-node of the node,ensure all
// SqlSelect can be rewrite.
if (current instanceof SqlCall) {
SqlCall call = (SqlCall) current;
List<SqlNode> newOperands = new ArrayList<>();
for (int i = 0; i < call.getOperandList().size(); i++) {
newOperands.add(performHiveUdtfRewriteInternal
(call.getOperandList().get(i), select2TableFunctionInfos));
}

for (int i = 0; i < newOperands.size(); i++) {
if (newOperands.get(i) != null) {
call.setOperand(i, newOperands.get(i));
}
}
} else if (current instanceof SqlNodeList) {
SqlNodeList nodeList = (SqlNodeList) current;
List<SqlNode> newNodes = new ArrayList<>();
for (int i = 0; i < nodeList.size(); i++) {
newNodes.add(
performHiveUdtfRewriteInternal(
nodeList.get(i), select2TableFunctionInfos));
}

for (int i = 0; i < newNodes.size(); i++) {
if (newNodes.get(i) != null) {
nodeList.set(i, newNodes.get(i));
}
}
}
return current;
}

/**
* Rewrite the "select a.id table_func(a.id) as (f0,f1)" to
* "select a.id,_table_function_0.f0, _table_function_0.f1"
* @param select SqlSelect Node
* @param selectItems select items to rewrite
* @param select2TableFunctionInfos Mapping of SqlSelect and it's TableFunction
* @return new SelectItems after rewrite
*/
private SqlNodeList performRewriteForSelectItem(SqlSelect select, SqlNodeList selectItems,
Map<SqlSelect, TableFunctionInfo> select2TableFunctionInfos) {
// step1. find the table function in the select items.
for (int i = 0; i < selectItems.size(); i++) {
SqlNode selectItem = selectItems.get(i);
if (selectItem.getKind() == SqlKind.AS) {
SqlNode udtfNode = ((SqlBasicCall) selectItem).getOperands()[0];
SqlNode aliasNode = ((SqlBasicCall) selectItem).getOperands()[1];

// test if this is a "table_func() as (f0,f1)" select item.
if (udtfNode instanceof SqlBasicCall
&& ((SqlBasicCall) udtfNode).getOperator() instanceof SqlFunction
&& aliasNode instanceof SqlNodeList) {

SqlFunction function = (SqlFunction)
((SqlBasicCall) udtfNode).getOperator();
List<SqlOperator> overloads = new ArrayList<>();
opTab.lookupOperatorOverloads(function.getNameAsId(),
SqlFunctionCategory.USER_DEFINED_TABLE_FUNCTION, SqlSyntax.FUNCTION, overloads);

if (overloads.size() == 0) {
throw newValidationError(udtfNode,
RESOURCE.exceptTableFunction(function.getName()));
}
// this is a table function
if (overloads.size() == 1 && overloads.get(0)
instanceof SqlUserDefinedTableFunction) {
//Only one table function allowed in select
if (select2TableFunctionInfos.containsKey(select)) {
throw newValidationError(udtfNode, RESOURCE.onlyOneTableFunctionAllowedInSelect());
}
TableFunctionInfo tableFunctionInfo = new TableFunctionInfo();
tableFunctionInfo.node = (SqlBasicCall) udtfNode;
tableFunctionInfo.selectIndex = i;
tableFunctionInfo.fieldNames = (SqlNodeList) aliasNode;
tableFunctionInfo.tableName = "_table_function_" + nextTableFunctionNameId++;

select2TableFunctionInfos.put(select, tableFunctionInfo);
}
}
}
}
// step2. rewrite the select items
TableFunctionInfo tableFunctionInfo = select2TableFunctionInfos.get(select);
if (tableFunctionInfo != null) {
SqlNodeList newSelectItems = new SqlNodeList(SqlParserPos.ZERO);

for (int k = 0; k < tableFunctionInfo.selectIndex; k++) {
newSelectItems.add(selectItems.get(k));
}
// add "(f0,f1)" to the select list
for (int k = 0; k < tableFunctionInfo.fieldNames.size(); k++) {
SqlIdentifier field = new SqlIdentifier(
Lists.newArrayList(tableFunctionInfo.tableName,
tableFunctionInfo.fieldNames.get(k).toString()),
SqlParserPos.ZERO);
newSelectItems.add(field);
}

for (int k = tableFunctionInfo.selectIndex + 1; k < selectItems.size(); k++) {
newSelectItems.add(selectItems.get(k));
}
return newSelectItems;
}
return selectItems;
}

/**
* Create LateralTableAs node for table function
* @param info information for table function in SqlSelect
* @return
*/
private SqlBasicCall createLateralTable(TableFunctionInfo info) {
SqlCollectionTableOperator to =
new SqlCollectionTableOperator("LATERAL TABLE", SqlModality.RELATION);

SqlBasicCall tableFunctionNode = info.node;
// change the function category to USER_DEFINED_TABLE_FUNCTION
if (info.node.getOperator() instanceof SqlUnresolvedFunction) {
SqlUnresolvedFunction function = (SqlUnresolvedFunction) info.node.getOperator();
if (!function.getFunctionType().isTableFunction()) {
tableFunctionNode = new SqlBasicCall(
new SqlUnresolvedFunction(function.getNameAsId(),
function.getReturnTypeInference(),
function.getOperandTypeInference(),
function.getOperandTypeChecker(),
function.getParamTypes(),
SqlFunctionCategory.USER_DEFINED_TABLE_FUNCTION),
info.node.getOperands(),
info.node.getParserPosition());
}
}
SqlBasicCall tableCall = new SqlBasicCall(to,
new SqlNode[]{ tableFunctionNode }, SqlParserPos.ZERO);
SqlLateralOperator lateralOp = new SqlLateralOperator(SqlKind.LATERAL);
SqlBasicCall lateralCall = new SqlBasicCall(lateralOp,
new SqlNode[]{ tableCall }, SqlParserPos.ZERO);

SqlAsOperator asOp = SqlStdOperatorTable.AS;
SqlNode[] operands = new SqlNode[2 + info.fieldNames.size()];
SqlIdentifier tableName = new SqlIdentifier(info.tableName, SqlParserPos.ZERO);
operands[0] = lateralCall;
operands[1] = tableName;
for (int i = 0; i < info.fieldNames.size(); i++) {
operands[2 + i] = info.fieldNames.get(i);
}
SqlBasicCall lateralTableAs = new SqlBasicCall(asOp, operands, SqlParserPos.ZERO);
return lateralTableAs;
}

/**
* table function info
*/
private static class TableFunctionInfo {
// table function index in the select list
public int selectIndex;
// table function function node
public SqlBasicCall node;
// table function return field names
public SqlNodeList fieldNames;
// table function table name
public String tableName;
}

private SqlSelect getInnerSelect(SqlNode node) {
for (;;) {
if (node instanceof SqlSelect) {
Expand Down
Loading

0 comments on commit 635914d

Please sign in to comment.