diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index f41185f9a4..528296cb84 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteadmin/pkg/auth" "github.com/golang/protobuf/ptypes" @@ -83,6 +85,11 @@ type ExecutionManager struct { urlData dataInterfaces.RemoteURLInterface } +func getExecutionContext(ctx context.Context, id *core.WorkflowExecutionIdentifier) context.Context { + ctx = contextutils.WithExecutionID(ctx, id.Name) + return contextutils.WithProjectDomain(ctx, id.Project, id.Domain) +} + // Returns the unique string which identifies the authenticated end user (if any). func getUser(ctx context.Context) string { principalContextUser := ctx.Value(auth.PrincipalContextKey) @@ -182,21 +189,22 @@ func (m *ExecutionManager) offloadInputs(ctx context.Context, literalMap *core.L } func (m *ExecutionManager) launchExecutionAndPrepareModel( - ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) (*models.Execution, error) { + ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) ( + context.Context, *models.Execution, error) { err := validation.ValidateExecutionRequest(ctx, request, m.db, m.config.ApplicationConfiguration()) if err != nil { logger.Debugf(ctx, "Failed to validate ExecutionCreateRequest %+v with err %v", request, err) - return nil, err + return nil, nil, err } launchPlanModel, err := util.GetLaunchPlanModel(ctx, m.db, *request.Spec.LaunchPlan) if err != nil { logger.Debugf(ctx, "Failed to get launch plan model for ExecutionCreateRequest %+v with err %v", request, err) - return nil, err + return nil, nil, err } launchPlan, err := transformers.FromLaunchPlanModel(launchPlanModel) if err != nil { logger.Debugf(ctx, "Failed to transform launch plan model %+v with err %v", launchPlanModel, err) - return nil, err + return nil, nil, err } executionInputs, err := validation.CheckAndFetchInputsForExecution( request.Inputs, @@ -208,12 +216,12 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( logger.Debugf(ctx, "Failed to CheckAndFetchInputsForExecution with request.Inputs: %+v"+ "fixed inputs: %+v and expected inputs: %+v with err %v", request.Inputs, launchPlan.Spec.FixedInputs, launchPlan.Closure.ExpectedInputs, err) - return nil, err + return nil, nil, err } workflow, err := util.GetWorkflow(ctx, m.db, m.storageClient, *launchPlan.Spec.WorkflowId) if err != nil { logger.Debugf(ctx, "Failed to get workflow with id %+v with err %v", launchPlan.Spec.WorkflowId, err) - return nil, err + return nil, nil, err } name := util.GetExecutionName(request) workflowExecutionID := core.WorkflowExecutionIdentifier{ @@ -221,6 +229,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( Domain: request.Domain, Name: name, } + ctx = getExecutionContext(ctx, &workflowExecutionID) // Get the node execution (if any) that launched this execution var parentNodeExecutionID uint @@ -229,7 +238,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( if err != nil { logger.Errorf(ctx, "Failed to get node execution [%+v] that launched this execution [%+v] with error %v", request.Spec.Metadata.ParentNodeExecution, workflowExecutionID, err) - return nil, err + return nil, nil, err } parentNodeExecutionID = parentNodeExecutionModel.ID @@ -245,11 +254,11 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( inputsURI, err := m.offloadInputs(ctx, executionInputs, &workflowExecutionID, shared.Inputs) if err != nil { - return nil, err + return nil, nil, err } userInputsURI, err := m.offloadInputs(ctx, request.Inputs, &workflowExecutionID, shared.UserInputs) if err != nil { - return nil, err + return nil, nil, err } // TODO: Reduce CRD size and use offloaded input URI to blob store instead. @@ -262,7 +271,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( } err = m.addLabelsAndAnnotations(request.Spec, &executeWorkflowInputs) if err != nil { - return nil, err + return nil, nil, err } execInfo, err := m.workflowExecutor.ExecuteWorkflow(ctx, executeWorkflowInputs) @@ -270,7 +279,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( m.systemMetrics.PropellerFailures.Inc() logger.Infof(ctx, "Failed to execute workflow %+v with execution id %+v and inputs %+v with err %v", request, workflowExecutionID, executionInputs, err) - return nil, err + return nil, nil, err } executionCreatedAt := time.Now() acceptanceDelay := executionCreatedAt.Sub(requestedAt) @@ -308,9 +317,9 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( if err != nil { logger.Infof(ctx, "Failed to create execution model in transformer for id: [%+v] with err: %v", workflowExecutionID, err) - return nil, err + return nil, nil, err } - return executionModel, nil + return ctx, executionModel, nil } // Inserts an execution model into the database store and emits platform metrics. @@ -341,7 +350,9 @@ func (m *ExecutionManager) CreateExecution( if request.Inputs == nil || len(request.Inputs.Literals) == 0 { request.Inputs = request.GetSpec().GetInputs() } - executionModel, err := m.launchExecutionAndPrepareModel(ctx, request, requestedAt) + var executionModel *models.Execution + var err error + ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, request, requestedAt) if err != nil { return nil, err } @@ -387,7 +398,8 @@ func (m *ExecutionManager) RelaunchExecution( inputs = spec.Inputs } executionSpec.Metadata.Mode = admin.ExecutionMetadata_RELAUNCH - executionModel, err := m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{ + var executionModel *models.Execution + ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{ Project: request.Id.Project, Domain: request.Id.Domain, Name: request.Name, @@ -543,6 +555,7 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi logger.Debugf(ctx, "received invalid CreateWorkflowEventRequest [%s]: %v", request.RequestId, err) return nil, err } + ctx = getExecutionContext(ctx, request.Event.ExecutionId) logger.Debugf(ctx, "Received workflow execution event for [%+v] transitioning to phase [%v]", request.Event.ExecutionId, request.Event.Phase) @@ -617,6 +630,7 @@ func (m *ExecutionManager) GetExecution( logger.Debugf(ctx, "GetExecution request [%+v] failed validation with err: %v", request, err) return nil, err } + ctx = getExecutionContext(ctx, request.Id) executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id) if err != nil { logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err) @@ -667,6 +681,7 @@ func (m *ExecutionManager) GetExecution( func (m *ExecutionManager) GetExecutionData( ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) { + ctx = getExecutionContext(ctx, request.Id) executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id) if err != nil { logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err) @@ -718,6 +733,7 @@ func (m *ExecutionManager) ListExecutions( logger.Debugf(ctx, "ListExecutions request [%+v] failed validation with err: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Id.Project, Domain: request.Id.Domain, @@ -842,6 +858,7 @@ func (m *ExecutionManager) TerminateExecution( logger.Debugf(ctx, "received terminate execution request: %v with invalid identifier: %v", request, err) return nil, err } + ctx = getExecutionContext(ctx, request.Id) // Save the abort reason (best effort) executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.GetResourceInput{ Project: request.Id.Project, diff --git a/flyteadmin/pkg/manager/impl/launch_plan_manager.go b/flyteadmin/pkg/manager/impl/launch_plan_manager.go index 8752b20cb8..b6e5de18d3 100644 --- a/flyteadmin/pkg/manager/impl/launch_plan_manager.go +++ b/flyteadmin/pkg/manager/impl/launch_plan_manager.go @@ -5,6 +5,8 @@ import ( "context" "strconv" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteadmin/pkg/async/schedule/aws" "github.com/lyft/flytestdlib/promutils" @@ -44,6 +46,16 @@ type LaunchPlanManager struct { metrics launchPlanMetrics } +func getLaunchPlanContext(ctx context.Context, identifier *core.Identifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain) + return contextutils.WithLaunchPlanID(ctx, identifier.Name) +} + +func (m *LaunchPlanManager) getNamedEntityContext(ctx context.Context, identifier *admin.NamedEntityIdentifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain) + return contextutils.WithLaunchPlanID(ctx, identifier.Name) +} + func (m *LaunchPlanManager) CreateLaunchPlan( ctx context.Context, request admin.LaunchPlanCreateRequest) (*admin.LaunchPlanCreateResponse, error) { @@ -71,6 +83,7 @@ func (m *LaunchPlanManager) CreateLaunchPlan( logger.Debugf(ctx, "could not create launch plan: %+v, request failed validation with err: %v", request.Id, err) return nil, err } + ctx = getLaunchPlanContext(ctx, request.Id) launchPlan := transformers.CreateLaunchPlan(request, workflowInterface.Outputs) launchPlanDigest, err := util.GetLaunchPlanDigest(ctx, &launchPlan) if err != nil { @@ -328,6 +341,7 @@ func (m *LaunchPlanManager) UpdateLaunchPlan(ctx context.Context, request admin. if err := validation.ValidateIdentifier(request.Id, common.LaunchPlan); err != nil { logger.Debugf(ctx, "can't update launch plan [%+v] state, invalid identifier: %v", request.Id, err) } + ctx = getLaunchPlanContext(ctx, request.Id) switch request.State { case admin.LaunchPlanState_INACTIVE: return m.disableLaunchPlan(ctx, request) @@ -346,6 +360,7 @@ func (m *LaunchPlanManager) GetLaunchPlan(ctx context.Context, request admin.Obj logger.Debugf(ctx, "can't get launch plan [%+v] with invalid identifier: %v", request.Id, err) return nil, err } + ctx = getLaunchPlanContext(ctx, request.Id) return util.GetLaunchPlan(ctx, m.db, *request.Id) } @@ -355,6 +370,7 @@ func (m *LaunchPlanManager) GetActiveLaunchPlan(ctx context.Context, request adm logger.Debugf(ctx, "can't get active launch plan [%+v] with invalid request: %v", request.Id, err) return nil, err } + ctx = m.getNamedEntityContext(ctx, request.Id) filters, err := util.GetActiveLaunchPlanVersionFilters(request.Id.Project, request.Id.Domain, request.Id.Name) if err != nil { @@ -388,6 +404,7 @@ func (m *LaunchPlanManager) ListLaunchPlans(ctx context.Context, request admin.R logger.Debugf(ctx, "") return nil, err } + ctx = m.getNamedEntityContext(ctx, request.Id) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Id.Project, @@ -447,6 +464,7 @@ func (m *LaunchPlanManager) ListActiveLaunchPlans(ctx context.Context, request a logger.Debugf(ctx, "") return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) filters, err := util.ListActiveLaunchPlanVersionsFilters(request.Project, request.Domain) if err != nil { @@ -496,7 +514,7 @@ func (m *LaunchPlanManager) ListActiveLaunchPlans(ctx context.Context, request a // At least project name and domain must be specified along with limit. func (m *LaunchPlanManager) ListLaunchPlanIds(ctx context.Context, request admin.NamedEntityIdentifierListRequest) ( *admin.NamedEntityIdentifierList, error) { - + ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Project, Domain: request.Domain, diff --git a/flyteadmin/pkg/manager/impl/named_entity_manager.go b/flyteadmin/pkg/manager/impl/named_entity_manager.go index d060fef0bb..dfec0d59c9 100644 --- a/flyteadmin/pkg/manager/impl/named_entity_manager.go +++ b/flyteadmin/pkg/manager/impl/named_entity_manager.go @@ -4,6 +4,8 @@ import ( "context" "strconv" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteadmin/pkg/common" "github.com/lyft/flyteadmin/pkg/errors" "google.golang.org/grpc/codes" @@ -36,6 +38,7 @@ func (m *NamedEntityManager) UpdateNamedEntity(ctx context.Context, request admi logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain) // Ensure entity exists before trying to update it _, err := util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id) @@ -58,6 +61,7 @@ func (m *NamedEntityManager) GetNamedEntity(ctx context.Context, request admin.N logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain) return util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id) } @@ -67,6 +71,7 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Project, diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager.go b/flyteadmin/pkg/manager/impl/node_execution_manager.go index 2836575588..a58ed44108 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager.go @@ -4,6 +4,8 @@ import ( "context" "strconv" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteadmin/pkg/manager/impl/shared" "github.com/lyft/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" @@ -58,6 +60,12 @@ var isParent = common.NewMapFilter(map[string]interface{}{ shared.ParentTaskExecutionID: nil, }) +func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecutionIdentifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.ExecutionId.Project, identifier.ExecutionId.Domain) + ctx = contextutils.WithExecutionID(ctx, identifier.ExecutionId.Name) + return contextutils.WithNodeID(ctx, identifier.NodeId) +} + func (m *NodeExecutionManager) createNodeExecutionWithEvent( ctx context.Context, request *admin.NodeExecutionEventRequest) error { @@ -149,6 +157,10 @@ func (m *NodeExecutionManager) updateNodeExecutionWithEvent( func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admin.NodeExecutionEventRequest) ( *admin.NodeExecutionEventResponse, error) { + if err := validation.ValidateNodeExecutionIdentifier(request.Event.Id); err != nil { + logger.Debugf(ctx, "CreateNodeEvent called with invalid identifier [%+v]: %v", request.Event.Id, err) + } + ctx = getNodeExecutionContext(ctx, request.Event.Id) executionID := request.Event.Id.ExecutionId logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase [%v], w/ Metadata [%v]", request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata) @@ -208,6 +220,7 @@ func (m *NodeExecutionManager) GetNodeExecution( if err := validation.ValidateNodeExecutionIdentifier(request.Id); err != nil { logger.Debugf(ctx, "get node execution called with invalid identifier [%+v]: %v", request.Id, err) } + ctx = getNodeExecutionContext(ctx, request.Id) nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id) if err != nil { logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v", @@ -282,6 +295,7 @@ func (m *NodeExecutionManager) ListNodeExecutions( if err := validation.ValidateNodeExecutionListRequest(request); err != nil { return nil, err } + ctx = getExecutionContext(ctx, request.WorkflowExecutionId) identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, *request.WorkflowExecutionId) if err != nil { @@ -299,6 +313,7 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask( if err := validation.ValidateNodeExecutionForTaskListRequest(request); err != nil { return nil, err } + ctx = getTaskExecutionContext(ctx, request.TaskExecutionId) identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters( ctx, *request.TaskExecutionId.NodeExecutionId.ExecutionId) if err != nil { @@ -323,6 +338,7 @@ func (m *NodeExecutionManager) GetNodeExecutionData( if err := validation.ValidateNodeExecutionIdentifier(request.Id); err != nil { logger.Debugf(ctx, "can't get node execution data with invalid identifier [%+v]: %v", request.Id, err) } + ctx = getNodeExecutionContext(ctx, request.Id) nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id) if err != nil { logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v", diff --git a/flyteadmin/pkg/manager/impl/project_domain_manager.go b/flyteadmin/pkg/manager/impl/project_domain_manager.go index 21c1a5d9e2..f790edb92c 100644 --- a/flyteadmin/pkg/manager/impl/project_domain_manager.go +++ b/flyteadmin/pkg/manager/impl/project_domain_manager.go @@ -3,6 +3,8 @@ package impl import ( "context" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteadmin/pkg/manager/impl/validation" "github.com/lyft/flyteadmin/pkg/repositories/transformers" @@ -23,6 +25,7 @@ func (m *ProjectDomainManager) UpdateProjectDomain( if err := validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil { return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain) model, err := transformers.ToProjectDomainModel(*request.Attributes) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index df5a8bcb8f..d8d488f03a 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -5,6 +5,8 @@ import ( "fmt" "strconv" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" @@ -41,6 +43,11 @@ type TaskExecutionManager struct { urlData dataInterfaces.RemoteURLInterface } +func getTaskExecutionContext(ctx context.Context, identifier *core.TaskExecutionIdentifier) context.Context { + ctx = getNodeExecutionContext(ctx, identifier.NodeExecutionId) + return contextutils.WithTaskID(ctx, fmt.Sprintf("%s-%v", identifier.TaskId.Name, identifier.RetryAttempt)) +} + func (m *TaskExecutionManager) createTaskExecution( ctx context.Context, nodeExecutionModel *models.NodeExecution, request *admin.TaskExecutionEventRequest) ( models.TaskExecution, error) { @@ -94,6 +101,7 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req NodeExecutionId: nodeExecutionID, RetryAttempt: request.Event.RetryAttempt, } + ctx = getTaskExecutionContext(ctx, &taskExecutionID) logger.Debugf(ctx, "Received task execution event for [%+v] transitioning to phase [%v]", taskExecutionID, request.Event.Phase) nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, nodeExecutionID) @@ -166,6 +174,12 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req func (m *TaskExecutionManager) GetTaskExecution( ctx context.Context, request admin.TaskExecutionGetRequest) (*admin.TaskExecution, error) { + err := validation.ValidateTaskExecutionIdentifier(request.Id) + if err != nil { + logger.Debugf(ctx, "Failed to validate GetTaskExecution [%+v] with err: %v", request.Id, err) + return nil, err + } + ctx = getTaskExecutionContext(ctx, request.Id) taskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.Id) if err != nil { return nil, err @@ -184,6 +198,7 @@ func (m *TaskExecutionManager) ListTaskExecutions( logger.Debugf(ctx, "ListTaskExecutions request [%+v] is invalid: %v", request, err) return nil, err } + ctx = getNodeExecutionContext(ctx, request.NodeExecutionId) identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, *request.NodeExecutionId) if err != nil { @@ -240,6 +255,7 @@ func (m *TaskExecutionManager) GetTaskExecutionData( if err := validation.ValidateTaskExecutionIdentifier(request.Id); err != nil { logger.Debugf(ctx, "Invalid identifier [%+v]: %v", request.Id, err) } + ctx = getTaskExecutionContext(ctx, request.Id) taskExecution, err := m.GetTaskExecution(ctx, admin.TaskExecutionGetRequest{ Id: request.Id, }) diff --git a/flyteadmin/pkg/manager/impl/task_manager.go b/flyteadmin/pkg/manager/impl/task_manager.go index a9f339db32..9198d1258c 100644 --- a/flyteadmin/pkg/manager/impl/task_manager.go +++ b/flyteadmin/pkg/manager/impl/task_manager.go @@ -6,6 +6,9 @@ import ( "strconv" "time" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/contextutils" + "github.com/prometheus/client_golang/prometheus" "github.com/lyft/flytestdlib/promutils" @@ -42,6 +45,11 @@ type TaskManager struct { metrics taskMetrics } +func getTaskContext(ctx context.Context, identifier *core.Identifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain) + return contextutils.WithTaskID(ctx, identifier.Name) +} + func setDefaults(request admin.TaskCreateRequest) (admin.TaskCreateRequest, error) { if request.Id == nil { return request, errors.NewFlyteAdminError(codes.InvalidArgument, @@ -60,6 +68,7 @@ func (t *TaskManager) CreateTask( logger.Debugf(ctx, "Task [%+v] failed validation with err: %v", request.Id, err) return nil, err } + ctx = getTaskContext(ctx, request.Id) finalizedRequest, err := setDefaults(request) if err != nil { return nil, err @@ -119,6 +128,7 @@ func (t *TaskManager) GetTask(ctx context.Context, request admin.ObjectGetReques if err := validation.ValidateIdentifier(request.Id, common.Task); err != nil { logger.Debugf(ctx, "invalid identifier [%+v]: %v", request.Id, err) } + ctx = getTaskContext(ctx, request.Id) task, err := util.GetTask(ctx, t.db, *request.Id) if err != nil { logger.Debugf(ctx, "Failed to get task with id [%+v] with err %v", err, request.Id) @@ -133,6 +143,8 @@ func (t *TaskManager) ListTasks(ctx context.Context, request admin.ResourceListR logger.Debugf(ctx, "Invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain) + ctx = contextutils.WithTaskID(ctx, request.Id.Name) spec := util.FilterSpec{ Project: request.Id.Project, Domain: request.Id.Domain, @@ -193,6 +205,7 @@ func (t *TaskManager) ListUniqueTaskIdentifiers(ctx context.Context, request adm logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Project, Domain: request.Domain, diff --git a/flyteadmin/pkg/manager/impl/workflow_manager.go b/flyteadmin/pkg/manager/impl/workflow_manager.go index e6db431c53..177b0f080b 100644 --- a/flyteadmin/pkg/manager/impl/workflow_manager.go +++ b/flyteadmin/pkg/manager/impl/workflow_manager.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/lyft/flytestdlib/contextutils" + "github.com/golang/protobuf/ptypes" "github.com/lyft/flyteadmin/pkg/common" "github.com/lyft/flyteadmin/pkg/errors" @@ -46,6 +48,11 @@ type WorkflowManager struct { metrics workflowMetrics } +func getWorkflowContext(ctx context.Context, identifier *core.Identifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain) + return contextutils.WithWorkflowID(ctx, identifier.Name) +} + func (w *WorkflowManager) setDefaults(request admin.WorkflowCreateRequest) (admin.WorkflowCreateRequest, error) { // TODO: Also add environment and configuration defaults once those have been determined. if request.Id == nil { @@ -131,6 +138,7 @@ func (w *WorkflowManager) CreateWorkflow( if err := validation.ValidateWorkflow(ctx, request, w.db, w.config.ApplicationConfiguration()); err != nil { return nil, err } + ctx = getWorkflowContext(ctx, request.Id) finalizedRequest, err := w.setDefaults(request) if err != nil { logger.Debugf(ctx, "Failed to set defaults for workflow with id [%+v] with err %v", request.Id, err) @@ -210,6 +218,7 @@ func (w *WorkflowManager) GetWorkflow(ctx context.Context, request admin.ObjectG logger.Debugf(ctx, "invalid identifier [%+v]: %v", request.Id, err) return nil, err } + ctx = getWorkflowContext(ctx, request.Id) workflow, err := util.GetWorkflow(ctx, w.db, w.storageClient, *request.Id) if err != nil { logger.Infof(ctx, "Failed to get workflow with id [%+v] with err %v", request.Id, err) @@ -225,6 +234,8 @@ func (w *WorkflowManager) ListWorkflows( if err := validation.ValidateResourceListRequest(request); err != nil { return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain) + ctx = contextutils.WithWorkflowID(ctx, request.Id.Name) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Id.Project, Domain: request.Id.Domain, @@ -279,6 +290,7 @@ func (w *WorkflowManager) ListWorkflowIdentifiers(ctx context.Context, request a logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) return nil, err } + ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain) filters, err := util.GetDbFilters(util.FilterSpec{ Project: request.Project,