Skip to content

Commit

Permalink
Read entirety of database config (flyteorg#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored May 10, 2022
1 parent 072ce74 commit 411ac2f
Show file tree
Hide file tree
Showing 21 changed files with 63 additions and 64 deletions.
1 change: 1 addition & 0 deletions cmd/entrypoints/serve_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build integration
// +build integration

package entrypoints
Expand Down
10 changes: 5 additions & 5 deletions pkg/async/cloudevent/implementations/cloudevent_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ func (p *Publisher) Publish(ctx context.Context, notificationType string, msg pr
var phase string
var eventTime time.Time

switch msg.(type) {
switch msgType := msg.(type) {
case *admin.WorkflowExecutionEventRequest:
e := msg.(*admin.WorkflowExecutionEventRequest).Event
e := msgType.Event
executionID = e.ExecutionId.String()
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()
case *admin.TaskExecutionEventRequest:
e := msg.(*admin.TaskExecutionEventRequest).Event
e := msgType.Event
executionID = e.TaskId.String()
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()
case *admin.NodeExecutionEventRequest:
e := msg.(*admin.NodeExecutionEventRequest).Event
executionID = msg.(*admin.NodeExecutionEventRequest).Event.Id.String()
e := msgType.Event
executionID = msgType.Event.Id.String()
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()
default:
Expand Down
5 changes: 2 additions & 3 deletions pkg/manager/impl/resources/resource_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package resources

import (
"context"
"fmt"
"testing"

"github.com/flyteorg/flyteadmin/pkg/errors"
Expand Down Expand Up @@ -130,7 +129,7 @@ func TestUpdateWorkflowAttributes_CreateOrMerge(t *testing.T) {
} else if override.TaskType == "hive" {
assert.EqualValues(t, []string{"plugin b"}, override.PluginId)
} else {
t.Error(fmt.Sprintf("Unexpected task type [%s] plugin override committed to db", override.TaskType))
t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType)
}
}
createOrUpdateCalled = true
Expand Down Expand Up @@ -301,7 +300,7 @@ func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) {
} else if override.TaskType == "hive" {
assert.EqualValues(t, []string{"plugin b"}, override.PluginId)
} else {
t.Error(fmt.Sprintf("Unexpected task type [%s] plugin override committed to db", override.TaskType))
t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType)
}
}
createOrUpdateCalled = true
Expand Down
3 changes: 1 addition & 2 deletions pkg/repositories/config/migrations_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"fmt"
"testing"

mocket "github.com/Selvatico/go-mocket"
Expand Down Expand Up @@ -29,7 +28,7 @@ func GetDbForTest(t *testing.T) *gorm.DB {
mocket.Catcher.Register()
db, err := gorm.Open(postgres.New(postgres.Config{DriverName: mocket.DriverName}))
if err != nil {
t.Fatal(fmt.Sprintf("Failed to open mock db with err %v", err))
t.Fatalf("Failed to open mock db with err %v", err)
}
return db
}
24 changes: 12 additions & 12 deletions pkg/repositories/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors"
"gorm.io/driver/sqlite"

runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flytestdlib/logger"
"github.com/jackc/pgconn"
"gorm.io/driver/postgres"
Expand Down Expand Up @@ -44,7 +43,7 @@ func resolvePassword(ctx context.Context, passwordVal, passwordPath string) stri
}

// Produces the DSN (data source name) for opening a postgres db connection.
func getPostgresDsn(ctx context.Context, pgConfig *runtimeInterfaces.PostgresConfig) string {
func getPostgresDsn(ctx context.Context, pgConfig database.PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
Expand All @@ -57,7 +56,7 @@ func getPostgresDsn(ctx context.Context, pgConfig *runtimeInterfaces.PostgresCon

// GetDB uses the dbConfig to create gorm DB object. If the db doesn't exist for the dbConfig then a new one is created
// using the default db for the provider. eg : postgres has default dbName as postgres
func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) (
func GetDB(ctx context.Context, dbConfig *database.DbConfig, logConfig *logger.Config) (
*gorm.DB, error) {
if dbConfig == nil {
panic("Cannot initialize database repository from empty db config")
Expand All @@ -71,22 +70,22 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig
var err error

switch {
case dbConfig.SQLiteConfig != nil:
if dbConfig.SQLiteConfig.File == "" {
case !(dbConfig.SQLite.IsEmpty()):
if dbConfig.SQLite.File == "" {
return nil, fmt.Errorf("illegal sqlite database configuration. `file` is a required parameter and should be a path")
}
gormDb, err = gorm.Open(sqlite.Open(dbConfig.SQLiteConfig.File), gormConfig)
gormDb, err = gorm.Open(sqlite.Open(dbConfig.SQLite.File), gormConfig)
if err != nil {
return nil, err
}
case dbConfig.PostgresConfig != nil && (len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0):
gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.PostgresConfig)
case !(dbConfig.Postgres.IsEmpty()):
gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.Postgres)
if err != nil {
return nil, err
}

case len(dbConfig.DeprecatedHost) > 0 || len(dbConfig.DeprecatedUser) > 0 || len(dbConfig.DeprecatedDbName) > 0:
pgConfig := &runtimeInterfaces.PostgresConfig{
pgConfig := database.PostgresConfig{
Host: dbConfig.DeprecatedHost,
Port: dbConfig.DeprecatedPort,
DbName: dbConfig.DeprecatedDbName,
Expand All @@ -105,11 +104,11 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig
}

// Setup connection pool settings
return gormDb, setupDbConnectionPool(gormDb, dbConfig)
return gormDb, setupDbConnectionPool(ctx, gormDb, dbConfig)
}

// Creates DB if it doesn't exist for the passed in config
func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig *runtimeInterfaces.PostgresConfig) (*gorm.DB, error) {
func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig database.PostgresConfig) (*gorm.DB, error) {

dialector := postgres.Open(getPostgresDsn(ctx, pgConfig))
gormDb, err := gorm.Open(dialector, gormConfig)
Expand Down Expand Up @@ -155,13 +154,14 @@ func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p
return gorm.Open(dialector, gormConfig)
}

func setupDbConnectionPool(gormDb *gorm.DB, dbConfig *runtimeInterfaces.DbConfig) error {
func setupDbConnectionPool(ctx context.Context, gormDb *gorm.DB, dbConfig *database.DbConfig) error {
genericDb, err := gormDb.DB()
if err != nil {
return err
}
genericDb.SetConnMaxLifetime(dbConfig.ConnMaxLifeTime.Duration)
genericDb.SetMaxIdleConns(dbConfig.MaxIdleConnections)
genericDb.SetMaxOpenConns(dbConfig.MaxOpenConnections)
logger.Infof(ctx, "Set connection pool values to [%+v]", genericDb.Stats())
return nil
}
24 changes: 13 additions & 11 deletions pkg/repositories/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
"testing"
"time"

runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flytestdlib/database"

"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"

Expand All @@ -33,7 +34,7 @@ func TestResolvePassword(t *testing.T) {
}

func TestGetPostgresDsn(t *testing.T) {
pgConfig := &runtimeInterfaces.PostgresConfig{
pgConfig := database.PostgresConfig{
Host: "localhost",
Port: 5432,
DbName: "postgres",
Expand Down Expand Up @@ -73,16 +74,17 @@ func TestGetPostgresDsn(t *testing.T) {
}

func TestSetupDbConnectionPool(t *testing.T) {
ctx := context.TODO()
t.Run("successful", func(t *testing.T) {
gormDb, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
assert.Nil(t, err)
dbConfig := &runtimeInterfaces.DbConfig{
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
}
err = setupDbConnectionPool(gormDb, dbConfig)
err = setupDbConnectionPool(ctx, gormDb, dbConfig)
assert.Nil(t, err)
genericDb, err := gormDb.DB()
assert.Nil(t, err)
Expand All @@ -91,12 +93,12 @@ func TestSetupDbConnectionPool(t *testing.T) {
t.Run("unset duration", func(t *testing.T) {
gormDb, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
assert.Nil(t, err)
dbConfig := &runtimeInterfaces.DbConfig{
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
}
err = setupDbConnectionPool(gormDb, dbConfig)
err = setupDbConnectionPool(ctx, gormDb, dbConfig)
assert.Nil(t, err)
genericDb, err := gormDb.DB()
assert.Nil(t, err)
Expand All @@ -108,13 +110,13 @@ func TestSetupDbConnectionPool(t *testing.T) {
ConnPool: &gorm.PreparedStmtDB{},
},
}
dbConfig := &runtimeInterfaces.DbConfig{
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
}
err := setupDbConnectionPool(gormDb, dbConfig)
err := setupDbConnectionPool(ctx, gormDb, dbConfig)
assert.NotNil(t, err)
})
}
Expand All @@ -123,14 +125,14 @@ func TestGetDB(t *testing.T) {
ctx := context.TODO()

t.Run("missing DB Config", func(t *testing.T) {
_, err := GetDB(ctx, &runtimeInterfaces.DbConfig{}, &logger.Config{})
_, err := GetDB(ctx, &database.DbConfig{}, &logger.Config{})
assert.Error(t, err)
})

t.Run("sqlite config", func(t *testing.T) {
dbFile := path.Join(t.TempDir(), "admin.db")
db, err := GetDB(ctx, &runtimeInterfaces.DbConfig{
SQLiteConfig: &runtimeInterfaces.SQLiteConfig{
db, err := GetDB(ctx, &database.DbConfig{
SQLite: database.SQLiteConfig{
File: dbFile,
},
}, &logger.Config{})
Expand Down
3 changes: 1 addition & 2 deletions pkg/repositories/gormimpl/utils_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package gormimpl

import (
"fmt"
"testing"

"github.com/flyteorg/flyteadmin/pkg/common"
Expand All @@ -24,7 +23,7 @@ func GetDbForTest(t *testing.T) *gorm.DB {
mocket.Catcher.Register()
db, err := gorm.Open(postgres.New(postgres.Config{DriverName: mocket.DriverName}))
if err != nil {
t.Fatal(fmt.Sprintf("Failed to open mock db with err %v", err))
t.Fatalf("Failed to open mock db with err %v", err)
}
return db
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/repositories/transformers/execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package transformers

import (
"context"
"fmt"
"math"
"strings"
"testing"
Expand Down Expand Up @@ -458,7 +457,7 @@ func TestSetExecutionAborted(t *testing.T) {
var actualClosure admin.ExecutionClosure
err = proto.Unmarshal(existingModel.Closure, &actualClosure)
if err != nil {
t.Fatal(fmt.Sprintf("Failed to marshal execution closure: %v", err))
t.Fatalf("Failed to marshal execution closure: %v", err)
}
assert.True(t, proto.Equal(&admin.ExecutionClosure{
OutputResult: &admin.ExecutionClosure_AbortMetadata{
Expand Down
16 changes: 2 additions & 14 deletions pkg/runtime/application_config_provider.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package runtime

import (
"fmt"

"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flytestdlib/config"
Expand Down Expand Up @@ -83,18 +81,8 @@ var cloudEventsConfig = config.MustRegisterSection(cloudEvents, &interfaces.Clou
// Implementation of an interfaces.ApplicationConfiguration
type ApplicationConfigurationProvider struct{}

func (p *ApplicationConfigurationProvider) GetDbConfig() *interfaces.DbConfig {
databaseConfig := database.GetConfig()
switch {
case !databaseConfig.SQLite.IsEmpty():
sqliteConfig := interfaces.SQLiteConfig(databaseConfig.SQLite)
return &interfaces.DbConfig{SQLiteConfig: &sqliteConfig}
case !databaseConfig.Postgres.IsEmpty():
postgresConfig := interfaces.PostgresConfig(databaseConfig.Postgres)
return &interfaces.DbConfig{PostgresConfig: &postgresConfig}
default:
panic(fmt.Errorf("database config cannot be empty"))
}
func (p *ApplicationConfigurationProvider) GetDbConfig() *database.DbConfig {
return database.GetConfig()
}

func (p *ApplicationConfigurationProvider) GetTopLevelConfig() *interfaces.ApplicationConfig {
Expand Down
12 changes: 6 additions & 6 deletions pkg/runtime/config_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ func TestPostgresConfig(t *testing.T) {

configProvider := NewConfigurationProvider()
dbConfig := configProvider.ApplicationConfiguration().GetDbConfig()
assert.Equal(t, 5432, dbConfig.PostgresConfig.Port)
assert.Equal(t, "postgres", dbConfig.PostgresConfig.Host)
assert.Equal(t, "postgres", dbConfig.PostgresConfig.User)
assert.Equal(t, "postgres", dbConfig.PostgresConfig.DbName)
assert.Equal(t, "sslmode=disable", dbConfig.PostgresConfig.ExtraOptions)
assert.Equal(t, 5432, dbConfig.Postgres.Port)
assert.Equal(t, "postgres", dbConfig.Postgres.Host)
assert.Equal(t, "postgres", dbConfig.Postgres.User)
assert.Equal(t, "postgres", dbConfig.Postgres.DbName)
assert.Equal(t, "sslmode=disable", dbConfig.Postgres.ExtraOptions)
}

func TestSqliteConfig(t *testing.T) {
Expand All @@ -80,5 +80,5 @@ func TestSqliteConfig(t *testing.T) {

configProvider := NewConfigurationProvider()
dbConfig := configProvider.ApplicationConfiguration().GetDbConfig()
assert.Equal(t, "admin.db", dbConfig.SQLiteConfig.File)
assert.Equal(t, "admin.db", dbConfig.SQLite.File)
}
3 changes: 2 additions & 1 deletion pkg/runtime/interfaces/application_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/database"
"github.com/golang/protobuf/ptypes/wrappers"

"golang.org/x/time/rate"
Expand Down Expand Up @@ -500,7 +501,7 @@ type DomainsConfig = []Domain

// Defines the interface to return top-level config structs necessary to start up a flyteadmin application.
type ApplicationConfiguration interface {
GetDbConfig() *DbConfig
GetDbConfig() *database.DbConfig
GetTopLevelConfig() *ApplicationConfig
GetSchedulerConfig() *SchedulerConfig
GetRemoteDataConfig() *RemoteDataConfig
Expand Down
7 changes: 4 additions & 3 deletions pkg/runtime/mocks/mock_application_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package mocks

import (
"github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flytestdlib/database"
)

type MockApplicationProvider struct {
dbConfig interfaces.DbConfig
dbConfig database.DbConfig
topLevelConfig interfaces.ApplicationConfig
schedulerConfig interfaces.SchedulerConfig
remoteDataConfig interfaces.RemoteDataConfig
Expand All @@ -15,11 +16,11 @@ type MockApplicationProvider struct {
cloudEventConfig interfaces.CloudEventsConfig
}

func (p *MockApplicationProvider) GetDbConfig() *interfaces.DbConfig {
func (p *MockApplicationProvider) GetDbConfig() *database.DbConfig {
return &p.dbConfig
}

func (p *MockApplicationProvider) SetDbConfig(dbConfig interfaces.DbConfig) {
func (p *MockApplicationProvider) SetDbConfig(dbConfig database.DbConfig) {
p.dbConfig = dbConfig
}

Expand Down
Loading

0 comments on commit 411ac2f

Please sign in to comment.