diff --git a/modules/fbs-core/api/src/main/scala/de/thm/ii/fbs/services/checker/SqlCheckerRemoteCheckerService.scala b/modules/fbs-core/api/src/main/scala/de/thm/ii/fbs/services/checker/SqlCheckerRemoteCheckerService.scala index 230ff5ee5..57c5c6e96 100644 --- a/modules/fbs-core/api/src/main/scala/de/thm/ii/fbs/services/checker/SqlCheckerRemoteCheckerService.scala +++ b/modules/fbs-core/api/src/main/scala/de/thm/ii/fbs/services/checker/SqlCheckerRemoteCheckerService.scala @@ -83,13 +83,10 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}") resultText: String, extInfo: String): Unit = { SqlCheckerRemoteCheckerService.isCheckerRun.getOrDefault(submission.id, SqlCheckerState.Runner) match { case SqlCheckerState.Runner => - println("a") SqlCheckerRemoteCheckerService.isCheckerRun.put(submission.id, SqlCheckerState.Checker) this.notify(task.id, submission.id, checkerConfiguration, userService.find(submission.userID.get).get) if (exitCode == 2 && hintsEnabled(checkerConfiguration)) { - println("b") if (extInfo != null) { - println("c") SqlCheckerRemoteCheckerService.extInfo.put(submission.id, extInfo) } } else { @@ -97,12 +94,10 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}") super.handle(submission, checkerConfiguration, task, exitCode, resultText, extInfo) } case SqlCheckerState.Checker => - println("e") SqlCheckerRemoteCheckerService.isCheckerRun.remove(submission.id) val extInfo = SqlCheckerRemoteCheckerService.extInfo.remove(submission.id) this.handleSelf(submission, checkerConfiguration, task, exitCode, resultText, extInfo) case SqlCheckerState.Ignore => - println("f") SqlCheckerRemoteCheckerService.isCheckerRun.remove(submission.id) } } @@ -139,9 +134,14 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}") } } if (query.distance.isPresent) { - hints ++= "Distanz zur nächstens Musterlösung: " - hints ++= Math.round(query.distance.get / 50).toString - hints ++= "\n" + val steps = Math.round(query.distance.get / 50) + if (steps == 0) { + hints ++= "Du bist ganz nah an der Lösung, es sind nur noch kleine Änderung notwendig.\n" + } else { + hints ++= "Es sind " + hints ++= steps.toString + hints ++= " Änderungen erforderlich, um Deine Lösung an die nächstgelegene Musterlösung anzupassen.\n" + } } if (sci.showExtendedHints && sci.showExtendedHintsAt <= attempts) { //ToDo @@ -150,11 +150,9 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}") private def formatV2(hints: StringBuilder, query: SQLCheckerQuery): Unit = { for (error <- query.errors.asScala) { - hints ++= "In " + hints ++= "Mistake in " hints ++= error.trace.asScala.mkString(", ") - hints ++= " expected " - hints ++= error.expected - hints ++= " but got " + hints ++= " where " hints ++= error.got hints ++= "\n\n" } diff --git a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py index 83e37a65e..1c20d3af9 100644 --- a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py +++ b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py @@ -53,6 +53,8 @@ def __init__(self): sqlparse.sql.Comparison: self.visit_comparison, sqlparse.sql.Function: self.visit_function, sqlparse.sql.Parenthesis: self.visit_parenthesis, + sqlparse.sql.Operation: self.visit_operation, + sqlparse.sql.TypedLiteral: self.visit_typed_literal, } def recursive_visit(self, token: sqlparse.sql.TokenList): @@ -81,9 +83,15 @@ def visit_function(self, token: sqlparse.sql.Function): def visit_parenthesis(self, token: sqlparse.sql.Parenthesis): self.recursive_visit(token) + def visit_operation(self, token: sqlparse.sql.Operation): + self.recursive_visit(token) + def visit_literal(self, token: sqlparse.tokens.Token): pass + def visit_typed_literal(self, token: sqlparse.sql.TypedLiteral): + self.recursive_visit(token) + def visit(self, tokens: list[sqlparse.sql.Token]): for token in tokens: if token.ttype is not None: @@ -94,7 +102,7 @@ def visit(self, tokens: list[sqlparse.sql.Token]): raise ValueError("unhandled token", token) def trace_to_str_list(self) -> list[str]: - return [token_to_str(entry) for entry in self.parent_stack] + return [self._token_to_str(entry) for entry in self.parent_stack] class SqlParserDfs(SqlParseVisitor): @@ -113,7 +121,7 @@ def visit_literal(self, token: sqlparse.tokens.Token): class SqlParserCoVisitor(SqlParseVisitor): - def __init__(self, solution, message_overrides=None): + def __init__(self, solution, message_overrides=None, token_names_overrides=None): super().__init__() self._solution = solution self._i = 0 @@ -121,6 +129,10 @@ def __init__(self, solution, message_overrides=None): "end_of_query": "End of query", "end_of_token": "End of token", } | (message_overrides or {}) + self._token_names_overrides = { + "IdentifierList": "Select Attributes", + "TypedLiteral": "Interval", + } | (token_names_overrides or {}) self._error_depth = None self.errors = [] @@ -130,9 +142,9 @@ def visit(self, tokens: list[sqlparse.sql.Token]): should, _ = self._solution[self._i] self.errors.append( Error( - token_to_str(should), + self._token_to_str(should), self._messages["end_of_query"], - [token_to_str(tokens[0])], + [self._token_to_str(tokens[0])], ) ) @@ -161,7 +173,7 @@ def _should_compare(self, token, comparator) -> bool: if end_of_token_error: self.errors.append( Error( - token_to_str(token), + self._token_to_str(token), self._messages["end_of_token"], self.trace_to_str_list(), ) @@ -171,7 +183,7 @@ def _should_compare(self, token, comparator) -> bool: self.errors.append( Error( self._messages["end_of_query"], - token_to_str(token), + self._token_to_str(token), self.trace_to_str_list(), ) ) @@ -179,7 +191,7 @@ def _should_compare(self, token, comparator) -> bool: self.errors.append( Error( self._messages["end_of_token"], - token_to_str(token), + self._token_to_str(token), self.trace_to_str_list(), ) ) @@ -187,7 +199,7 @@ def _should_compare(self, token, comparator) -> bool: elif not comparator(token, should): self.errors.append( Error( - token_to_str(should), token_to_str(token), self.trace_to_str_list() + self._token_to_str(should), self._token_to_str(token), self.trace_to_str_list() ) ) else: @@ -210,6 +222,8 @@ def visit_literal(self, token: sqlparse.tokens.Token): ) super().visit_literal(token) + def _map_token_name(self, token_name: str) -> str: + return self._token_names_overrides.get(token_name, token_name) -def token_to_str(token: sqlparse.tokens.Token) -> str: - return token.__class__.__name__ if token.ttype is None else repr(token.value) + def _token_to_str(self, token: sqlparse.tokens.Token) -> str: + return self._map_token_name(token.__class__.__name__) if token.ttype is None else repr(token.value) diff --git a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py index 0a3061607..93ef7cbdd 100644 --- a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py +++ b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py @@ -87,6 +87,62 @@ def test_compare_with_aggregate(self): ) assert len(errors) == 1 + def test_compare_deep(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT AVG(registraion_date) FROM users))", + "SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT MIN(registraion_date) FROM users))", + ) + assert len(errors) == 1 + + def test_compare_much_error(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT AVG(registraion_date) FROM users))", + "SELECT email FROM users WHERE username IN (SELECT email FROM users WHERE registration_date < (SELECT MAX(registraion_date) FROM users))", + ) + assert len(errors) == 2 + + def test_compare_identifier_list(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT username, password FROM users", + "SELECT FROM users", + ) + assert len(errors) == 2 + assert errors[0].expected == "Select Attributes" + + def test_very_complex_query(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT monat, AVG(tage_bis_erstes_gebot) AS durchschnittliche_tage FROM ( SELECT EXTRACT(MONTH FROM Registriert_am) AS monat, EXTRACT(DAY FROM (MIN(Geboten_am) - Registriert_am)) AS tage_bis_erstes_gebot FROM Gebot g JOIN Kunde k ON g.Bieter = k.KNr GROUP BY Bieter, Registriert_am ) AS tage GROUP BY monat ORDER BY monat;", + "SELECT monat, AVG(tage_bis_erstes_gebot) AS durchschnittliche_tage FROM ( SELECT EXTRACT(MONTH FROM Registriert_am) AS monat, EXTRACT(DAY FROM (MIN(Geboten_am) - Registriert_am)) AS tage_bis_erstes_gebot FROM Gebot g JOIN Kunde k ON g.Bieter = k.KNr GROUP BY Bieter, Registriert_am ) AS tage GROUP BY monat;" + ) + assert len(errors) == 2 + + def test_with_with(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "WITH (SELECT username FROM blocked_users) AS bu SELECT * FROM bu", + "WITH (SELECT email FROM blocked_users) AS bu SELECT * FROM bu" + ) + assert len(errors) == 1 + + def test_very_very_complex_query(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT g.Auktion, g.Bieter, g.Geboten_am, g.Gebotspreis FROM Gebot g JOIN Auktion a ON g.Auktion = a.ANr WHERE g.Geboten_am >= a.Eingestellt_am AND g.Geboten_am <= a.Eingestellt_am + INTERVAL '7 days' AND g.Gebotspreis > COALESCE( ( SELECT MAX(g_prev.Gebotspreis) FROM Gebot g_prev WHERE g_prev.Auktion = g.Auktion AND g_prev.Geboten_am < g.Geboten_am ), a.Startpreis ) ORDER BY g.Auktion, g.Geboten_am;", + "SELECT g.Auktion, g.Bieter, g.Geboten_am, g.Gebotspreis FROM Gebot g JOIN Auktion a ON g.Auktion = a.ANr WHERE g.Geboten_am >= a.Eingestellt_am AND g.Geboten_am <= a.Eingestellt_am + INTERVAL '7 days' AND g.Gebotspreis > COALESCE( ( SELECT MAX(g_prev.Gebotspreis) FROM Gebot g_prev WHERE g_prev.Auktion = g.Auktion AND g_prev.Geboten_am < g.Geboten_am ), a.Startpreis ) ORDER BY g.Geboten_am, g.Auktion;" + ) + assert len(errors) == 2 + + def test_not_null(self): + comparator = SqlparseComparator() + errors = comparator.compare( + "SELECT username FROM user WHERE banned_at IS NULL", + "SELECT username FROM user WHERE banned_at IS NOT NULL" + ) + assert len(errors) == 1 if __name__ == "__main__": unittest.main()