Skip to content

Commit

Permalink
fix: Concurrent access to trigger connection maps (#2814)
Browse files Browse the repository at this point in the history
Signed-off-by: gokulav137 <[email protected]>
Signed-off-by: gokulav137 <[email protected]>
  • Loading branch information
gokulav137 authored Sep 27, 2023
1 parent 5b8c754 commit 334097f
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 72 deletions.
35 changes: 35 additions & 0 deletions common/string_keyed_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package common

import "sync"

// Concurrent Safe String keyed map
type StringKeyedMap[T any] struct {
items map[string]T
lock *sync.RWMutex
}

func NewStringKeyedMap[T any]() StringKeyedMap[T] {
return StringKeyedMap[T]{
items: make(map[string]T, 0),
lock: &sync.RWMutex{},
}
}

func (sm *StringKeyedMap[T]) Store(key string, item T) {
sm.lock.Lock()
defer sm.lock.Unlock()
sm.items[key] = item
}

func (sm *StringKeyedMap[T]) Load(key string) (T, bool) {
sm.lock.RLock()
defer sm.lock.RUnlock()
ok, item := sm.items[key]
return ok, item
}

func (sm *StringKeyedMap[T]) Delete(key string) {
sm.lock.Lock()
defer sm.lock.Unlock()
delete(sm.items, key)
}
37 changes: 19 additions & 18 deletions sensors/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"k8s.io/client-go/dynamic"
"k8s.io/client-go/kubernetes"

"github.com/argoproj/argo-events/common"
sensormetrics "github.com/argoproj/argo-events/metrics"
eventbusv1alpha1 "github.com/argoproj/argo-events/pkg/apis/eventbus/v1alpha1"
"github.com/argoproj/argo-events/pkg/apis/sensor/v1alpha1"
Expand All @@ -51,25 +52,25 @@ type SensorContext struct {
hostname string

// httpClients holds the reference to HTTP clients for HTTP triggers.
httpClients map[string]*http.Client
httpClients common.StringKeyedMap[*http.Client]
// customTriggerClients holds the references to the gRPC clients for the custom trigger servers
customTriggerClients map[string]*grpc.ClientConn
customTriggerClients common.StringKeyedMap[*grpc.ClientConn]
// http client to send slack messages.
slackHTTPClient *http.Client
// kafkaProducers holds references to the active kafka producers
kafkaProducers map[string]sarama.AsyncProducer
kafkaProducers common.StringKeyedMap[sarama.AsyncProducer]
// pulsarProducers holds references to the active pulsar producers
pulsarProducers map[string]pulsar.Producer
pulsarProducers common.StringKeyedMap[pulsar.Producer]
// natsConnections holds the references to the active nats connections.
natsConnections map[string]*natslib.Conn
natsConnections common.StringKeyedMap[*natslib.Conn]
// awsLambdaClients holds the references to active AWS Lambda clients.
awsLambdaClients map[string]*lambda.Lambda
awsLambdaClients common.StringKeyedMap[*lambda.Lambda]
// openwhiskClients holds the references to active OpenWhisk clients.
openwhiskClients map[string]*whisk.Client
openwhiskClients common.StringKeyedMap[*whisk.Client]
// azureEventHubsClients holds the references to active Azure Event Hub clients.
azureEventHubsClients map[string]*eventhubs.Hub
azureEventHubsClients common.StringKeyedMap[*eventhubs.Hub]
// azureServiceBusClients holds the references to active Azure Service Bus clients.
azureServiceBusClients map[string]*servicebus.Sender
azureServiceBusClients common.StringKeyedMap[*servicebus.Sender]
metrics *sensormetrics.Metrics
}

Expand All @@ -82,18 +83,18 @@ func NewSensorContext(kubeClient kubernetes.Interface, dynamicClient dynamic.Int
eventBusConfig: eventBusConfig,
eventBusSubject: eventBusSubject,
hostname: hostname,
httpClients: make(map[string]*http.Client),
customTriggerClients: make(map[string]*grpc.ClientConn),
httpClients: common.NewStringKeyedMap[*http.Client](),
customTriggerClients: common.NewStringKeyedMap[*grpc.ClientConn](),
slackHTTPClient: &http.Client{
Timeout: time.Minute * 5,
},
kafkaProducers: make(map[string]sarama.AsyncProducer),
pulsarProducers: make(map[string]pulsar.Producer),
natsConnections: make(map[string]*natslib.Conn),
awsLambdaClients: make(map[string]*lambda.Lambda),
openwhiskClients: make(map[string]*whisk.Client),
azureEventHubsClients: make(map[string]*eventhubs.Hub),
azureServiceBusClients: make(map[string]*servicebus.Sender),
kafkaProducers: common.NewStringKeyedMap[sarama.AsyncProducer](),
pulsarProducers: common.NewStringKeyedMap[pulsar.Producer](),
natsConnections: common.NewStringKeyedMap[*natslib.Conn](),
awsLambdaClients: common.NewStringKeyedMap[*lambda.Lambda](),
openwhiskClients: common.NewStringKeyedMap[*whisk.Client](),
azureEventHubsClients: common.NewStringKeyedMap[*eventhubs.Hub](),
azureServiceBusClients: common.NewStringKeyedMap[*servicebus.Sender](),
metrics: metrics,
}
}
6 changes: 3 additions & 3 deletions sensors/triggers/apache-openwhisk/apache-openwhisk.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ type TriggerImpl struct {
}

// NewTriggerImpl returns a new TriggerImpl
func NewTriggerImpl(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, openWhiskClients map[string]*whisk.Client, logger *zap.SugaredLogger) (*TriggerImpl, error) {
func NewTriggerImpl(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, openWhiskClients common.StringKeyedMap[*whisk.Client], logger *zap.SugaredLogger) (*TriggerImpl, error) {
openwhisktrigger := trigger.Template.OpenWhisk

client, ok := openWhiskClients[trigger.Template.Name]
client, ok := openWhiskClients.Load(trigger.Template.Name)
if !ok {
logger.Debugw("OpenWhisk trigger value", zap.Any("name", trigger.Template.Name), zap.Any("trigger", *trigger.Template.OpenWhisk))
logger.Infow("instantiating OpenWhisk client", zap.Any("trigger-name", trigger.Template.Name))
Expand Down Expand Up @@ -82,7 +82,7 @@ func NewTriggerImpl(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, openWhis
return nil, fmt.Errorf("failed to instantiate OpenWhisk client, %w", err)
}

openWhiskClients[trigger.Template.Name] = client
openWhiskClients.Store(trigger.Template.Name, client)
}

return &TriggerImpl{
Expand Down
7 changes: 4 additions & 3 deletions sensors/triggers/aws-lambda/aws-lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/service/lambda"
"go.uber.org/zap"

"github.com/argoproj/argo-events/common"
"github.com/argoproj/argo-events/common/logging"
commonaws "github.com/argoproj/argo-events/eventsources/common/aws"
apicommon "github.com/argoproj/argo-events/pkg/apis/common"
Expand All @@ -45,17 +46,17 @@ type AWSLambdaTrigger struct {
}

// NewAWSLambdaTrigger returns a new AWS Lambda context
func NewAWSLambdaTrigger(lambdaClients map[string]*lambda.Lambda, sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger) (*AWSLambdaTrigger, error) {
func NewAWSLambdaTrigger(lambdaClients common.StringKeyedMap[*lambda.Lambda], sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger) (*AWSLambdaTrigger, error) {
lambdatrigger := trigger.Template.AWSLambda

lambdaClient, ok := lambdaClients[trigger.Template.Name]
lambdaClient, ok := lambdaClients.Load(trigger.Template.Name)
if !ok {
awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(lambdatrigger.Region, lambdatrigger.RoleARN, lambdatrigger.AccessKey, lambdatrigger.SecretKey, nil)
if err != nil {
return nil, fmt.Errorf("failed to create a AWS session, %w", err)
}
lambdaClient = lambda.New(awsSession, &aws.Config{Region: &lambdatrigger.Region})
lambdaClients[trigger.Template.Name] = lambdaClient
lambdaClients.Store(trigger.Template.Name, lambdaClient)
}

return &AWSLambdaTrigger{
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/azure-event-hubs/azure_event_hubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ type AzureEventHubsTrigger struct {
}

// NewAzureEventHubsTrigger returns a new azure event hubs context.
func NewAzureEventHubsTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, azureEventHubsClient map[string]*eventhub.Hub, logger *zap.SugaredLogger) (*AzureEventHubsTrigger, error) {
func NewAzureEventHubsTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, azureEventHubsClient common.StringKeyedMap[*eventhub.Hub], logger *zap.SugaredLogger) (*AzureEventHubsTrigger, error) {
azureEventHubsTrigger := trigger.Template.AzureEventHubs

hub, ok := azureEventHubsClient[trigger.Template.Name]
hub, ok := azureEventHubsClient.Load(trigger.Template.Name)

if !ok {
// form event hubs connection string in the ff format:
Expand All @@ -72,7 +72,7 @@ func NewAzureEventHubsTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger
return nil, err
}

azureEventHubsClient[trigger.Template.Name] = hub
azureEventHubsClient.Store(trigger.Template.Name, hub)
}

return &AzureEventHubsTrigger{
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/azure-service-bus/azure_service_bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ type AzureServiceBusTrigger struct {
Logger *zap.SugaredLogger
}

func NewAzureServiceBusTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, azureServiceBusClients map[string]*servicebus.Sender, logger *zap.SugaredLogger) (*AzureServiceBusTrigger, error) {
func NewAzureServiceBusTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, azureServiceBusClients common.StringKeyedMap[*servicebus.Sender], logger *zap.SugaredLogger) (*AzureServiceBusTrigger, error) {
triggerLogger := logger.With(logging.LabelTriggerType, apicommon.AzureServiceBusTrigger)
azureServiceBusTrigger := trigger.Template.AzureServiceBus

sender, ok := azureServiceBusClients[trigger.Template.Name]
sender, ok := azureServiceBusClients.Load(trigger.Template.Name)

if !ok {
connStr, err := common.GetSecretFromVolume(azureServiceBusTrigger.ConnectionString)
Expand Down Expand Up @@ -91,7 +91,7 @@ func NewAzureServiceBusTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigge
return nil, err
}

azureServiceBusClients[trigger.Template.Name] = sender
azureServiceBusClients.Store(trigger.Template.Name, sender)
}

return &AzureServiceBusTrigger{
Expand Down
8 changes: 4 additions & 4 deletions sensors/triggers/custom-trigger/custom-trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type CustomTrigger struct {
}

// NewCustomTrigger returns a new custom trigger
func NewCustomTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger, customTriggerClients map[string]*grpc.ClientConn) (*CustomTrigger, error) {
func NewCustomTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger, customTriggerClients common.StringKeyedMap[*grpc.ClientConn]) (*CustomTrigger, error) {
customTrigger := &CustomTrigger{
Sensor: sensor,
Trigger: trigger,
Expand All @@ -57,15 +57,15 @@ func NewCustomTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger

ct := trigger.Template.CustomTrigger

if conn, ok := customTriggerClients[trigger.Template.Name]; ok {
if conn, ok := customTriggerClients.Load(trigger.Template.Name); ok {
if conn.GetState() == connectivity.Ready {
logger.Info("trigger client connection is ready...")
customTrigger.triggerClient = triggers.NewTriggerClient(conn)
return customTrigger, nil
}

logger.Info("trigger client connection is closed, creating new one...")
delete(customTriggerClients, trigger.Template.Name)
customTriggerClients.Delete(trigger.Template.Name)
}

logger.Infow("instantiating trigger client...", zap.Any("server-url", ct.ServerURL))
Expand Down Expand Up @@ -117,7 +117,7 @@ func NewCustomTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger
}

customTrigger.triggerClient = triggers.NewTriggerClient(conn)
customTriggerClients[trigger.Template.Name] = conn
customTriggerClients.Store(trigger.Template.Name, conn)

logger.Info("successfully setup the trigger client...")
return customTrigger, nil
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ type HTTPTrigger struct {
}

// NewHTTPTrigger returns a new HTTP trigger
func NewHTTPTrigger(httpClients map[string]*http.Client, sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger) (*HTTPTrigger, error) {
func NewHTTPTrigger(httpClients common.StringKeyedMap[*http.Client], sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, logger *zap.SugaredLogger) (*HTTPTrigger, error) {
httptrigger := trigger.Template.HTTP

client, ok := httpClients[trigger.Template.Name]
client, ok := httpClients.Load(trigger.Template.Name)
if !ok {
client = &http.Client{}

Expand All @@ -69,7 +69,7 @@ func NewHTTPTrigger(httpClients map[string]*http.Client, sensor *v1alpha1.Sensor
}
client.Timeout = timeout

httpClients[trigger.Template.Name] = client
httpClients.Store(trigger.Template.Name, client)
}

return &HTTPTrigger{
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ type KafkaTrigger struct {
}

// NewKafkaTrigger returns a new kafka trigger context.
func NewKafkaTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, kafkaProducers map[string]sarama.AsyncProducer, logger *zap.SugaredLogger) (*KafkaTrigger, error) {
func NewKafkaTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, kafkaProducers common.StringKeyedMap[sarama.AsyncProducer], logger *zap.SugaredLogger) (*KafkaTrigger, error) {
kafkatrigger := trigger.Template.Kafka
triggerLogger := logger.With(logging.LabelTriggerType, apicommon.KafkaTrigger)

producer, ok := kafkaProducers[trigger.Template.Name]
producer, ok := kafkaProducers.Load(trigger.Template.Name)
var schema *srclient.Schema

if !ok {
Expand Down Expand Up @@ -133,7 +133,7 @@ func NewKafkaTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, kafkaPr
}
}()

kafkaProducers[trigger.Template.Name] = producer
kafkaProducers.Store(trigger.Template.Name, producer)
}

if kafkatrigger.SchemaRegistry != nil {
Expand Down
26 changes: 13 additions & 13 deletions sensors/triggers/kafka/kafka_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/argoproj/argo-events/common"
"github.com/argoproj/argo-events/common/logging"
apicommon "github.com/argoproj/argo-events/pkg/apis/common"
"github.com/argoproj/argo-events/pkg/apis/sensor/v1alpha1"
Expand Down Expand Up @@ -60,15 +61,14 @@ var sensorObj = &v1alpha1.Sensor{
},
}

func getFakeKafkaTrigger(producers map[string]sarama.AsyncProducer) (*KafkaTrigger, error) {
func getFakeKafkaTrigger(producers common.StringKeyedMap[sarama.AsyncProducer]) (*KafkaTrigger, error) {
return NewKafkaTrigger(sensorObj.DeepCopy(), sensorObj.Spec.Triggers[0].DeepCopy(), producers, logging.NewArgoEventsLogger())
}

func TestNewKafkaTrigger(t *testing.T) {
producer := mocks.NewAsyncProducer(t, nil)
producers := map[string]sarama.AsyncProducer{
"fake-trigger": producer,
}
producers := common.NewStringKeyedMap[sarama.AsyncProducer]()
producers.Store("fake-trigger", producer)
trigger, err := NewKafkaTrigger(sensorObj.DeepCopy(), sensorObj.Spec.Triggers[0].DeepCopy(), producers, logging.NewArgoEventsLogger())

assert.Nil(t, err)
Expand All @@ -78,9 +78,9 @@ func TestNewKafkaTrigger(t *testing.T) {

func TestKafkaTrigger_FetchResource(t *testing.T) {
producer := mocks.NewAsyncProducer(t, nil)
trigger, err := getFakeKafkaTrigger(map[string]sarama.AsyncProducer{
"fake-trigger": producer,
})
producers := common.NewStringKeyedMap[sarama.AsyncProducer]()
producers.Store("fake-trigger", producer)
trigger, err := getFakeKafkaTrigger(producers)
assert.Nil(t, err)
obj, err := trigger.FetchResource(context.TODO())
assert.Nil(t, err)
Expand All @@ -92,9 +92,9 @@ func TestKafkaTrigger_FetchResource(t *testing.T) {

func TestKafkaTrigger_ApplyResourceParameters(t *testing.T) {
producer := mocks.NewAsyncProducer(t, nil)
trigger, err := getFakeKafkaTrigger(map[string]sarama.AsyncProducer{
"fake-trigger": producer,
})
producers := common.NewStringKeyedMap[sarama.AsyncProducer]()
producers.Store("fake-trigger", producer)
trigger, err := getFakeKafkaTrigger(producers)
assert.Nil(t, err)

testEvents := map[string]*v1alpha1.Event{
Expand Down Expand Up @@ -136,9 +136,9 @@ func TestKafkaTrigger_ApplyResourceParameters(t *testing.T) {

func TestKafkaTrigger_Execute(t *testing.T) {
producer := mocks.NewAsyncProducer(t, nil)
trigger, err := getFakeKafkaTrigger(map[string]sarama.AsyncProducer{
"fake-trigger": producer,
})
producers := common.NewStringKeyedMap[sarama.AsyncProducer]()
producers.Store("fake-trigger", producer)
trigger, err := getFakeKafkaTrigger(producers)
assert.Nil(t, err)
testEvents := map[string]*v1alpha1.Event{
"fake-dependency": {
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/nats/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ type NATSTrigger struct {
}

// NewNATSTrigger returns new nats trigger.
func NewNATSTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, natsConnections map[string]*natslib.Conn, logger *zap.SugaredLogger) (*NATSTrigger, error) {
func NewNATSTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, natsConnections common.StringKeyedMap[*natslib.Conn], logger *zap.SugaredLogger) (*NATSTrigger, error) {
natstrigger := trigger.Template.NATS

conn, ok := natsConnections[trigger.Template.Name]
conn, ok := natsConnections.Load(trigger.Template.Name)
if !ok {
var err error
opts := natslib.GetDefaultOptions()
Expand All @@ -67,7 +67,7 @@ func NewNATSTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, natsConn
return nil, err
}

natsConnections[trigger.Template.Name] = conn
natsConnections.Store(trigger.Template.Name, conn)
}

return &NATSTrigger{
Expand Down
6 changes: 3 additions & 3 deletions sensors/triggers/pulsar/pulsar.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ type PulsarTrigger struct {
}

// NewPulsarTrigger returns a new Pulsar trigger context.
func NewPulsarTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, pulsarProducers map[string]pulsar.Producer, logger *zap.SugaredLogger) (*PulsarTrigger, error) {
func NewPulsarTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, pulsarProducers common.StringKeyedMap[pulsar.Producer], logger *zap.SugaredLogger) (*PulsarTrigger, error) {
pulsarTrigger := trigger.Template.Pulsar

producer, ok := pulsarProducers[trigger.Template.Name]
producer, ok := pulsarProducers.Load(trigger.Template.Name)
if !ok {
var err error
tlsTrustCertsFilePath := ""
Expand Down Expand Up @@ -124,7 +124,7 @@ func NewPulsarTrigger(sensor *v1alpha1.Sensor, trigger *v1alpha1.Trigger, pulsar
return nil, err
}

pulsarProducers[trigger.Template.Name] = producer
pulsarProducers.Store(trigger.Template.Name, producer)
}

return &PulsarTrigger{
Expand Down
Loading

0 comments on commit 334097f

Please sign in to comment.