From 05b7dbf31d0694cb40f3902bb7b0e708e6264787 Mon Sep 17 00:00:00 2001 From: taylanisikdemir Date: Wed, 13 Dec 2023 11:40:58 -0800 Subject: [PATCH] Address map access data race in matching engine (#5477) --- service/matching/matchingEngine.go | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/service/matching/matchingEngine.go b/service/matching/matchingEngine.go index cda6b7dff80..d38c371f336 100644 --- a/service/matching/matchingEngine.go +++ b/service/matching/matchingEngine.go @@ -164,10 +164,10 @@ func (e *matchingEngineImpl) Stop() { } } -func (e *matchingEngineImpl) getTaskLists(maxCount int) (lists []taskListManager) { +func (e *matchingEngineImpl) getTaskLists(maxCount int) []taskListManager { e.taskListsLock.RLock() defer e.taskListsLock.RUnlock() - lists = make([]taskListManager, 0, len(e.taskLists)) + lists := make([]taskListManager, 0, len(e.taskLists)) count := 0 for _, tlMgr := range e.taskLists { lists = append(lists, tlMgr) @@ -176,7 +176,7 @@ func (e *matchingEngineImpl) getTaskLists(maxCount int) (lists []taskListManager break } } - return + return lists } func (e *matchingEngineImpl) String() string { @@ -190,11 +190,7 @@ func (e *matchingEngineImpl) String() string { // Returns taskListManager for a task list. If not already cached gets new range from DB and // if successful creates one. -func (e *matchingEngineImpl) getTaskListManager( - taskList *taskListID, - taskListKind *types.TaskListKind, -) (taskListManager, error) { - +func (e *matchingEngineImpl) getTaskListManager(taskList *taskListID, taskListKind *types.TaskListKind) (taskListManager, error) { // The first check is an optimization so almost all requests will have a task list manager // and return avoiding the write lock e.taskListsLock.RLock() @@ -203,6 +199,7 @@ func (e *matchingEngineImpl) getTaskListManager( return result, nil } e.taskListsLock.RUnlock() + // If it gets here, write lock and check again in case a task list is created between the two locks e.taskListsLock.Lock() if result, ok := e.taskLists[*taskList]; ok { @@ -240,9 +237,7 @@ func (e *matchingEngineImpl) getTaskListManager( return mgr, nil } -func (e *matchingEngineImpl) getTaskListByDomainLocked( - domainID string, -) *types.GetTaskListsByDomainResponse { +func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string) *types.GetTaskListsByDomainResponse { decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse) activityTaskListMap := make(map[string]*types.DescribeTaskListResponse) for tl, tlm := range e.taskLists { @@ -270,11 +265,12 @@ func (e *matchingEngineImpl) updateTaskList(taskList *taskListID, mgr taskListMa func (e *matchingEngineImpl) removeTaskListManager(tlMgr taskListManager) { id := tlMgr.TaskListID() e.taskListsLock.Lock() + defer e.taskListsLock.Unlock() currentTlMgr, ok := e.taskLists[*id] if ok && tlMgr == currentTlMgr { delete(e.taskLists, *id) } - e.taskListsLock.Unlock() + e.metricsClient.Scope(metrics.MatchingTaskListMgrScope).UpdateGauge( metrics.TaskListManagersGauge, float64(len(e.taskLists)), @@ -886,9 +882,7 @@ func (e *matchingEngineImpl) getAllPartitions( } // Loads a task from persistence and wraps it in a task context -func (e *matchingEngineImpl) getTask( - ctx context.Context, taskList *taskListID, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind, -) (*InternalTask, error) { +func (e *matchingEngineImpl) getTask(ctx context.Context, taskList *taskListID, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind) (*InternalTask, error) { tlMgr, err := e.getTaskListManager(taskList, taskListKind) if err != nil { return nil, err