Internal change. #283

Open · wants to merge 1 commit into base: main
59 changes: 44 additions & 15 deletions tensorflow_gnn/models/multi_head_attention/layers.py
@@ -1,3 +1,4 @@
# pyformat: mode=yapf
# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Contains a Multi-Head Attention and associated layers."""
from typing import Any, Callable, Collection, Mapping, Optional, Union
from typing import Any, Callable, Collection, Literal, Mapping, Optional, Union
import warnings

import tensorflow as tf
@@ -72,13 +73,13 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase):
attended to each other, which means we do NOT compute $N^2$ pairs of scores
as the original Transformer-style Attention.

Users are able to remove the scaling of attention scores (score_scaling=False)
or add an activation on the transformed query (controled by
`attention_activation`). However, we recommend to remove the scaling when
using an `attention_activation` since activating both of them may lead to
degrated accuracy. One can also customize the transformation kernels with
different intializers, regularizers as well as the use of bias terms, using
the other arguments.
Users are able to remove the scaling of attention scores
(`score_scaling="none"`) or add an activation on the transformed query
(controlled by `attention_activation`). However, we recommend to remove the
scaling when using an `attention_activation` since activating both of them may
lead to degraded accuracy. One can also customize the transformation kernels
with different initializers, regularizers as well as the use of bias terms,
using the other arguments.

Example: Transformer-style attention on neighbors along incoming edges
whose result is concatenated with the old node state and passed through
@@ -157,9 +158,15 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase):
only queries are transformed since the two transformations on queries and
keys are equivalent to one. (The presence of transformations on values is
independent of this arg.)
score_scaling: If true, the attention scores are divided by the square root
of the dimension of keys (i.e., per_head_channels if transform_keys=True,
else whatever the dimension of combined sender inputs is).
score_scaling: One of either `"none"`, `"rsqrt_dim"`, or
`"trainable_sigmoid"`. If set to `"rsqrt_dim"`, the attention scores are
divided by the square root of the dimension of keys (i.e.,
`per_head_channels` if `transform_keys=True`, otherwise whatever the
dimension of combined sender inputs is). If set to `"trainable_sigmoid"`,
the scores are scaled with `sigmoid(x)`, where `x` is a trainable weight
of the model that is initialized to `-5.0`, which initially makes all the
attention weights equal and slowly ramps up as the other weights in the
layer converge. Defaults to `"rsqrt_dim"`.
transform_values_after_pooling: By default, each attention head applies
the value transformation, then pools with attention coefficients.
Setting this option pools inputs with attention coefficients, then applies
@@ -186,7 +193,8 @@ def __init__(
kernel_regularizer: Union[None, str,
tf.keras.regularizers.Regularizer] = None,
transform_keys: bool = True,
score_scaling: bool = True,
score_scaling: Literal["none", "rsqrt_dim",
"trainable_sigmoid"] = "rsqrt_dim",
transform_values_after_pooling: bool = False,
**kwargs):
kwargs.setdefault("name", "multi_head_attention_conv")
@@ -222,7 +230,7 @@ def __init__(
self._edge_dropout_layer = None

# Check for conflicting options.
if attention_activation is not None and score_scaling:
if attention_activation is not None and score_scaling != "none":
warnings.warn(
"using both an activation on transformed inputs and score scaling "
"may lead to degraded accuracy if the activation function restricts "
@@ -300,6 +308,9 @@ def __init__(
kernel_regularizer=kernel_regularizer,
name="value_pooled")

if self._score_scaling == "trainable_sigmoid":
self._score_scaling_weight = None

def get_config(self):
return dict(
num_heads=self._num_heads,
@@ -419,9 +430,27 @@ def convolve(self,
# [num_items, *extra_dims, num_heads, 1]
attention_coefficients = tf.expand_dims(
tf.einsum("...j,...j->...", queries, keys), axis=-1)
if self._score_scaling:

# Optionally scale the attention scores.
if self._score_scaling == "none":
pass
elif self._score_scaling == "rsqrt_dim":
attention_coefficients *= tf.math.rsqrt(
tf.cast(keys.shape[-1], tf.float32))
tf.cast(tf.shape(keys)[-1], tf.float32))
elif self._score_scaling == "trainable_sigmoid":
if self._score_scaling_weight is None:
self._score_scaling_weight = self.add_weight(
name="score_scaling",
shape=[],
dtype=tf.float32,
initializer=tf.keras.initializers.Constant(-5.0),
trainable=True,
)
attention_coefficients *= tf.keras.activations.sigmoid(
self._score_scaling_weight)
else:
raise ValueError("Unknown value MultiHeadAttentionConv("
f"score_scaling='{self._score_scaling}')")

attention_coefficients = extra_receiver_ops["softmax"](
attention_coefficients)
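
The diff above replaces the boolean `score_scaling` flag with a string argument (`"none"`, `"rsqrt_dim"`, or `"trainable_sigmoid"`). A minimal usage sketch of the new argument, assuming the import path `tensorflow_gnn.models.multi_head_attention` and a `GraphTensor` with an edge set named `"edges"` (both mirroring the tests below, not shown in this diff):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

# Sketch only: constructor arguments mirror those exercised in the tests below.
conv = multi_head_attention.MultiHeadAttentionConv(
    num_heads=4,
    per_head_channels=16,
    receiver_tag=tfgnn.TARGET,
    score_scaling="trainable_sigmoid",  # new option; "rsqrt_dim" is the default, "none" disables scaling
)

# Per the docstring above, when adding an attention_activation it is recommended
# to disable score scaling to avoid degraded accuracy:
conv_with_activation = multi_head_attention.MultiHeadAttentionConv(
    num_heads=4,
    per_head_channels=16,
    receiver_tag=tfgnn.TARGET,
    attention_activation="relu",
    score_scaling="none",
)

# Applied to a graph's edge set, as in the tests:
# outputs = conv(graph, edge_set_name="edges")
```
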
150 changes: 131 additions & 19 deletions tensorflow_gnn/models/multi_head_attention/layers_test.py
@@ -1,3 +1,4 @@
# pyformat: mode=yapf
# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +16,7 @@
"""Tests for Multi-Head Attention."""

import enum
import math
import os

from absl.testing import parameterized
@@ -32,8 +34,7 @@ class ReloadModel(int, enum.Enum):

class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(("", False),
("TransformAfter", True))
@parameterized.named_parameters(("", False), ("TransformAfter", True))
def testBasic(self, transform_values_after_pooling):
"""Tests that a single-headed MHA is correct given predefined weights."""
# NOTE: Many following tests use minor variations of the explicit
@@ -121,13 +122,12 @@ def testBasic(self, transform_values_after_pooling):
[0., 0., 0.])
else:
# Same weights, but as Einsum kernel "hvc".
weights["multi_head_attention_conv/value_pooled/kernel:0"].assign(
[[
[0., -1., 0.],
[-1., 0., 0.],
[-1., -1., 0.],
[0., 0., 1.1],
]])
weights["multi_head_attention_conv/value_pooled/kernel:0"].assign([[
[0., -1., 0.],
[-1., 0., 0.],
[-1., -1., 0.],
[0., 0., 1.1],
]])
weights["multi_head_attention_conv/value_pooled/bias:0"].assign(
[[0., 0., 0.]])

@@ -165,8 +165,7 @@ def testBasic(self, transform_values_after_pooling):
self.assertAllClose(got_2, want_2, atol=.0001)

def testAttentionActivation(self):
"""Tests that a single-headed MHA correctly applies attention activations.
"""
"""Tests that a single-headed MHA correctly applies attention activations."""

# The same test graph as in the testBasic above.
gt_input = _get_test_bidi_cycle_graph(
@@ -177,16 +176,15 @@ def testAttentionActivation(self):
]))

def get_conv(attention_activation=None):
"""Constructs a MultiHeadAttentionConv with the given attention_activation.
"""
"""Constructs a MultiHeadAttentionConv with the given attention_activation."""

conv = multi_head_attention.MultiHeadAttentionConv(
num_heads=1,
per_head_channels=3,
receiver_tag=tfgnn.TARGET,
attention_activation=attention_activation,
activation=None,
score_scaling=False,
score_scaling="none",
)

_ = conv(gt_input, edge_set_name="edges") # Build weights.
@@ -290,6 +288,122 @@ def get_conv(attention_activation=None):
self.assertAllEqual(got.shape, (3, 3))
self.assertAllClose(got, want, atol=.0001)

def testScoreScalingTypes(self):
"""Tests that the different types of score scaling are applied correctly."""

# The same test graph as in the testBasic above.
gt_input = _get_test_bidi_cycle_graph(
tf.constant([
[1.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 0.0, 2.0],
[0.0, 0.0, 1.0, 3.0],
]))

def get_conv(score_scaling=None):
"""Constructs a MultiHeadAttentionConv with the given score_scaling."""

conv = multi_head_attention.MultiHeadAttentionConv(
num_heads=1,
per_head_channels=3,
receiver_tag=tfgnn.TARGET,
activation=None,
score_scaling=score_scaling,
)

_ = conv(gt_input, edge_set_name="edges") # Build weights.
weights = {v.name: v for v in conv.trainable_weights}
if score_scaling == "trainable_sigmoid":
# Additional trainable weight for the score scaling.
self.assertLen(weights, 7)
else:
self.assertLen(weights, 6)

weights["multi_head_attention_conv/query/kernel:0"].assign(
# The node states times the query kernel should be:
#
# [[0., 1., 0.],
# [0., 0., -1.],
# [1., 0., 0.]]
#
# i.e. the second query vector has negative values, which, after
# activation with the `relu` function, should be all zeros.
[
[0.0, 1.0, 0.0],
[0.0, 0.0, -1.0],
[1.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
])
weights["multi_head_attention_conv/query/bias:0"].assign([0.0, 0.0, 0.0])

weights["multi_head_attention_conv/key_node/kernel:0"].assign(
# The key_node kernel is chosen such that the product with the
# node states is:
#
# [[-1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]
#
# i.e. the third key vector has negative values, which, after
# activation with the `relu` function, should be all zeros.
[
[-1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0],
])
weights["multi_head_attention_conv/key_node/bias:0"].assign(
[0.0, 0.0, 0.0])

# The attention scores are computed as the product of the transformed
# queries and keys (with a zero diagonal since there are no self edges and
# hence no self-attention), scaled by a factor s.
#
# s * [[0., 1., 0.],      [[ 0, s,  0],
#      [0., 0., -1],  ==   [ 0, 0, -s],
#      [-1, 0., 0.]]       [-s, 0,  0]]
#
# Attention weights are computed by applying softmax to each row except
# the diagonal element. Recall that
# softmax([s, 0]) = [exp(s), 1] / (exp(s) + 1), and
# softmax([0, -s]) = softmax([s, 0]) = [exp(s), 1] / (exp(s) + 1),
# which explains the expected values below, with w = exp(s).

weights["multi_head_attention_conv/value_node/kernel:0"].assign(
# Identity matrix such that the transformed node states are `eye(3)`.
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0],
])
weights["multi_head_attention_conv/value_node/bias:0"].assign(
[0.0, 0.0, 0.0])

return conv

named_scalings = {
"none": 1.0,
"rsqrt_dim": 1.0 / math.sqrt(3.0),
"trainable_sigmoid": tf.keras.activations.sigmoid(-5.0),
}

for scaling_name, scaling_factor in named_scalings.items():
with self.subTest(f"with_{scaling_name}"):
conv = get_conv(score_scaling=scaling_name)
got = conv(gt_input, edge_set_name="edges")

# Since the transformed values are just the identity matrix, we recover
# the attention weights for each query.
w = tf.math.exp(scaling_factor).numpy()
want = tf.constant([
[0.0, w, 1.0],
[w, 0.0, 1.0],
[1.0, w, 0.0],
]) / tf.constant(
w + 1.0, dtype=tf.float32)
self.assertAllEqual(got.shape, (3, 3))
self.assertAllClose(got, want, atol=0.0001)

def testNoTransformKeys(self):
"""Tests that the no key transformation variant of MHA is correct."""

@@ -386,8 +500,7 @@ def testNoTransformKeys(self):
self.assertAllEqual(got_2.shape, (3, 2, 3))
self.assertAllClose(got_2, want_2, atol=.0001)

@parameterized.named_parameters(("", False),
("TransformAfter", True))
@parameterized.named_parameters(("", False), ("TransformAfter", True))
def testMultihead(self, transform_values_after_pooling):
"""Extends testBasic with multiple attention heads."""
# The same test graph as in the testBasic above.
Expand All @@ -404,7 +517,7 @@ def testMultihead(self, transform_values_after_pooling):
receiver_tag=tfgnn.TARGET,
activation="relu",
use_bias=False, # Don't create /bias variables.
score_scaling=False, # Disable score scaling.
score_scaling="none", # Disable score scaling.
transform_values_after_pooling=transform_values_after_pooling,
)

@@ -475,8 +588,7 @@ def testMultihead(self, transform_values_after_pooling):
self.assertAllClose(got, want, atol=.0001)

@parameterized.named_parameters(
("", ReloadModel.SKIP, False),
("TransformAfter", ReloadModel.SKIP, True),
("", ReloadModel.SKIP, False), ("TransformAfter", ReloadModel.SKIP, True),
("Restored", ReloadModel.SAVED_MODEL, False),
("RestoredTransformAfter", ReloadModel.SAVED_MODEL, True),
("RestoredKeras", ReloadModel.KERAS, False),
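
As a quick cross-check of the arithmetic that `testScoreScalingTypes` above expects: for each receiver node the off-diagonal scores are `[s, 0]`, where `s` is the scaling factor under test, so softmax yields `[exp(s), 1] / (exp(s) + 1)`; with the trainable sigmoid initialized to `-5.0`, `s = sigmoid(-5.0) ≈ 0.0067` and the weights start out nearly uniform. A small standalone sketch of this computation (plain Python, no TensorFlow; independent of the test code):

```python
import math

def expected_weights(scaling_factor: float):
  """Softmax over one receiver's off-diagonal scores [s, 0], with s = scaling_factor * 1.0."""
  w = math.exp(scaling_factor)
  return [w / (w + 1.0), 1.0 / (w + 1.0)]

# The three scalings exercised by the test; sigmoid(-5.0) is the initial value of the
# trainable weight, so attention starts out close to uniform and can ramp up in training.
for name, factor in {
    "none": 1.0,
    "rsqrt_dim": 1.0 / math.sqrt(3.0),                 # keys have dimension 3 in the test graph
    "trainable_sigmoid": 1.0 / (1.0 + math.exp(5.0)),  # sigmoid(-5.0) ~= 0.0067
}.items():
  print(name, expected_weights(factor))
# Approximate values: none -> [0.731, 0.269]; rsqrt_dim -> [0.640, 0.360];
# trainable_sigmoid -> [0.502, 0.498].
```
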