Skip to content

Commit

Permalink
MQTT streaming: SourceQueue backpressure (#1577)
Browse files Browse the repository at this point in the history
* mqtt streaming - introduce a mechanism to use SourceQueue backpressure

When offering commands to the internal queue, a backpressure strategy is
now used to ensure that elements are not dropped. Instead, we wait for
consumption downstream before continuing.

We want WatchedActorTerminatedException to complete stream processing as before, but not with a failure. A session shutting down isn’t in itself a failure.
  • Loading branch information
longshorej authored and ennru committed Jul 1, 2019
1 parent 60d201a commit ed5770c
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 107 deletions.
3 changes: 3 additions & 0 deletions mqtt-streaming/src/main/mima-filters/1.0.2.backwards.excludes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# PR #1577
# https://github.com/akka/alpakka/pull/1577
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.alpakka.mqtt.streaming.MqttSessionSettings.this")
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ object MqttSessionSettings {
* Configuration settings for client and server usage.
*/
final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
val clientSendBufferSize: Int = 10,
val clientTerminationWatcherBufferSize: Int = 100,
val commandParallelism: Int = 50,
val eventParallelism: Int = 10,
Expand All @@ -52,6 +53,13 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,

import akka.util.JavaDurationConverters._

/**
* Just for clients - the number of commands that can be buffered while connected to a server. Defaults
* to 10. Any commands received beyond this will apply backpressure.
*/
def withClientSendBufferSize(clientSendBufferSize: Int): MqttSessionSettings =
copy(clientSendBufferSize = clientSendBufferSize)

/**
* The maximum size of a packet that is allowed to be decoded. Defaults to 4k.
*/
Expand Down Expand Up @@ -224,12 +232,13 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,

/**
* Just for servers - the number of commands that can be buffered while connected to a client. Defaults
* to 100. Any commands received beyond this will be dropped.
* to 100. Any commands received beyond this will apply backpressure.
*/
def withServerSendBufferSize(serverSendBufferSize: Int): MqttSessionSettings =
copy(serverSendBufferSize = serverSendBufferSize)

private def copy(maxPacketSize: Int = maxPacketSize,
clientSendBufferSize: Int = clientSendBufferSize,
clientTerminationWatcherBufferSize: Int = clientTerminationWatcherBufferSize,
commandParallelism: Int = commandParallelism,
eventParallelism: Int = eventParallelism,
Expand All @@ -245,6 +254,7 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
serverSendBufferSize: Int = serverSendBufferSize) =
new MqttSessionSettings(
maxPacketSize,
clientSendBufferSize,
clientTerminationWatcherBufferSize,
commandParallelism,
eventParallelism,
Expand All @@ -261,5 +271,21 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
)

override def toString: String =
s"MqttSessionSettings(maxPacketSize=$maxPacketSize,clientTerminationWatcherBufferSize=$clientTerminationWatcherBufferSize,commandParallelism=$commandParallelism,eventParallelism=$eventParallelism,receiveConnectTimeout=$receiveConnectTimeout,receiveConnAckTimeout=$receiveConnAckTimeout,receivePubAckRecTimeout=$producerPubAckRecTimeout,receivePubCompTimeout=$producerPubCompTimeout,receivePubAckRecTimeout=$consumerPubAckRecTimeout,receivePubCompTimeout=$consumerPubCompTimeout,receivePubRelTimeout=$consumerPubRelTimeout,receiveSubAckTimeout=$receiveSubAckTimeout,receiveUnsubAckTimeout=$receiveUnsubAckTimeout,serverSendBufferSize=$serverSendBufferSize)"
"MqttSessionSettings(" +
s"maxPacketSize=$maxPacketSize," +
s"clientSendBufferSize=$clientSendBufferSize," +
s"clientTerminationWatcherBufferSize=$clientTerminationWatcherBufferSize," +
s"commandParallelism=$commandParallelism," +
s"eventParallelism=$eventParallelism," +
s"receiveConnectTimeout=${receiveConnectTimeout.toCoarsest}," +
s"receiveConnAckTimeout=${receiveConnAckTimeout.toCoarsest}," +
s"receivePubAckRecTimeout=${producerPubAckRecTimeout.toCoarsest}," +
s"receivePubCompTimeout=${producerPubCompTimeout.toCoarsest}," +
s"receivePubAckRecTimeout=${consumerPubAckRecTimeout.toCoarsest}," +
s"receivePubCompTimeout=${consumerPubCompTimeout.toCoarsest}," +
s"receivePubRelTimeout=${consumerPubRelTimeout.toCoarsest}," +
s"receiveSubAckTimeout=${receiveSubAckTimeout.toCoarsest}," +
s"receiveUnsubAckTimeout=${receiveUnsubAckTimeout.toCoarsest}," +
s"serverSendBufferSize=$serverSendBufferSize" +
")"
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ package akka.stream.alpakka.mqtt.streaming
package impl

import akka.NotUsed
import akka.actor.typed.{ActorRef, Behavior, ChildFailed, PostStop, Terminated}
import akka.actor.typed._
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.annotation.InternalApi
import akka.stream.{Materializer, OverflowStrategy}
import akka.stream.{Materializer, OverflowStrategy, QueueOfferResult}
import akka.stream.scaladsl.{BroadcastHub, Keep, Source, SourceQueueWithComplete}
import akka.util.ByteString

import scala.concurrent.Promise
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NoStackTrace
import scala.util.{Failure, Success}
import scala.util.{Either, Failure, Success}

/*
* A client connector is a Finite State Machine that manages MQTT client
Expand Down Expand Up @@ -154,6 +154,8 @@ import scala.util.{Failure, Success}
settings
)

final case class WaitingForQueueOfferResult(nextBehavior: Behavior[Event], stash: Seq[Event])

sealed abstract class Event(val connectionId: ByteString)

final case class ConnectReceivedLocally(override val connectionId: ByteString,
Expand Down Expand Up @@ -207,6 +209,11 @@ import scala.util.{Failure, Success}
remote: Promise[Unsubscriber.ForwardUnsubscribe])
extends Event(connectionId)

final case class QueueOfferCompleted(override val connectionId: ByteString,
result: Either[Throwable, QueueOfferResult])
extends Event(connectionId)
with QueueOfferState.QueueOfferCompleted

sealed abstract class Command
sealed abstract class ForwardConnectCommand
case object ForwardConnect extends ForwardConnectCommand
Expand All @@ -226,18 +233,22 @@ import scala.util.{Failure, Success}
Behaviors
.receivePartial[Event] {
case (context, ConnectReceivedLocally(connectionId, connect, connectData, remote)) =>
import context.executionContext

val (queue, source) = Source
.queue[ForwardConnectCommand](1, OverflowStrategy.dropHead)
.queue[ForwardConnectCommand](data.settings.clientSendBufferSize, OverflowStrategy.backpressure)
.toMat(BroadcastHub.sink)(Keep.both)
.run()

remote.success(source)

queue.offer(ForwardConnect)
data.stash.foreach(context.self.tell)
queue
.offer(ForwardConnect)
.onComplete(result => context.self.tell(QueueOfferCompleted(connectionId, result.toEither)))

if (connect.connectFlags.contains(ConnectFlags.CleanSession)) {
val nextState = if (connect.connectFlags.contains(ConnectFlags.CleanSession)) {
context.children.foreach(context.stop)

serverConnect(
ConnectReceived(
connectionId,
Expand Down Expand Up @@ -281,6 +292,9 @@ import scala.util.{Failure, Success}
)

}

QueueOfferState.waitForQueueOfferCompleted(nextState, stash = data.stash)

case (_, ConnectionLost(_)) =>
Behavior.same
case (_, e) =>
Expand Down Expand Up @@ -391,30 +405,37 @@ import scala.util.{Failure, Success}
if connectionId != data.connectionId =>
context.self ! connect
disconnect(context, data.remote, data)

case (_, event) if event.connectionId.nonEmpty && event.connectionId != data.connectionId =>
Behaviors.same

case (context, ConnectionLost(_)) =>
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)

case (context, DisconnectReceivedLocally(_, remote)) =>
remote.success(ForwardDisconnect)
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)

case (context, SubscribeReceivedLocally(_, _, subscribeData, remote)) =>
context.watch(
context.spawnAnonymous(Subscriber(subscribeData, remote, data.subscriberPacketRouter, data.settings))
)
serverConnected(data)

case (context, UnsubscribeReceivedLocally(_, _, unsubscribeData, remote)) =>
context.watch(
context
.spawnAnonymous(Unsubscriber(unsubscribeData, remote, data.unsubscriberPacketRouter, data.settings))
)
serverConnected(data)

case (_, PublishReceivedFromRemote(_, publish, local))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
local.success(Consumer.ForwardPublish)
serverConnected(data, resetPingReqTimer = false)

case (context,
prfr @ PublishReceivedFromRemote(_, publish @ Publish(_, topicName, Some(packetId), _), local)) =>
data.activeConsumers.get(topicName) match {
Expand All @@ -424,17 +445,21 @@ import scala.util.{Failure, Success}
context.spawn(Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings),
consumerName)
context.watchWith(consumer, ConsumerFree(publish.topicName))

serverConnected(data.copy(activeConsumers = data.activeConsumers + (publish.topicName -> consumer)),
resetPingReqTimer = false)

case Some(consumer) if publish.flags.contains(ControlPacketFlags.DUP) =>
consumer ! Consumer.DupPublishReceivedFromRemote(local)
serverConnected(data, resetPingReqTimer = false)

case Some(_) =>
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)),
resetPingReqTimer = false
)
}

case (context, ConsumerFree(topicName)) =>
val i = data.pendingRemotePublications.indexWhere(_._1 == topicName)
if (i >= 0) {
Expand Down Expand Up @@ -463,10 +488,20 @@ import scala.util.{Failure, Success}
} else {
serverConnected(data.copy(activeConsumers = data.activeConsumers - topicName))
}
case (_, PublishReceivedLocally(publish, _))

case (context, PublishReceivedLocally(publish, _))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
data.remote.offer(ForwardPublish(publish, None))
serverConnected(data)
import context.executionContext

data.remote
.offer(ForwardPublish(publish, None))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data),
stash = Seq.empty
)

case (context, prl @ PublishReceivedLocally(publish, publishData)) =>
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size)
if (!data.activeProducers.contains(publish.topicName)) {
Expand All @@ -489,6 +524,7 @@ import scala.util.{Failure, Success}
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
}

case (context, ProducerFree(topicName)) =>
val i = data.pendingLocalPublications.indexWhere(_._1 == topicName)
if (i >= 0) {
Expand Down Expand Up @@ -518,19 +554,48 @@ import scala.util.{Failure, Success}
} else {
serverConnected(data.copy(activeProducers = data.activeProducers - topicName))
}
case (_, ReceivedProducerPublishingCommand(Producer.ForwardPublish(publish, packetId))) =>
data.remote.offer(ForwardPublish(publish, packetId))
Behaviors.same
case (_, ReceivedProducerPublishingCommand(Producer.ForwardPubRel(_, packetId))) =>
data.remote.offer(ForwardPubRel(packetId))
Behaviors.same

case (context, ReceivedProducerPublishingCommand(Producer.ForwardPublish(publish, packetId))) =>
import context.executionContext

data.remote
.offer(ForwardPublish(publish, packetId))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data, resetPingReqTimer = false),
stash = Seq.empty
)

case (context, ReceivedProducerPublishingCommand(Producer.ForwardPubRel(_, packetId))) =>
import context.executionContext

data.remote
.offer(ForwardPubRel(packetId))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data, resetPingReqTimer = false),
stash = Seq.empty
)

case (context, SendPingReqTimeout(_)) if data.pendingPingResp =>
data.remote.fail(PingFailed)
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)
case (_, SendPingReqTimeout(_)) =>
data.remote.offer(ForwardPingReq)
serverConnected(data.copy(pendingPingResp = true))

case (context, SendPingReqTimeout(_)) =>
import context.executionContext

data.remote
.offer(ForwardPingReq)
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data.copy(pendingPingResp = true)),
stash = Seq.empty
)

case (_, PingRespReceivedFromRemote(_, local)) =>
local.success(ForwardPingResp)
serverConnected(data.copy(pendingPingResp = false))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.stream.alpakka.mqtt.streaming.impl

import akka.actor.typed.Behavior
import akka.actor.typed.scaladsl.Behaviors
import akka.stream.QueueOfferResult

private[mqtt] object QueueOfferState {

/**
* A marker trait that holds a result for SourceQueue#offer
*/
trait QueueOfferCompleted {
def result: Either[Throwable, QueueOfferResult]
}

/**
* A behavior that stashes messages until a response to the SourceQueue#offer
* method is received.
*
* This is to be used only with SourceQueues that use backpressure.
*/
def waitForQueueOfferCompleted[T](behavior: Behavior[T], stash: Seq[T]): Behavior[T] =
Behaviors
.receive[T] {
case (context, completed: QueueOfferCompleted) =>
completed.result match {
case Right(QueueOfferResult.Enqueued) =>
stash.foreach(context.self.tell)

behavior

case Right(other) =>
throw new IllegalStateException(s"Failed to offer to queue: $other")

case Left(failure) =>
throw failure
}

case (_, other) =>
waitForQueueOfferCompleted(behavior, stash = stash :+ other)
}
.orElse(behavior) // handle signals immediately
}
Loading

0 comments on commit ed5770c

Please sign in to comment.