Disparate Performance between Python and Java #602

Open

ryanhausen opened this issue Feb 11, 2025 · 6 comments

@ryanhausen

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Ubuntu 24.04 x86_64
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.0.0
  • Java version (i.e., the output of java -version): openjdk version "21.0.6" 2025-01-21
  • Java command line flags (e.g., GC parameters):
  • Python version (if transferring a model trained in Python): 3.12.8
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 12.8.61/8905
  • GPU model and memory: V100 (32GB)

Describe the current behavior

Executing the exported model using TensorFlow in Python takes significantly less time than calling the same function using TensorFlow Java. I suspect that I am just not using the Java API correctly, because a small change to the Python code can lead to comparably poor performance in Python as well.

Describe the expected behavior

The function calls should take a comparable amount of time.

Code to reproduce the issue

I have the following Python function:

@tf.function(
    input_signature=[
            tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="data"),  # [k, n, n]
            tf.TensorSpec(shape=[1, 2048, 2048], dtype=tf.float32, name="image"),  # [1, n, n]
            tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="psf"),  # [k, n, n]
    ],
    jit_compile=True
)
def rl_step(
    data: tf.Tensor,  # [k, n, n]
    image: tf.Tensor, # [1, n, n]
    psf: tf.Tensor,   # [k, n, n]
) -> tf.Tensor: # [k, n, n]
    psf_fft = tf.signal.rfft2d(psf)
    psft_fft = tf.signal.rfft2d(tf.reverse(psf, axis=(-2, -1)))
    denom = tf.reduce_sum(
        tf.signal.irfft2d(psf_fft * tf.signal.rfft2d(data)),
        axis=0,
        keepdims=True
    )
    img_err = image / denom
    return data * tf.signal.irfft2d(tf.signal.rfft2d(img_err) * psft_fft)

In Python, this function is applied iteratively over the same tensor as below:

    image_tensor = tf.constant(image) # [k, n, n]
    measured_psf_tensor = tf.constant(measured_psf) # [1, n, n]
    data_tensor = tf.constant(data) # [k, n, n]

    for i in range(10):
        start = time()
        data = rl_step(data_tensor, image_tensor, measured_psf_tensor)
        print(f"Iter {i}:", time() - start, "seconds.")

Here image, measured_psf, and data are all 3D float32 arrays, with n=2048 and k=41.

This prints timings around the following:

Iter 0: 0.2061774730682373 seconds.
Iter 1: 0.004193544387817383 seconds.
Iter 2: 0.0007469654083251953 seconds.
Iter 3: 0.000415802001953125 seconds.
Iter 4: 0.0004220008850097656 seconds.
Iter 5: 0.0004246234893798828 seconds.
Iter 6: 0.0004112720489501953 seconds.
Iter 7: 0.00042128562927246094 seconds.
Iter 8: 0.0004055500030517578 seconds.
Iter 9: 0.00040721893310546875 seconds.

I tried exporting the model by adding the following after the timing code:

    mod = tf.Module()
    mod.f = rl_step
    tf.saved_model.save(mod, "pure_tf_export")

Then I tried to use this exported model from the Java API:

        String modelLocation = "./pure_tf_export";
        try(Graph g = new Graph(); Session s = new Session(g)){
            SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();

            try (Tensor imageTensor = TFloat32.tensorOf(image);
                Tensor psfTensor = TFloat32.tensorOf(psf);
                Tensor dataTensor = TFloat32.tensorOf(data)
            ){
                Map<String, Tensor> inputs = new HashMap<String, Tensor>();
                inputs.put("data", dataTensor);
                inputs.put("image", imageTensor);
                inputs.put("psf", psfTensor);

                for (int i = 0; i < 10; i++){

                    Instant start = Instant.now();

                    Result result = model.function("serving_default").call(inputs);
                    inputs.replace("data", result.get("output_0").get());

                    System.out.println("Iter " + i + " " + (Duration.between(start, Instant.now()).toMillis()/1000f) + " seconds");
                }
            }
        }

And I get timings as follows:

Iter 0 0.701 seconds
Iter 1 0.528 seconds
Iter 2 0.874 seconds
Iter 3 0.224 seconds
Iter 4 0.254 seconds
Iter 5 1.622 seconds
Iter 6 0.241 seconds
Iter 7 0.224 seconds
Iter 8 0.231 seconds
Iter 9 0.228 seconds

I am pretty sure I am making a simple mistake somewhere; I suspect it is in how I am instantiating the Tensors. I know that in Python, if you don't use tf.constant, the timings go up a lot.

Any help would be very much appreciated. I tried looking through the documentation and the tensorflow java-examples repository, but couldn't spot what I am doing wrong.

Thanks again!

@Craigacp
Collaborator

The Python snippet isn't applied iteratively; data is not fed back into data_tensor. Is this a typo? The SavedModelBundle has its own session & graph, so you shouldn't need to make those. The Java code is also leaking old data tensors, as they aren't closed before they are replaced, but that shouldn't be a speed issue.
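
Something along these lines should avoid the leak (a rough sketch, untested; it assumes the same image/psf/data arrays, imports, and modelLocation as your snippet, and reuses the bundle's own session instead of creating a separate Graph and Session):

// Sketch only; same imports as the original snippet (SavedModelBundle, TFloat32, Tensor, Result, etc.)
try (SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();
     Tensor imageTensor = TFloat32.tensorOf(image);
     Tensor psfTensor = TFloat32.tensorOf(psf)) {

    Map<String, Tensor> inputs = new HashMap<>();
    inputs.put("image", imageTensor);
    inputs.put("psf", psfTensor);
    inputs.put("data", TFloat32.tensorOf(data));

    for (int i = 0; i < 10; i++) {
        Instant start = Instant.now();

        Result result = model.function("serving_default").call(inputs);

        // Close the previous iteration's data tensor before replacing it,
        // so native memory isn't leaked.
        Tensor oldData = inputs.get("data");
        inputs.put("data", result.get("output_0").get());
        oldData.close();

        System.out.println("Iter " + i + " " + (Duration.between(start, Instant.now()).toMillis() / 1000f) + " seconds");
    }

    // Close the final output tensor once we're done with it.
    inputs.get("data").close();
}

Whether that changes the timings much I don't know, but it at least removes the leak and the unused Graph/Session.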

Other than the above points, there's also the fact that Java won't warm up in 10 iterations; it'll be more like a few hundred. But you're likely measuring something problematic in the way we're using TF (assuming it's not the Python typo). The tf.function's JIT compilation might not produce the same code as a saved model; it's hard to say what goes on inside that JIT.

@ryanhausen
Author

The Python snippet isn't applied iteratively; data is not fed back into data_tensor. Is this a typo?

Yep, sorry, typo on my part; I was cleaning up my code before putting it up here and messed that up.

The Java code is also leaking old data tensors, as they aren't closed before they are replaced, but that shouldn't be a speed issue.

That's good to know, thanks. I am still trying to figure out how to set the code up properly.

Other than the above points, there's also the fact that Java won't warm up in 10 iterations; it'll be more like a few hundred. But you're likely measuring something problematic in the way we're using TF (assuming it's not the Python typo).

The "warm-up" time of a few hundred iterations is interesting. I'll try running the code for longer just to see what it does wrt to the timing. Do you know why there is a longer warm-up time for the java api than in the python?

The tf.function's JIT compilation might not produce the same code as a saved model; it's hard to say what goes on inside that JIT.

This is true. I am pretty sure I got the same timings in Python when restoring the saved model there as well, but I'll confirm.


For my use case, latency is important. Would you suspect that I could get better times (comparable to the Python times) with a direct implementation in the Java API?

I really appreciate you taking the time to look at my code and your feedback is already helpful!

@Craigacp
Collaborator

Other than the above points, there's also the fact that Java won't warm up in 10 iterations; it'll be more like a few hundred. But you're likely measuring something problematic in the way we're using TF (assuming it's not the Python typo).

The "warm-up" time of a few hundred iterations is interesting. I'll try running the code for longer just to see what it does wrt to the timing. Do you know why there is a longer warm-up time for the java api than in the python?

That's not a TF thing; it's that Java's interpreter is not as fast as the C2-generated code, which will kick in after a few hundred iterations. The Java interpreter is probably comparable to, if not faster than, Python's interpreter, but it's not performance-optimized in the same way, and I've not seen any comparisons.

The tf.function's JIT compilation might not produce the same code as a saved model; it's hard to say what goes on inside that JIT.

This is true. I am pretty sure I got the same timings in Python when restoring the saved model there as well, but I'll confirm.

For my use case, latency is important. Would you suspect that I could get better times (comparable to the Python times) with a direct implementation in the Java API?

I really appreciate you taking the time to look at my code and your feedback is already helpful!

We expose some of the XLA ops which are generated by the JIT-compiled TF code, but I'm not sure how to get the best performance out of the system with respect to that. You could try implementing the model itself in TF-Java; I think we should expose those operations, but I'm not sure if that would be faster.

Other things to consider are placement of operations on GPU vs CPU. Are you using our GPU binaries or the CPU ones?

@ryanhausen
Author

Other than the above points, there's also the fact that Java won't warm up in 10 iterations; it'll be more like a few hundred. But you're likely measuring something problematic in the way we're using TF (assuming it's not the Python typo).

The "warm-up" time of a few hundred iterations is interesting. I'll try running the code for longer just to see what it does wrt to the timing. Do you know why there is a longer warm-up time for the java api than in the python?

That's not a TF thing; it's that Java's interpreter is not as fast as the C2-generated code, which will kick in after a few hundred iterations. The Java interpreter is probably comparable to, if not faster than, Python's interpreter, but it's not performance-optimized in the same way, and I've not seen any comparisons.

I see. I'll play around with the compilation settings and rerun the timings. Thanks!

The tf.function's JIT compilation might not produce the same code as a saved model; it's hard to say what goes on inside that JIT.

This is true. I am pretty sure I got the same timings in Python when restoring the saved model there as well, but I'll confirm.
For my use case, latency is important. Would you suspect that I could get better times (comparable to the Python times) with a direct implementation in the Java API?
I really appreciate you taking the time to look at my code and your feedback is already helpful!

We expose some of the XLA ops which are generated by the JIT-compiled TF code, but I'm not sure how to get the best performance out of the system with respect to that. You could try implementing the model itself in TF-Java; I think we should expose those operations, but I'm not sure if that would be faster.

Ok, thanks. I'll look into prototyping a pure Java API-based implementation.

Other things to consider are placement of operations on GPU vs CPU. Are you using our GPU binaries or the CPU ones?

The GPU makes a really big difference for this calculation. I am pretty sure I am using the GPU binaries: I grabbed the following from Maven, logged device placement, and the output seemed to indicate that things were being put on the GPU. I did a run on the CPU and it was slower, but I don't remember by how much.

<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow-core-api -->
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-api</artifactId>
    <version>1.0.0</version>
</dependency>
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-native</artifactId>
    <version>1.0.0</version>
    <classifier>linux-x86_64-gpu</classifier>
</dependency>

Thanks again for your help!

@Craigacp
Collaborator

Ok. Did you check the saved model speed in Python? If that's slower than the tf.function and similar to the TF-Java speed, then you'd have to look into turning on XLA in TF-Java, and I'm not sure how to do that.

@ryanhausen
Author

ryanhausen commented Feb 13, 2025

I did check the speed of running the saved model in Python, and it is as quick as running it before saving: ~0.004 seconds after the first few iterations.
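
For reference, the check I ran was roughly like this (a sketch, using the same tensors as in the first snippet):

# same imports and data_tensor/image_tensor/measured_psf_tensor as in the first snippet
restored = tf.saved_model.load("pure_tf_export")

for i in range(10):
    start = time()
    out = restored.f(data_tensor, image_tensor, measured_psf_tensor)
    print(f"Iter {i}:", time() - start, "seconds.")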

So perhaps I can fix it by forcing the C2 compiler to be used, or by writing the function in such a way that it requires fewer calls from the Java side, like:

@tf.function
def wrapper(data, img, psf):
    data = rl_step(data, img, psf)
    data = rl_step(data, img, psf)
    # etc.
    return data

and see if that helps, or use a pure TensorFlow Java implementation and compare.
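
As a sketch of the first idea, with the iteration count baked in at trace time (N_ITERS and rl_steps are just illustrative names):

N_ITERS = 10  # fixed Python constant; the loop below unrolls at trace time

@tf.function(jit_compile=True)
def rl_steps(data, image, psf):
    # Python-level loop: tracing unrolls this into N_ITERS chained rl_step calls
    for _ in range(N_ITERS):
        data = rl_step(data, image, psf)
    return data

That would cut the number of Java-side calls by a factor of N_ITERS, at the cost of a bigger graph.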
