Skip to content

Commit

Permalink
Record who created & terminated executions (flyteorg#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Dec 12, 2019
1 parent d1b921d commit 3191386
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 27 deletions.
7 changes: 3 additions & 4 deletions flyteadmin/Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions flyteadmin/pkg/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
const (
LoginRedirectURLParameter = "redirect_url"
bearerTokenContextKey contextutils.Key = "bearer"
emailContextKey contextutils.Key = "email"
PrincipalContextKey contextutils.Key = "principal"
)

type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD
Expand Down Expand Up @@ -120,7 +120,7 @@ func GetCallbackHandler(ctx context.Context, authContext interfaces.Authenticati
func AuthenticationLoggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// Invoke 'handler' to use your gRPC server implementation and get
// the response.
logger.Debugf(ctx, "gRPC server info in logging interceptor email %s method %s\n", ctx.Value(emailContextKey), info.FullMethod)
logger.Debugf(ctx, "gRPC server info in logging interceptor email %s method %s\n", ctx.Value(PrincipalContextKey), info.FullMethod)
return handler(ctx, req)
}

Expand Down Expand Up @@ -186,7 +186,7 @@ func GetAuthenticationInterceptor(authContext interfaces.AuthenticationContext)
}

func WithUserEmail(ctx context.Context, email string) context.Context {
return context.WithValue(ctx, emailContextKey, email)
return context.WithValue(ctx, PrincipalContextKey, email)
}

// This is effectively middleware for the grpc gateway, it allows us to modify the translation between HTTP request
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

func TestWithUserEmail(t *testing.T) {
ctx := WithUserEmail(context.Background(), "abc")
assert.Equal(t, "abc", ctx.Value(emailContextKey))
assert.Equal(t, "abc", ctx.Value(PrincipalContextKey))
}

func TestGetLoginHandler(t *testing.T) {
Expand Down
21 changes: 19 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strconv"
"time"

"github.com/lyft/flyteadmin/pkg/auth"

"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
dataInterfaces "github.com/lyft/flyteadmin/pkg/data/interfaces"
Expand Down Expand Up @@ -42,6 +44,7 @@ import (
const parentContainerQueueKey = "parent_queue"
const childContainerQueueKey = "child_queue"
const noSourceExecutionID = 0
const principalContextKeyFormat = "%v"

// Map of [project] -> map of [domain] -> stop watch
type projectDomainScopedStopWatchMap = map[string]map[string]*promutils.StopWatch
Expand Down Expand Up @@ -80,6 +83,15 @@ type ExecutionManager struct {
urlData dataInterfaces.RemoteURLInterface
}

// getUser returns the string form of the value stored on ctx under
// auth.PrincipalContextKey, which identifies the authenticated end user.
// It returns the empty string when ctx carries no such value.
func getUser(ctx context.Context) string {
	principal := ctx.Value(auth.PrincipalContextKey)
	if principal == nil {
		// No principal recorded on this context.
		return ""
	}
	return fmt.Sprintf(principalContextKeyFormat, principal)
}

func (m *ExecutionManager) populateExecutionQueue(
ctx context.Context, identifier core.Identifier, compiledWorkflow *core.CompiledWorkflowClosure) {
queueConfig := m.queueAllocator.GetQueue(ctx, identifier)
Expand Down Expand Up @@ -291,6 +303,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
Cluster: execInfo.Cluster,
InputsURI: inputsURI,
UserInputsURI: userInputsURI,
Principal: getUser(ctx),
})
if err != nil {
logger.Infof(ctx, "Failed to create execution model in transformer for id: [%+v] with err: %v",
Expand Down Expand Up @@ -554,7 +567,7 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi
return nil, errors.NewAlreadyInTerminalStateError(ctx, errorMsg, curPhase)
}

err = transformers.UpdateExecutionModelState(executionModel, request, nil)
err = transformers.UpdateExecutionModelState(executionModel, request)
if err != nil {
logger.Debugf(ctx, "failed to transform updated workflow execution model [%+v] after receiving event with err: %v",
request.Event.ExecutionId, err)
Expand Down Expand Up @@ -848,7 +861,11 @@ func (m *ExecutionManager) TerminateExecution(
return nil, err
}

executionModel.AbortCause = request.Cause
err = transformers.SetExecutionAborted(&executionModel, request.Cause, getUser(ctx))
if err != nil {
logger.Debugf(ctx, "failed to add abort metadata for execution [%+v] with err: %v", request.Id, err)
return nil, err
}
err = m.db.ExecutionRepo().UpdateExecution(ctx, executionModel)
if err != nil {
logger.Debugf(ctx, "failed to save abort cause for terminated execution: %+v with err: %v", request.Id, err)
Expand Down
29 changes: 27 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"testing"

"github.com/lyft/flyteadmin/pkg/auth"

"github.com/lyft/flyteadmin/pkg/common"
commonMocks "github.com/lyft/flyteadmin/pkg/common/mocks"

Expand Down Expand Up @@ -192,6 +194,15 @@ func getMockRepositoryForExecTest() repositories.RepositoryInterface {

func TestCreateExecution(t *testing.T) {
repository := getMockRepositoryForExecTest()
principal := "principal"
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(
func(ctx context.Context, input models.Execution) error {
var spec admin.ExecutionSpec
err := proto.Unmarshal(input.Spec, &spec)
assert.NoError(t, err)
assert.Equal(t, principal, spec.Metadata.Principal)
return nil
})
setDefaultLpCallbackForExecTest(repository)
mockExecutor := workflowengineMocks.NewMockExecutor()
mockExecutor.(*workflowengineMocks.MockExecutor).SetExecuteWorkflowCallback(
Expand All @@ -212,7 +223,11 @@ func TestCreateExecution(t *testing.T) {
repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockExecutor,
mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL)
request := testutils.GetExecutionRequest()
response, err := execManager.CreateExecution(context.Background(), request, requestedAt)
request.Spec.Metadata = &admin.ExecutionMetadata{
Principal: "unused - populated from authenticated context",
}
ctx := context.WithValue(context.Background(), auth.PrincipalContextKey, principal)
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)

expectedResponse := &admin.ExecutionCreateResponse{
Expand Down Expand Up @@ -1791,6 +1806,7 @@ func TestTerminateExecution(t *testing.T) {
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc)

abortCause := "abort cause"
principal := "principal"
updateExecutionFunc := func(
context context.Context, execution models.Execution) error {
assert.Equal(t, "project", execution.Project)
Expand All @@ -1805,6 +1821,14 @@ func TestTerminateExecution(t *testing.T) {
"an abort call should not change ExecutionUpdatedAt until a corresponding execution event is received")
assert.Equal(t, abortCause, execution.AbortCause)
assert.Equal(t, testCluster, execution.Cluster)

var unmarshaledClosure admin.ExecutionClosure
err := proto.Unmarshal(execution.Closure, &unmarshaledClosure)
assert.NoError(t, err)
assert.True(t, proto.Equal(&admin.AbortMetadata{
Cause: abortCause,
Principal: principal,
}, unmarshaledClosure.GetAbortMetadata()))
return nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateExecutionCallback(updateExecutionFunc)
Expand All @@ -1823,7 +1847,8 @@ func TestTerminateExecution(t *testing.T) {
repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockExecutor,
mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL)

resp, err := execManager.TerminateExecution(context.Background(), admin.ExecutionTerminateRequest{
ctx := context.WithValue(context.Background(), auth.PrincipalContextKey, principal)
resp, err := execManager.TerminateExecution(ctx, admin.ExecutionTerminateRequest{
Id: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Expand Down
46 changes: 39 additions & 7 deletions flyteadmin/pkg/repositories/transformers/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,19 @@ type CreateExecutionModelInput struct {
Cluster string
InputsURI storage.DataReference
UserInputsURI storage.DataReference
Principal string
}

// CreateExecutionModel transforms an ExecutionCreateRequest into an Execution model.
func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, error) {
spec, err := proto.Marshal(input.RequestSpec)
requestSpec := input.RequestSpec
if len(input.Principal) > 0 {
if requestSpec.Metadata == nil {
requestSpec.Metadata = &admin.ExecutionMetadata{}
}
requestSpec.Metadata.Principal = input.Principal
}
spec, err := proto.Marshal(requestSpec)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "Failed to serialize execution spec: %v", err)
}
Expand Down Expand Up @@ -86,7 +94,7 @@ func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, e

// Updates an existing model given a WorkflowExecution event.
func UpdateExecutionModelState(
execution *models.Execution, request admin.WorkflowExecutionEventRequest, abortCause *string) error {
execution *models.Execution, request admin.WorkflowExecutionEventRequest) error {
var executionClosure admin.ExecutionClosure
err := proto.Unmarshal(execution.Closure, &executionClosure)
if err != nil {
Expand Down Expand Up @@ -136,9 +144,29 @@ func UpdateExecutionModelState(
return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to marshal execution closure: %v", err)
}
execution.Closure = marshaledClosure
if abortCause != nil {
execution.AbortCause = *abortCause
return nil
}

// SetExecutionAborted stores abort metadata (cause and principal) on the execution model.
//
// The serialized closure is decoded, its output result is replaced with the abort
// metadata, and the closure is re-serialized back onto the model. The model's
// AbortCause column is set to cause as well. The phase is not touched here: it is only
// updated once the abort event is propagated by flytepropeller, so the metadata is
// saved preemptively at the time of the abort.
//
// An Internal-coded error is returned if the closure cannot be unmarshaled or marshaled;
// in that case the model is left unmodified.
func SetExecutionAborted(execution *models.Execution, cause, principal string) error {
	var executionClosure admin.ExecutionClosure
	if err := proto.Unmarshal(execution.Closure, &executionClosure); err != nil {
		return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to unmarshal execution closure: %v", err)
	}

	abortMetadata := &admin.AbortMetadata{
		Cause:     cause,
		Principal: principal,
	}
	executionClosure.OutputResult = &admin.ExecutionClosure_AbortMetadata{
		AbortMetadata: abortMetadata,
	}

	updatedClosure, err := proto.Marshal(&executionClosure)
	if err != nil {
		return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to marshal execution closure: %v", err)
	}

	// Only mutate the model once serialization has succeeded.
	execution.Closure = updatedClosure
	execution.AbortCause = cause
	return nil
}

Expand All @@ -162,9 +190,13 @@ func FromExecutionModel(executionModel models.Execution) (*admin.Execution, erro
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to unmarshal closure")
}
id := GetExecutionIdentifier(&executionModel)
if executionModel.Phase == core.WorkflowExecution_ABORTED.String() {
closure.OutputResult = &admin.ExecutionClosure_AbortCause{
AbortCause: executionModel.AbortCause,
if executionModel.Phase == core.WorkflowExecution_ABORTED.String() && closure.GetAbortMetadata() == nil {
// In the case of data predating the AbortMetadata field we manually set it in the closure only
// if it does not yet exist.
closure.OutputResult = &admin.ExecutionClosure_AbortMetadata{
AbortMetadata: &admin.AbortMetadata{
Cause: executionModel.AbortCause,
},
}
}

Expand Down
Loading

0 comments on commit 3191386

Please sign in to comment.