Disparate Performance between Python and Java #602
Comments
The Python snippet isn't applied iteratively. Other than the above points, there's also the fact that Java won't warm up in 10 iterations; it'll be more like a few hundred. But you're likely measuring something problematic in the way we're using TF (assuming it's not the Python typo). A tf.function's JIT compilation might not produce the same code as a saved model; it's hard to say what goes on inside that JIT.
Yep, sorry, typo on my part. I was cleaning up my code before putting it up here and messed that up.
That's good to know, thanks. I am still trying to figure out how to properly set the code up.
The "warm-up" time of a few hundred iterations is interesting. I'll try running the code for longer just to see what it does with respect to the timing. Do you know why there is a longer warm-up time for the Java API than for Python?
This is true. I am pretty sure I got the same timings for Python when restoring the saved model in Python as well, but I'll confirm. For our use case, latency is important. Would you suspect that I could get better times (comparable to the Python times) with a direct implementation in the Java API? I really appreciate you taking the time to look at my code, and your feedback is already helpful!
That's not a TF thing; it's that Java's interpreter is not as fast as the C2-generated code, which will kick in after a few hundred iterations. The Java interpreter is probably comparable to, if not faster than, Python's interpreter, but it's not performance-optimized in the same way, and I've not seen any comparisons.
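For timing, I'd separate warm-up from measurement along these lines; a minimal sketch, where runOnce() is a placeholder for however you invoke the model:

int warmupCalls = 500;   // C2 typically kicks in well before this
int timedCalls = 100;
for (int i = 0; i < warmupCalls; i++) {
    runOnce();           // placeholder: your model invocation
}
long start = System.nanoTime();
for (int i = 0; i < timedCalls; i++) {
    runOnce();
}
System.out.printf("avg per call: %.3f ms%n",
        (System.nanoTime() - start) / 1e6 / timedCalls);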
We expose some of the XLA ops which are generated by the JIT-compiled TF code, but I'm not sure how to get the best performance out of the system with respect to that. You could try implementing the model itself in TF-Java; I think we should expose those operations, but I'm not sure if that would be faster. Other things to consider are the placement of operations on GPU vs. CPU. Are you using our GPU binaries or the CPU ones?
I see; I'll play around with the compilation settings and rerun the timings. Thanks!
Ok, thanks. I'll look into prototyping a pure Java API-based implementation.
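As a rough sketch of what I have in mind (the single mul op is a stand-in for the real rl_step math, the shapes are guesses from my n and k, and this assumes the 1.0.x API where run() returns an AutoCloseable Result):

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Result;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;

try (Graph g = new Graph()) {
    Ops tf = Ops.create(g);
    // Placeholders standing in for the data and img inputs.
    Placeholder<TFloat32> data = tf.placeholder(TFloat32.class,
            Placeholder.shape(Shape.of(2048, 2048, 41)));
    Placeholder<TFloat32> img = tf.placeholder(TFloat32.class,
            Placeholder.shape(Shape.of(2048, 2048, 41)));
    // Stand-in op; the real rl_step math would be built here instead.
    Operand<TFloat32> step = tf.math.mul(data, img);

    try (Session s = new Session(g);
         TFloat32 dataT = TFloat32.tensorOf(Shape.of(2048, 2048, 41));
         TFloat32 imgT = TFloat32.tensorOf(Shape.of(2048, 2048, 41));
         Result r = s.runner().feed(data, dataT).feed(img, imgT)
                 .fetch(step).run()) {
        System.out.println(r.get(0).shape());
    }
}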
The GPU makes a really big difference for this calculation. I am pretty sure I am using the GPU binaries: I grabbed the following from Maven, logged device placement, and the output seemed to indicate that things were being put on the GPU. I did a run on the CPU and it was slower, but I don't remember by how much.

<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow-core-api -->
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-api</artifactId>
    <version>1.0.0</version>
</dependency>
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-native</artifactId>
    <version>1.0.0</version>
    <classifier>linux-x86_64-gpu</classifier>
</dependency>

Thanks again for your help!
Ok. Did you check the saved model speed in Python? If that's slower than the tf.function and similar to the TF-Java speed, then you'd have to look into turning on XLA in TF-Java, and I'm not sure how to do that.
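If you want to experiment, though, the knob that usually controls XLA auto-clustering in session-based runtimes is the global JIT level in ConfigProto. A sketch, unverified, and note the proto package name has moved between releases:

import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.ConfigProto;      // older releases: org.tensorflow.proto.framework.*
import org.tensorflow.proto.GraphOptions;
import org.tensorflow.proto.OptimizerOptions;

ConfigProto config = ConfigProto.newBuilder()
        .setLogDevicePlacement(true)          // doubles as a CPU/GPU placement check
        .setGraphOptions(GraphOptions.newBuilder()
                .setOptimizerOptions(OptimizerOptions.newBuilder()
                        .setGlobalJitLevel(OptimizerOptions.GlobalJitLevel.ON_1)))
        .build();

SavedModelBundle bundle = SavedModelBundle.loader("/path/to/export")
        .withTags("serve")
        .withConfigProto(config)
        .load();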
I did check the speed of running the saved model in Python, and it is as quick as running it before saving it: ~0.004 seconds after the first few iterations. So perhaps I can fix it by forcing the C2 compiler to be used, or by writing the function in such a way that it requires fewer calls from the Java side, like

@tf.function
def wrapper(data, img, psf):
    # Unroll several steps into one traced call so a single
    # invocation from Java covers many iterations.
    data = rl_step(data, img, psf)
    data = rl_step(data, img, psf)
    # ... repeat for as many steps as needed
    return data

and see if that helps. Or use a pure TensorFlow Java implementation and see if that helps.
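On the Java side, unrolling like that would turn N session round trips into a single call, something like this (bundle and inputs standing in for the loaded SavedModelBundle's function handle and input map in my code):

// One Java -> native call now covers all the unrolled steps.
try (Result out = bundle.function("serving_default").call(inputs)) {
    // out.get(0) is the estimate after all the unrolled steps
}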
System information
Java version (output of java -version): openjdk version "21.0.6" 2025-01-21

Describe the current behavior
Executing the exported model using TensorFlow in Python takes significantly less time than calling the same function using TensorFlow Java. I suspect that I am just not using the Java API correctly, because a small change to the Python can lead to comparably poor performance in the Python as well.
Describe the expected behavior
The function calls should take a comparable amount of time.
Code to reproduce the issue
I have the following Python function:
In Python, this function is applied iteratively over the same tensor, as below:
Here image, measured_psf, and data are all 3D arrays with dtype=float32, with n=2048 and k=41.
This prints timings around the following:
I tried exporting the model by adding the following after the timing code:
Now I tried to use this exported model from the Java API.
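Roughly it follows the standard load-and-call pattern, along the lines of the sketch below (the signature key and input names are placeholders here, the shapes are guesses from n and k, and 1.0.x's call() returning an AutoCloseable Result is assumed):

import java.util.Map;
import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

try (SavedModelBundle bundle = SavedModelBundle.load("/path/to/export", "serve");
     TFloat32 data = TFloat32.tensorOf(Shape.of(2048, 2048, 41));
     TFloat32 img = TFloat32.tensorOf(Shape.of(2048, 2048, 41));
     TFloat32 psf = TFloat32.tensorOf(Shape.of(41, 41, 41))) {
    Map<String, Tensor> inputs = Map.of("data", data, "img", img, "psf", psf);
    for (int i = 0; i < 10; i++) {
        long t0 = System.nanoTime();
        try (Result out = bundle.function("serving_default").call(inputs)) {
            // out.get(0) holds the updated estimate
        }
        System.out.printf("call %d took %.3f ms%n", i, (System.nanoTime() - t0) / 1e6);
    }
}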
And I get timings as follows:
I am pretty sure I am making a simple mistake somewhere. I suspect it is in how I am instantiating the Tensors; I know that in Python, if you don't use tf.constant, the timings go up a lot. Any help would be very much appreciated. I tried looking through the documentation and the tensorflow java-examples repository, but couldn't spot what I am doing wrong.
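On the Java side, the analogous thing I want to rule out is paying for tensor construction on every call. A sketch of allocating once and reusing (the shape is a guess from n and k, and the copy uses the org.tensorflow.ndarray StdArrays helper):

import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

float[][][] raw = new float[2048][2048][41];  // shape guessed from n and k
// Copy into a tensor once, outside the timing loop, and reuse it.
try (TFloat32 dataT = TFloat32.tensorOf(StdArrays.ndCopyOf(raw))) {
    for (int i = 0; i < 10; i++) {
        // feed dataT to the model here; no per-iteration tensor creation
    }
}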
Thanks again!