Skip to content

Commit 1cbcb4a

Browse files
mbezoyan authored and jenkins committed
[scrooge] Speeding up serialization of collections and in particular arrays of primitives
Differential Revision: https://phabricator.twitter.biz/D1173708
1 parent 2468700 commit 1cbcb4a

File tree

8 files changed

+207
-30
lines changed

8 files changed

+207
-30
lines changed

scrooge-benchmark/BUILD

-5
This file was deleted.

scrooge-benchmark/BUILD.bazel

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
target(
2+
dependencies = [
3+
"scrooge/scrooge-benchmark/src/main/scala:benchmark",
4+
],
5+
)
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
scala_library(
1+
scala_benchmark_jmh(
2+
name = "benchmark",
23
sources = ["**/*.scala"],
34
compiler_option_sets = ["fatal_warnings"],
45
platform = "java8",
@@ -11,20 +12,23 @@ scala_library(
1112
"scrooge/scrooge-core/src/main/scala",
1213
"scrooge/scrooge-serializer",
1314
],
14-
exports = [
15-
"3rdparty/jvm/org/openjdk/jmh:jmh-core",
16-
],
1715
)
1816

1917
jvm_binary(
2018
name = "jmh",
2119
main = "org.openjdk.jmh.Main",
2220
platform = "java8",
2321
dependencies = [
24-
":scala",
22+
":benchmark_compiled_benchmark_lib",
2523
scoped(
2624
"3rdparty/jvm/org/slf4j:slf4j-nop",
2725
scope = "runtime",
2826
),
2927
],
3028
)
29+
30+
jvm_app(
31+
name = "jmh-bundle",
32+
basename = "scrooge-benchmark-bundle",
33+
binary = ":jmh",
34+
)

scrooge-benchmark/src/main/scala/com/twitter/scrooge/benchmark/Collections.scala

+71-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package com.twitter.scrooge.benchmark
22

3+
import com.twitter.scrooge.ThriftStruct
34
import com.twitter.scrooge.ThriftStructCodec
45
import java.io.ByteArrayOutputStream
56
import java.util.concurrent.TimeUnit
67
import java.util.Random
7-
import org.apache.thrift.protocol.{TProtocol, TBinaryProtocol}
8+
import org.apache.thrift.protocol.TBinaryProtocol
9+
import org.apache.thrift.protocol.TProtocol
810
import org.apache.thrift.transport.TTransport
911
import org.openjdk.jmh.annotations._
1012
import thrift.benchmark._
@@ -39,6 +41,7 @@ class TRewindable extends TTransport {
3941

4042
def rewind(): Unit = {
4143
pos = 0
44+
arr.reset()
4245
}
4346

4447
def inspect: String = {
@@ -63,33 +66,57 @@ class Collections(size: Int) {
6366
val list: TRewindable = new TRewindable
6467
val listProt: TBinaryProtocol = new TBinaryProtocol(list)
6568

69+
val listDouble: TRewindable = new TRewindable
70+
val listDoubleProt: TBinaryProtocol = new TBinaryProtocol(listDouble)
71+
6672
val rng: Random = new Random(31415926535897932L)
6773

6874
val mapVals: mutable.Builder[(Long, String), Map[Long, String]] = Map.newBuilder[Long, String]
6975
val setVals: mutable.Builder[Long, Set[Long]] = Set.newBuilder[Long]
7076
val listVals: mutable.Builder[Long, Seq[Long]] = Seq.newBuilder[Long]
77+
val arrayVals = new Array[Long](size)
78+
val arrayDoublesVals = new Array[Double](size)
7179

72-
val m: Unit = for (_ <- (0 until size)) {
80+
val m: Unit = for (i <- (0 until size)) {
7381
val num = rng.nextLong()
7482
mapVals += (num -> num.toString)
7583
setVals += num
7684
listVals += num
85+
arrayVals(i) = num
86+
arrayDoublesVals(i) = num
7787
}
7888

79-
MapCollections.encode(MapCollections(mapVals.result), mapProt)
80-
SetCollections.encode(SetCollections(setVals.result), setProt)
81-
ListCollections.encode(ListCollections(listVals.result), listProt)
89+
val mapCollections: MapCollections = MapCollections(mapVals.result)
90+
val setCollections: SetCollections = SetCollections(setVals.result)
91+
val listCollections: ListCollections = ListCollections(listVals.result)
92+
val arrayCollections: ListCollections = ListCollections(arrayVals)
93+
val arrayDoubleCollections: ListDoubleCollections = ListDoubleCollections(arrayDoublesVals)
94+
95+
MapCollections.encode(mapCollections, mapProt)
96+
SetCollections.encode(setCollections, setProt)
97+
ListCollections.encode(listCollections, listProt)
98+
ListDoubleCollections.encode(arrayDoubleCollections, listDoubleProt)
8299

83-
def run(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = {
100+
def decode(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = {
84101
codec.decode(prot)
85102
buff.rewind()
86103
}
104+
105+
/**
 * Serializes `obj` through `codec` onto `prot`, then rewinds `buff` so the
 * next invocation starts from a rewound transport.
 *
 * @param codec codec used to encode `obj`
 * @param prot  protocol the struct is written to
 * @param buff  transport that is rewound after the write; presumably the one
 *              backing `prot` (the caller is responsible for pairing them)
 * @param obj   struct instance to serialize
 */
def encode[T <: ThriftStruct](
  codec: ThriftStructCodec[T],
  prot: TProtocol,
  buff: TRewindable,
  obj: T
): Unit = {
  codec.encode(obj, prot)
  buff.rewind()
}
87114
}
88115

89116
object CollectionsBenchmark {
90117
@State(Scope.Thread)
91118
class CollectionsState {
92-
@Param(Array("1", "5", "10", "100", "500", "1000"))
119+
@Param(Array("1", "5", "10", "100", "500"))
93120
var size: Int = 1
94121

95122
var col: Collections = _
@@ -98,24 +125,53 @@ object CollectionsBenchmark {
98125
def setup(): Unit = {
99126
col = new Collections(size)
100127
}
101-
102128
}
103129
}
104130

105-
@OutputTimeUnit(TimeUnit.NANOSECONDS)
131+
@OutputTimeUnit(TimeUnit.SECONDS)
106132
@BenchmarkMode(Array(Mode.Throughput))
133+
@Fork(1)
134+
@Warmup(iterations = 3, time = 10, timeUnit = TimeUnit.SECONDS)
135+
@Measurement(iterations = 5, time = 10, timeUnit = TimeUnit.SECONDS)
107136
class CollectionsBenchmark {
108137
import CollectionsBenchmark._
109138

110139
@Benchmark
111-
def timeMap(state: CollectionsState): Unit =
112-
state.col.run(MapCollections, state.col.mapProt, state.col.map)
140+
def timeEncodeList(state: CollectionsState): Unit =
141+
state.col.encode(
142+
ListCollections,
143+
state.col.listProt,
144+
state.col.list,
145+
state.col.listCollections
146+
)
147+
148+
@Benchmark
149+
def timeEncodeArray(state: CollectionsState): Unit =
150+
state.col.encode(
151+
ListCollections,
152+
state.col.listProt,
153+
state.col.list,
154+
state.col.arrayCollections
155+
)
156+
157+
@Benchmark
158+
def timeEncodeDoubleArray(state: CollectionsState): Unit =
159+
state.col.encode(
160+
ListDoubleCollections,
161+
state.col.listDoubleProt,
162+
state.col.listDouble,
163+
state.col.arrayDoubleCollections
164+
)
165+
166+
@Benchmark
167+
def timeDecodeMap(state: CollectionsState): Unit =
168+
state.col.decode(MapCollections, state.col.mapProt, state.col.map)
113169

114170
@Benchmark
115-
def timeSet(state: CollectionsState): Unit =
116-
state.col.run(SetCollections, state.col.setProt, state.col.set)
171+
def timeDecodeSet(state: CollectionsState): Unit =
172+
state.col.decode(SetCollections, state.col.setProt, state.col.set)
117173

118174
@Benchmark
119-
def timeList(state: CollectionsState): Unit =
120-
state.col.run(ListCollections, state.col.listProt, state.col.list)
175+
def timeDecodeList(state: CollectionsState): Unit =
176+
state.col.decode(ListCollections, state.col.listProt, state.col.list)
121177
}

scrooge-benchmark/src/main/thrift/collections.thrift

+4
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ struct SetCollections {
1212
struct ListCollections {
1313
1: list<i64> longs
1414
}
15+
16+
struct ListDoubleCollections {
17+
1: list<double> doubles
18+
}

scrooge-core/src/main/scala/com/twitter/scrooge/internal/TProtocols.scala

+98-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ import com.twitter.scrooge.TFieldBlob
44
import com.twitter.scrooge.ThriftEnum
55
import com.twitter.scrooge.ThriftUnion
66
import java.nio.ByteBuffer
7+
import java.util.function.ObjDoubleConsumer
8+
import java.util.function.ObjLongConsumer
79
import org.apache.thrift.protocol._
810
import scala.collection.immutable
911
import scala.collection.mutable
12+
import scala.collection.mutable.ArrayBuffer
1013

1114
/**
1215
* Reads and writes fields for a `TProtocol`. Intended to be used
@@ -93,11 +96,26 @@ final class TProtocols private[TProtocols] {
9396
elementType: Byte,
9497
writeElement: (TProtocol, T) => Unit
9598
): Unit = {
96-
protocol.writeListBegin(new TList(typeForCollection(elementType), list.size))
99+
val size = list.size
100+
protocol.writeListBegin(new TList(typeForCollection(elementType), size))
97101
list match {
102+
case wrappedArray: mutable.WrappedArray[T] =>
103+
val arr = wrappedArray.array
104+
var i = 0
105+
while (i < size) {
106+
val el: T = arr(i).asInstanceOf[T]
107+
writeElement(protocol, el)
108+
i += 1
109+
}
110+
case arrayBuffer: ArrayBuffer[T] =>
111+
var i = 0
112+
while (i < size) {
113+
writeElement(protocol, arrayBuffer(i))
114+
i += 1
115+
}
98116
case _: IndexedSeq[_] =>
99117
var i = 0
100-
while (i < list.size) {
118+
while (i < size) {
101119
writeElement(protocol, list(i))
102120
i += 1
103121
}
@@ -109,6 +127,78 @@ final class TProtocols private[TProtocols] {
109127
protocol.writeListEnd()
110128
}
111129

130+
/**
 * Writes a list of doubles to `protocol`: a list header, every element in
 * iteration order via `writeElement`, then the list end marker.
 *
 * The header carries `typeForCollection(elementType)` and the element count,
 * which is read once up front. Array-backed and indexed sequences are walked
 * by index; anything else is walked through its iterator.
 *
 * @param protocol     destination protocol
 * @param list         values to write
 * @param elementType  thrift type code passed to `typeForCollection` for the header
 * @param writeElement consumer invoked once per element with `protocol` and the value
 */
def writeListDouble(
  protocol: TProtocol,
  list: collection.Seq[Double],
  elementType: Byte,
  writeElement: ObjDoubleConsumer[TProtocol]
): Unit = {
  val count = list.size
  protocol.writeListBegin(new TList(typeForCollection(elementType), count))
  list match {
    case primitives: mutable.WrappedArray.ofDouble =>
      // Read straight from the backing Array[Double] so each value reaches
      // the consumer as a primitive double.
      val backing = primitives.array
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, backing(idx))
        idx += 1
      }
    case buffer: ArrayBuffer[Double] =>
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, buffer(idx))
        idx += 1
      }
    case _: IndexedSeq[_] =>
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, list(idx))
        idx += 1
      }
    case _ =>
      // Non-indexed sequences: positional access may be slow, so iterate instead.
      val elements = list.iterator
      while (elements.hasNext) {
        writeElement.accept(protocol, elements.next())
      }
  }
  protocol.writeListEnd()
}
165+
166+
/**
 * Writes a list of 64-bit integers to `protocol`: a list header, every
 * element in iteration order via `writeElement`, then the list end marker.
 *
 * The header carries `typeForCollection(elementType)` and the element count,
 * which is read once up front. Array-backed and indexed sequences are walked
 * by index; anything else is walked through its iterator.
 *
 * @param protocol     destination protocol
 * @param list         values to write
 * @param elementType  thrift type code passed to `typeForCollection` for the header
 * @param writeElement consumer invoked once per element with `protocol` and the value
 */
def writeListI64(
  protocol: TProtocol,
  list: collection.Seq[Long],
  elementType: Byte,
  writeElement: ObjLongConsumer[TProtocol]
): Unit = {
  val count = list.size
  protocol.writeListBegin(new TList(typeForCollection(elementType), count))
  list match {
    case primitives: mutable.WrappedArray.ofLong =>
      // Read straight from the backing Array[Long] so each value reaches
      // the consumer as a primitive long.
      val backing = primitives.array
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, backing(idx))
        idx += 1
      }
    case buffer: ArrayBuffer[Long] =>
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, buffer(idx))
        idx += 1
      }
    case _: IndexedSeq[_] =>
      var idx = 0
      while (idx < count) {
        writeElement.accept(protocol, list(idx))
        idx += 1
      }
    case _ =>
      // Non-indexed sequences: positional access may be slow, so iterate instead.
      val elements = list.iterator
      while (elements.hasNext) {
        writeElement.accept(protocol, elements.next())
      }
  }
  protocol.writeListEnd()
}
201+
112202
def writeSet[T](
113203
protocol: TProtocol,
114204
set: collection.Set[T],
@@ -193,9 +283,15 @@ object TProtocols {
193283
val writeI64Fn: (TProtocol, Long) => Unit =
194284
(protocol, value) => protocol.writeI64(value)
195285

286+
/**
 * Element writer for i64 lists, expressed as a `java.util.function.ObjLongConsumer`
 * so the value is passed as a primitive `long`.
 */
val writeI64Consumer: ObjLongConsumer[TProtocol] =
  new ObjLongConsumer[TProtocol] {
    override def accept(protocol: TProtocol, value: Long): Unit =
      protocol.writeI64(value)
  }
288+
196289
val writeDoubleFn: (TProtocol, Double) => Unit =
197290
(protocol, value) => protocol.writeDouble(value)
198291

292+
/**
 * Element writer for double lists, expressed as a `java.util.function.ObjDoubleConsumer`
 * so the value is passed as a primitive `double`.
 */
val writeDoubleConsumer: ObjDoubleConsumer[TProtocol] =
  new ObjDoubleConsumer[TProtocol] {
    override def accept(protocol: TProtocol, value: Double): Unit =
      protocol.writeDouble(value)
  }
294+
199295
val writeStringFn: (TProtocol, String) => Unit =
200296
(protocol, value) => protocol.writeString(value)
201297

scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala

+20-3
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,25 @@ trait StructTemplate { self: TemplateGenerator =>
236236
}
237237
}
238238

239+
/**
 * Builds the code fragment that writes a list field in generated structs.
 *
 * Annotated element types are unwrapped and re-dispatched. Lists of `double`
 * and `i64` are routed to the dedicated `writeListDouble` / `writeListI64`
 * entry points with the matching consumer from `TProtocols`. Every other
 * element type falls back to the generic `writeList` with a per-element
 * writer obtained from `genWriteValueFn2`.
 *
 * @param elementFieldType type of the list's elements
 * @param fieldName        fragment naming the list value in the generated code
 * @param protoName        name of the protocol variable in the generated code
 * @return the fragment containing the generated write call
 */
@scala.annotation.tailrec
private[this] def genWriteListFn(
  elementFieldType: FieldType,
  fieldName: CodeFragment,
  protoName: String
): CodeFragment = {
  // Fully-qualified object that holds the primitive consumers referenced by the emitted code.
  val consumers = "_root_.com.twitter.scrooge.internal.TProtocols"
  elementFieldType match {
    case annotated: AnnotatedFieldType =>
      genWriteListFn(annotated.unwrap, fieldName, protoName)
    case TDouble =>
      v(s"$rootProtos.writeListDouble($protoName, $fieldName, TType.DOUBLE, $consumers.writeDoubleConsumer)")
    case TI64 =>
      v(s"$rootProtos.writeListI64($protoName, $fieldName, TType.I64, $consumers.writeI64Consumer)")
    case other =>
      val elementTypeConst = s"TType.${genConstType(other)}"
      val elementWriter = genWriteValueFn2(other)
      v(s"$rootProtos.writeList($protoName, $fieldName, $elementTypeConst, $elementWriter)")
  }
}
257+
239258
@scala.annotation.tailrec
240259
private[this] def genWriteValueFn2(fieldType: FieldType): CodeFragment = {
241260
fieldType match {
@@ -306,9 +325,7 @@ trait StructTemplate { self: TemplateGenerator =>
306325
val writeElement = genWriteValueFn2(t.eltType)
307326
v(s"$rootProtos.writeSet($protoName, $fieldName, $elemFieldType, $writeElement)")
308327
case t: ListType =>
309-
val elemFieldType = s"TType.${genConstType(t.eltType)}"
310-
val writeElement = genWriteValueFn2(t.eltType)
311-
v(s"$rootProtos.writeList($protoName, $fieldName, $elemFieldType, $writeElement)")
328+
genWriteListFn(t.eltType, fieldName, protoName)
312329
case t: MapType =>
313330
val keyType = s"TType.${genConstType(t.keyType)}"
314331
val valType = s"TType.${genConstType(t.valueType)}"

0 commit comments

Comments
 (0)