diff --git a/retry_decorator/retry_decorator.py b/retry_decorator/retry_decorator.py index b5cc4b7..8d3bafd 100644 --- a/retry_decorator/retry_decorator.py +++ b/retry_decorator/retry_decorator.py @@ -8,6 +8,14 @@ import random +def _is_valid_iter(i): + if not isinstance(i, (list, tuple)): + return False + elif len(i) != 2: + raise ValueError("provided list|tuple needs to have size of 2") + return True + + def _deco_retry(f, exc=Exception, tries=10, timeout_secs=1.0, logger=None, callback_by_exception=None): """ Common function logic for the internal retry flows. @@ -19,22 +27,27 @@ def _deco_retry(f, exc=Exception, tries=10, timeout_secs=1.0, logger=None, callb :param callback_by_exception: :return: """ + def f_retry(*args, **kwargs): mtries, mdelay = tries, timeout_secs run_one_last_time = True + while mtries > 1: try: return f(*args, **kwargs) except exc as e: # check if this exception is something the caller wants special handling for callback_errors = callback_by_exception or {} + for error_type in callback_errors: if isinstance(e, error_type): callback_logic = callback_by_exception[error_type] - should_break_out = run_one_last_time = False - if isinstance(callback_logic, (list, tuple)): + + should_break_out = False + run_one_last_time = True + if _is_valid_iter(callback_logic): callback_logic, should_break_out = callback_logic - if isinstance(should_break_out, (list, tuple)): + if _is_valid_iter(should_break_out): should_break_out, run_one_last_time = should_break_out callback_logic() if should_break_out: # caller requests we stop handling this exception @@ -62,7 +75,7 @@ def retry(exc=Exception, tries=10, timeout_secs=1.0, logger=None, callback_by_ex :param timeout_secs: general delay between retries (we do employ a jitter) :param logger: an optional logger object :param callback_by_exception: callback/method invocation on certain exceptions - :type callback_by_exception: None or dict + :type callback_by_exception: None, list, tuple, function or dict """ # We re-use `RetryHandler` so that we can reduce duplication; decorator is still useful! retry_handler = RetryHandler(exc, tries, timeout_secs, logger, callback_by_exception) @@ -78,6 +91,14 @@ class RetryHandler(object): def __init__( self, exc=Exception, tries=10, timeout_secs=1.0, logger=None, callback_by_exception=None, ): + if not isinstance(tries, int): + raise TypeError("[tries] arg needs to be of int type") + elif tries < 1: + raise ValueError("[tries] arg needs to be an int >= 1") + + if callable(callback_by_exception) or isinstance(callback_by_exception, (list, tuple)): + callback_by_exception = {Exception: callback_by_exception} + self.exc = exc self.tries = tries self.timeout_secs = timeout_secs diff --git a/tests/test_callback.py b/tests/test_callback.py index 6cf919f..7482718 100644 --- a/tests/test_callback.py +++ b/tests/test_callback.py @@ -1,52 +1,204 @@ import unittest -import functools +from functools import partial import retry_decorator class ClassForTesting(object): hello = None + cb_counter = 0 # counts how many times callback was invoked + exe_counter = 0 # counts how many times our retriable logic was invoked + + +class ExampleTestError(Exception): + pass class_for_testing = ClassForTesting() class MyTestCase(unittest.TestCase): - def test_something(self): + + def setUp(self): + class_for_testing.hello = None + class_for_testing.cb_counter = 0 + class_for_testing.exe_counter = 0 + + def test_callback_invoked_on_configured_exception_type(self): try: my_test_func() except Exception: # for the dangling exception (the "final" function execution) pass - self.assertIn(class_for_testing.hello, ('world', 'fish', )) + self.assertEqual(class_for_testing.hello, 'world') def test_two_exceptions_to_check_use_one(self): try: my_test_func_2() except Exception: pass - self.assertIn(class_for_testing.hello, ('world', 'fish', )) + self.assertEqual(class_for_testing.hello, 'fish') + self.assertEqual(class_for_testing.cb_counter, 1) + self.assertEqual(class_for_testing.exe_counter, 2) + def test_callback_by_exception_may_be_func(self): + try: + my_test_func_3() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'foo') -def callback_logic(instance, attr_to_set, value_to_set): - print('Callback called for {}, {}, {}'.format(instance, attr_to_set, value_to_set)) - setattr(instance, attr_to_set, value_to_set) + def test_callback_by_exception_may_be_tuple(self): + try: + my_test_func_4() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'bar') + def test_verify_correct_amount_of_retries_and_callback_invokations(self): + try: + my_test_func_5() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'bar') + self.assertEqual(class_for_testing.cb_counter, 10) + self.assertEqual(class_for_testing.exe_counter, 6) -class ExampleTestError(Exception): - pass + def test_verify_correct_amount_of_retries_and_callback_invokations2(self): + try: + my_test_func_6() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'foo') + self.assertEqual(class_for_testing.cb_counter, 5) + self.assertEqual(class_for_testing.exe_counter, 6) + + def test_verify_breakout_true_works(self): + try: + my_test_func_7() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'baz') + self.assertEqual(class_for_testing.cb_counter, 6) # we had 2 handlers, but because of breakout=True only first of them was ever ran + self.assertEqual(class_for_testing.exe_counter, 7) + + def test_verify_run_last_time_false_works(self): + try: + my_test_func_8() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'bar') + self.assertEqual(class_for_testing.cb_counter, 14) + self.assertEqual(class_for_testing.exe_counter, 7) # note value is tries-1 because of run_one_last_time=False + + def test_verify_tries_1_is_ok(self): + try: + my_test_func_9() + except Exception: + pass + self.assertEqual(class_for_testing.hello, None) + self.assertEqual(class_for_testing.cb_counter, 0) + self.assertEqual(class_for_testing.exe_counter, 1) + + def test_verify_run_last_time_false_with_2_tries(self): + try: + my_test_func_10() + except Exception: + pass + self.assertEqual(class_for_testing.hello, 'foo') + self.assertEqual(class_for_testing.cb_counter, 1) + self.assertEqual(class_for_testing.exe_counter, 1) + + def test_verify_tries_0_errors_out(self): + try: + retry_decorator.retry(tries=0, callback_by_exception=partial(callback_logic, class_for_testing, 'hello', 'foo')) + raise AssertionError('Expected ValueError to be thrown') + except ValueError: + pass + + def test_verify_tries_not_int_is_error(self): + try: + retry_decorator.retry(tries='not int', callback_by_exception=partial(callback_logic, class_for_testing, 'hello', 'foo')) + raise AssertionError('Expected TypeError to be thrown') + except TypeError: + pass + + +def callback_logic(instance, attr_to_set, value_to_set): + print('Callback called for {}; setting attr [{}] to value [{}]'.format(instance, attr_to_set, value_to_set)) + setattr(instance, attr_to_set, value_to_set) + instance.cb_counter += 1 @retry_decorator.retry(exc=ExampleTestError, tries=2, callback_by_exception={ - ExampleTestError: functools.partial(callback_logic, class_for_testing, 'hello', 'world')}) + ExampleTestError: partial(callback_logic, class_for_testing, 'hello', 'world')}) def my_test_func(): raise ExampleTestError('oh noes.') @retry_decorator.retry(exc=(ExampleTestError, AttributeError), tries=2, callback_by_exception={ - AttributeError: functools.partial(callback_logic, class_for_testing, 'hello', 'fish')}) + AttributeError: partial(callback_logic, class_for_testing, 'hello', 'fish')}) def my_test_func_2(): + class_for_testing.exe_counter += 1 raise AttributeError('attribute oh noes.') +@retry_decorator.retry(tries=2, callback_by_exception=partial(callback_logic, class_for_testing, 'hello', 'foo')) +def my_test_func_3(): + raise TypeError('type oh noes.') + + +@retry_decorator.retry(tries=2, callback_by_exception=(partial(callback_logic, class_for_testing, 'hello', 'bar'), False)) +def my_test_func_4(): + raise TypeError('type oh noes.') + + +@retry_decorator.retry(tries=6, callback_by_exception={ + TypeError: partial(callback_logic, class_for_testing, 'hello', 'foo'), + Exception: partial(callback_logic, class_for_testing, 'hello', 'bar') + }) +def my_test_func_5(): + class_for_testing.exe_counter += 1 + raise TypeError('type oh noes.') + + +@retry_decorator.retry(exc=ExampleTestError, tries=6, callback_by_exception={ + TypeError: partial(callback_logic, class_for_testing, 'hello', 'bar'), + ExampleTestError: partial(callback_logic, class_for_testing, 'hello', 'foo') + }) +def my_test_func_6(): + class_for_testing.exe_counter += 1 + raise ExampleTestError('oh noes.') + + +@retry_decorator.retry(tries=7, callback_by_exception={ + TypeError: (partial(callback_logic, class_for_testing, 'hello', 'baz'), True), + Exception: partial(callback_logic, class_for_testing, 'hello', 'foo') + }) +def my_test_func_7(): + class_for_testing.exe_counter += 1 + raise TypeError('type oh noes.') + + +@retry_decorator.retry(tries=8, callback_by_exception={ + TypeError: partial(callback_logic, class_for_testing, 'hello', 'foo'), + Exception: (partial(callback_logic, class_for_testing, 'hello', 'bar'), (False, False)) + }) +def my_test_func_8(): + class_for_testing.exe_counter += 1 + raise TypeError('type oh noes.') + + +@retry_decorator.retry(tries=1, callback_by_exception=partial(callback_logic, class_for_testing, 'hello', 'foo')) +def my_test_func_9(): + class_for_testing.exe_counter += 1 + raise TypeError('type oh noes.') + + +@retry_decorator.retry(tries=2, callback_by_exception=(partial(callback_logic, class_for_testing, 'hello', 'foo'), (False, False))) +def my_test_func_10(): + class_for_testing.exe_counter += 1 + raise TypeError('type oh noes.') + + if __name__ == '__main__': unittest.main()