diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index 1c250ee40..77d48d84f 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -328,6 +328,9 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None): finally: self._appends_in_progress.decrement() + def reset_next_batch_expiry_time(self): + self._next_batch_expiry_time_ms = float('inf') + def maybe_update_next_batch_expiry_time(self, batch): self._next_batch_expiry_time_ms = min(self._next_batch_expiry_time_ms, batch.created * 1000 + self.delivery_timeout_ms) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 4a88b2f7a..7a4c557c8 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -77,7 +77,7 @@ def _maybe_remove_from_inflight_batches(self, batch): queue.pop() heapq.heapify(queue) - def _get_expired_inflight_batches(self): + def _get_expired_inflight_batches(self, now=None): """Get the in-flight batches that has reached delivery timeout.""" expired_batches = [] to_remove = [] @@ -174,7 +174,7 @@ def run_once(self): def _send_producer_data(self, now=None): now = time.time() if now is None else now # get the list of partitions with data ready to send - result = self._accumulator.ready(self._metadata) + result = self._accumulator.ready(self._metadata, now=now) ready_nodes, next_ready_check_delay, unknown_leaders_exist = result # if there are any partitions whose leaders are not known yet, force @@ -195,7 +195,7 @@ def _send_producer_data(self, now=None): # create produce requests batches_by_node = self._accumulator.drain( - self._metadata, ready_nodes, self.config['max_request_size']) + self._metadata, ready_nodes, self.config['max_request_size'], now=now) for batch_list in six.itervalues(batches_by_node): for batch in batch_list: @@ -209,8 +209,9 @@ def _send_producer_data(self, now=None): for batch in batch_list: self._accumulator.muted.add(batch.topic_partition) - expired_batches = self._accumulator.expired_batches() - expired_batches.extend(self._get_expired_inflight_batches()) + self._accumulator.reset_next_batch_expiry_time() + expired_batches = self._accumulator.expired_batches(now=now) + expired_batches.extend(self._get_expired_inflight_batches(now=now)) if expired_batches: log.debug("%s: Expired %s batches in accumulator", str(self), len(expired_batches)) diff --git a/test/test_sender.py b/test/test_sender.py index 0731454df..6d29c1e44 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -240,3 +240,16 @@ def test_maybe_wait_for_producer_id(): def test_run_once(): pass + + +def test__send_producer_data_expiry_time_reset(sender, accumulator, mocker): + now = time.time() + tp = TopicPartition('foo', 0) + mocker.patch.object(sender, '_failed_produce') + result = accumulator.append(tp, 0, b'key', b'value', [], now=now) + poll_timeout_ms = sender._send_producer_data(now=now) + assert poll_timeout_ms == accumulator.config['delivery_timeout_ms'] + sender._failed_produce.assert_not_called() + now += accumulator.config['delivery_timeout_ms'] + poll_timeout_ms = sender._send_producer_data(now=now) + assert poll_timeout_ms > 0