-
Notifications
You must be signed in to change notification settings - Fork 2.1k
migrate tweet classification example to use keras 3 #2211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
migrate tweet classification example to use keras 3 #2211
Conversation
Summary of ChangesHello @Shi-pra-19, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the tweet classification example to ensure full compatibility with Keras 3. The primary goal is to modernize the codebase by updating Keras imports, encapsulating the Universal Sentence Encoder within a custom Keras layer, and transitioning to the new Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request successfully migrates the tweet classification example to Keras 3 by updating imports, using keras.ops, and wrapping the tensorflow_hub layer in a custom Keras layer. The changes are well-aligned with the goals of Keras 3. I've provided a few suggestions to further improve the implementation of the custom layer and to align the prediction logic with Keras best practices. These changes will enhance the code's maintainability and robustness.
|
|
||
| class SentenceEncoderLayer(keras.layers.Layer): | ||
| def __init__(self, **kwargs): | ||
| super(SentenceEncoderLayer, self).__init__(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def call(self, inputs): | ||
| return self.encoder(inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a good practice for custom layers to accept and forward the training argument in the call method. This ensures the layer behaves correctly in both training and inference modes, even if the underlying hub.KerasLayer doesn't change its behavior in this case. It makes the layer more robust and aligned with Keras API conventions.
| def call(self, inputs): | |
| return self.encoder(inputs) | |
| def call(self, inputs, training=False): | |
| return self.encoder(inputs, training=training) |
| preds = model_1.predict_step(text) | ||
| preds = tf.squeeze(tf.round(preds)) | ||
| preds = keras.ops.squeeze(keras.ops.round(preds)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While predict_step works, it's not part of the public Keras API and is intended for internal use. The recommended public API for getting predictions is model.predict(). It's more idiomatic and returns NumPy arrays, which simplifies the subsequent processing. Using predict() would make the code more robust to future Keras updates.
| preds = model_1.predict_step(text) | |
| preds = tf.squeeze(tf.round(preds)) | |
| preds = keras.ops.squeeze(keras.ops.round(preds)) | |
| preds = model_1.predict(text) | |
| preds = np.squeeze(np.round(preds)) |
Refactor the example to use keras 3 and create custom layer to wrap Sentence Encoder.