Softplus leaks memory (and is no longer needed) #2008

@nfergu

Description

The TensorFlow Probability implementation of softplus leaks memory, and it appears to no longer be needed. That is, I think the standard tf.nn.softplus implementation can be used now, as the numerical stability issues appear to have been solved.

Currently the implementation of softplus is as follows (from here):

# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:

  @tf.custom_gradient
  def _stable_grad_softplus(x):
    """A (more) numerically stable softplus than `tf.nn.softplus`."""
    x = tf.convert_to_tensor(x)
    if x.dtype == tf.float64:
      cutoff = -20
    else:
      cutoff = -9

    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

    def grad_fn(dy):
      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

    return y, grad_fn

This leaks memory (in non-JAX mode) due to a couple of issues:

  • The grad_fn closure captures the tensor bound to x. That closure then ends up in TensorFlow's gradient registry, which is never cleared, so the tensor hangs around forever.
  • For a similar reason, TensorFlow's custom_gradient implementation itself also leaks memory. See 97697 for more details.

Here is a Colab notebook to demonstrate the memory leak.
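
For reference, below is a rough, self-contained sketch (my own, not the linked notebook) of one way to watch the growth. It assumes psutil is installed, that the Softplus bijector's forward pass goes through the _stable_grad_softplus shown above, and that each fresh tf.function trace exercises the graph-mode custom_gradient path described in the bullets; on affected versions, resident memory should climb with the number of traces instead of staying flat.

import numpy as np
import psutil
import tensorflow as tf
import tensorflow_probability as tfp

process = psutil.Process()

for step in range(30):
  # A fresh ~4 MB constant per iteration; if nothing retained it, RSS would stay roughly flat.
  big = tf.constant(np.random.randn(1_000_000).astype(np.float32))
  # A fresh tf.function forces a new trace, so _stable_grad_softplus runs in graph
  # mode and (per the description above) registers another grad_fn closure that
  # captures its input.
  tf.function(lambda: tfp.bijectors.Softplus().forward(big))()
  del big
  if (step + 1) % 10 == 0:
    rss_mib = process.memory_info().rss / 2**20
    print(f'after {step + 1} traces: rss = {rss_mib:.0f} MiB')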

However, I believe that the numerical stability issues with tf.nn.softplus have been solved. Specifically:

  • The tf.nn.softplus implementation now uses log1p, as of this commit on May 1, 2020.
  • The gradient computation for tf.nn.softplus now uses math_ops.sigmoid, as of this commit on April 4, 2019.
  • The Eigen implementation of sigmoid (which I think is here) computes it as e^x / (1.0 + e^x), so the e^x approximation in _stable_grad_softplus seems unnecessary to me. If e^x is very small then 1.0 + e^x is exactly 1.0, so the result is equivalent to e^x. If e^x > 1.0 then e^x / (1.0 + e^x) will be (I think) more accurate than approximating the gradient as e^x. But I am not a numerical stability expert, so I may be wrong. A quick spot check is sketched after this list.
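
To check this on a given TensorFlow build, here is a small spot check (my own sketch, not from the issue, with input points chosen by me) that compares tf.nn.softplus and its autodiff gradient against the exp(x) and sigmoid(x) references in the far negative tail that the cutoff was meant to protect. With the log1p-based forward and sigmoid-based gradient from the commits above, the printed pairs should agree closely.

import tensorflow as tf

for dtype in (tf.float32, tf.float64):
  # Points well below the cutoffs in _stable_grad_softplus (-9 / -20), where
  # softplus(x) = log1p(exp(x)) ~= exp(x) and d/dx softplus(x) = sigmoid(x) ~= exp(x).
  x = tf.constant([-60.0, -40.0, -25.0, -15.0], dtype=dtype)

  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.softplus(x)
  grad = tape.gradient(y, x)

  print(dtype.name)
  print('  softplus(x)          :', y.numpy())
  print('  exp(x) reference     :', tf.exp(x).numpy())
  print('  grad of softplus(x)  :', grad.numpy())
  print('  sigmoid(x) reference :', tf.math.sigmoid(x).numpy())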
