diff --git a/flyteplugins/go/tasks/plugins/presto/config/config.go b/flyteplugins/go/tasks/plugins/presto/config/config.go index 47fa6a8401..de86392dae 100644 --- a/flyteplugins/go/tasks/plugins/presto/config/config.go +++ b/flyteplugins/go/tasks/plugins/presto/config/config.go @@ -76,6 +76,7 @@ var ( Environment: URLMustParse(""), DefaultRoutingGroup: "adhoc", DefaultUser: "flyte-default-user", + UseNamespaceAsUser: true, RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 100}, {Name: "etl", Limit: 25}}, RefreshCacheConfig: RefreshCacheConfig{ Name: "presto", @@ -101,6 +102,7 @@ type Config struct { Environment config.URL `json:"environment" pflag:",Environment endpoint for Presto to use"` DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` DefaultUser string `json:"defaultUser" pflag:",Default Presto user"` + UseNamespaceAsUser bool `json:"useNamespaceAsUser" pflag:",Use the K8s namespace as the user"` RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` RefreshCacheConfig RefreshCacheConfig `json:"refreshCacheConfig" pflag:"Refresh cache config"` ReadRateLimiterConfig RateLimiterConfig `json:"readRateLimiterConfig" pflag:"Rate limiter config for read requests going to Presto"` diff --git a/flyteplugins/go/tasks/plugins/presto/config/config_flags.go b/flyteplugins/go/tasks/plugins/presto/config/config_flags.go index 36211bcfdb..40287064c0 100755 --- a/flyteplugins/go/tasks/plugins/presto/config/config_flags.go +++ b/flyteplugins/go/tasks/plugins/presto/config/config_flags.go @@ -44,6 +44,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultUser"), defaultConfig.DefaultUser, "Default Presto user") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "useNamespaceAsUser"), defaultConfig.UseNamespaceAsUser, "Use the K8s namespace as the user") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.name"), defaultConfig.RefreshCacheConfig.Name, "The name of the rate limiter") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.syncPeriod"), defaultConfig.RefreshCacheConfig.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.workers"), defaultConfig.RefreshCacheConfig.Workers, "Number of parallel workers to refresh the cache") diff --git a/flyteplugins/go/tasks/plugins/presto/config/config_flags_test.go b/flyteplugins/go/tasks/plugins/presto/config/config_flags_test.go index bd4652d235..48f57df7c8 100755 --- a/flyteplugins/go/tasks/plugins/presto/config/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/presto/config/config_flags_test.go @@ -165,6 +165,28 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_useNamespaceAsUser", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("useNamespaceAsUser"); err == nil { + assert.Equal(t, bool(defaultConfig.UseNamespaceAsUser), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("useNamespaceAsUser", testValue) + if vBool, err := cmdFlags.GetBool("useNamespaceAsUser"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.UseNamespaceAsUser) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_refreshCacheConfig.name", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state.go b/flyteplugins/go/tasks/plugins/presto/execution_state.go index 7cfb647b84..5268cf4499 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state.go @@ -93,6 +93,8 @@ type Query struct { ExternalLocation string `json:"externalLocation"` } +const PrestoSource = "flyte" + // This is the main state iteration func HandleExecutionState( ctx context.Context, @@ -296,6 +298,11 @@ func GetNextQuery( if err != nil { return Query{}, err } + var user = getUser(ctx, prestoCfg.DefaultUser) + + if prestoCfg.UseNamespaceAsUser { + user = tCtx.TaskExecutionMetadata().GetNamespace() + } statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) @@ -305,8 +312,8 @@ func GetNextQuery( RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), Catalog: catalog, Schema: schema, - Source: "flyte", - User: getUser(ctx, prestoCfg.DefaultUser), + Source: PrestoSource, + User: user, }, TempTableName: tempTableName + "_temp", ExternalTableName: tempTableName + "_external",