Skip to content

Commit

Permalink
Create admin db if it doesn't exist (flyteorg#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmahindrakar-oss authored Mar 11, 2022
1 parent ddac175 commit 735b59c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 14 deletions.
3 changes: 1 addition & 2 deletions cmd/entrypoints/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"

"github.com/flyteorg/flyteadmin/pkg/repositories"

"github.com/flyteorg/flyteadmin/pkg/repositories/config"
"github.com/flyteorg/flyteadmin/pkg/runtime"
"github.com/flyteorg/flytestdlib/logger"

"github.com/go-gormigrate/gormigrate/v2"
"github.com/spf13/cobra"
_ "gorm.io/driver/postgres" // Required to import database driver.
Expand All @@ -32,7 +32,6 @@ var migrateCmd = &cobra.Command{
if err != nil {
logger.Fatal(ctx, err)
}

sqlDB, err := db.DB()
if err != nil {
logger.Fatal(ctx, err)
Expand Down
81 changes: 69 additions & 12 deletions pkg/repositories/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@ import (
"fmt"
"io/ioutil"
"os"
"reflect"
"strings"

repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors"
runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flytestdlib/logger"

"github.com/jackc/pgconn"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gormLogger "gorm.io/gorm/logger"
)

const pqInvalidDBCode = "3D000"
const defaultDB = "postgres"

// getGormLogLevel converts between the flytestdlib configured log level to the equivalent gorm log level.
func getGormLogLevel(ctx context.Context, logConfig *logger.Config) gormLogger.LogLevel {
if logConfig == nil {
Expand Down Expand Up @@ -70,19 +77,27 @@ func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConf
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// GetDB uses the dbConfig to create gorm DB object. If the db doesn't exist for the dbConfig then a new one is created
// using the default db for the provider. eg : postgres has default dbName as postgres
func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) (
*gorm.DB, error) {
gormDb *gorm.DB, err error) {
if dbConfig == nil {
panic("Cannot initialize database repository from empty db config")
}
var dialector gorm.Dialector
logLevel := getGormLogLevel(ctx, logConfig)
gormConfig := &gorm.Config{
Logger: gormLogger.Default.LogMode(getGormLogLevel(ctx, logConfig)),
DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating,
}

// TODO: add other gorm-supported db type handling in further case blocks.
switch {
// TODO: Figure out a better proxy for a non-empty postgres config
case len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0:
dialector = postgres.Open(getPostgresDsn(ctx, dbConfig.PostgresConfig))
// TODO: add other gorm-supported db type handling in further case blocks.
gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.PostgresConfig)
if err != nil {
return nil, err
}

case len(dbConfig.DeprecatedHost) > 0 || len(dbConfig.DeprecatedUser) > 0 || len(dbConfig.DeprecatedDbName) > 0:
pgConfig := runtimeInterfaces.PostgresConfig{
Host: dbConfig.DeprecatedHost,
Expand All @@ -94,21 +109,63 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig
ExtraOptions: dbConfig.DeprecatedExtraOptions,
Debug: dbConfig.DeprecatedDebug,
}
dialector = postgres.Open(getPostgresDsn(ctx, pgConfig))
gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, pgConfig)
if err != nil {
return nil, err
}
default:
panic(fmt.Sprintf("Unrecognized database config %v", dbConfig))
}
gormDb, err := gorm.Open(dialector, &gorm.Config{
Logger: gormLogger.Default.LogMode(logLevel),
DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating,
})

// Setup connection pool settings
return gormDb, setupDbConnectionPool(gormDb, dbConfig)
}

// Creates DB if it doesn't exist for the passed in config
func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig runtimeInterfaces.PostgresConfig) (*gorm.DB, error) {

dialector := postgres.Open(getPostgresDsn(ctx, pgConfig))
gormDb, err := gorm.Open(dialector, gormConfig)
if err == nil {
return gormDb, nil
}

// Check if its invalid db code error
cErr, ok := err.(repoErrors.ConnectError)
if !ok {
logger.Errorf(ctx, "Failed to cast error of type: %v, err: %v", reflect.TypeOf(err),
err)
return nil, err
}
pqError := cErr.Unwrap().(*pgconn.PgError)
if pqError.Code != pqInvalidDBCode {
return nil, err
}

logger.Warningf(ctx, "Database [%v] does not exist", pgConfig.DbName)

// Every postgres installation includes a 'postgres' database by default. We connect to that now in order to
// initialize the user-specified database.
defaultDbPgConfig := pgConfig
defaultDbPgConfig.DbName = defaultDB
defaultDBDialector := postgres.Open(getPostgresDsn(ctx, defaultDbPgConfig))
gormDb, err = gorm.Open(defaultDBDialector, gormConfig)
if err != nil {
return nil, err
}

// Setup connection pool settings
return gormDb, setupDbConnectionPool(gormDb, dbConfig)
// Because we asserted earlier that the db does not exist, we create it now.
logger.Infof(ctx, "Creating database %v", pgConfig.DbName)

// NOTE: golang sql drivers do not support parameter injection for CREATE calls
createDBStatement := fmt.Sprintf("CREATE DATABASE %s", pgConfig.DbName)
result := gormDb.Exec(createDBStatement)

if result.Error != nil {
return nil, result.Error
}
// Now try connecting to the db again
return gorm.Open(dialector, gormConfig)
}

func setupDbConnectionPool(gormDb *gorm.DB, dbConfig *runtimeInterfaces.DbConfig) error {
Expand Down
5 changes: 5 additions & 0 deletions pkg/repositories/errors/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,8 @@ func NewPostgresErrorTransformer(scope promutils.Scope) ErrorTransformer {
metrics: metrics,
}
}

type ConnectError interface {
Unwrap() error
Error() string
}

0 comments on commit 735b59c

Please sign in to comment.