From 0b61e1fad80abadea334af2fd6a19b4ed68253b9 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Fri, 27 Dec 2024 05:12:08 +0000 Subject: [PATCH] CASSGO-42: don't panic if no applied column is returned I also added checks against MapScan failing which could also trigger this bug. Several integration tests were added to validate various edge cases. Patch by James Hartig for CASSGO-42 --- CHANGELOG.md | 1 + cassandra_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++--- helpers.go | 19 ++++++++++++---- session.go | 16 +++++++++----- 4 files changed, 80 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 67c88a141..c2f40556b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Retry policy now takes into account query idempotency (CASSGO-27) - Don't return error to caller with RetryType Ignore (CASSGO-28) +- Don't panic in MapExecuteBatchCAS if no `[applied]` column is returned (CASSGO-42) ## [1.7.0] - 2024-09-23 diff --git a/cassandra_test.go b/cassandra_test.go index ec6969190..bfe34a508 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,7 +32,6 @@ import ( "context" "errors" "fmt" - "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -45,6 +44,8 @@ import ( "time" "unicode" + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" ) @@ -476,7 +477,7 @@ func TestCAS(t *testing.T) { if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { - t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) + t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } insertBatch := session.Batch(LoggedBatch) @@ -492,7 +493,7 @@ func TestCAS(t *testing.T) { if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { - t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) + t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } else { if scan := iter.Scan(&applied, &titleCAS, &revidCAS, &modifiedCAS); scan && applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) @@ -503,6 +504,55 @@ func TestCAS(t *testing.T) { t.Fatal("scan:", err) } } + + casMap = make(map[string]interface{}) + if applied, err := session.Query(`SELECT revid FROM cas_table WHERE title = ?`, + title+"_foo").MapScanCAS(casMap); err != nil { + t.Fatal("select:", err) + } else if applied { + t.Fatal("select shouldn't have returned applied") + } + + if _, err := session.Query(`SELECT revid FROM cas_table WHERE title = ?`, + title+"_foo").ScanCAS(&revidCAS); err == nil { + t.Fatal("select: should have returned an error") + } + + notCASBatch := session.Batch(LoggedBatch) + notCASBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?)", title+"_baz", revid, modified) + casMap = make(map[string]interface{}) + if _, _, err := session.MapExecuteBatchCAS(notCASBatch, casMap); err != ErrNotFound { + t.Fatal("insert should have returned not found:", err) + } + + notCASBatch = session.Batch(LoggedBatch) + notCASBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?)", title+"_baz", revid, modified) + casMap = make(map[string]interface{}) + if _, _, err := session.ExecuteBatchCAS(notCASBatch, &revidCAS); err != ErrNotFound { + t.Fatal("insert should have returned not found:", err) + } + + failBatch = session.Batch(LoggedBatch) + failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == nil { + t.Fatal("update should have errored") + } + // make sure MapScanCAS does not panic when MapScan fails + casMap = make(map[string]interface{}) + casMap["last_modified"] = false + if _, err := session.Query(`UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`, + modified).MapScanCAS(casMap); err == nil { + t.Fatal("update should hvae errored", err) + } + + // make sure MapExecuteBatchCAS does not panic when MapScan fails + failBatch = session.Batch(LoggedBatch) + failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + casMap = make(map[string]interface{}) + casMap["last_modified"] = false + if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == nil { + t.Fatal("update should have errored") + } } func TestDurationType(t *testing.T) { diff --git a/helpers.go b/helpers.go index f2faee9e0..a4812c1d4 100644 --- a/helpers.go +++ b/helpers.go @@ -322,6 +322,7 @@ func TupleColumnName(c string, n int) string { return fmt.Sprintf("%s[%d]", c, n) } +// RowData returns the RowData for the iterator. func (iter *Iter) RowData() (RowData, error) { if iter.err != nil { return RowData{}, iter.err @@ -334,6 +335,7 @@ func (iter *Iter) RowData() (RowData, error) { if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { val, err := column.TypeInfo.NewWithError() if err != nil { + iter.err = err return RowData{}, err } columns = append(columns, column.Name) @@ -343,6 +345,7 @@ func (iter *Iter) RowData() (RowData, error) { columns = append(columns, TupleColumnName(column.Name, i)) val, err := elem.NewWithError() if err != nil { + iter.err = err return RowData{}, err } values = append(values, val) @@ -364,7 +367,10 @@ func (iter *Iter) rowMap() (map[string]interface{}, error) { return nil, iter.err } - rowData, _ := iter.RowData() + rowData, err := iter.RowData() + if err != nil { + return nil, err + } iter.Scan(rowData.Values...) m := make(map[string]interface{}, len(rowData.Columns)) rowData.rowMap(m) @@ -379,7 +385,10 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { } // Not checking for the error because we just did - rowData, _ := iter.RowData() + rowData, err := iter.RowData() + if err != nil { + return nil, err + } dataToReturn := make([]map[string]interface{}, 0) for iter.Scan(rowData.Values...) { m := make(map[string]interface{}, len(rowData.Columns)) @@ -435,8 +444,10 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } - // Not checking for the error because we just did - rowData, _ := iter.RowData() + rowData, err := iter.RowData() + if err != nil { + return false + } for i, col := range rowData.Columns { if dest, ok := m[col]; ok { diff --git a/session.go b/session.go index d04a13672..a44ba45ed 100644 --- a/session.go +++ b/session.go @@ -785,7 +785,7 @@ func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bo iter.Scan(&applied) } - return applied, iter, nil + return applied, iter, iter.err } // MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS, @@ -798,8 +798,11 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) return false, nil, err } iter.MapScan(dest) - applied = dest["[applied]"].(bool) - delete(dest, "[applied]") + // check if [applied] was returned, otherwise it might not be CAS + if _, ok := dest["[applied]"]; ok { + applied = dest["[applied]"].(bool) + delete(dest, "[applied]") + } // we usually close here, but instead of closing, just returin an error // if MapScan failed. Although Close just returns err, using Close @@ -1387,8 +1390,11 @@ func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error return false, err } iter.MapScan(dest) - applied = dest["[applied]"].(bool) - delete(dest, "[applied]") + // check if [applied] was returned, otherwise it might not be CAS + if _, ok := dest["[applied]"]; ok { + applied = dest["[applied]"].(bool) + delete(dest, "[applied]") + } return applied, iter.Close() }