diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py index b7c16d5391..6d930b3f59 100644 --- a/sql/engines/mssql.py +++ b/sql/engines/mssql.py @@ -316,6 +316,10 @@ def filter_sql(self, sql="", limit_num=0): # 对查询sql增加limit限制 if re.match(r"^select", sql_lower): if sql_lower.find(" top ") == -1: + if sql_lower.find(" distinct ") > 0: + return sql_lower.replace( + "distinct", "distinct top {}".format(limit_num) + ) return sql_lower.replace("select", "select top {}".format(limit_num)) return sql.strip() diff --git a/sql/engines/tests.py b/sql/engines/tests.py index 50939b2d2a..88cea4d093 100644 --- a/sql/engines/tests.py +++ b/sql/engines/tests.py @@ -192,6 +192,13 @@ def test_filter_sql(self): check_result = new_engine.filter_sql(sql=banned_sql, limit_num=10) self.assertEqual(check_result, "select top 10 user from user_table") + def test_filter_sql_with_distinct(self): + new_engine = MssqlEngine(instance=self.ins1) + # 只抽查一个函数 + banned_sql = "select distinct * from user_table" + check_result = new_engine.filter_sql(sql=banned_sql, limit_num=10) + self.assertEqual(check_result, "select distinct top 10 * from user_table") + def test_execute_check(self): new_engine = MssqlEngine(instance=self.ins1) test_sql = (