diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala index 241ebddb..5db1b674 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala @@ -56,7 +56,7 @@ object PubSub { val (acquire, release) = acquireAndRelease[F, K, V](client, codec) // One exclusive connection for subscriptions and another connection for publishing / stats for { - state <- Resource.eval(PubSubState.make[F, K, V]) + state <- Resource.eval(PubSubState.make[F, K, V](shards = None)) sConn <- Resource.make(acquire)(release) pConn <- Resource.make(acquire)(release) } yield new LivePubSubCommands[F, K, V](state, sConn, pConn) @@ -84,7 +84,7 @@ object PubSub { ): Resource[F, SubscribeCommands[F, Stream[F, *], K, V]] = { val (acquire, release) = acquireAndRelease[F, K, V](client, codec) for { - state <- Resource.eval(PubSubState.make[F, K, V]) + state <- Resource.eval(PubSubState.make[F, K, V](shards = None)) conn <- Resource.make(acquire)(release) } yield new Subscriber(state, conn) } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala index 06f884ca..0044a782 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala @@ -16,21 +16,214 @@ package dev.profunktor.redis4cats.pubsub.internals +import cats.{ Applicative, Monad } import cats.syntax.all._ -import cats.effect.kernel.Concurrent -import cats.effect.std.AtomicCell +import cats.effect.kernel.{ Concurrent, Deferred, MonadCancelThrow, Ref, Resource } +import cats.effect.std.{ AtomicCell, MapRef } import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent } +import fs2.Stream +import fs2.concurrent.Topic /** We use `AtomicCell` instead of `Ref` because we need locking while side-effecting. */ -case class PubSubState[F[_], K, V]( - channelSubs: AtomicCell[F, Map[RedisChannel[K], Redis4CatsSubscription[F, V]]], - patternSubs: AtomicCell[F, Map[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]]] -) -object PubSubState { - def make[F[_]: Concurrent, K, V]: F[PubSubState[F, K, V]] = +private[pubsub] trait PubSubState[F[_], K, V] { + def channelSubs: PubSubState.SubscriptionMap[F, RedisChannel[K], V] + def patternSubs: PubSubState.SubscriptionMap[F, RedisPattern[K], RedisPatternEvent[K, V]] +} + +private[pubsub] object PubSubState { + + trait SubscriptionMap[F[_], K, V] { + def counts: F[Map[K, Long]] + + def subscribe(key: K)(create: Resource[F, Topic[F, Option[V]]]): Stream[F, V] + + def unsubscribe(key: K): F[Unit] + } + + private object SubscriptionMap { + def makeCell[F[_]: Concurrent, K, V]: F[SubscriptionMap[F, K, V]] = + AtomicCell[F].of(Map.empty[K, Redis4CatsSubscription[F, V]]).map(fromCell[F, K, V]) + + def fromCell[F[_]: MonadCancelThrow, K, V]( + cell: AtomicCell[F, Map[K, Redis4CatsSubscription[F, V]]] + ): SubscriptionMap[F, K, V] = + new SubscriptionMap[F, K, V] { + override def counts: F[Map[K, Long]] = + cell.get.map(_.iterator.map { case (k, v) => k -> v.subscribers }.toMap) + + override def subscribe(key: K)(create: Resource[F, Topic[F, Option[V]]]): Stream[F, V] = + Stream.eval(addSubscription(key)(create)).flatMap(_.stream(remove(key))) + + private def addSubscription(key: K)(create: Resource[F, Topic[F, Option[V]]]): F[Redis4CatsSubscription[F, V]] = + cell.evalModify { subscribers => + val getSubscription = subscribers.get(key) match { + case Some(subscription) => + // We have an existing subscription, mark that it has one more subscriber. + subscription.addSubscriber.pure[F] + case None => + // No existing subscription, create a new one. + create.allocated.map { case (topic, cleanup) => + Redis4CatsSubscription(topic, subscribers = 1, cleanup) + } + } + getSubscription.map(s => (subscribers.updated(key, s), s)) + } + + private def remove(key: K): F[Unit] = + cell.evalUpdate { subscribers => + subscribers.get(key) match { + case Some(sub) => + if (sub.isLastSubscriber) sub.cleanup.as(subscribers - key) + else subscribers.updated(key, sub.removeSubscriber).pure + case None => + // We were notified about stream termination but we don't have a subscription, this would be a bug + subscribers.pure + } + } + + override def unsubscribe(key: K): F[Unit] = + cell.get.map(_.get(key)).flatMap { + // No subscription = nothing to do + case None => Applicative[F].unit + // Publish `None` which will terminate all streams, which will perform cleanup once the last stream + // terminates. + case Some(sub) => sub.topic.publish1(None).void + } + } + + def fromShards[F[_]: Monad, K, V](shards: Vector[SubscriptionMap[F, K, V]]): SubscriptionMap[F, K, V] = + new SubscriptionMap[F, K, V] { + override def counts: F[Map[K, Long]] = shards.foldMapM(_.counts) + + override def subscribe(key: K)(create: Resource[F, Topic[F, Option[V]]]): Stream[F, V] = + getKeyShard(key).subscribe(key)(create) + + override def unsubscribe(key: K): F[Unit] = + getKeyShard(key).unsubscribe(key) + + private def getKeyShard(key: K): SubscriptionMap[F, K, V] = { + val location = Math.abs(key.## % shards.size) + shards(location) + } + } + + sealed trait SubscriptionState[F[_], V] + object SubscriptionState { + final case class Active[F[_], V](subscription: Redis4CatsSubscription[F, V]) extends SubscriptionState[F, V] + final case class Starting[F[_], V](done: F[Unit]) extends SubscriptionState[F, V] + final case class ShuttingDown[F[_], V](done: F[Unit]) extends SubscriptionState[F, V] + } + + def makeRef[F[_]: Concurrent, K, V]: F[SubscriptionMap[F, K, V]] = + Ref[F].of(Map.empty[K, SubscriptionState[F, V]]).map(fromRef[F, K, V]) + + def fromRef[F[_]: Concurrent, K, V]( + ref: Ref[F, Map[K, SubscriptionState[F, V]]] + ): SubscriptionMap[F, K, V] = + new SubscriptionMap[F, K, V] { + import SubscriptionState._ + + private val mapRef = MapRef.fromSingleImmutableMapRef(ref) + + override def counts: F[Map[K, Long]] = + ref.get.map(_.iterator.collect { case (k, Active(v)) => k -> v.subscribers }.toMap) + + override def subscribe(key: K)(create: Resource[F, Topic[F, Option[V]]]): Stream[F, V] = + Stream.eval(addSubscription(key)(create)).flatMap(_.stream(remove(key))) + + private def addSubscription(key: K)(create: Resource[F, Topic[F, Option[V]]]): F[Redis4CatsSubscription[F, V]] = + Deferred[F, Unit].flatMap { d => + ref.flatModify[Redis4CatsSubscription[F, V]] { subscribers => + subscribers.get(key) match { + case Some(Active(subscription)) => + // We have an existing subscription, mark that it has one more subscriber. + val newSubscription = subscription.addSubscriber + (subscribers.updated(key, Active(newSubscription)), newSubscription.pure[F]) + case Some(ShuttingDown(wait)) => + // an existing subscription is getting shut down, wait and try again + (subscribers, wait >> addSubscription(key)(create)) + case Some(Starting(wait)) => + // an existing subscription is getting created, wait and try again + (subscribers, wait >> addSubscription(key)(create)) + case None => + // No existing subscription, create a new one. + val start = create.allocated.flatMap { case (topic, cleanup) => + val subscription = Redis4CatsSubscription(topic, subscribers = 1, cleanup) + mapRef(key).flatModify { + case Some(Starting(_)) => (Some(Active(subscription)), d.complete(()).as(subscription)) + case _ => + // this would be a bug, we only expect a starting subscription + // TODO should we error? + (None, cleanup >> d.complete(()) >> addSubscription(key)(create)) + } + } + (subscribers.updated(key, Starting(d.get)), start) + } + } + } + + private def remove(key: K): F[Unit] = + Deferred[F, Unit].flatMap { d => + mapRef(key).flatModify { + case Some(Active(sub)) => + if (sub.isLastSubscriber) { + val cleanup = sub.cleanup >> mapRef(key).flatModify { + case Some(ShuttingDown(_)) => (None, d.complete(()).void) + case _ => (None, Applicative[F].unit) // TODO bug + } + (Some(ShuttingDown(d.get)), cleanup) + } else (Some(Active(sub.removeSubscriber)), Applicative[F].unit) + case other => // bug + (other, Applicative[F].unit) + } + } + + override def unsubscribe(key: K): F[Unit] = + ref.get.map(_.get(key)).flatMap { + // No subscription = nothing to do + case None => Applicative[F].unit + // Subscription already shutting down = nothing to do + case Some(ShuttingDown(_)) => Applicative[F].unit + // Publish `None` which will terminate all streams, which will perform cleanup once the last stream + // terminates. + case Some(Active(sub)) => sub.topic.publish1(None).void + // wait until the subscription has started and unsubscribe + case Some(Starting(wait)) => wait >> unsubscribe(key) + } + } + } + + def make[F[_]: Concurrent, K, V](shards: Option[Int]): F[PubSubState[F, K, V]] = + shards.filter(_ > 1) match { + case None => singleRef[F, K, V] + case Some(n) => sharded[F, K, V](n) + } + + def single[F[_]: Concurrent, K, V]: F[PubSubState[F, K, V]] = + for { + channelSubs0 <- SubscriptionMap.makeCell[F, RedisChannel[K], V] + patternSubs0 <- SubscriptionMap.makeCell[F, RedisPattern[K], RedisPatternEvent[K, V]] + } yield new PubSubStateImpl[F, K, V](channelSubs0, patternSubs0) + + private def sharded[F[_]: Concurrent, K, V](number: Int): F[PubSubState[F, K, V]] = { + assert(number > 1) + for { + channelShards <- SubscriptionMap.makeCell[F, RedisChannel[K], V].replicateA(number) + patternShards <- SubscriptionMap.makeCell[F, RedisPattern[K], RedisPatternEvent[K, V]].replicateA(number) + } yield new PubSubStateImpl[F, K, V]( + SubscriptionMap.fromShards(channelShards.toVector), + SubscriptionMap.fromShards(patternShards.toVector) + ) + } + + private def singleRef[F[_]: Concurrent, K, V]: F[PubSubState[F, K, V]] = for { - channelSubs <- AtomicCell[F].of(Map.empty[RedisChannel[K], Redis4CatsSubscription[F, V]]) - patternSubs <- AtomicCell[F].of(Map.empty[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]]) - } yield apply(channelSubs, patternSubs) + channelSubs0 <- SubscriptionMap.makeRef[F, RedisChannel[K], V] + patternSubs0 <- SubscriptionMap.makeRef[F, RedisPattern[K], RedisPatternEvent[K, V]] + } yield new PubSubStateImpl[F, K, V](channelSubs0, patternSubs0) + private class PubSubStateImpl[F[_], K, V]( + override val channelSubs: PubSubState.SubscriptionMap[F, RedisChannel[K], V], + override val patternSubs: PubSubState.SubscriptionMap[F, RedisPattern[K], RedisPatternEvent[K, V]] + ) extends PubSubState[F, K, V] } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala index 9a126313..a66e9575 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala @@ -18,10 +18,8 @@ package dev.profunktor.redis4cats package pubsub package internals -import cats.Applicative import cats.effect.kernel._ -import cats.effect.kernel.implicits._ -import cats.effect.std.{ AtomicCell, Dispatcher } +import cats.effect.std.Dispatcher import cats.syntax.all._ import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent } import dev.profunktor.redis4cats.effect.{ FutureLift, Log } @@ -66,103 +64,38 @@ private[pubsub] class Subscriber[F[_]: Async: FutureLift: Log, K, V]( Subscriber.unsubscribeFrom(pattern, state.patternSubs) override def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]] = - state.channelSubs.get.map(_.iterator.map { case (k, v) => k -> v.subscribers }.toMap) + state.channelSubs.counts override def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]] = - state.patternSubs.get.map(_.iterator.map { case (k, v) => k -> v.subscribers }.toMap) + state.patternSubs.counts } object Subscriber { - /** Check if we have a subscriber for this channel and remove it if we do. - * - * If it is the last subscriber, perform the subscription cleanup. - */ - private def onStreamTermination[F[_]: Applicative: Log, K, V]( - subs: AtomicCell[F, Map[K, Redis4CatsSubscription[F, V]]], - key: K - ): F[Unit] = subs.evalUpdate { subscribers => - subscribers.get(key) match { - case None => - Log[F] - .error( - s"We were notified about stream termination for $key but we don't have a subscription, " + - s"this is a bug in redis4cats!" - ) - .as(subscribers) - case Some(sub) => - if (!sub.isLastSubscriber) subscribers.updated(key, sub.removeSubscriber).pure - else sub.cleanup.as(subscribers - key) - } - } - - private def unsubscribeFrom[F[_]: MonadCancelThrow: Log, K, V]( + private def unsubscribeFrom[F[_], K, V]( key: K, - subs: AtomicCell[F, Map[K, Redis4CatsSubscription[F, V]]] + state: PubSubState.SubscriptionMap[F, K, V] ): F[Unit] = - subs.evalUpdate { subscribers => - subscribers.get(key) match { - case None => - // No subscription = nothing to do - Log[F] - .debug(s"Not unsubscribing from $key because we don't have a subscription") - .as(subscribers) - case Some(sub) => - // Publish `None` which will terminate all streams, which will perform cleanup once the last stream - // terminates. - (Log[F].info( - s"Unsubscribing from $key with ${sub.subscribers} subscribers" - ) *> sub.topic.publish1(None)).uncancelable.as(subscribers) - } - } + state.unsubscribe(key) private def subscribe[F[_]: Async: Log, TypedKey, SubValue, K, V]( key: TypedKey, - subs: AtomicCell[F, Map[TypedKey, Redis4CatsSubscription[F, SubValue]]], + state: PubSubState.SubscriptionMap[F, TypedKey, SubValue], subConnection: StatefulRedisPubSubConnection[K, V], subscribeToRedis: F[Unit], unsubscribeFromRedis: F[Unit] )(makeListener: (Dispatcher[F], Topic[F, Option[SubValue]]) => RedisPubSubListener[K, V]): Stream[F, SubValue] = - Stream - .eval(subs.evalModify { subscribers => - def stream(sub: Redis4CatsSubscription[F, SubValue]) = - sub.stream(onStreamTermination(subs, key)) - - subscribers.get(key) match { - case Some(subscription) => - // We have an existing subscription, mark that it has one more subscriber. - val newSubscription = subscription.addSubscriber - val newSubscribers = subscribers.updated(key, newSubscription) - Log[F] - .debug( - s"Returning existing subscription for $key, " + - s"subscribers: ${subscription.subscribers} -> ${newSubscription.subscribers}" - ) - .as((newSubscribers, stream(newSubscription))) - - case None => - // No existing subscription, create a new one. - val makeSubscription = for { - _ <- Log[F].info(s"Creating subscription for $key") - // We use parallel dispatcher because multiple subscribers can be interested in the same key - dispatcherTpl <- Dispatcher.parallel[F].allocated - (dispatcher, cleanupDispatcher) = dispatcherTpl - topic <- Topic[F, Option[SubValue]] - listener = makeListener(dispatcher, topic) - cleanupListener = Sync[F].delay(subConnection.removeListener(listener)) - cleanup = ( - Log[F].debug(s"Cleaning up resources for $key subscription") *> - unsubscribeFromRedis *> cleanupListener *> cleanupDispatcher *> - Log[F].debug(s"Cleaned up resources for $key subscription") - ).uncancelable - _ <- Sync[F].delay(subConnection.addListener(listener)) - _ <- subscribeToRedis - sub = Redis4CatsSubscription(topic, subscribers = 1, cleanup) - newSubscribers = subscribers.updated(key, sub) - _ <- Log[F].debug(s"Created subscription for $key") - } yield (newSubscribers, stream(sub)) - - makeSubscription.uncancelable - } - }) - .flatten + state.subscribe(key) { + for { + _ <- Resource.eval(Log[F].info(s"Creating subscription for $key")) + // We use parallel dispatcher because multiple subscribers can be interested in the same key + dispatcher <- Dispatcher.parallel[F] + topic <- Resource.eval(Topic[F, Option[SubValue]]) + _ <- Resource.make { + val listener = makeListener(dispatcher, topic) + Sync[F].delay(subConnection.addListener(listener)).as(listener) + }(listener => Sync[F].delay(subConnection.removeListener(listener))) + _ <- Resource.make(subscribeToRedis)(_ => unsubscribeFromRedis) + _ <- Resource.eval(Log[F].debug(s"Created subscription for $key")) + } yield topic + } }