From bd1b5dfe25f344afb1866d3aacd7807381a3903d Mon Sep 17 00:00:00 2001 From: "yuhang.wang" Date: Fri, 14 Jun 2024 15:50:36 +0800 Subject: [PATCH 1/3] support hive create function: Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-CreateFunction --- sql_metadata/keywords_lists.py | 1 + sql_metadata/parser.py | 17 ++++++++++++++++- test/test_query_type.py | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sql_metadata/keywords_lists.py b/sql_metadata/keywords_lists.py index c7d31183..8521c296 100644 --- a/sql_metadata/keywords_lists.py +++ b/sql_metadata/keywords_lists.py @@ -108,6 +108,7 @@ class TokenType(str, Enum): "CREATETABLE": QueryType.CREATE, "ALTERTABLE": QueryType.ALTER, "DROPTABLE": QueryType.DROP, + "CREATEFUNCTION": QueryType.CREATE, } # all the keywords we care for - rest is ignored in assigning diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index 1b88748e..cf704878 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -85,6 +85,19 @@ def query(self) -> str: """ return self._query.replace("\n", " ").replace(" ", " ") + @staticmethod + def get_switch_by_create_query(tokens: List[SQLToken], index: int) -> str: + switch = tokens[index].normalized + tokens[index + 1].normalized + + # Hive CREATE FUNCTION + if any( + index + i < len(tokens) and tokens[index + i].normalized == "FUNCTION" + for i in (1, 2) + ): + switch = "CREATEFUNCTION" + + return switch + @property def query_type(self) -> str: """ @@ -114,7 +127,9 @@ def query_type(self) -> str: ) .position ) - if tokens[index].normalized in ["CREATE", "ALTER", "DROP"]: + if tokens[index].normalized == "CREATE": + switch = self.get_switch_by_create_query(tokens, index) + elif tokens[index].normalized in ("ALTER", "DROP"): switch = tokens[index].normalized + tokens[index + 1].normalized else: switch = tokens[index].normalized diff --git a/test/test_query_type.py b/test/test_query_type.py index 44b8f333..8e1a92e3 100644 --- a/test/test_query_type.py +++ b/test/test_query_type.py @@ -93,3 +93,22 @@ def test_multiple_redundant_parentheses_create(): """ parser = Parser(query) assert parser.query_type == QueryType.CREATE + + +def test_hive_create_function(): + query = """ + CREATE FUNCTION simple_udf AS 'com.example.hive.udf.SimpleUDF' + USING JAR 'hdfs:///user/hive/udfs/simple-udf.jar' + WITH SERDEPROPERTIES ( + "hive.udf.param1"="value1", + "hive.udf.param2"="value2" + ); + """ + parser = Parser(query) + assert parser.query_type == QueryType.CREATE + + query = """ + CREATE TEMPORARY FUNCTION myudf AS 'com.udf.myudf'; + """ + parser = Parser(query) + assert parser.query_type == QueryType.CREATE From 12c3628dbc6a78afef02eff3376649e519d0d8cc Mon Sep 17 00:00:00 2001 From: "yuhang.wang" Date: Fri, 14 Jun 2024 16:17:54 +0800 Subject: [PATCH 2/3] Pylint --- sql_metadata/parser.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index cf704878..dd4457e8 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -85,19 +85,6 @@ def query(self) -> str: """ return self._query.replace("\n", " ").replace(" ", " ") - @staticmethod - def get_switch_by_create_query(tokens: List[SQLToken], index: int) -> str: - switch = tokens[index].normalized + tokens[index + 1].normalized - - # Hive CREATE FUNCTION - if any( - index + i < len(tokens) and tokens[index + i].normalized == "FUNCTION" - for i in (1, 2) - ): - switch = "CREATEFUNCTION" - - return switch - @property def query_type(self) -> str: """ @@ -1094,3 +1081,19 @@ def _flatten_sqlparse(self): yield tok else: yield token + + @staticmethod + def _get_switch_by_create_query(tokens: List[SQLToken], index: int) -> str: + """ + Return the switch that creates query type. + """ + switch = tokens[index].normalized + tokens[index + 1].normalized + + # Hive CREATE FUNCTION + if any( + index + i < len(tokens) and tokens[index + i].normalized == "FUNCTION" + for i in (1, 2) + ): + switch = "CREATEFUNCTION" + + return switch From 4b235487b4c07f2c907238277c14d48118f18d0b Mon Sep 17 00:00:00 2001 From: Yuhang <50909599+MiuNice@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:23:20 +0800 Subject: [PATCH 3/3] func: get_switch_by_create_query set to private method --- sql_metadata/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index dd4457e8..64f40219 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -115,7 +115,7 @@ def query_type(self) -> str: .position ) if tokens[index].normalized == "CREATE": - switch = self.get_switch_by_create_query(tokens, index) + switch = self._get_switch_by_create_query(tokens, index) elif tokens[index].normalized in ("ALTER", "DROP"): switch = tokens[index].normalized + tokens[index + 1].normalized else: