From 735b59c79395f38b4268c8ae4af5a302fbb86f48 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Fri, 11 Mar 2022 21:15:56 +0530 Subject: [PATCH] Create admin db if it doesn't exist (#367) --- cmd/entrypoints/migrate.go | 3 +- pkg/repositories/database.go | 81 ++++++++++++++++++++++++----- pkg/repositories/errors/postgres.go | 5 ++ 3 files changed, 75 insertions(+), 14 deletions(-) diff --git a/cmd/entrypoints/migrate.go b/cmd/entrypoints/migrate.go index fdad8c7fcc..df1393f6cc 100644 --- a/cmd/entrypoints/migrate.go +++ b/cmd/entrypoints/migrate.go @@ -4,10 +4,10 @@ import ( "context" "github.com/flyteorg/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyteadmin/pkg/repositories/config" "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flytestdlib/logger" + "github.com/go-gormigrate/gormigrate/v2" "github.com/spf13/cobra" _ "gorm.io/driver/postgres" // Required to import database driver. @@ -32,7 +32,6 @@ var migrateCmd = &cobra.Command{ if err != nil { logger.Fatal(ctx, err) } - sqlDB, err := db.DB() if err != nil { logger.Fatal(ctx, err) diff --git a/pkg/repositories/database.go b/pkg/repositories/database.go index e4dd13cc3f..61c1616966 100644 --- a/pkg/repositories/database.go +++ b/pkg/repositories/database.go @@ -5,15 +5,22 @@ import ( "fmt" "io/ioutil" "os" + "reflect" "strings" + repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/logger" + + "github.com/jackc/pgconn" "gorm.io/driver/postgres" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" ) +const pqInvalidDBCode = "3D000" +const defaultDB = "postgres" + // getGormLogLevel converts between the flytestdlib configured log level to the equivalent gorm log level. func getGormLogLevel(ctx context.Context, logConfig *logger.Config) gormLogger.LogLevel { if logConfig == nil { @@ -70,19 +77,27 @@ func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConf pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions) } +// GetDB uses the dbConfig to create gorm DB object. If the db doesn't exist for the dbConfig then a new one is created +// using the default db for the provider. eg : postgres has default dbName as postgres func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) ( - *gorm.DB, error) { + gormDb *gorm.DB, err error) { if dbConfig == nil { panic("Cannot initialize database repository from empty db config") } - var dialector gorm.Dialector - logLevel := getGormLogLevel(ctx, logConfig) + gormConfig := &gorm.Config{ + Logger: gormLogger.Default.LogMode(getGormLogLevel(ctx, logConfig)), + DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating, + } + // TODO: add other gorm-supported db type handling in further case blocks. switch { // TODO: Figure out a better proxy for a non-empty postgres config case len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0: - dialector = postgres.Open(getPostgresDsn(ctx, dbConfig.PostgresConfig)) - // TODO: add other gorm-supported db type handling in further case blocks. + gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.PostgresConfig) + if err != nil { + return nil, err + } + case len(dbConfig.DeprecatedHost) > 0 || len(dbConfig.DeprecatedUser) > 0 || len(dbConfig.DeprecatedDbName) > 0: pgConfig := runtimeInterfaces.PostgresConfig{ Host: dbConfig.DeprecatedHost, @@ -94,21 +109,63 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig ExtraOptions: dbConfig.DeprecatedExtraOptions, Debug: dbConfig.DeprecatedDebug, } - dialector = postgres.Open(getPostgresDsn(ctx, pgConfig)) + gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, pgConfig) + if err != nil { + return nil, err + } default: panic(fmt.Sprintf("Unrecognized database config %v", dbConfig)) } - gormDb, err := gorm.Open(dialector, &gorm.Config{ - Logger: gormLogger.Default.LogMode(logLevel), - DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating, - }) + // Setup connection pool settings + return gormDb, setupDbConnectionPool(gormDb, dbConfig) +} + +// Creates DB if it doesn't exist for the passed in config +func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig runtimeInterfaces.PostgresConfig) (*gorm.DB, error) { + + dialector := postgres.Open(getPostgresDsn(ctx, pgConfig)) + gormDb, err := gorm.Open(dialector, gormConfig) + if err == nil { + return gormDb, nil + } + + // Check if its invalid db code error + cErr, ok := err.(repoErrors.ConnectError) + if !ok { + logger.Errorf(ctx, "Failed to cast error of type: %v, err: %v", reflect.TypeOf(err), + err) + return nil, err + } + pqError := cErr.Unwrap().(*pgconn.PgError) + if pqError.Code != pqInvalidDBCode { + return nil, err + } + + logger.Warningf(ctx, "Database [%v] does not exist", pgConfig.DbName) + + // Every postgres installation includes a 'postgres' database by default. We connect to that now in order to + // initialize the user-specified database. + defaultDbPgConfig := pgConfig + defaultDbPgConfig.DbName = defaultDB + defaultDBDialector := postgres.Open(getPostgresDsn(ctx, defaultDbPgConfig)) + gormDb, err = gorm.Open(defaultDBDialector, gormConfig) if err != nil { return nil, err } - // Setup connection pool settings - return gormDb, setupDbConnectionPool(gormDb, dbConfig) + // Because we asserted earlier that the db does not exist, we create it now. + logger.Infof(ctx, "Creating database %v", pgConfig.DbName) + + // NOTE: golang sql drivers do not support parameter injection for CREATE calls + createDBStatement := fmt.Sprintf("CREATE DATABASE %s", pgConfig.DbName) + result := gormDb.Exec(createDBStatement) + + if result.Error != nil { + return nil, result.Error + } + // Now try connecting to the db again + return gorm.Open(dialector, gormConfig) } func setupDbConnectionPool(gormDb *gorm.DB, dbConfig *runtimeInterfaces.DbConfig) error { diff --git a/pkg/repositories/errors/postgres.go b/pkg/repositories/errors/postgres.go index a61c613458..29f2efce7d 100644 --- a/pkg/repositories/errors/postgres.go +++ b/pkg/repositories/errors/postgres.go @@ -104,3 +104,8 @@ func NewPostgresErrorTransformer(scope promutils.Scope) ErrorTransformer { metrics: metrics, } } + +type ConnectError interface { + Unwrap() error + Error() string +}