Skip to content

Commit 962dfe8

Browse files
StarryZhang-whusaikatG
authored andcommitted
Add tensorflow and pytorch random seed setting.
Tests and README are included. Add support for setting tensorflow and pytorch random seed. It can help detect flaky tests due to the randomness, provide reproducibility. Co-Authored-By: Saikat Dutta <[email protected]>
1 parent e4bc22f commit 962dfe8

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ All of these features are on by default but can be disabled with flags.
6767
.. |numpy.random| replace:: ``numpy.random``
6868
__ https://numpy.org/doc/stable/reference/random/index.html
6969

70+
* If `TensorFlow <https://www.tensorflow.org/>`_ is installed, its random seed in ``tensorflow.random`` is reset at the start of every test.
71+
72+
* If `PyTorch <https://pytorch.org/>`_ is installed, its random seed is reset at the start of every test. The random seed of each test is recorded, and can play a role in detecting flaky tests.
73+
7074
* If additional random generators are used, they can be registered under the
7175
``pytest_randomly.random_seeder``
7276
`entry point <https://packaging.python.org/specifications/entry-points/>`_ and

src/pytest_randomly/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,22 @@
6767
except ImportError: # pragma: no cover
6868
have_numpy = False
6969

70+
# tensorflow
71+
try:
72+
import tensorflow as tf
73+
74+
have_tensorflow = True
75+
except ImportError: # pragma: no cover
76+
have_tensorflow = False
77+
78+
# pytorch
79+
try:
80+
import torch
81+
82+
have_pytorch = True
83+
except ImportError: # pragma: no cover
84+
have_pytorch = False
85+
7086

7187
default_seed = random.Random().getrandbits(32)
7288

@@ -180,6 +196,17 @@ def _reseed(config: Config, offset: int = 0) -> int:
180196
else:
181197
np_random.set_state(np_random_states[numpy_seed])
182198

199+
if have_tensorflow: # pragma: no branch
200+
tf.random.set_seed(seed)
201+
# TensorFlow 1.x compatibility
202+
if hasattr(tf, 'compat'):
203+
tf.compat.v1.set_random_seed(seed)
204+
205+
if have_pytorch: # pragma: no branch
206+
torch.manual_seed(seed)
207+
if torch.cuda.is_available(): # Also seed CUDA if available
208+
torch.cuda.manual_seed_all(seed)
209+
183210
if entrypoint_reseeds is None:
184211
eps = entry_points(group="pytest_randomly.random_seeder")
185212
entrypoint_reseeds = [e.load() for e in eps]

tests/test_pytest_randomly.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,40 @@ def test_one():
727727
out.assert_outcomes(passed=1)
728728

729729

730+
def test_tensorflow(ourtester):
731+
ourtester.makepyfile(
732+
test_one="""
733+
import tensorflow as tf
734+
735+
def test_one():
736+
assert tf.random.uniform([]) == tf.constant(0.16513085, dtype=tf.float32)
737+
738+
def test_two():
739+
assert tf.random.uniform([]) == tf.constant(0.16513085, dtype=tf.float32)
740+
"""
741+
)
742+
743+
out = ourtester.runpytest("--randomly-seed=1")
744+
out.assert_outcomes(passed=2)
745+
746+
747+
def test_pytorch(ourtester):
748+
ourtester.makepyfile(
749+
test_one="""
750+
import torch
751+
752+
def test_one():
753+
assert torch.rand(1) == torch.tensor([0.757631599903106689453125])
754+
755+
def test_two():
756+
assert torch.rand(1) == torch.tensor([0.757631599903106689453125])
757+
"""
758+
)
759+
760+
out = ourtester.runpytest("--randomly-seed=1")
761+
out.assert_outcomes(passed=2)
762+
763+
730764
def test_failing_import(testdir):
731765
"""Test with pytest raising CollectError or ImportError.
732766

0 commit comments

Comments
 (0)