Skip to content

Commit

Permalink
feat(sql-checker): improve feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Zitrone44 committed Dec 10, 2024
1 parent a1cbece commit 3957561
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,21 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}")
resultText: String, extInfo: String): Unit = {
SqlCheckerRemoteCheckerService.isCheckerRun.getOrDefault(submission.id, SqlCheckerState.Runner) match {
case SqlCheckerState.Runner =>
println("a")
SqlCheckerRemoteCheckerService.isCheckerRun.put(submission.id, SqlCheckerState.Checker)
this.notify(task.id, submission.id, checkerConfiguration, userService.find(submission.userID.get).get)
if (exitCode == 2 && hintsEnabled(checkerConfiguration)) {
println("b")
if (extInfo != null) {
println("c")
SqlCheckerRemoteCheckerService.extInfo.put(submission.id, extInfo)
}
} else {
SqlCheckerRemoteCheckerService.isCheckerRun.put(submission.id, SqlCheckerState.Ignore)
super.handle(submission, checkerConfiguration, task, exitCode, resultText, extInfo)
}
case SqlCheckerState.Checker =>
println("e")
SqlCheckerRemoteCheckerService.isCheckerRun.remove(submission.id)
val extInfo = SqlCheckerRemoteCheckerService.extInfo.remove(submission.id)
this.handleSelf(submission, checkerConfiguration, task, exitCode, resultText, extInfo)
case SqlCheckerState.Ignore =>
println("f")
SqlCheckerRemoteCheckerService.isCheckerRun.remove(submission.id)
}
}
Expand Down Expand Up @@ -139,9 +134,14 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}")
}
}
if (query.distance.isPresent) {
hints ++= "Distanz zur nächstens Musterlösung: "
hints ++= Math.round(query.distance.get / 50).toString
hints ++= "\n"
val steps = Math.round(query.distance.get / 50)
if (steps == 0) {
hints ++= "Du bist ganz nah an der Lösung, es sind nur noch kleine Änderung notwendig.\n"
} else {
hints ++= "Es sind "
hints ++= steps.toString
hints ++= " Änderungen erforderlich, um Deine Lösung an die nächstgelegene Musterlösung anzupassen.\n"
}
}
if (sci.showExtendedHints && sci.showExtendedHintsAt <= attempts) {
//ToDo
Expand All @@ -150,11 +150,9 @@ class SqlCheckerRemoteCheckerService(@Value("${services.masterRunner.insecure}")

private def formatV2(hints: StringBuilder, query: SQLCheckerQuery): Unit = {
for (error <- query.errors.asScala) {
hints ++= "In "
hints ++= "Mistake in "
hints ++= error.trace.asScala.mkString(", ")
hints ++= " expected "
hints ++= error.expected
hints ++= " but got "
hints ++= " where "
hints ++= error.got
hints ++= "\n\n"
}
Expand Down
34 changes: 24 additions & 10 deletions modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self):
sqlparse.sql.Comparison: self.visit_comparison,
sqlparse.sql.Function: self.visit_function,
sqlparse.sql.Parenthesis: self.visit_parenthesis,
sqlparse.sql.Operation: self.visit_operation,
sqlparse.sql.TypedLiteral: self.visit_typed_literal,
}

def recursive_visit(self, token: sqlparse.sql.TokenList):
Expand Down Expand Up @@ -81,9 +83,15 @@ def visit_function(self, token: sqlparse.sql.Function):
def visit_parenthesis(self, token: sqlparse.sql.Parenthesis):
self.recursive_visit(token)

def visit_operation(self, token: sqlparse.sql.Operation):
self.recursive_visit(token)

def visit_literal(self, token: sqlparse.tokens.Token):
pass

def visit_typed_literal(self, token: sqlparse.sql.TypedLiteral):
self.recursive_visit(token)

def visit(self, tokens: list[sqlparse.sql.Token]):
for token in tokens:
if token.ttype is not None:
Expand All @@ -94,7 +102,7 @@ def visit(self, tokens: list[sqlparse.sql.Token]):
raise ValueError("unhandled token", token)

def trace_to_str_list(self) -> list[str]:
return [token_to_str(entry) for entry in self.parent_stack]
return [self._token_to_str(entry) for entry in self.parent_stack]


class SqlParserDfs(SqlParseVisitor):
Expand All @@ -113,14 +121,18 @@ def visit_literal(self, token: sqlparse.tokens.Token):


class SqlParserCoVisitor(SqlParseVisitor):
def __init__(self, solution, message_overrides=None):
def __init__(self, solution, message_overrides=None, token_names_overrides=None):
super().__init__()
self._solution = solution
self._i = 0
self._messages = {
"end_of_query": "End of query",
"end_of_token": "End of token",
} | (message_overrides or {})
self._token_names_overrides = {
"IdentifierList": "Select Attributes",
"TypedLiteral": "Interval",
} | (token_names_overrides or {})
self._error_depth = None
self.errors = []

Expand All @@ -130,9 +142,9 @@ def visit(self, tokens: list[sqlparse.sql.Token]):
should, _ = self._solution[self._i]
self.errors.append(
Error(
token_to_str(should),
self._token_to_str(should),
self._messages["end_of_query"],
[token_to_str(tokens[0])],
[self._token_to_str(tokens[0])],
)
)

Expand Down Expand Up @@ -161,7 +173,7 @@ def _should_compare(self, token, comparator) -> bool:
if end_of_token_error:
self.errors.append(
Error(
token_to_str(token),
self._token_to_str(token),
self._messages["end_of_token"],
self.trace_to_str_list(),
)
Expand All @@ -171,23 +183,23 @@ def _should_compare(self, token, comparator) -> bool:
self.errors.append(
Error(
self._messages["end_of_query"],
token_to_str(token),
self._token_to_str(token),
self.trace_to_str_list(),
)
)
elif should_depth < len(self.parent_stack):
self.errors.append(
Error(
self._messages["end_of_token"],
token_to_str(token),
self._token_to_str(token),
self.trace_to_str_list(),
)
)
self._i -= 1
elif not comparator(token, should):
self.errors.append(
Error(
token_to_str(should), token_to_str(token), self.trace_to_str_list()
self._token_to_str(should), self._token_to_str(token), self.trace_to_str_list()
)
)
else:
Expand All @@ -210,6 +222,8 @@ def visit_literal(self, token: sqlparse.tokens.Token):
)
super().visit_literal(token)

def _map_token_name(self, token_name: str) -> str:
return self._token_names_overrides.get(token_name, token_name)

def token_to_str(token: sqlparse.tokens.Token) -> str:
return token.__class__.__name__ if token.ttype is None else repr(token.value)
def _token_to_str(self, token: sqlparse.tokens.Token) -> str:
return self._map_token_name(token.__class__.__name__) if token.ttype is None else repr(token.value)
56 changes: 56 additions & 0 deletions modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,62 @@ def test_compare_with_aggregate(self):
)
assert len(errors) == 1

def test_compare_deep(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT AVG(registraion_date) FROM users))",
"SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT MIN(registraion_date) FROM users))",
)
assert len(errors) == 1

def test_compare_much_error(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT email FROM users WHERE username IN (SELECT username FROM users WHERE registration_date > (SELECT AVG(registraion_date) FROM users))",
"SELECT email FROM users WHERE username IN (SELECT email FROM users WHERE registration_date < (SELECT MAX(registraion_date) FROM users))",
)
assert len(errors) == 2

def test_compare_identifier_list(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT username, password FROM users",
"SELECT FROM users",
)
assert len(errors) == 2
assert errors[0].expected == "Select Attributes"

def test_very_complex_query(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT monat, AVG(tage_bis_erstes_gebot) AS durchschnittliche_tage FROM ( SELECT EXTRACT(MONTH FROM Registriert_am) AS monat, EXTRACT(DAY FROM (MIN(Geboten_am) - Registriert_am)) AS tage_bis_erstes_gebot FROM Gebot g JOIN Kunde k ON g.Bieter = k.KNr GROUP BY Bieter, Registriert_am ) AS tage GROUP BY monat ORDER BY monat;",
"SELECT monat, AVG(tage_bis_erstes_gebot) AS durchschnittliche_tage FROM ( SELECT EXTRACT(MONTH FROM Registriert_am) AS monat, EXTRACT(DAY FROM (MIN(Geboten_am) - Registriert_am)) AS tage_bis_erstes_gebot FROM Gebot g JOIN Kunde k ON g.Bieter = k.KNr GROUP BY Bieter, Registriert_am ) AS tage GROUP BY monat;"
)
assert len(errors) == 2

def test_with_with(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"WITH (SELECT username FROM blocked_users) AS bu SELECT * FROM bu",
"WITH (SELECT email FROM blocked_users) AS bu SELECT * FROM bu"
)
assert len(errors) == 1

def test_very_very_complex_query(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT g.Auktion, g.Bieter, g.Geboten_am, g.Gebotspreis FROM Gebot g JOIN Auktion a ON g.Auktion = a.ANr WHERE g.Geboten_am >= a.Eingestellt_am AND g.Geboten_am <= a.Eingestellt_am + INTERVAL '7 days' AND g.Gebotspreis > COALESCE( ( SELECT MAX(g_prev.Gebotspreis) FROM Gebot g_prev WHERE g_prev.Auktion = g.Auktion AND g_prev.Geboten_am < g.Geboten_am ), a.Startpreis ) ORDER BY g.Auktion, g.Geboten_am;",
"SELECT g.Auktion, g.Bieter, g.Geboten_am, g.Gebotspreis FROM Gebot g JOIN Auktion a ON g.Auktion = a.ANr WHERE g.Geboten_am >= a.Eingestellt_am AND g.Geboten_am <= a.Eingestellt_am + INTERVAL '7 days' AND g.Gebotspreis > COALESCE( ( SELECT MAX(g_prev.Gebotspreis) FROM Gebot g_prev WHERE g_prev.Auktion = g.Auktion AND g_prev.Geboten_am < g.Geboten_am ), a.Startpreis ) ORDER BY g.Geboten_am, g.Auktion;"
)
assert len(errors) == 2

def test_not_null(self):
comparator = SqlparseComparator()
errors = comparator.compare(
"SELECT username FROM user WHERE banned_at IS NULL",
"SELECT username FROM user WHERE banned_at IS NOT NULL"
)
assert len(errors) == 1

if __name__ == "__main__":
unittest.main()

0 comments on commit 3957561

Please sign in to comment.