diff --git a/clickhouse_std.go b/clickhouse_std.go index 7b18480a2b..466d345fb2 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -239,12 +239,32 @@ func (std *stdDriver) ResetSession(ctx context.Context) error { var _ driver.SessionResetter = (*stdDriver)(nil) -func (std *stdDriver) Ping(ctx context.Context) error { return std.conn.ping(ctx) } +func (std *stdDriver) Ping(ctx context.Context) error { + if std.conn.isBad() { + std.debugf("Ping: connection is bad") + return driver.ErrBadConn + } + + return std.conn.ping(ctx) +} var _ driver.Pinger = (*stdDriver)(nil) -func (std *stdDriver) Begin() (driver.Tx, error) { return std, nil } +func (std *stdDriver) Begin() (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("Begin: connection is bad") + return nil, driver.ErrBadConn + } + + return std, nil +} + func (std *stdDriver) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("BeginTx: connection is bad") + return nil, driver.ErrBadConn + } + return std, nil } @@ -280,6 +300,11 @@ func (std *stdDriver) CheckNamedValue(nv *driver.NamedValue) error { return nil var _ driver.NamedValueChecker = (*stdDriver)(nil) func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if std.conn.isBad() { + std.debugf("ExecContext: connection is bad") + return nil, driver.ErrBadConn + } + var err error if options := queryOptions(ctx); options.async.ok { err = std.conn.asyncInsert(ctx, query, options.async.wait, rebind(args)...) @@ -299,6 +324,11 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv } func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if std.conn.isBad() { + std.debugf("QueryContext: connection is bad") + return nil, driver.ErrBadConn + } + r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...) if isConnBrokenError(err) { std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err) @@ -319,6 +349,11 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) { } func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if std.conn.isBad() { + std.debugf("PrepareContext: connection is bad") + return nil, driver.ErrBadConn + } + batch, err := std.conn.prepareBatch(ctx, query, ldriver.PrepareBatchOptions{}, func(*connect, error) {}, func(context.Context) (*connect, error) { return nil, nil }) if err != nil { if isConnBrokenError(err) { diff --git a/tests/issues/1395_test.go b/tests/issues/1395_test.go new file mode 100644 index 0000000000..7292733ced --- /dev/null +++ b/tests/issues/1395_test.go @@ -0,0 +1,100 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package issues + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" + + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func Test1395(t *testing.T) { + testEnv, err := clickhouse_tests.GetTestEnvironment("issues") + require.NoError(t, err) + opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, false) + conn, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts)) + require.NoError(t, err) + + ctx := context.Background() + + singleConn, err := conn.Conn(ctx) + if err != nil { + t.Fatalf("Get single conn from pool: %v", err) + } + + tx1 := func(c *sql.Conn) error { + tx, err := c.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "begin tx") + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS test_table +ON CLUSTER my +(id UInt32, name String) +ENGINE = MergeTree() +ORDER BY id`) + if err != nil { + return errors.Wrap(err, "create table") + } + + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "commit tx") + } + + return nil + } + + err = tx1(singleConn) + require.Error(t, err, "expected error due to cluster is not configured") + + tx2 := func(c *sql.Conn) error { + tx, err := c.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "begin tx") + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "INSERT INTO test_table (id, name) VALUES (?, ?)", 1, "test_name") + if err != nil { + return errors.Wrap(err, "failed to insert record") + } + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "commit tx") + } + + return nil + } + require.NotPanics( + t, + func() { + err := tx2(singleConn) + require.ErrorIs(t, err, driver.ErrBadConn) + }, + "must not panics", + ) +}