diff --git a/src/private/channel_spsc.nim b/src/private/channel_spsc.nim index 784417f..6de284f 100644 --- a/src/private/channel_spsc.nim +++ b/src/private/channel_spsc.nim @@ -8,18 +8,32 @@ import std/[atomics, times] type ChannelMode* = enum SPSC + MPSC ## Multi-Producer Single-Consumer SPSCSlot[T] = object value: T sequence: Atomic[int] + # Padding to prevent false sharing between producer and consumer cache lines + CacheLinePad = object + pad: array[64, byte] + Channel*[T] = ref object mode: ChannelMode # Lock-free single producer single consumer buffer: seq[SPSCSlot[T]] mask: int + + # SPSC fields head: Atomic[int] # Producer writes here + pad1: CacheLinePad # Prevent false sharing tail: Atomic[int] # Consumer reads here + pad2: CacheLinePad + + # MPSC-specific fields + mpscHead: Atomic[int] # Atomic head for multi-producer CAS + mpscCount: Atomic[int] # Track count for wait-free full check + capacity: int proc newChannel*[T](size: int, mode: ChannelMode): Channel[T] = @@ -35,9 +49,13 @@ proc newChannel*[T](size: int, mode: ChannelMode): Channel[T] = result.head.store(0, moRelaxed) result.tail.store(0, moRelaxed) -proc trySend*[T](c: Channel[T], value: T): bool = - ## Try to send a value to the channel. Returns true if successful. - ## Non-blocking operation. + # Initialize MPSC-specific fields + if mode == MPSC: + result.mpscHead.store(0, moRelaxed) + result.mpscCount.store(0, moRelaxed) + +proc trySendSPSC[T](c: Channel[T], value: T): bool = + ## SPSC-optimized send (single producer, relaxed ordering). let currentHead = c.head.load(moRelaxed) let currentTail = c.tail.load(moAcquire) @@ -54,9 +72,40 @@ proc trySend*[T](c: Channel[T], value: T): bool = c.head.store(currentHead + 1, moRelease) return true -proc tryReceive*[T](c: Channel[T], value: var T): bool = - ## Try to receive a value from the channel. Returns true if successful. - ## Non-blocking operation. +proc trySendMPSC[T](c: Channel[T], value: T): bool = + ## MPSC wait-free send (multiple producers, CAS on head). + ## Based on dbittman's wait-free MPSC algorithm + JCTools patterns. + + # Step 1: Atomically increment count to reserve a slot (wait-free) + let count = c.mpscCount.fetchAdd(1, moAcquire) + if count >= c.capacity: + # Queue full - backoff + discard c.mpscCount.fetchSub(1, moRelease) + return false + + # Step 2: Atomically claim a slot by incrementing head (wait-free) + let myHead = c.mpscHead.fetchAdd(1, moAcquire) + + # Step 3: Write to the slot (no contention, we own it) + let slot = myHead and c.mask + c.buffer[slot].value = value + + # Step 4: Publish the write with release semantics for consumer visibility + c.buffer[slot].sequence.store(myHead + 1, moRelease) + + return true + +proc trySend*[T](c: Channel[T], value: T): bool = + ## Try to send a value to the channel. Returns true if successful. + ## Non-blocking operation. Dispatches to SPSC or MPSC implementation. + case c.mode + of SPSC: + return trySendSPSC(c, value) + of MPSC: + return trySendMPSC(c, value) + +proc tryReceiveSPSC[T](c: Channel[T], value: var T): bool = + ## SPSC-optimized receive (single consumer, relaxed ordering). let currentTail = c.tail.load(moRelaxed) let currentHead = c.head.load(moAcquire) @@ -77,14 +126,57 @@ proc tryReceive*[T](c: Channel[T], value: var T): bool = c.tail.store(currentTail + 1, moRelease) return true +proc tryReceiveMPSC[T](c: Channel[T], value: var T): bool = + ## MPSC receive (single consumer, must handle concurrent producers). + ## Uses mpscHead instead of head for accurate empty check. + let currentTail = c.tail.load(moRelaxed) + let currentHead = c.mpscHead.load(moAcquire) + + # Check if empty + if currentTail >= currentHead: + return false + + # Read from slot + let slot = currentTail and c.mask + let seq = c.buffer[slot].sequence.load(moAcquire) + + # Wait-free: if producer hasn't published yet, return false (not an error) + if seq != currentTail + 1: + return false # Producer in-flight, try again later + + value = c.buffer[slot].value + + # Update tail and decrement count + c.tail.store(currentTail + 1, moRelease) + discard c.mpscCount.fetchSub(1, moRelease) + + return true + +proc tryReceive*[T](c: Channel[T], value: var T): bool = + ## Try to receive a value from the channel. Returns true if successful. + ## Non-blocking operation. Dispatches to SPSC or MPSC implementation. + case c.mode + of SPSC: + return tryReceiveSPSC(c, value) + of MPSC: + return tryReceiveMPSC(c, value) + proc capacity*[T](c: Channel[T]): int = ## Get the capacity of the channel. c.capacity proc isEmpty*[T](c: Channel[T]): bool = ## Check if the channel is empty. - c.tail.load(moRelaxed) >= c.head.load(moRelaxed) + case c.mode + of SPSC: + c.tail.load(moRelaxed) >= c.head.load(moRelaxed) + of MPSC: + c.tail.load(moRelaxed) >= c.mpscHead.load(moRelaxed) proc isFull*[T](c: Channel[T]): bool = ## Check if the channel is full. - c.head.load(moRelaxed) - c.tail.load(moRelaxed) >= c.capacity \ No newline at end of file + case c.mode + of SPSC: + c.head.load(moRelaxed) - c.tail.load(moRelaxed) >= c.capacity + of MPSC: + c.mpscCount.load(moRelaxed) >= c.capacity \ No newline at end of file diff --git a/tests/performance/benchmark_mpsc.nim b/tests/performance/benchmark_mpsc.nim new file mode 100644 index 0000000..be8493a --- /dev/null +++ b/tests/performance/benchmark_mpsc.nim @@ -0,0 +1,316 @@ +## Benchmarks comparing MPSC vs SPSC channel performance +## +## Tests throughput, latency, and scalability under various workloads + +import std/[atomics, times, strformat, strutils] +import ../../src/private/channel_spsc as ch + +# Avoid ambiguity with system.Channel +type Channel[T] = ch.Channel[T] +const SPSC = ch.SPSC +const MPSC = ch.MPSC + +type + BenchResult = object + name: string + items: int + duration: float + throughputMops: float + avgLatencyNs: float + +proc formatBenchResult(r: BenchResult): string = + &"{r.name:<40} | {r.items:>10} items | {r.duration:>6.3f}s | {r.throughputMops:>7.2f} Mops/s | {r.avgLatencyNs:>8.1f} ns/op" + +# ============================================================================ +# SPSC Benchmarks +# ============================================================================ + +proc benchSPSCThroughput(capacity, numItems: int): BenchResult = + let chan = ch.newChannel[int](capacity, SPSC) + var producerThread: Thread[tuple[chan: Channel[int], n: int]] + + proc producer(args: tuple[chan: Channel[int], n: int]) {.thread.} = + for i in 0 ..< args.n: + while not args.chan.trySend(i): + discard + + let t0 = cpuTime() + createThread(producerThread, producer, (chan, numItems)) + + var val: int + for _ in 0 ..< numItems: + while not chan.tryReceive(val): + discard + + joinThread(producerThread) + let elapsed = cpuTime() - t0 + + BenchResult( + name: &"SPSC throughput (cap={capacity})", + items: numItems, + duration: elapsed, + throughputMops: float(numItems) / elapsed / 1_000_000.0, + avgLatencyNs: elapsed / float(numItems) * 1_000_000_000.0 + ) + +proc benchSPSCLatency(capacity, numSamples: int): BenchResult = + let chan = ch.newChannel[int](capacity, SPSC) + var producerThread: Thread[tuple[chan: Channel[int], n: int]] + + proc producer(args: tuple[chan: Channel[int], n: int]) {.thread.} = + for i in 0 ..< args.n: + while not args.chan.trySend(i): + discard + + createThread(producerThread, producer, (chan, numSamples)) + + var latencies = newSeq[float](numSamples) + var val: int + + for i in 0 ..< numSamples: + let t0 = cpuTime() + while not chan.tryReceive(val): + discard + latencies[i] = (cpuTime() - t0) * 1_000_000_000.0 + + joinThread(producerThread) + + var sum = 0.0 + for lat in latencies: + sum += lat + + BenchResult( + name: &"SPSC latency (cap={capacity})", + items: numSamples, + duration: sum / 1_000_000_000.0, + throughputMops: 0.0, + avgLatencyNs: sum / float(numSamples) + ) + +# ============================================================================ +# MPSC Benchmarks +# ============================================================================ + +proc benchMPSCThroughput(capacity, numProducers, itemsPerProducer: int): BenchResult = + let chan = ch.newChannel[int](capacity, MPSC) + let totalItems = numProducers * itemsPerProducer + + var producerThreads = newSeq[Thread[tuple[chan: Channel[int], n: int]]](numProducers) + + proc producer(args: tuple[chan: Channel[int], n: int]) {.thread.} = + for i in 0 ..< args.n: + while not args.chan.trySend(i): + discard + + let t0 = cpuTime() + + for i in 0 ..< numProducers: + createThread(producerThreads[i], producer, (chan, itemsPerProducer)) + + var val: int + for _ in 0 ..< totalItems: + while not chan.tryReceive(val): + discard + + for i in 0 ..< numProducers: + joinThread(producerThreads[i]) + + let elapsed = cpuTime() - t0 + + BenchResult( + name: &"MPSC throughput {numProducers}P (cap={capacity})", + items: totalItems, + duration: elapsed, + throughputMops: float(totalItems) / elapsed / 1_000_000.0, + avgLatencyNs: elapsed / float(totalItems) * 1_000_000_000.0 + ) + +proc benchMPSCLatency(capacity, numProducers, numSamples: int): BenchResult = + let chan = ch.newChannel[int](capacity, MPSC) + var producerThreads = newSeq[Thread[tuple[chan: Channel[int], n: int]]](numProducers) + + proc producer(args: tuple[chan: Channel[int], n: int]) {.thread.} = + for i in 0 ..< args.n: + while not args.chan.trySend(i): + discard + + for i in 0 ..< numProducers: + createThread(producerThreads[i], producer, (chan, numSamples)) + + let totalItems = numProducers * numSamples + var latencies = newSeq[float](totalItems) + var val: int + + for i in 0 ..< totalItems: + let t0 = cpuTime() + while not chan.tryReceive(val): + discard + latencies[i] = (cpuTime() - t0) * 1_000_000_000.0 + + for i in 0 ..< numProducers: + joinThread(producerThreads[i]) + + var sum = 0.0 + for lat in latencies: + sum += lat + + BenchResult( + name: &"MPSC latency {numProducers}P (cap={capacity})", + items: totalItems, + duration: sum / 1_000_000_000.0, + throughputMops: 0.0, + avgLatencyNs: sum / float(totalItems) + ) + +# ============================================================================ +# Scalability Benchmark +# ============================================================================ + +proc benchMPSCScalability(): seq[BenchResult] = + result = @[] + const ItemsPerProducer = 100_000 + const Capacity = 1024 + + echo "\n=== MPSC Scalability (fixed ", ItemsPerProducer, " items/producer) ===" + echo &"{\"Benchmark\":<40} | {\"Items\":>10} | {\"Time\":>8} | {\"Throughput\":>14} | {\"Latency\":>12}" + echo repeat("=", 100) + + for numProducers in [1, 2, 4, 8]: + let r = benchMPSCThroughput(Capacity, numProducers, ItemsPerProducer) + result.add(r) + echo formatBenchResult(r) + +# ============================================================================ +# Burst Workload Benchmark +# ============================================================================ + +proc benchBurstWorkload(mode: ch.ChannelMode, capacity, numBursts, burstSize: int): BenchResult = + let chan = ch.newChannel[int](capacity, mode) + let totalItems = numBursts * burstSize + + let modeStr = if mode == SPSC: "SPSC" else: "MPSC-1P" + + var producerThread: Thread[tuple[chan: Channel[int], bursts: int, size: int]] + + proc burstProducer(args: tuple[chan: Channel[int], bursts: int, size: int]) {.thread.} = + for burst in 0 ..< args.bursts: + for i in 0 ..< args.size: + while not args.chan.trySend(burst * args.size + i): + discard + # Small pause between bursts + for _ in 0 ..< 100: discard + + let t0 = cpuTime() + createThread(producerThread, burstProducer, (chan, numBursts, burstSize)) + + var val: int + for _ in 0 ..< totalItems: + while not chan.tryReceive(val): + discard + + joinThread(producerThread) + let elapsed = cpuTime() - t0 + + BenchResult( + name: &"{modeStr} burst workload (cap={capacity})", + items: totalItems, + duration: elapsed, + throughputMops: float(totalItems) / elapsed / 1_000_000.0, + avgLatencyNs: elapsed / float(totalItems) * 1_000_000_000.0 + ) + +# ============================================================================ +# Size Comparison Benchmark +# ============================================================================ + +proc benchSizeComparison(): seq[BenchResult] = + result = @[] + const NumItems = 1_000_000 + + echo "\n=== Channel Size Impact ===" + echo &"{\"Benchmark\":<40} | {\"Items\":>10} | {\"Time\":>8} | {\"Throughput\":>14} | {\"Latency\":>12}" + echo repeat("=", 100) + + for capacity in [64, 256, 1024, 4096]: + let r1 = benchSPSCThroughput(capacity, NumItems) + result.add(r1) + echo formatBenchResult(r1) + + let r2 = benchMPSCThroughput(capacity, 4, NumItems div 4) + result.add(r2) + echo formatBenchResult(r2) + +# ============================================================================ +# Main Benchmark Suite +# ============================================================================ + +proc runBenchmarks() = + echo "\n" & repeat("=", 100) + echo "nimsync Channel Benchmarks: MPSC vs SPSC" + echo repeat("=", 100) + + var results: seq[BenchResult] = @[] + + # Throughput benchmarks + echo "\n=== Throughput Comparison ===" + echo &"{\"Benchmark\":<40} | {\"Items\":>10} | {\"Time\":>8} | {\"Throughput\":>14} | {\"Latency\":>12}" + echo repeat("=", 100) + + block: + let r = benchSPSCThroughput(1024, 1_000_000) + results.add(r) + echo formatBenchResult(r) + + for numProducers in [1, 2, 4, 8]: + let r = benchMPSCThroughput(1024, numProducers, 1_000_000 div numProducers) + results.add(r) + echo formatBenchResult(r) + + # Latency benchmarks + echo "\n=== Latency Comparison ===" + echo &"{\"Benchmark\":<40} | {\"Items\":>10} | {\"Time\":>8} | {\"Throughput\":>14} | {\"Latency\":>12}" + echo repeat("=", 100) + + block: + let r = benchSPSCLatency(128, 10_000) + results.add(r) + echo formatBenchResult(r) + + for numProducers in [1, 2, 4]: + let r = benchMPSCLatency(128, numProducers, 10_000 div numProducers) + results.add(r) + echo formatBenchResult(r) + + # Scalability + results.add(benchMPSCScalability()) + + # Size comparison + results.add(benchSizeComparison()) + + # Burst workload + echo "\n=== Burst Workload ===" + echo &"{\"Benchmark\":<40} | {\"Items\":>10} | {\"Time\":>8} | {\"Throughput\":>14} | {\"Latency\":>12}" + echo repeat("=", 100) + + block: + let r = benchBurstWorkload(SPSC, 256, 1000, 100) + results.add(r) + echo formatBenchResult(r) + + block: + let r = benchBurstWorkload(MPSC, 256, 1000, 100) + results.add(r) + echo formatBenchResult(r) + + # Summary + echo "\n" & repeat("=", 100) + echo "Benchmark Summary" + echo repeat("=", 100) + echo "- SPSC: Optimized for single producer/consumer, uses relaxed atomics" + echo "- MPSC: Supports multiple concurrent producers via wait-free fetchAdd" + echo "- Both implementations use cache-line padding to prevent false sharing" + echo "- Expected performance: 100-200M ops/sec depending on contention" + echo repeat("=", 100) & "\n" + +when isMainModule: + runBenchmarks() diff --git a/tests/unit/channels/test_mpsc_channel.nim b/tests/unit/channels/test_mpsc_channel.nim new file mode 100644 index 0000000..25682cd --- /dev/null +++ b/tests/unit/channels/test_mpsc_channel.nim @@ -0,0 +1,259 @@ +## Unit tests for MPSC (Multi-Producer Single-Consumer) channel + +import std/[atomics, times, strformat] +import ../../../src/private/channel_spsc as ch +import ../../support/test_fixtures + +# Alias to avoid conflict with system.Channel +type Channel[T] = ch.Channel[T] + +type + Stats = object + sent: Atomic[int] + received: Atomic[int] + failed: Atomic[int] + +proc testBasicMPSC() = + echo "Testing basic MPSC send/receive..." + + let chan = ch.newChannel[int](16, ch.MPSC) + assert chan.capacity == 16 + assert chan.isEmpty + + # Single send/receive + assert chan.trySend(42) + var val: int + assert chan.tryReceive(val) + assert val == 42 + assert chan.isEmpty + + echo "✓ Basic MPSC operations work" + +proc testMPSCMultiProducersSingleConsumer() = + echo "Testing MPSC with 4 producers, 1 consumer..." + + let chan = ch.newChannel[int](1024, ch.MPSC) + const NumProducers = 4 + const ItemsPerProducer = 10000 + const TotalItems = NumProducers * ItemsPerProducer + + var stats: Stats + stats.sent.store(0) + stats.received.store(0) + stats.failed.store(0) + + # Producer threads + var producerThreads: array[NumProducers, Thread[tuple[chan: Channel[int], id: int, stats: ptr Stats]]] + + proc producerProc(args: tuple[chan: Channel[int], id: int, stats: ptr Stats]) {.thread.} = + let start = args.id * ItemsPerProducer + for i in 0 ..< ItemsPerProducer: + let val = start + i + while not args.chan.trySend(val): + # Spin-wait if full + discard + discard args.stats.sent.fetchAdd(1) + + # Start producers + for i in 0 ..< NumProducers: + createThread(producerThreads[i], producerProc, (chan, i, stats.addr)) + + # Consumer (main thread) + var received: seq[int] = @[] + var val: int + + while stats.received.load() < TotalItems: + if chan.tryReceive(val): + received.add(val) + discard stats.received.fetchAdd(1) + + # Wait for producers to finish + for i in 0 ..< NumProducers: + joinThread(producerThreads[i]) + + # Drain any remaining items + while chan.tryReceive(val): + received.add(val) + discard stats.received.fetchAdd(1) + + assert received.len == TotalItems, "Expected " & $TotalItems & " items, got " & $received.len + echo "✓ Received all ", TotalItems, " items from ", NumProducers, " producers" + + # Verify all values are present (no duplicates or missing) + var expected = newSeq[bool](TotalItems) + for v in received: + assert v >= 0 and v < TotalItems, "Invalid value: " & $v + assert not expected[v], "Duplicate value: " & $v + expected[v] = true + + for i, seen in expected: + assert seen, "Missing value: " & $i + + echo "✓ All values unique and accounted for" + +proc testMPSCStressTest() = + echo "Testing MPSC stress test (1M items)..." + + let chan = ch.newChannel[int](2048, ch.MPSC) + const NumProducers = 8 + const ItemsPerProducer = 125000 # 8 * 125k = 1M + const TotalItems = NumProducers * ItemsPerProducer + + var stats: Stats + stats.sent.store(0) + stats.received.store(0) + + let startTime = cpuTime() + + var producerThreads: array[NumProducers, Thread[tuple[chan: Channel[int], id: int, stats: ptr Stats]]] + + proc producerProc(args: tuple[chan: Channel[int], id: int, stats: ptr Stats]) {.thread.} = + let start = args.id * ItemsPerProducer + for i in 0 ..< ItemsPerProducer: + while not args.chan.trySend(start + i): + discard + discard args.stats.sent.fetchAdd(1) + + for i in 0 ..< NumProducers: + createThread(producerThreads[i], producerProc, (chan, i, stats.addr)) + + # Consumer + var count = 0 + var val: int + while count < TotalItems: + if chan.tryReceive(val): + count += 1 + + for i in 0 ..< NumProducers: + joinThread(producerThreads[i]) + + let elapsed = cpuTime() - startTime + let throughput = float(TotalItems) / elapsed / 1_000_000.0 + + echo &"✓ Processed {TotalItems} items in {elapsed:.3f}s" + echo &" Throughput: {throughput:.2f}M ops/sec" + +proc testMPSCFullAndEmpty() = + echo "Testing MPSC full/empty conditions..." + + let chan = ch.newChannel[int](8, ch.MPSC) + + # Test empty + assert chan.isEmpty + assert not chan.isFull + var val: int + assert not chan.tryReceive(val) + + # Fill up + for i in 0 ..< 8: + assert chan.trySend(i), "Send " & $i & " failed" + + assert chan.isFull + assert not chan.isEmpty + assert not chan.trySend(999), "Should reject when full" + + # Drain + for i in 0 ..< 8: + assert chan.tryReceive(val), "Receive " & $i & " failed" + assert val == i + + assert chan.isEmpty + assert not chan.isFull + + echo "✓ Full/empty detection works correctly" + +proc testMPSCBurstWorkload() = + echo "Testing MPSC with bursty workload..." + + let chan = ch.newChannel[int](256, ch.MPSC) + const NumProducers = 4 + const NumBursts = 100 + const BurstSize = 50 + + var stats: Stats + stats.sent.store(0) + stats.received.store(0) + + var producerThreads: array[NumProducers, Thread[tuple[chan: Channel[int], stats: ptr Stats]]] + + proc burstProducer(args: tuple[chan: Channel[int], stats: ptr Stats]) {.thread.} = + for burst in 0 ..< NumBursts: + for i in 0 ..< BurstSize: + while not args.chan.trySend(burst * BurstSize + i): + discard + discard args.stats.sent.fetchAdd(1) + # Small pause between bursts + for _ in 0 ..< 1000: discard + + for i in 0 ..< NumProducers: + createThread(producerThreads[i], burstProducer, (chan, stats.addr)) + + # Consumer + const TotalItems = NumProducers * NumBursts * BurstSize + var val: int + var count = 0 + + while count < TotalItems: + if chan.tryReceive(val): + count += 1 + + for i in 0 ..< NumProducers: + joinThread(producerThreads[i]) + + assert count == TotalItems + echo "✓ Handled ", TotalItems, " items in bursty workload" + +proc testMPSCLatency() = + echo "Testing MPSC latency..." + + let chan = ch.newChannel[int](128, ch.MPSC) + const NumSamples = 10000 + + var latencies = newSeq[float](NumSamples) + + # Single producer/consumer for latency measurement + var producerThread: Thread[Channel[int]] + + proc latencyProducer(c: Channel[int]) {.thread.} = + for i in 0 ..< NumSamples: + while not c.trySend(i): + discard + + createThread(producerThread, latencyProducer, chan) + + var val: int + for i in 0 ..< NumSamples: + let t0 = cpuTime() + while not chan.tryReceive(val): + discard + let t1 = cpuTime() + latencies[i] = (t1 - t0) * 1_000_000_000.0 # nanoseconds + + joinThread(producerThread) + + # Calculate stats + var sum = 0.0 + var minLat = latencies[0] + var maxLat = latencies[0] + + for lat in latencies: + sum += lat + if lat < minLat: minLat = lat + if lat > maxLat: maxLat = lat + + let avgLat = sum / float(NumSamples) + + echo &"✓ Latency ({NumSamples} samples):" + echo &" Avg: {avgLat:.1f} ns/op" + echo &" Min: {minLat:.1f} ns/op" + echo &" Max: {maxLat:.1f} ns/op" + +when isMainModule: + echo "\n=== MPSC Channel Unit Tests ===" + testBasicMPSC() + testMPSCFullAndEmpty() + testMPSCMultiProducersSingleConsumer() + testMPSCBurstWorkload() + testMPSCStressTest() + testMPSCLatency() + echo "\n=== All MPSC tests passed! ===\n"