Skip to content

Commit

Permalink
Fix: rawstring backslashes for bigquery
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 16, 2023
1 parent d27e8f8 commit 58fe190
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
17 changes: 11 additions & 6 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,11 @@ def bytestring_sql(self, expression: exp.ByteString) -> str:
return this

def rawstring_sql(self, expression: exp.RawString) -> str:
string = expression.this
if self.RAW_START:
return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
return f"{self.RAW_START}{self.escape_str(expression.this)}{self.RAW_END}"
string = self.escape_str(string.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"

def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
this = self.sql(expression, "this")
Expand Down Expand Up @@ -1420,10 +1422,13 @@ def lock_sql(self, expression: exp.Lock) -> str:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
return text

def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text

def loaddata_sql(self, expression: exp.LoadData) -> str:
Expand Down
15 changes: 15 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ def test_bigquery(self):
with self.assertRaises(ValueError):
transpile("'\\'", read="bigquery")

self.validate_all(
"r'x\\''",
write={
"bigquery": "r'x\\''",
"hive": "'x\\''",
},
)

self.validate_all(
"r'x\\y'",
write={
"bigquery": "r'x\\y'",
"hive": "'x\\\\y'",
},
)
self.validate_all(
"'\\\\'",
write={
Expand Down

0 comments on commit 58fe190

Please sign in to comment.