Skip to content

Commit b1dae2e

Browse files
committed
RecordAccumulator: Use helper method to get/set _tp_locks; get dq with lock in reenqueue()
1 parent 707913f commit b1dae2e

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

kafka/producer/record_accumulator.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,13 @@ def delivery_timeout_ms(self):
251251
def next_expiry_time_ms(self):
252252
return self._next_batch_expiry_time_ms
253253

254+
def _tp_lock(self, tp):
255+
if tp not in self._tp_locks:
256+
with self._tp_locks[None]:
257+
if tp not in self._tp_locks:
258+
self._tp_locks[tp] = threading.Lock()
259+
return self._tp_locks[tp]
260+
254261
def append(self, tp, timestamp_ms, key, value, headers, now=None):
255262
"""Add a record to the accumulator, return the append result.
256263
@@ -275,12 +282,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
275282
# not miss batches in abortIncompleteBatches().
276283
self._appends_in_progress.increment()
277284
try:
278-
if tp not in self._tp_locks:
279-
with self._tp_locks[None]:
280-
if tp not in self._tp_locks:
281-
self._tp_locks[tp] = threading.Lock()
282-
283-
with self._tp_locks[tp]:
285+
with self._tp_lock(tp):
284286
# check if we have an in-progress batch
285287
dq = self._batches[tp]
286288
if dq:
@@ -290,7 +292,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
290292
batch_is_full = len(dq) > 1 or last.records.is_full()
291293
return future, batch_is_full, False
292294

293-
with self._tp_locks[tp]:
295+
with self._tp_lock(tp):
294296
# Need to check if producer is closed again after grabbing the
295297
# dequeue lock.
296298
assert not self._closed, 'RecordAccumulator is closed'
@@ -333,8 +335,7 @@ def expired_batches(self, now=None):
333335
"""Get a list of batches which have been sitting in the accumulator too long and need to be expired."""
334336
expired_batches = []
335337
for tp in list(self._batches.keys()):
336-
assert tp in self._tp_locks, 'TopicPartition not in locks dict'
337-
with self._tp_locks[tp]:
338+
with self._tp_lock(tp):
338339
# iterate over the batches and expire them if they have stayed
339340
# in accumulator for more than request_timeout_ms
340341
dq = self._batches[tp]
@@ -352,14 +353,12 @@ def expired_batches(self, now=None):
352353

353354
def reenqueue(self, batch, now=None):
354355
"""
355-
Re-enqueue the given record batch in the accumulator. In Sender.completeBatch method, we check
356-
whether the batch has reached deliveryTimeoutMs or not. Hence we do not do the delivery timeout check here.
356+
Re-enqueue the given record batch in the accumulator. In Sender._complete_batch method, we check
357+
whether the batch has reached delivery_timeout_ms or not. Hence we do not do the delivery timeout check here.
357358
"""
358359
batch.retry(now=now)
359-
assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict'
360-
assert batch.topic_partition in self._batches, 'TopicPartition not in batches'
361-
dq = self._batches[batch.topic_partition]
362-
with self._tp_locks[batch.topic_partition]:
360+
with self._tp_lock(batch.topic_partition):
361+
dq = self._batches[batch.topic_partition]
363362
dq.appendleft(batch)
364363

365364
def ready(self, cluster, now=None):
@@ -412,7 +411,7 @@ def ready(self, cluster, now=None):
412411
elif tp in self.muted:
413412
continue
414413

415-
with self._tp_locks[tp]:
414+
with self._tp_lock(tp):
416415
dq = self._batches[tp]
417416
if not dq:
418417
continue
@@ -445,7 +444,7 @@ def ready(self, cluster, now=None):
445444
def has_undrained(self):
446445
"""Check whether there are any batches which haven't been drained"""
447446
for tp in list(self._batches.keys()):
448-
with self._tp_locks[tp]:
447+
with self._tp_lock(tp):
449448
dq = self._batches[tp]
450449
if len(dq):
451450
return True
@@ -485,7 +484,7 @@ def drain_batches_for_one_node(self, cluster, node_id, max_size, now=None):
485484
if tp not in self._batches:
486485
continue
487486

488-
with self._tp_locks[tp]:
487+
with self._tp_lock(tp):
489488
dq = self._batches[tp]
490489
if len(dq) == 0:
491490
continue
@@ -619,7 +618,7 @@ def _abort_batches(self, error):
619618
for batch in self._incomplete.all():
620619
tp = batch.topic_partition
621620
# Close the batch before aborting
622-
with self._tp_locks[tp]:
621+
with self._tp_lock(tp):
623622
batch.records.close()
624623
self._batches[tp].remove(batch)
625624
batch.abort(error)
@@ -628,7 +627,7 @@ def _abort_batches(self, error):
628627
def abort_undrained_batches(self, error):
629628
for batch in self._incomplete.all():
630629
tp = batch.topic_partition
631-
with self._tp_locks[tp]:
630+
with self._tp_lock(tp):
632631
aborted = False
633632
if not batch.is_done:
634633
aborted = True

0 commit comments

Comments
 (0)