Skip to content

Commit

Permalink
Address map access data race in matching engine
Browse files Browse the repository at this point in the history
  • Loading branch information
taylanisikdemir committed Dec 13, 2023
1 parent d4f0fd7 commit 92e07d9
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions service/matching/matchingEngine.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ func (e *matchingEngineImpl) Stop() {
}
}

func (e *matchingEngineImpl) getTaskLists(maxCount int) (lists []taskListManager) {
func (e *matchingEngineImpl) getTaskLists(maxCount int) []taskListManager {
e.taskListsLock.RLock()
defer e.taskListsLock.RUnlock()
lists = make([]taskListManager, 0, len(e.taskLists))
lists := make([]taskListManager, 0, len(e.taskLists))
count := 0
for _, tlMgr := range e.taskLists {
lists = append(lists, tlMgr)
Expand All @@ -176,7 +176,7 @@ func (e *matchingEngineImpl) getTaskLists(maxCount int) (lists []taskListManager
break
}
}
return
return lists
}

func (e *matchingEngineImpl) String() string {
Expand All @@ -190,11 +190,7 @@ func (e *matchingEngineImpl) String() string {

// Returns taskListManager for a task list. If not already cached gets new range from DB and
// if successful creates one.
func (e *matchingEngineImpl) getTaskListManager(
taskList *taskListID,
taskListKind *types.TaskListKind,
) (taskListManager, error) {

func (e *matchingEngineImpl) getTaskListManager(taskList *taskListID, taskListKind *types.TaskListKind) (taskListManager, error) {
// The first check is an optimization so almost all requests will have a task list manager
// and return avoiding the write lock
e.taskListsLock.RLock()
Expand All @@ -203,6 +199,7 @@ func (e *matchingEngineImpl) getTaskListManager(
return result, nil
}
e.taskListsLock.RUnlock()

// If it gets here, write lock and check again in case a task list is created between the two locks
e.taskListsLock.Lock()
if result, ok := e.taskLists[*taskList]; ok {
Expand Down Expand Up @@ -240,9 +237,7 @@ func (e *matchingEngineImpl) getTaskListManager(
return mgr, nil
}

func (e *matchingEngineImpl) getTaskListByDomainLocked(
domainID string,
) *types.GetTaskListsByDomainResponse {
func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string) *types.GetTaskListsByDomainResponse {
decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse)
activityTaskListMap := make(map[string]*types.DescribeTaskListResponse)
for tl, tlm := range e.taskLists {
Expand Down Expand Up @@ -270,11 +265,12 @@ func (e *matchingEngineImpl) updateTaskList(taskList *taskListID, mgr taskListMa
func (e *matchingEngineImpl) removeTaskListManager(tlMgr taskListManager) {
id := tlMgr.TaskListID()
e.taskListsLock.Lock()
defer e.taskListsLock.Unlock()
currentTlMgr, ok := e.taskLists[*id]
if ok && tlMgr == currentTlMgr {
delete(e.taskLists, *id)
}
e.taskListsLock.Unlock()

e.metricsClient.Scope(metrics.MatchingTaskListMgrScope).UpdateGauge(
metrics.TaskListManagersGauge,
float64(len(e.taskLists)),
Expand Down Expand Up @@ -886,9 +882,7 @@ func (e *matchingEngineImpl) getAllPartitions(
}

// Loads a task from persistence and wraps it in a task context
func (e *matchingEngineImpl) getTask(
ctx context.Context, taskList *taskListID, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind,
) (*InternalTask, error) {
func (e *matchingEngineImpl) getTask(ctx context.Context, taskList *taskListID, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind) (*InternalTask, error) {
tlMgr, err := e.getTaskListManager(taskList, taskListKind)
if err != nil {
return nil, err
Expand Down

0 comments on commit 92e07d9

Please sign in to comment.