This repository was archived by the owner on Nov 7, 2024. It is now read-only.

SVD on the JAX backend, and thus split_node, cannot be jitted when max_truncation_err is set #953

@refraction-ray

Description


SVD and split_node work fine on the TensorFlow backend when compiled with tf.function:

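```python
import tensorflow as tf
import tensornetwork as tn

tn.set_default_backend("tensorflow")

@tf.function
def f(b):
    a = tn.Node(b)
    # Truncated SVD split: the number of kept singular values depends on the data.
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(tf.ones([2, 2, 2, 2]))  # works under tf.function
```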

But the same code fails on the JAX backend under jax.jit:

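```python
import jax
from jax import numpy as jnp
import tensornetwork as tn

tn.set_default_backend("jax")

@jax.jit
def f(b):
    a = tn.Node(b)
    # max_truncation_err makes the kept rank data-dependent, which jit cannot trace.
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(jnp.ones([2, 2, 2, 2]))  # raises ConcretizationTypeError
```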

The error is raised by the svd function in backends/numpy/decompositions.py, at the line `num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)`, as `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`.
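The failure can be reproduced without TensorNetwork. The sketch below is a simplified, hypothetical re-implementation of the rank computation around the quoted `min(...)` line (the names `num_kept` and `jit_fails` are illustrative, not library API). Run eagerly, it returns a concrete count; under `jax.jit` the count is an abstract tracer, and Python's `min()` has to convert a traced comparison to `bool`, which raises a `ConcretizationTypeError` (in recent JAX versions, its subclass `TracerBoolConversionError`).

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError


def num_kept(s, max_singular_values, max_truncation_err):
    """Simplified sketch of the rank computation inside a truncated SVD."""
    # Truncation error incurred by dropping the smallest 1, 2, ..., n singular values.
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
    # Number of singular values that must be kept: a data-dependent array scalar.
    num_sing_vals_err = jnp.count_nonzero(trunc_errs > max_truncation_err)
    # Python's min() converts the comparison result to bool, which needs a concrete value.
    return min(max_singular_values, num_sing_vals_err)


def jit_fails(s):
    """Return the exception type name if the jitted version fails, else None."""
    try:
        jax.jit(num_kept, static_argnums=1)(s, 4, 0.5)
    except ConcretizationTypeError as err:
        return type(err).__name__
    return None


s = jnp.array([3.0, 1.0, 0.1, 0.01])
print(num_kept(s, 4, 0.5))  # eager: works, 2 values kept
print(jit_fails(s))  # jitted: raises a ConcretizationTypeError subclass
```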

I actually expected this error before trying, since a jax.jit-compiled function requires every array shape (inputs, intermediates, and outputs) to be known at trace time, so it supports only a subset of what tf.function can handle. Because split_node with max_truncation_err returns nodes whose shapes depend on the singular values, i.e. on the input data, it seems to be incompatible with JAX's jit mechanism.
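To make the shape argument concrete, here is a minimal eager sketch of an error-based truncated SVD. The helper `truncated_svd` is hypothetical and is not TensorNetwork's implementation; it only shows that two inputs with the same shape yield factors with different shapes, which a single compiled jax.jit program cannot represent.

```python
import jax.numpy as jnp


def truncated_svd(m, max_truncation_err):
    """Eager (non-jitted) truncated SVD; hypothetical helper for illustration."""
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    # Error from discarding the smallest 1, 2, ..., n singular values.
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
    # Data-dependent number of singular values to keep (a Python int, so eager only).
    k = int(jnp.count_nonzero(trunc_errs > max_truncation_err))
    return u[:, :k], s[:k], vh[:k, :]


# Two inputs with the same (4, 4) shape...
a = jnp.diag(jnp.array([3.0, 2.0, 1.0, 0.5]))
b = jnp.diag(jnp.array([3.0, 0.1, 0.05, 0.01]))
# ...give left factors with different shapes.
ua, _, _ = truncated_svd(a, 0.6)
ub, _, _ = truncated_svd(b, 0.6)
print(ua.shape, ub.shape)  # (4, 3) (4, 1)
```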

Any thoughts or workarounds? I believe applying split_node with max_singular_values is very common in tensor network algorithms, and it would be great if such algorithms could be jitted.
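One possible workaround, sketched under the assumption that zero-padding is acceptable for the algorithm at hand, is to keep a static rank (a fixed `max_singular_values` passed as a static argument) and mask out, rather than remove, the singular values that the `max_truncation_err` criterion would discard. The function name is hypothetical and this is not a TensorNetwork API. Every output shape then depends only on static arguments, so the function traces under jax.jit, and the data-dependent rank is returned as a traced scalar instead of being used as a shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def masked_truncated_svd(m, max_singular_values, max_truncation_err):
    """Jittable truncated SVD with a static output rank (hypothetical sketch)."""
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    # Error-based rank, computed on the full spectrum but kept as a traced scalar.
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
    k_err = jnp.count_nonzero(trunc_errs > max_truncation_err)
    # Static rank: both operands are Python ints at trace time.
    chi = min(max_singular_values, s.shape[0])
    u, s, vh = u[:, :chi], s[:chi], vh[:chi, :]
    # Zero out (instead of slicing away) singular values past the data-dependent cutoff.
    s = jnp.where(jnp.arange(chi) < k_err, s, 0.0)
    return u, s, vh, jnp.minimum(k_err, chi)


a = jnp.diag(jnp.array([3.0, 2.0, 1.0, 0.5]))
b = jnp.diag(jnp.array([3.0, 0.1, 0.05, 0.01]))
for m in (a, b):
    u, s, vh, k = masked_truncated_svd(m, 4, 0.6)
    print(u.shape, s, int(k))  # same shapes for both inputs; only k and the mask differ
```

The trade-off is that downstream contractions still pay for the full static bond dimension, and the zeroed singular values must be handled consistently (for example, never inverted).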

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
