Skip to content

Commit

Permalink
Multiple bug fixes and enhancements for std regex (#256)
Browse files Browse the repository at this point in the history
Implement namedGroups (which was missing from the first implementation)
- JS is tricky because you need to convert Python-style named groups
(the only ones supported by RE2) to perl ones
- You need to manually find them due to the lack of clear API in the JDK

Align our implementation with jrsonnet implementations:
- return null if no match
- make fullmatch just partialmatch with anchors - it also allows us to
share more code.

Fix the - problem with regexQuote in RE2 across the board
  • Loading branch information
stephenamar-db authored Jan 7, 2025
1 parent 7d75fd7 commit 9143c58
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 113 deletions.
19 changes: 18 additions & 1 deletion sjsonnet/src-js/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,27 @@ object Platform {
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]
private val namedGroupPattern = Pattern.compile("\\(\\?<(.+?)>.*?\\)")
private val namedGroupPatternReplace = Pattern.compile("(\\(\\?P<)(.+?>.*?\\))")

// scala.js does not rely on re2. Per https://www.scala-js.org/doc/regular-expressions.html.
// Expect to see some differences in behavior.
def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))
def getPatternFromCache(pat: String) : Pattern = {
val fixedPattern = namedGroupPatternReplace.matcher(pat).replaceAll("(?<$2")
regexCache.computeIfAbsent(pat, _ => Pattern.compile(fixedPattern))
}


def getNamedGroupsMap(pat: Pattern): Map[String, Int] = {
val namedGroups = Map.newBuilder[String, Int]
val matcher = namedGroupPattern.matcher(pat.pattern())
while (matcher.find()) {
for (i <- 1 to matcher.groupCount()) {
namedGroups += matcher.group(i) -> i
}
}
namedGroups.result()
}

def regexQuote(s: String): String = Pattern.quote(s)
}
15 changes: 14 additions & 1 deletion sjsonnet/src-jvm/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.tukaani.xz.XZOutputStream
import org.yaml.snakeyaml.{LoaderOptions, Yaml}
import org.yaml.snakeyaml.constructor.SafeConstructor

import scala.collection.compat._
import scala.jdk.CollectionConverters._

object Platform {
Expand Down Expand Up @@ -112,7 +113,19 @@ object Platform {
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]
private val dashPattern = getPatternFromCache("-")

def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))

def regexQuote(s: String): String = Pattern.quote(s)
def getNamedGroupsMap(pat: Pattern): Map[String, Int] = pat.namedGroups().asScala.view.mapValues(_.intValue()).toMap

def regexQuote(s: String): String = {
val quote = Pattern.quote(s)
val matcher = dashPattern.matcher(quote)
if (matcher.find()) {
matcher.replaceAll("\\\\-")
} else {
quote
}
}
}
19 changes: 16 additions & 3 deletions sjsonnet/src-native/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import java.util.Base64
import java.util.zip.GZIPOutputStream
import scala.scalanative.regex.Pattern

import scala.collection.compat._

object Platform {
def gzipBytes(b: Array[Byte]): String = {
val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream(b.length)
Expand Down Expand Up @@ -54,9 +56,20 @@ object Platform {
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]
// scala native is powered by RE2, per https://scala-native.org/en/latest/lib/javalib.html#regular-expressions-java-util-regexp
// It should perform similarly to the JVM implementation.
private val dashPattern = getPatternFromCache("-")

def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))

def regexQuote(s: String): String = Pattern.quote(s)
def getNamedGroupsMap(pat: Pattern): Map[String, Int] = scala.jdk.javaapi.CollectionConverters.asScala(
pat.re2.namedGroups).view.mapValues(_.intValue()).toMap

def regexQuote(s: String): String = {
val quote = Pattern.quote(s)
val matcher = dashPattern.matcher(quote)
if (matcher.find()) {
matcher.replaceAll("\\\\-")
} else {
quote
}
}
}
21 changes: 4 additions & 17 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -483,26 +483,14 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
}

private object StripUtils {
private val dashPattern = Platform.getPatternFromCache("-")

private def cleanupPattern(chars: String): String = {
val matcher = dashPattern.matcher(chars)
if (matcher.find()) {
matcher.replaceAll("") + "-"
} else {
chars
}
}

private def getLeadingPattern(chars: String): String = "^[" + Platform.regexQuote(chars) + "]+"

private def getTrailingPattern(chars: String): String = "[" + Platform.regexQuote(chars) + "]+$"

def unspecializedStrip(str: String, chars: String, left: Boolean, right: Boolean): String = {
var s = str
val cleanedUpPattern = cleanupPattern(chars)
if (right) s = Platform.getPatternFromCache(getTrailingPattern(cleanedUpPattern)).matcher(s).replaceAll("")
if (left) s = Platform.getPatternFromCache(getLeadingPattern(cleanedUpPattern)).matcher(s).replaceAll("")
if (right) s = Platform.getPatternFromCache(getTrailingPattern(chars)).matcher(s).replaceAll("")
if (left) s = Platform.getPatternFromCache(getLeadingPattern(chars)).matcher(s).replaceAll("")
s
}

Expand All @@ -512,9 +500,8 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
right: Boolean,
functionName: String
) extends Val.Builtin1(functionName, "str") {
private[this] val cleanedUpPattern = cleanupPattern(chars)
private[this] val leftPattern = Platform.getPatternFromCache(getLeadingPattern(cleanedUpPattern))
private[this] val rightPattern = Platform.getPatternFromCache(getTrailingPattern(cleanedUpPattern))
private[this] val leftPattern = Platform.getPatternFromCache(getLeadingPattern(chars))
private[this] val rightPattern = Platform.getPatternFromCache(getTrailingPattern(chars))

def evalRhs(str: Val, ev: EvalScope, pos: Position): Val = {
var s = str.asString
Expand Down
78 changes: 25 additions & 53 deletions sjsonnet/src/sjsonnet/StdRegex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,66 +4,38 @@ import sjsonnet.Expr.Member.Visibility
import sjsonnet.Val.Obj

object StdRegex {
private final def regexPartialMatch(pos: Position, pattern: String, str: String): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern)
val matcher = compiledPattern.matcher(str)

if (matcher.find()) {
val captures = Range.Int.inclusive(1, matcher.groupCount(), 1)
.map(i => Val.Str(pos.noOffset, Option(matcher.group(i)).getOrElse("")))
.toArray
val namedCaptures = Platform.getNamedGroupsMap(compiledPattern).map {
case (k, v) =>
k -> new Obj.ConstMember(true, Visibility.Normal, captures(v - 1))
}.toSeq

Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal, Val.Str(pos.noOffset, str)),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, captures)),
"namedCaptures" -> new Obj.ConstMember(true, Visibility.Normal, Val.Obj.mk(pos.noOffset, namedCaptures: _*))
)
} else {
Val.Null(pos.noOffset)
}
}

def functions: Map[String, Val.Builtin] = Map(
"regexPartialMatch" -> new Val.Builtin2("regexPartialMatch", "pattern", "str") {
override def evalRhs(pattern: Val, str: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
var returnStr: Val = null
val captures = Array.newBuilder[Val]
val groupCount = matcher.groupCount()
while (matcher.find()) {
if (returnStr == null) {
val m = matcher.group(0)
if (m != null) {
returnStr = Val.Str(pos.noOffset, matcher.group(0))
} else {
returnStr = Val.Null(pos.noOffset)
}
}
for (i <- 1 to groupCount) {
val m = matcher.group(i)
if (m == null) {
captures += Val.Null(pos.noOffset)
} else {
captures += Val.Str(pos.noOffset, m)
}
}
}
val result = captures.result()
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal,
if (returnStr == null) Val.Null(pos.noOffset) else returnStr),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, result))
)
regexPartialMatch(pos, pattern.asString, str.asString)
}
},
"regexFullMatch" -> new Val.Builtin2("regexFullMatch", "pattern", "str") {
override def evalRhs(pattern: Val, str: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
if (!matcher.matches()) {
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal, Val.Null(pos.noOffset)),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, Array.empty[Lazy]))
)
} else {
val captures = Array.newBuilder[Val]
val groupCount = matcher.groupCount()
for (i <- 0 to groupCount) {
val m = matcher.group(i)
if (m == null) {
captures += Val.Null(pos.noOffset)
} else {
captures += Val.Str(pos.noOffset, m)
}
}
val result = captures.result()
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal, result.head),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, result.drop(1)))
)
}
regexPartialMatch(pos, s"^${pattern.asString}$$", str.asString)
}
},
"regexGlobalReplace" -> new Val.Builtin3("regexGlobalReplace", "str", "pattern", "to") {
Expand Down
72 changes: 72 additions & 0 deletions sjsonnet/test/resources/test_suite/regex.jsonnet
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
std.assertEqual(std.native('regexFullMatch')(@'e', 'hello'), null) &&

std.assertEqual(
std.native('regexFullMatch')(@'h.*o', 'hello'),
{
string: 'hello',
captures: [],
namedCaptures: {},
}
) &&

std.assertEqual(
std.native('regexFullMatch')(@'h(.*)o', 'hello'),
{
string: 'hello',
captures: ['ell'],
namedCaptures: {},
}
) &&

std.assertEqual(
std.native('regexFullMatch')(@'h(?P<mid>.*)o', 'hello'),
{
string: 'hello',
captures: ['ell'],
namedCaptures: {
mid: 'ell',
},
}
) &&

std.assertEqual(std.native('regexPartialMatch')(@'world', 'hello'), null) &&

std.assertEqual(
std.native('regexPartialMatch')(@'e', 'hello'),
{
string: 'hello',
captures: [],
namedCaptures: {},
}
) &&

std.assertEqual(
std.native('regexPartialMatch')(@'e(.*)o', 'hello'),
{
string: 'hello',
captures: ['ll'],
namedCaptures: {},
}
) &&

std.assertEqual(
std.native('regexPartialMatch')(@'e(?P<mid>.*)o', 'hello'),
{
string: 'hello',
captures: ['ll'],
namedCaptures: {
mid: 'll',
},
}
) &&

std.assertEqual(std.native('regexQuoteMeta')(@'1.5-2.0?'), '1\\.5\\-2\\.0\\?') &&


std.assertEqual(std.native('regexReplace')('wishyfishyisishy', @'ish', 'and'), 'wandyfishyisishy') &&
std.assertEqual(std.native('regexReplace')('yabba dabba doo', @'b+', 'd'), 'yada dabba doo') &&

std.assertEqual(std.native('regexGlobalReplace')('wishyfishyisishy', @'ish', 'and'), 'wandyfandyisandy') &&
std.assertEqual(std.native('regexGlobalReplace')('yabba dabba doo', @'b+', 'd'), 'yada dada doo') &&

true
1 change: 1 addition & 0 deletions sjsonnet/test/src-jvm/sjsonnet/FileTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ object FileTests extends TestSuite{
// test("recursive_function") - check()
test("recursive_import_ok") - check()
test("recursive_object") - check()
test("regex") - check()
test("sanity") - checkGolden()
test("sanity2") - checkGolden()
test("shebang") - check()
Expand Down
38 changes: 0 additions & 38 deletions sjsonnet/test/src/sjsonnet/StdRegexTests.scala

This file was deleted.

0 comments on commit 9143c58

Please sign in to comment.