Skip to content

Commit

Permalink
Merge remote-tracking branch 'gocql/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
avelanarius committed Jul 24, 2023
2 parents 8c8fbcc + 7a686db commit 61be561
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 27 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,5 @@ João Reis <[email protected]>
Lauro Ramos Venancio <[email protected]>
Dmitry Kropachev <[email protected]>
Oliver Boyle <[email protected]>
Jackson Fleming <[email protected]>
Sylwia Szunejko <[email protected]>
5 changes: 4 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var (
"com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator",
"com.scylladb.auth.SaslauthdAuthenticator",
"com.scylladb.auth.TransitionalAuthenticator",
"com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator",
}
)

Expand Down Expand Up @@ -1428,9 +1429,11 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
customPayload: qry.customPayload,
}

// Set "lwt" property in the query if it is present in preparedMetadata
// Set "lwt", keyspace", "table" property in the query if it is present in preparedMetadata
qry.routingInfo.mu.Lock()
qry.routingInfo.lwt = info.request.lwt
qry.routingInfo.keyspace = info.request.keyspace
qry.routingInfo.table = info.request.table
qry.routingInfo.mu.Unlock()
} else {
frame = &writeQueryFrame{
Expand Down
4 changes: 4 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func TestApprove(t *testing.T) {
approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true,
approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true,
approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true,
approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true,
approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true,
approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", nil): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true,
Expand Down
45 changes: 35 additions & 10 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (c *controlConn) setupConn(conn *Conn) error {
}

if err := c.registerEvents(conn); err != nil {
return err
return fmt.Errorf("register events: %v", err)
}

ch := &connHost{
Expand Down Expand Up @@ -347,6 +347,20 @@ func (c *controlConn) reconnect() {
}
defer atomic.StoreInt32(&c.reconnecting, 0)

conn, err := c.attemptReconnect()

if conn == nil {
c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err)
return
}

err = c.session.refreshRing()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
}

func (c *controlConn) attemptReconnect() (*Conn, error) {
hosts := c.session.ring.allHosts()
hosts = shuffleHosts(hosts)

Expand All @@ -363,6 +377,25 @@ func (c *controlConn) reconnect() {
ch.conn.Close()
}

conn, err := c.attemptReconnectToAnyOfHosts(hosts)

if conn != nil {
return conn, err
}

c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err)
c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
// Fallback to initial contact points, as it may be the case that all known initialHosts
// changed their IPs while keeping the same hostname(s).
initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
if resolvErr != nil {
return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
}

return c.attemptReconnectToAnyOfHosts(initialHosts)
}

func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) {
var conn *Conn
var err error
for _, host := range hosts {
Expand All @@ -379,15 +412,7 @@ func (c *controlConn) reconnect() {
conn.Close()
conn = nil
}
if conn == nil {
c.session.logger.Printf("gocql: control unable to register events: %v\n", err)
return
}

err = c.session.refreshRing()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
return conn, err
}

func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
Expand Down
13 changes: 8 additions & 5 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,10 @@ type preparedMetadata struct {

// proto v4+
pkeyColumns []int

keyspace string

table string
}

func (r preparedMetadata) String() string {
Expand Down Expand Up @@ -981,26 +985,25 @@ func (f *framer) parsePreparedMetadata() preparedMetadata {
return meta
}

var keyspace, table string
globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec
if globalSpec {
keyspace = f.readString()
table = f.readString()
meta.keyspace = f.readString()
meta.table = f.readString()
}

var cols []ColumnInfo
if meta.colCount < 1000 {
// preallocate columninfo to avoid excess copying
cols = make([]ColumnInfo, meta.colCount)
for i := 0; i < meta.colCount; i++ {
f.readCol(&cols[i], &meta.resultMetadata, globalSpec, keyspace, table)
f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
}
} else {
// use append, huge number of columns usually indicates a corrupt frame or
// just a huge row.
for i := 0; i < meta.colCount; i++ {
var col ColumnInfo
f.readCol(&col, &meta.resultMetadata, globalSpec, keyspace, table)
f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
cols = append(cols, col)
}
}
Expand Down
2 changes: 1 addition & 1 deletion host_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ func refreshRing(r *ringDescriber) error {
if !ok {
return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost)
}
if h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) {
if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) {
// no host IP change
host.update(h)
} else {
Expand Down
81 changes: 81 additions & 0 deletions keyspace_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//go:build all || integration
// +build all integration

package gocql

import (
"context"
"fmt"
"testing"
)

// Keyspace_table checks if Query.Keyspace() is updated based on prepared statement
func TestKeyspaceTable(t *testing.T) {
cluster := createCluster()

fallback := RoundRobinHostPolicy()
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

session, err := cluster.CreateSession()
if err != nil {
t.Fatal("createSession:", err)
}

cluster.Keyspace = "wrong_keyspace"

keyspace := "test1"
table := "table1"

err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
if err != nil {
t.Fatal("unable to drop keyspace:", err)
}

err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
}`, keyspace))

if err != nil {
t.Fatal("unable to create keyspace:", err)
}

if err := session.control.awaitSchemaAgreement(); err != nil {
t.Fatal(err)
}

err = createTable(session, fmt.Sprintf(`CREATE TABLE %s.%s (pk int, ck int, v int, PRIMARY KEY (pk, ck));
`, keyspace, table))

if err != nil {
t.Fatal("unable to create table:", err)
}

if err := session.control.awaitSchemaAgreement(); err != nil {
t.Fatal(err)
}

ctx := context.Background()

// insert a row
if err := session.Query(`INSERT INTO test1.table1(pk, ck, v) VALUES (?, ?, ?)`,
1, 2, 3).WithContext(ctx).Consistency(One).Exec(); err != nil {
t.Fatal(err)
}

var pk int

/* Search for a specific set of records whose 'pk' column matches
* the value of inserted row. */
qry := session.Query(`SELECT pk FROM test1.table1 WHERE pk = ? LIMIT 1`,
1).WithContext(ctx).Consistency(One)
if err := qry.Scan(&pk); err != nil {
t.Fatal(err)
}

// cluster.Keyspace was set to "wrong_keyspace", but during prepering statement
// Keyspace in Query should be changed to "test" and Table should be changed to table1
assertEqual(t, "qry.Keyspace()", "test1", qry.Keyspace())
assertEqual(t, "qry.Table()", "table1", qry.Table())
}
1 change: 1 addition & 0 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type ExecutableQuery interface {
speculativeExecutionPolicy() SpeculativeExecutionPolicy
GetRoutingKey() ([]byte, error)
Keyspace() string
Table() string
IsIdempotent() bool
IsLWT() bool
GetCustomPartitioner() partitioner
Expand Down
48 changes: 38 additions & 10 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ var queryPool = &sync.Pool{

func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) {
var hosts []*HostInfo
for _, hostport := range addrs {
resolvedHosts, err := hostInfo(hostport, defaultPort)
for _, hostaddr := range addrs {
resolvedHosts, err := hostInfo(hostaddr, defaultPort)
if err != nil {
// Try other hosts if unable to resolve DNS name
if _, ok := err.(*net.DNSError); ok {
Expand Down Expand Up @@ -643,8 +643,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
return nil, nil
}

table := info.request.columns[0].Table
keyspace := info.request.columns[0].Keyspace
table := info.request.table
keyspace := info.request.keyspace

partitioner, err := scyllaGetTablePartitioner(s, keyspace, table)
if err != nil {
Expand All @@ -665,6 +665,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
types: types,
lwt: info.request.lwt,
partitioner: partitioner,
keyspace: keyspace,
table: table,
}

inflight.value = routingKeyInfo
Expand Down Expand Up @@ -700,6 +702,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
types: make([]TypeInfo, size),
lwt: info.request.lwt,
partitioner: partitioner,
keyspace: keyspace,
table: table,
}

for keyIndex, keyColumn := range partitionKey {
Expand Down Expand Up @@ -951,6 +955,10 @@ type queryRoutingInfo struct {

// If not nil, represents a custom partitioner for the table.
partitioner partitioner

keyspace string

table string
}

func (qri *queryRoutingInfo) isLWT() bool {
Expand Down Expand Up @@ -1158,6 +1166,10 @@ func (q *Query) Keyspace() string {
if q.getKeyspace != nil {
return q.getKeyspace()
}
if q.routingInfo.keyspace != "" {
return q.routingInfo.keyspace
}

if q.session == nil {
return ""
}
Expand All @@ -1166,6 +1178,11 @@ func (q *Query) Keyspace() string {
return q.session.cfg.Keyspace
}

// Table returns name of the table the query will be executed against.
func (q *Query) Table() string {
return q.routingInfo.table
}

// GetRoutingKey gets the routing key to use for routing this query. If
// a routing key has not been explicitly set, then the routing key will
// be constructed if possible using the keyspace's schema and the query
Expand All @@ -1187,10 +1204,13 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
if err != nil {
return nil, err
}

if routingKeyInfo != nil {
q.routingInfo.mu.Lock()
q.routingInfo.lwt = routingKeyInfo.lwt
q.routingInfo.partitioner = routingKeyInfo.partitioner
q.routingInfo.keyspace = routingKeyInfo.keyspace
q.routingInfo.table = routingKeyInfo.table
q.routingInfo.mu.Unlock()
}
return createRoutingKey(routingKeyInfo, q.values)
Expand Down Expand Up @@ -1818,6 +1838,11 @@ func (b *Batch) Keyspace() string {
return b.keyspace
}

// Batch has no reasonable eqivalent of Query.Table().
func (b *Batch) Table() string {
return b.routingInfo.table
}

// Attempts returns the number of attempts made to execute the batch.
func (b *Batch) Attempts() int {
return b.metrics.attempts()
Expand Down Expand Up @@ -2106,8 +2131,10 @@ type routingKeyInfoLRU struct {
}

type routingKeyInfo struct {
indexes []int
types []TypeInfo
indexes []int
types []TypeInfo
keyspace string
table string
lwt bool
partitioner partitioner
}
Expand Down Expand Up @@ -2182,6 +2209,7 @@ func (t *traceWriter) Trace(traceId []byte) {
activity string
source string
elapsed int
thread string
)

t.mu.Lock()
Expand All @@ -2190,13 +2218,13 @@ func (t *traceWriter) Trace(traceId []byte) {
fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n",
traceId, coordinator, time.Duration(duration)*time.Microsecond)

iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed
iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed, thread
FROM system_traces.events
WHERE session_id = ?`, traceId)

for iter.Scan(&timestamp, &activity, &source, &elapsed) {
fmt.Fprintf(t.w, "%s: %s (source: %s, elapsed: %d)\n",
timestamp.Format("2006/01/02 15:04:05.999999"), activity, source, elapsed)
for iter.Scan(&timestamp, &activity, &source, &elapsed, &thread) {
fmt.Fprintf(t.w, "%s: %s [%s] (source: %s, elapsed: %d)\n",
timestamp.Format("2006/01/02 15:04:05.999999"), activity, thread, source, elapsed)
}

if err := iter.Close(); err != nil {
Expand Down

0 comments on commit 61be561

Please sign in to comment.