From 2aca31d5986a9e1c65a92264736de9fdc3b9b4ca Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Mon, 29 Mar 2021 19:38:16 -0400 Subject: [PATCH] GODRIVER-1923 Error if BSON cstrings contain null bytes (#622) --- bson/bsonrw/value_writer.go | 15 +++++++++- bson/marshal_test.go | 33 ++++++++++++++++++++++ x/bsonx/bsoncore/bsoncore.go | 26 +++++++++++++---- x/bsonx/bsoncore/bsoncore_test.go | 46 +++++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 6 deletions(-) diff --git a/bson/bsonrw/value_writer.go b/bson/bsonrw/value_writer.go index 3717198366..7b7d7ad3f2 100644 --- a/bson/bsonrw/value_writer.go +++ b/bson/bsonrw/value_writer.go @@ -12,6 +12,7 @@ import ( "io" "math" "strconv" + "strings" "sync" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -247,7 +248,12 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { switch vw.stack[vw.frame].mode { case mElement: - vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key) + key := vw.stack[vw.frame].key + if !isValidCString(key) { + return errors.New("BSON element key cannot contain null bytes") + } + + vw.buf = bsoncore.AppendHeader(vw.buf, t, key) case mValue: // TODO: Do this with a cache of the first 1000 or so array keys. vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) @@ -430,6 +436,9 @@ func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error { } func (vw *valueWriter) WriteRegex(pattern string, options string) error { + if !isValidCString(pattern) || !isValidCString(options) { + return errors.New("BSON regex values cannot contain null bytes") + } if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil { return err } @@ -602,3 +611,7 @@ func (vw *valueWriter) writeLength() error { vw.buf[start+3] = byte(length >> 24) return nil } + +func isValidCString(cs string) bool { + return !strings.ContainsRune(cs, '\x00') +} diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 7e570676b9..319870522d 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -8,6 +8,7 @@ package bson import ( "bytes" + "errors" "fmt" "reflect" "testing" @@ -267,3 +268,35 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { }) }) } + +func TestNullBytes(t *testing.T) { + t.Run("element keys", func(t *testing.T) { + doc := D{{"a\x00", "foobar"}} + res, err := Marshal(doc) + want := errors.New("BSON element key cannot contain null bytes") + assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res)) + }) + + t.Run("regex values", func(t *testing.T) { + wantErr := errors.New("BSON regex values cannot contain null bytes") + + testCases := []struct { + name string + pattern string + options string + }{ + {"null bytes in pattern", "a\x00", "i"}, + {"null bytes in options", "pattern", "i\x00"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + regex := primitive.Regex{ + Pattern: tc.pattern, + Options: tc.options, + } + res, err := Marshal(D{{"foo", regex}}) + assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res)) + }) + } + }) +} diff --git a/x/bsonx/bsoncore/bsoncore.go b/x/bsonx/bsoncore/bsoncore.go index dde741c0a3..97ef1b85d5 100644 --- a/x/bsonx/bsoncore/bsoncore.go +++ b/x/bsonx/bsoncore/bsoncore.go @@ -30,17 +30,21 @@ import ( "fmt" "math" "strconv" + "strings" "time" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" ) -// EmptyDocumentLength is the length of a document that has been started/ended but has no elements. -const EmptyDocumentLength = 5 - -// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings. -const nullTerminator = string(byte(0)) +const ( + // EmptyDocumentLength is the length of a document that has been started/ended but has no elements. + EmptyDocumentLength = 5 + // nullTerminator is a string version of the 0 byte that is appended at the end of cstrings. + nullTerminator = string(byte(0)) + invalidKeyPanicMsg = "BSON element keys cannot contain null bytes" + invalidRegexPanicMsg = "BSON regex values cannot contain null bytes" +) // AppendType will append t to dst and return the extended buffer. func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) } @@ -51,6 +55,10 @@ func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTermi // AppendHeader will append Type t and key to dst and return the extended // buffer. func AppendHeader(dst []byte, t bsontype.Type, key string) []byte { + if !isValidCString(key) { + panic(invalidKeyPanicMsg) + } + dst = AppendType(dst, t) dst = append(dst, key...) return append(dst, 0x00) @@ -430,6 +438,10 @@ func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, // AppendRegex will append pattern and options to dst and return the extended buffer. func AppendRegex(dst []byte, pattern, options string) []byte { + if !isValidCString(pattern) || !isValidCString(options) { + panic(invalidRegexPanicMsg) + } + return append(dst, pattern+nullTerminator+options+nullTerminator...) } @@ -844,3 +856,7 @@ func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte { dst = appendLength(dst, int32(len(b))) return append(dst, b...) } + +func isValidCString(cs string) bool { + return !strings.ContainsRune(cs, '\x00') +} diff --git a/x/bsonx/bsoncore/bsoncore_test.go b/x/bsonx/bsoncore/bsoncore_test.go index 8a31ecd913..84b889955d 100644 --- a/x/bsonx/bsoncore/bsoncore_test.go +++ b/x/bsonx/bsoncore/bsoncore_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/testutil/assert" ) func noerr(t *testing.T, err error) { @@ -899,6 +900,51 @@ func TestBuild(t *testing.T) { } } +func TestNullBytes(t *testing.T) { + // Helper function to execute the provided callback and assert that it panics with the expected message. The + // createBSONFn callback should create a BSON document/array/value and return the stringified version. + assertBSONCreationPanics := func(t *testing.T, createBSONFn func(), expected string) { + t.Helper() + + defer func() { + got := recover() + assert.Equal(t, expected, got, "expected panic with error %v, got error %v", expected, got) + }() + createBSONFn() + } + + t.Run("element keys", func(t *testing.T) { + createDocFn := func() { + NewDocumentBuilder().AppendString("a\x00", "foo") + } + assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg) + }) + t.Run("regex values", func(t *testing.T) { + testCases := []struct { + name string + pattern string + options string + }{ + {"null bytes in pattern", "a\x00", "i"}, + {"null bytes in options", "pattern", "i\x00"}, + } + for _, tc := range testCases { + t.Run(tc.name+"-AppendRegexElement", func(t *testing.T) { + createDocFn := func() { + AppendRegexElement(nil, "foo", tc.pattern, tc.options) + } + assertBSONCreationPanics(t, createDocFn, invalidRegexPanicMsg) + }) + t.Run(tc.name+"-AppendRegex", func(t *testing.T) { + createValFn := func() { + AppendRegex(nil, tc.pattern, tc.options) + } + assertBSONCreationPanics(t, createValFn, invalidRegexPanicMsg) + }) + } + }) +} + func compareDecimal128(d1, d2 primitive.Decimal128) bool { d1H, d1L := d1.GetBytes() d2H, d2L := d2.GetBytes()