diff --git a/service/history/historyEngine.go b/service/history/historyEngine.go index d0e4fbba3cc..d13716df4cb 100644 --- a/service/history/historyEngine.go +++ b/service/history/historyEngine.go @@ -249,40 +249,44 @@ func NewEngineWithShardContext( ) } historyEngImpl.decisionHandler = newDecisionHandler(historyEngImpl) - - nDCHistoryResender := xdc.NewNDCHistoryResender( - shard.GetDomainCache(), - shard.GetService().GetClientBean().GetRemoteAdminClient(currentClusterName), - func(ctx context.Context, request *h.ReplicateEventsV2Request) error { - return shard.GetService().GetHistoryClient().ReplicateEventsV2(ctx, request) - }, - shard.GetService().GetPayloadSerializer(), - nil, - shard.GetLogger(), - ) - historyRereplicator := xdc.NewHistoryRereplicator( - currentClusterName, - shard.GetDomainCache(), - shard.GetService().GetClientBean().GetRemoteAdminClient(currentClusterName), - func(ctx context.Context, request *h.ReplicateRawEventsRequest) error { - return shard.GetService().GetHistoryClient().ReplicateRawEvents(ctx, request) - }, - shard.GetService().GetPayloadSerializer(), - replicationTimeout, - nil, - shard.GetLogger(), - ) - replicationTaskExecutor := replication.NewTaskExecutor( - currentClusterName, - shard.GetDomainCache(), - nDCHistoryResender, - historyRereplicator, - historyEngImpl, - shard.GetMetricsClient(), - shard.GetLogger(), - ) var replicationTaskProcessors []replication.TaskProcessor + replicationTaskExecutors := make(map[string]replication.TaskExecutor) for _, replicationTaskFetcher := range replicationTaskFetchers.GetFetchers() { + sourceCluster := replicationTaskFetcher.GetSourceCluster() + nDCHistoryResender := xdc.NewNDCHistoryResender( + shard.GetDomainCache(), + shard.GetService().GetClientBean().GetRemoteAdminClient(sourceCluster), + func(ctx context.Context, request *h.ReplicateEventsV2Request) error { + return shard.GetService().GetHistoryClient().ReplicateEventsV2(ctx, request) + }, + shard.GetService().GetPayloadSerializer(), + nil, + shard.GetLogger(), + ) + historyRereplicator := xdc.NewHistoryRereplicator( + currentClusterName, + shard.GetDomainCache(), + shard.GetService().GetClientBean().GetRemoteAdminClient(sourceCluster), + func(ctx context.Context, request *h.ReplicateRawEventsRequest) error { + return shard.GetService().GetHistoryClient().ReplicateRawEvents(ctx, request) + }, + shard.GetService().GetPayloadSerializer(), + replicationTimeout, + nil, + shard.GetLogger(), + ) + replicationTaskExecutor := replication.NewTaskExecutor( + sourceCluster, + shard, + shard.GetDomainCache(), + nDCHistoryResender, + historyRereplicator, + historyEngImpl, + shard.GetMetricsClient(), + shard.GetLogger(), + ) + replicationTaskExecutors[sourceCluster] = replicationTaskExecutor + replicationTaskProcessor := replication.NewTaskProcessor( shard, historyEngImpl, @@ -294,7 +298,7 @@ func NewEngineWithShardContext( replicationTaskProcessors = append(replicationTaskProcessors, replicationTaskProcessor) } historyEngImpl.replicationTaskProcessors = replicationTaskProcessors - replicationMessageHandler := replication.NewDLQHandler(shard, replicationTaskExecutor) + replicationMessageHandler := replication.NewDLQHandler(shard, replicationTaskExecutors) historyEngImpl.replicationDLQHandler = replicationMessageHandler shard.SetEngine(historyEngImpl) diff --git a/service/history/replication/dlq_handler.go b/service/history/replication/dlq_handler.go index c5f7ca8fa23..04090fbb404 100644 --- a/service/history/replication/dlq_handler.go +++ b/service/history/replication/dlq_handler.go @@ -26,6 +26,7 @@ import ( "context" "github.com/uber/cadence/.gen/go/replicator" + workflow "github.com/uber/cadence/.gen/go/shared" "github.com/uber/cadence/common" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" @@ -33,6 +34,10 @@ import ( "github.com/uber/cadence/service/history/shard" ) +var ( + errInvalidCluster = &workflow.BadRequestError{Message: "Invalid target cluster name."} +) + type ( // DLQHandler is the interface handles replication DLQ messages DLQHandler interface { @@ -57,9 +62,9 @@ type ( } dlqHandlerImpl struct { - taskExecutor TaskExecutor - shard shard.Context - logger log.Logger + taskExecutors map[string]TaskExecutor + shard shard.Context + logger log.Logger } ) @@ -68,13 +73,17 @@ var _ DLQHandler = (*dlqHandlerImpl)(nil) // NewDLQHandler initialize the replication message DLQ handler func NewDLQHandler( shard shard.Context, - taskExecutor TaskExecutor, + taskExecutors map[string]TaskExecutor, ) DLQHandler { + if taskExecutors == nil { + panic("Failed to initialize replication DLQ handler due to nil task executors") + } + return &dlqHandlerImpl{ - shard: shard, - taskExecutor: taskExecutor, - logger: shard.GetLogger(), + shard: shard, + taskExecutors: taskExecutors, + logger: shard.GetLogger(), } } @@ -184,6 +193,10 @@ func (r *dlqHandlerImpl) MergeMessages( pageToken []byte, ) ([]byte, error) { + if _, ok := r.taskExecutors[sourceCluster]; !ok { + return nil, errInvalidCluster + } + tasks, ackLevel, token, err := r.readMessagesWithAckLevel( ctx, sourceCluster, @@ -193,8 +206,7 @@ func (r *dlqHandlerImpl) MergeMessages( ) for _, task := range tasks { - if _, err := r.taskExecutor.execute( - sourceCluster, + if _, err := r.taskExecutors[sourceCluster].execute( task, true, ); err != nil { diff --git a/service/history/replication/dlq_handler_test.go b/service/history/replication/dlq_handler_test.go index 4406a2b75b4..b5806eefa5e 100644 --- a/service/history/replication/dlq_handler_test.go +++ b/service/history/replication/dlq_handler_test.go @@ -55,6 +55,8 @@ type ( executionManager *mocks.ExecutionManager shardManager *mocks.ShardManager taskExecutor *MockTaskExecutor + taskExecutors map[string]TaskExecutor + sourceCluster string messageHandler *dlqHandlerImpl } @@ -96,11 +98,14 @@ func (s *dlqHandlerSuite) SetupTest() { s.shardManager = s.mockShard.Resource.ShardMgr s.clusterMetadata.EXPECT().GetCurrentClusterName().Return("active").AnyTimes() + s.taskExecutors = make(map[string]TaskExecutor) s.taskExecutor = NewMockTaskExecutor(s.controller) + s.sourceCluster = "test" + s.taskExecutors[s.sourceCluster] = s.taskExecutor s.messageHandler = NewDLQHandler( s.mockShard, - s.taskExecutor, + s.taskExecutors, ).(*dlqHandlerImpl) } @@ -111,7 +116,6 @@ func (s *dlqHandlerSuite) TearDownTest() { func (s *dlqHandlerSuite) TestReadMessages_OK() { ctx := context.Background() - sourceCluster := "test" lastMessageID := int64(1) pageSize := 1 pageToken := []byte{} @@ -128,7 +132,7 @@ func (s *dlqHandlerSuite) TestReadMessages_OK() { }, } s.executionManager.On("GetReplicationTasksFromDLQ", &persistence.GetReplicationTasksFromDLQRequest{ - SourceClusterName: sourceCluster, + SourceClusterName: s.sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ ReadLevel: -1, MaxReadLevel: lastMessageID, @@ -137,11 +141,11 @@ func (s *dlqHandlerSuite) TestReadMessages_OK() { }, }).Return(resp, nil).Times(1) - s.mockClientBean.EXPECT().GetRemoteAdminClient(sourceCluster).Return(s.adminClient).AnyTimes() + s.mockClientBean.EXPECT().GetRemoteAdminClient(s.sourceCluster).Return(s.adminClient).AnyTimes() s.adminClient.EXPECT(). GetDLQReplicationMessages(ctx, gomock.Any()). Return(&replicator.GetDLQReplicationMessagesResponse{}, nil) - tasks, token, err := s.messageHandler.ReadMessages(ctx, sourceCluster, lastMessageID, pageSize, pageToken) + tasks, token, err := s.messageHandler.ReadMessages(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) s.NoError(err) s.Nil(token) s.Nil(tasks) @@ -165,7 +169,6 @@ func (s *dlqHandlerSuite) TestPurgeMessages_OK() { func (s *dlqHandlerSuite) TestMergeMessages_OK() { ctx := context.Background() - sourceCluster := "test" lastMessageID := int64(1) pageSize := 1 pageToken := []byte{} @@ -182,7 +185,7 @@ func (s *dlqHandlerSuite) TestMergeMessages_OK() { }, } s.executionManager.On("GetReplicationTasksFromDLQ", &persistence.GetReplicationTasksFromDLQRequest{ - SourceClusterName: sourceCluster, + SourceClusterName: s.sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ ReadLevel: -1, MaxReadLevel: lastMessageID, @@ -191,7 +194,7 @@ func (s *dlqHandlerSuite) TestMergeMessages_OK() { }, }).Return(resp, nil).Times(1) - s.mockClientBean.EXPECT().GetRemoteAdminClient(sourceCluster).Return(s.adminClient).AnyTimes() + s.mockClientBean.EXPECT().GetRemoteAdminClient(s.sourceCluster).Return(s.adminClient).AnyTimes() replicationTask := &replicator.ReplicationTask{ TaskType: replicator.ReplicationTaskTypeHistory.Ptr(), SourceTaskId: common.Int64Ptr(lastMessageID), @@ -204,17 +207,17 @@ func (s *dlqHandlerSuite) TestMergeMessages_OK() { replicationTask, }, }, nil) - s.taskExecutor.EXPECT().execute(sourceCluster, replicationTask, true).Return(0, nil).Times(1) + s.taskExecutor.EXPECT().execute(replicationTask, true).Return(0, nil).Times(1) s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", &persistence.RangeDeleteReplicationTaskFromDLQRequest{ - SourceClusterName: sourceCluster, + SourceClusterName: s.sourceCluster, ExclusiveBeginTaskID: -1, InclusiveEndTaskID: lastMessageID, }).Return(nil).Times(1) s.shardManager.On("UpdateShard", mock.Anything).Return(nil) - token, err := s.messageHandler.MergeMessages(ctx, sourceCluster, lastMessageID, pageSize, pageToken) + token, err := s.messageHandler.MergeMessages(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) s.NoError(err) s.Nil(token) } diff --git a/service/history/replication/task_executor.go b/service/history/replication/task_executor.go index 0bfd805d5d6..0054ce093bf 100644 --- a/service/history/replication/task_executor.go +++ b/service/history/replication/task_executor.go @@ -35,16 +35,19 @@ import ( "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/xdc" "github.com/uber/cadence/service/history/engine" + "github.com/uber/cadence/service/history/shard" ) type ( // TaskExecutor is the executor for replication task TaskExecutor interface { - execute(sourceCluster string, replicationTask *r.ReplicationTask, forceApply bool) (int, error) + execute(replicationTask *r.ReplicationTask, forceApply bool) (int, error) } taskExecutorImpl struct { currentCluster string + sourceCluster string + shard shard.Context domainCache cache.DomainCache nDCHistoryResender xdc.NDCHistoryResender historyRereplicator xdc.HistoryRereplicator @@ -60,7 +63,8 @@ var _ TaskExecutor = (*taskExecutorImpl)(nil) // NewTaskExecutor creates an replication task executor // The executor uses by 1) DLQ replication task handler 2) history replication task processor func NewTaskExecutor( - currentCluster string, + sourceCluster string, + shard shard.Context, domainCache cache.DomainCache, nDCHistoryResender xdc.NDCHistoryResender, historyRereplicator xdc.HistoryRereplicator, @@ -69,7 +73,9 @@ func NewTaskExecutor( logger log.Logger, ) TaskExecutor { return &taskExecutorImpl{ - currentCluster: currentCluster, + currentCluster: shard.GetClusterMetadata().GetCurrentClusterName(), + sourceCluster: sourceCluster, + shard: shard, domainCache: domainCache, nDCHistoryResender: nDCHistoryResender, historyRereplicator: historyRereplicator, @@ -80,7 +86,6 @@ func NewTaskExecutor( } func (e *taskExecutorImpl) execute( - sourceCluster string, replicationTask *r.ReplicationTask, forceApply bool, ) (int, error) { @@ -96,7 +101,7 @@ func (e *taskExecutorImpl) execute( err = e.handleActivityTask(replicationTask, forceApply) case r.ReplicationTaskTypeHistory: scope = metrics.HistoryReplicationTaskScope - err = e.handleHistoryReplicationTask(sourceCluster, replicationTask, forceApply) + err = e.handleHistoryReplicationTask(replicationTask, forceApply) case r.ReplicationTaskTypeHistoryMetadata: // Without kafka we should not have size limits so we don't necessary need this in the new replication scheme. scope = metrics.HistoryMetadataReplicationTaskScope @@ -195,7 +200,6 @@ func (e *taskExecutorImpl) handleActivityTask( //TODO: remove this part after 2DC deprecation func (e *taskExecutorImpl) handleHistoryReplicationTask( - sourceCluster string, task *r.ReplicationTask, forceApply bool, ) error { @@ -207,7 +211,7 @@ func (e *taskExecutorImpl) handleHistoryReplicationTask( } request := &history.ReplicateEventsRequest{ - SourceCluster: common.StringPtr(sourceCluster), + SourceCluster: common.StringPtr(e.sourceCluster), DomainUUID: attr.DomainId, WorkflowExecution: &shared.WorkflowExecution{ WorkflowId: attr.WorkflowId, diff --git a/service/history/replication/task_executor_mock.go b/service/history/replication/task_executor_mock.go index 4ad357e0fbc..6a8172bff1a 100644 --- a/service/history/replication/task_executor_mock.go +++ b/service/history/replication/task_executor_mock.go @@ -59,16 +59,16 @@ func (m *MockTaskExecutor) EXPECT() *MockTaskExecutorMockRecorder { } // execute mocks base method -func (m *MockTaskExecutor) execute(sourceCluster string, replicationTask *replicator.ReplicationTask, forceApply bool) (int, error) { +func (m *MockTaskExecutor) execute(replicationTask *replicator.ReplicationTask, forceApply bool) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "execute", sourceCluster, replicationTask, forceApply) + ret := m.ctrl.Call(m, "execute", replicationTask, forceApply) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // execute indicates an expected call of execute -func (mr *MockTaskExecutorMockRecorder) execute(sourceCluster, replicationTask, forceApply interface{}) *gomock.Call { +func (mr *MockTaskExecutorMockRecorder) execute(replicationTask, forceApply interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "execute", reflect.TypeOf((*MockTaskExecutor)(nil).execute), sourceCluster, replicationTask, forceApply) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "execute", reflect.TypeOf((*MockTaskExecutor)(nil).execute), replicationTask, forceApply) } diff --git a/service/history/replication/task_executor_test.go b/service/history/replication/task_executor_test.go index 36f7287b526..c39bdef2916 100644 --- a/service/history/replication/task_executor_test.go +++ b/service/history/replication/task_executor_test.go @@ -118,6 +118,7 @@ func (s *taskExecutorSuite) SetupTest() { s.taskHandler = NewTaskExecutor( s.currentCluster, + s.mockShard, s.mockDomainCache, s.nDCHistoryResender, s.historyRereplicator, @@ -166,7 +167,7 @@ func (s *taskExecutorSuite) TestFilterTask() { &persistence.DomainReplicationConfig{ Clusters: []*persistence.ClusterReplicationConfig{ { - ClusterName: "test", + ClusterName: "active", }, }}, 0, @@ -213,7 +214,7 @@ func (s *taskExecutorSuite) TestProcessTaskOnce_SyncActivityReplicationTask() { } s.mockEngine.EXPECT().SyncActivity(gomock.Any(), request).Return(nil).Times(1) - _, err := s.taskHandler.execute(s.currentCluster, task, true) + _, err := s.taskHandler.execute(task, true) s.NoError(err) } @@ -240,7 +241,7 @@ func (s *taskExecutorSuite) TestProcessTaskOnce_HistoryReplicationTask() { } s.mockEngine.EXPECT().ReplicateEvents(gomock.Any(), request).Return(nil).Times(1) - _, err := s.taskHandler.execute(s.currentCluster, task, true) + _, err := s.taskHandler.execute(task, true) s.NoError(err) } @@ -265,6 +266,6 @@ func (s *taskExecutorSuite) TestProcess_HistoryV2ReplicationTask() { } s.mockEngine.EXPECT().ReplicateEventsV2(gomock.Any(), request).Return(nil).Times(1) - _, err := s.taskHandler.execute(s.currentCluster, task, true) + _, err := s.taskHandler.execute(task, true) s.NoError(err) } diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index c515f4cd6bf..931d5d2c922 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -377,7 +377,6 @@ func (p *taskProcessorImpl) processSingleTask(replicationTask *r.ReplicationTask func (p *taskProcessorImpl) processTaskOnce(replicationTask *r.ReplicationTask) error { startTime := time.Now() scope, err := p.taskExecutor.execute( - p.sourceCluster, replicationTask, false)