diff --git a/flyteadmin/pkg/repositories/database.go b/flyteadmin/pkg/repositories/database.go index 7a6e2dcc3c..5676ca8836 100644 --- a/flyteadmin/pkg/repositories/database.go +++ b/flyteadmin/pkg/repositories/database.go @@ -2,15 +2,14 @@ package repositories import ( "context" + "errors" "fmt" "io/ioutil" "os" - "reflect" "strings" "github.com/flyteorg/flytestdlib/database" - repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "gorm.io/driver/sqlite" "github.com/flyteorg/flytestdlib/logger" @@ -116,15 +115,7 @@ func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p return gormDb, nil } - // Check if its invalid db code error - cErr, ok := err.(repoErrors.ConnectError) - if !ok { - logger.Errorf(ctx, "Failed to cast error of type: %v, err: %v", reflect.TypeOf(err), - err) - return nil, err - } - pqError := cErr.Unwrap().(*pgconn.PgError) - if pqError.Code != pqInvalidDBCode { + if !isInvalidDBPgError(err) { return nil, err } @@ -154,6 +145,17 @@ func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p return gorm.Open(dialector, gormConfig) } +func isInvalidDBPgError(err error) bool { + pgErr := &pgconn.PgError{} + if !errors.As(err, &pgErr) { + // err chain does not contain a pgconn.PgError + return false + } + + // pgconn.PgError found in chain and set to pgErr + return pgErr.Code == pqInvalidDBCode +} + func setupDbConnectionPool(ctx context.Context, gormDb *gorm.DB, dbConfig *database.DbConfig) error { genericDb, err := gormDb.DB() if err != nil { diff --git a/flyteadmin/pkg/repositories/database_test.go b/flyteadmin/pkg/repositories/database_test.go index 59fdc4173d..0c81349f80 100644 --- a/flyteadmin/pkg/repositories/database_test.go +++ b/flyteadmin/pkg/repositories/database_test.go @@ -2,7 +2,9 @@ package repositories import ( "context" + "errors" "io/ioutil" + "net" "os" "path" "path/filepath" @@ -10,6 +12,7 @@ import ( "time" "github.com/flyteorg/flytestdlib/database" + "github.com/jackc/pgconn" "github.com/flyteorg/flytestdlib/config" "github.com/flyteorg/flytestdlib/logger" @@ -73,6 +76,57 @@ func TestGetPostgresDsn(t *testing.T) { }) } +type wrappedError struct { + err error +} + +func (e *wrappedError) Error() string { + return e.err.Error() +} + +func (e *wrappedError) Unwrap() error { + return e.err +} + +func TestIsInvalidDBPgError(t *testing.T) { + // wrap error with wrappedError when testing to ensure the function checks the whole error chain + + testCases := []struct { + Name string + Err error + ExpectedResult bool + }{ + { + Name: "nil error", + Err: nil, + ExpectedResult: false, + }, + { + Name: "not a PgError", + Err: &wrappedError{err: &net.OpError{Op: "connect", Err: errors.New("connection refused")}}, + ExpectedResult: false, + }, + { + Name: "PgError but not invalid DB", + Err: &wrappedError{&pgconn.PgError{Severity: "FATAL", Message: "out of memory", Code: "53200"}}, + ExpectedResult: false, + }, + { + Name: "PgError and is invalid DB", + Err: &wrappedError{&pgconn.PgError{Severity: "FATAL", Message: "database \"flyte\" does not exist", Code: "3D000"}}, + ExpectedResult: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.Name, func(t *testing.T) { + assert.Equal(t, tc.ExpectedResult, isInvalidDBPgError(tc.Err)) + }) + } +} + func TestSetupDbConnectionPool(t *testing.T) { ctx := context.TODO() t.Run("successful", func(t *testing.T) { diff --git a/flyteadmin/pkg/repositories/errors/postgres.go b/flyteadmin/pkg/repositories/errors/postgres.go index 29f2efce7d..a61c613458 100644 --- a/flyteadmin/pkg/repositories/errors/postgres.go +++ b/flyteadmin/pkg/repositories/errors/postgres.go @@ -104,8 +104,3 @@ func NewPostgresErrorTransformer(scope promutils.Scope) ErrorTransformer { metrics: metrics, } } - -type ConnectError interface { - Unwrap() error - Error() string -}