@@ -251,6 +251,13 @@ def delivery_timeout_ms(self):
251
251
def next_expiry_time_ms (self ):
252
252
return self ._next_batch_expiry_time_ms
253
253
254
+ def _tp_lock (self , tp ):
255
+ if tp not in self ._tp_locks :
256
+ with self ._tp_locks [None ]:
257
+ if tp not in self ._tp_locks :
258
+ self ._tp_locks [tp ] = threading .Lock ()
259
+ return self ._tp_locks [tp ]
260
+
254
261
def append (self , tp , timestamp_ms , key , value , headers , now = None ):
255
262
"""Add a record to the accumulator, return the append result.
256
263
@@ -275,12 +282,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
275
282
# not miss batches in abortIncompleteBatches().
276
283
self ._appends_in_progress .increment ()
277
284
try :
278
- if tp not in self ._tp_locks :
279
- with self ._tp_locks [None ]:
280
- if tp not in self ._tp_locks :
281
- self ._tp_locks [tp ] = threading .Lock ()
282
-
283
- with self ._tp_locks [tp ]:
285
+ with self ._tp_lock (tp ):
284
286
# check if we have an in-progress batch
285
287
dq = self ._batches [tp ]
286
288
if dq :
@@ -290,7 +292,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
290
292
batch_is_full = len (dq ) > 1 or last .records .is_full ()
291
293
return future , batch_is_full , False
292
294
293
- with self ._tp_locks [ tp ] :
295
+ with self ._tp_lock ( tp ) :
294
296
# Need to check if producer is closed again after grabbing the
295
297
# dequeue lock.
296
298
assert not self ._closed , 'RecordAccumulator is closed'
@@ -333,8 +335,7 @@ def expired_batches(self, now=None):
333
335
"""Get a list of batches which have been sitting in the accumulator too long and need to be expired."""
334
336
expired_batches = []
335
337
for tp in list (self ._batches .keys ()):
336
- assert tp in self ._tp_locks , 'TopicPartition not in locks dict'
337
- with self ._tp_locks [tp ]:
338
+ with self ._tp_lock (tp ):
338
339
# iterate over the batches and expire them if they have stayed
339
340
# in accumulator for more than request_timeout_ms
340
341
dq = self ._batches [tp ]
@@ -352,14 +353,12 @@ def expired_batches(self, now=None):
352
353
353
354
def reenqueue (self , batch , now = None ):
354
355
"""
355
- Re-enqueue the given record batch in the accumulator. In Sender.completeBatch method, we check
356
- whether the batch has reached deliveryTimeoutMs or not. Hence we do not do the delivery timeout check here.
356
+ Re-enqueue the given record batch in the accumulator. In Sender._complete_batch method, we check
357
+ whether the batch has reached delivery_timeout_ms or not. Hence we do not do the delivery timeout check here.
357
358
"""
358
359
batch .retry (now = now )
359
- assert batch .topic_partition in self ._tp_locks , 'TopicPartition not in locks dict'
360
- assert batch .topic_partition in self ._batches , 'TopicPartition not in batches'
361
- dq = self ._batches [batch .topic_partition ]
362
- with self ._tp_locks [batch .topic_partition ]:
360
+ with self ._tp_lock (batch .topic_partition ):
361
+ dq = self ._batches [batch .topic_partition ]
363
362
dq .appendleft (batch )
364
363
365
364
def ready (self , cluster , now = None ):
@@ -412,7 +411,7 @@ def ready(self, cluster, now=None):
412
411
elif tp in self .muted :
413
412
continue
414
413
415
- with self ._tp_locks [ tp ] :
414
+ with self ._tp_lock ( tp ) :
416
415
dq = self ._batches [tp ]
417
416
if not dq :
418
417
continue
@@ -445,7 +444,7 @@ def ready(self, cluster, now=None):
445
444
def has_undrained (self ):
446
445
"""Check whether there are any batches which haven't been drained"""
447
446
for tp in list (self ._batches .keys ()):
448
- with self ._tp_locks [ tp ] :
447
+ with self ._tp_lock ( tp ) :
449
448
dq = self ._batches [tp ]
450
449
if len (dq ):
451
450
return True
@@ -485,7 +484,7 @@ def drain_batches_for_one_node(self, cluster, node_id, max_size, now=None):
485
484
if tp not in self ._batches :
486
485
continue
487
486
488
- with self ._tp_locks [ tp ] :
487
+ with self ._tp_lock ( tp ) :
489
488
dq = self ._batches [tp ]
490
489
if len (dq ) == 0 :
491
490
continue
@@ -619,7 +618,7 @@ def _abort_batches(self, error):
619
618
for batch in self ._incomplete .all ():
620
619
tp = batch .topic_partition
621
620
# Close the batch before aborting
622
- with self ._tp_locks [ tp ] :
621
+ with self ._tp_lock ( tp ) :
623
622
batch .records .close ()
624
623
self ._batches [tp ].remove (batch )
625
624
batch .abort (error )
@@ -628,7 +627,7 @@ def _abort_batches(self, error):
628
627
def abort_undrained_batches (self , error ):
629
628
for batch in self ._incomplete .all ():
630
629
tp = batch .topic_partition
631
- with self ._tp_locks [ tp ] :
630
+ with self ._tp_lock ( tp ) :
632
631
aborted = False
633
632
if not batch .is_done :
634
633
aborted = True
0 commit comments