diff --git a/flytestdlib/contextutils/context.go b/flytestdlib/contextutils/context.go index 3b8224a79c..797d9a8089 100644 --- a/flytestdlib/contextutils/context.go +++ b/flytestdlib/contextutils/context.go @@ -1,10 +1,12 @@ -// Contains common flyte context utils. +// Package contextutils contains common flyte context utils. package contextutils import ( "context" "fmt" "runtime/pprof" + + "google.golang.org/grpc/metadata" ) type Key string @@ -25,6 +27,7 @@ const ( LaunchPlanIDKey Key = "lp" ResourceVersionKey Key = "res_ver" SignalIDKey Key = "signal" + RequestIDKey Key = "x-request-id" ) func (k Key) String() string { @@ -43,6 +46,7 @@ var logKeys = []Key{ RoutineLabelKey, LaunchPlanIDKey, ResourceVersionKey, + RequestIDKey, } // MetricKeysFromStrings is a convenience method to convert a slice of strings into a slice of Keys @@ -56,17 +60,17 @@ func MetricKeysFromStrings(keys []string) []Key { return res } -// Gets a new context with the resource version set. +// WithResourceVersion gets a new context with the resource version set. func WithResourceVersion(ctx context.Context, resourceVersion string) context.Context { return context.WithValue(ctx, ResourceVersionKey, resourceVersion) } -// Gets a new context with namespace set. +// WithNamespace gets a new context with namespace set. func WithNamespace(ctx context.Context, namespace string) context.Context { return context.WithValue(ctx, NamespaceKey, namespace) } -// Gets a new context with JobId set. If the existing context already has a job id, the new context will have +// WithJobID gets a new context with JobId set. If the existing context already has a job id, the new context will have // / set as the job id. func WithJobID(ctx context.Context, jobID string) context.Context { existingJobID := ctx.Value(JobIDKey) @@ -77,22 +81,22 @@ func WithJobID(ctx context.Context, jobID string) context.Context { return context.WithValue(ctx, JobIDKey, jobID) } -// Gets a new context with AppName set. +// WithAppName gets a new context with AppName set. func WithAppName(ctx context.Context, appName string) context.Context { return context.WithValue(ctx, AppNameKey, appName) } -// Gets a new context with Phase set. +// WithPhase gets a new context with Phase set. func WithPhase(ctx context.Context, phase string) context.Context { return context.WithValue(ctx, PhaseKey, phase) } -// Gets a new context with ExecutionID set. +// WithExecutionID gets a new context with ExecutionID set. func WithExecutionID(ctx context.Context, execID string) context.Context { return context.WithValue(ctx, ExecIDKey, execID) } -// Gets a new context with NodeID (nested) set. +// WithNodeID gets a new context with NodeID (nested) set. func WithNodeID(ctx context.Context, nodeID string) context.Context { existingNodeID := ctx.Value(NodeIDKey) if existingNodeID != nil { @@ -101,38 +105,44 @@ func WithNodeID(ctx context.Context, nodeID string) context.Context { return context.WithValue(ctx, NodeIDKey, nodeID) } -// Gets a new context with WorkflowName set. +// WithWorkflowID gets a new context with WorkflowName set. func WithWorkflowID(ctx context.Context, workflow string) context.Context { return context.WithValue(ctx, WorkflowIDKey, workflow) } -// Gets a new context with a launch plan ID set. +// WithLaunchPlanID gets a new context with a launch plan ID set. func WithLaunchPlanID(ctx context.Context, launchPlan string) context.Context { return context.WithValue(ctx, LaunchPlanIDKey, launchPlan) } -// Get new context with Project and Domain values set +// WithProjectDomain gets new context with Project and Domain values set func WithProjectDomain(ctx context.Context, project, domain string) context.Context { c := context.WithValue(ctx, ProjectKey, project) return context.WithValue(c, DomainKey, domain) } -// Gets a new context with WorkflowName set. +// WithTaskID gets a new context with WorkflowName set. func WithTaskID(ctx context.Context, taskID string) context.Context { return context.WithValue(ctx, TaskIDKey, taskID) } -// Gets a new context with TaskType set. +// WithTaskType gets a new context with TaskType set. func WithTaskType(ctx context.Context, taskType string) context.Context { return context.WithValue(ctx, TaskTypeKey, taskType) } -// Gets a new context with SignalID set. +// WithSignalID gets a new context with SignalID set. func WithSignalID(ctx context.Context, signalID string) context.Context { return context.WithValue(ctx, SignalIDKey, signalID) } -// Gets a new context with Go Routine label key set and a label assigned to the context using pprof.Labels. +// WithRequestID gets a new context with RequestID set. +func WithRequestID(ctx context.Context, requestID string) context.Context { + return metadata.AppendToOutgoingContext(context.WithValue(ctx, RequestIDKey, requestID), RequestIDKey.String(), requestID) +} + +// WithGoroutineLabel gets a new context with Go Routine label key set and a label assigned to the context using +// pprof.Labels. // You can then call pprof.SetGoroutineLabels(ctx) to annotate the current go-routine and have that show up in // pprof analysis. func WithGoroutineLabel(ctx context.Context, routineLabel string) context.Context { @@ -156,8 +166,8 @@ func addStringFieldWithDefaults(ctx context.Context, m map[string]string, fieldK m[fieldKey.String()] = val.(string) } -// Gets a map of all known logKeys set on the context. logKeys are special and should be used incase, context fields -// are to be added to the log lines. +// GetLogFields gets a map of all known logKeys set on the context. logKeys are special and should be used incase, +// context fields are to be added to the log lines. func GetLogFields(ctx context.Context) map[string]interface{} { res := map[string]interface{}{} for _, k := range logKeys { diff --git a/flytestdlib/contextutils/context_test.go b/flytestdlib/contextutils/context_test.go index f6cb842731..976a2184be 100644 --- a/flytestdlib/contextutils/context_test.go +++ b/flytestdlib/contextutils/context_test.go @@ -111,10 +111,11 @@ func TestWithSignalID(t *testing.T) { func TestGetFields(t *testing.T) { ctx := context.Background() - ctx = WithJobID(WithNamespace(ctx, "ns123"), "job123") + ctx = WithRequestID(WithJobID(WithNamespace(ctx, "ns123"), "job123"), "req123") m := GetLogFields(ctx) assert.Equal(t, "ns123", m[NamespaceKey.String()]) assert.Equal(t, "job123", m[JobIDKey.String()]) + assert.Equal(t, "req123", m[RequestIDKey.String()]) } func TestValues(t *testing.T) {