Skip to content

Commit 7b17dd5

Browse files
tomhennigandiegolascasas
authored andcommitted
Add split_rng=False (current default) to HTM.
Haiku plans to make split_rng a required argument to hk.vmap in an upcoming release. This change updates HTM to preserve the current behaviour. We also handle the case where users are using a release of Haiku without the split_rng option, for these users split_rng=False is implied. PiperOrigin-RevId: 428454975
1 parent 840bfe8 commit 7b17dd5

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

hierarchical_transformer_memory/hierarchical_attention/htm_attention.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Haiku module implementing hierarchical attention over memory."""
1616

17+
import functools
18+
import inspect
1719
from typing import Optional, NamedTuple
1820

1921
import chex
@@ -198,7 +200,7 @@ def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
198200
key=sub_sub_top_k_contents,
199201
value=sub_sub_top_k_contents)
200202
return sub_attention_results
201-
do_attention = hk.vmap(do_attention, in_axes=0)
203+
do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
202204
attention_results = do_attention(sub_inputs, top_k_contents)
203205
attention_results = jnp.squeeze(attention_results, axis=2)
204206
# Now collapse results across k memories
@@ -207,12 +209,27 @@ def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
207209
return attention_results
208210

209211
# vmap across batch
210-
batch_within_memory_attention = hk.vmap(_within_memory_attention,
211-
in_axes=0)
212+
batch_within_memory_attention = hk_vmap(_within_memory_attention,
213+
in_axes=0, split_rng=False)
212214
outputs = batch_within_memory_attention(
213215
queries,
214216
jax.lax.stop_gradient(augmented_contents),
215217
weights,
216218
top_k_indices)
217219

218220
return outputs
221+
222+
223+
@functools.wraps(hk.vmap)
224+
def hk_vmap(*args, **kwargs):
225+
"""Helper function to support older versions of Haiku."""
226+
# Older versions of Haiku did not have split_rng, but the behavior has always
227+
# been equivalent to split_rng=False.
228+
if "split_rng" not in inspect.signature(hk.vmap).parameters:
229+
kwargs.setdefault("split_rng", False)
230+
if kwargs.get["split_rng"]:
231+
raise ValueError("The installed version of Haiku only supports "
232+
"`split_rng=False`, please upgrade Haiku.")
233+
del kwargs["split_rng"]
234+
235+
return hk.vmap(*args, **kwargs)

0 commit comments

Comments
 (0)