Fix and update create_tfrecords #685
base: keras3-dev
Conversation
Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB.
Summary of Changes
Hello @takumiohym, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new, robust solution for preparing image datasets for machine learning by converting them into TensorFlow TFRecord format. The core of this solution is an Apache Beam pipeline that reads image URIs and labels from a CSV file, processes the images, and outputs sharded TFRecord files for both training and validation. This enhancement streamlines the data preparation phase for image-based models, ensuring efficient data handling and compatibility with Google Cloud's data processing services.
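At a high level, such a pipeline reads CSV rows, turns each referenced image into a `tf.train.Example`, and writes serialized, sharded records. A minimal sketch of that shape, assuming the script's `ParseCsv` and `CreateTFExample` DoFns and illustrative I/O paths (the actual script's options and train/validation split are not shown):

```python
import apache_beam as beam

# Assumed import: DoFns defined in the new script.
from create_tfrecords import ParseCsv, CreateTFExample

def run(input_csv, output_dir):
    with beam.Pipeline() as pipeline:
        examples = (
            pipeline
            | "Read CSV" >> beam.io.ReadFromText(input_csv)
            | "Parse Rows" >> beam.ParDo(ParseCsv())
            | "Create TF Examples" >> beam.ParDo(CreateTFExample())
        )
        _ = (
            examples
            | "Serialize" >> beam.Map(lambda x: x.SerializeToString())
            | "Write" >> beam.io.tfrecordio.WriteToTFRecord(f"{output_dir}/train.tfrecord")
        )
```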
Highlights
- New TFRecord Creation Script: A new Python script, `create_tfrecords.py`, has been added. This script utilizes Apache Beam and TensorFlow to convert image datasets from a CSV input file into sharded TFRecord files suitable for machine learning training. It includes functionality to split the data into training and validation sets.
- TensorFlow Dependency for New Script: A dedicated `requirements.txt` file has been introduced within the new solution's directory, specifying `tensorflow==2.18.1` as a direct dependency for the TFRecord creation process.
- Apache Beam GCP Support: The main `requirements.txt` file has been updated to include the `gcp` extra for the `apache-beam` dependency. This change ensures that Apache Beam can leverage Google Cloud Platform-specific features, particularly for running pipelines on Dataflow (see the dependency sketch after this list).
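For reference, the dependency declarations described above would look roughly like this (a sketch; the `apache-beam` version pin, if any, is not stated in this summary):

```text
# New solution's requirements.txt
tensorflow==2.18.1

# Main requirements.txt: the [gcp] extra pulls in Dataflow and GCS support
apache-beam[gcp]
```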
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point by creating a comment using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands on the current page.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a new Apache Beam pipeline script for creating TFRecords from image data, along with dependency updates. The script is a great starting point. My review provides feedback to enhance its robustness and reusability, and to ensure the data processing is deterministic for better reproducibility. The main suggestions focus on avoiding hardcoded values, implementing deterministic data splitting, and adding error handling for data parsing.
```python
def partition_fn(example, num_partitions, train_percent):
    if random.random() < train_percent:
        return 0
    return 1
```
The use of `random.random()` for partitioning data makes the train/validation split non-deterministic. This is detrimental to experiment reproducibility, as each run will produce a different dataset split.

A better approach is to use a deterministic partitioning function based on a stable property of the data, like `image_uri`. This requires changing the pipeline to partition the data before creating `tf.train.Example` objects, as the `image_uri` is not available after that step.
Here's a suggested implementation:

- First, add `import hashlib` to the top of your file.
- Then, replace the current `partition_fn` (lines 58-61) with this deterministic version:

```python
def partition_fn(element, num_partitions, train_percent):
    """Deterministically partitions elements based on their image_uri."""
    image_uri_hash = abs(
        int(hashlib.md5(element.image_uri.encode("utf-8")).hexdigest(), 16)
    )
    if (image_uri_hash % 100) < (train_percent * 100):
        return 0  # Training set
    return 1  # Validation set
```

- Finally, restructure the pipeline in the `run` function by replacing lines 136-141 with the following to apply the partition before creating the TFExamples:

```python
train_rows, val_rows = (
    rows
    | "Split Data" >> beam.Partition(partition_fn, 2, train_percent=TRAIN_PERCENT)
)
train = train_rows | "Create Train TF Examples" >> beam.ParDo(CreateTFExample())
val = val_rows | "Create Val TF Examples" >> beam.ParDo(CreateTFExample())
```
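For intuition, here is a minimal standalone sketch (with hypothetical URIs, outside Beam) showing that the hash-based bucket assignment is stable across runs, unlike `random.random()`:

```python
import hashlib

def split_bucket(uri, train_percent):
    """Deterministically map a URI to 0 (train) or 1 (validation)."""
    uri_hash = int(hashlib.md5(uri.encode("utf-8")).hexdigest(), 16)
    return 0 if (uri_hash % 100) < train_percent * 100 else 1

# Hypothetical image URIs for illustration.
uris = [f"gs://my-bucket/flowers/img_{i}.jpg" for i in range(1000)]
buckets = [split_bucket(u, 0.8) for u in uris]

# Identical inputs yield identical assignments on every run.
assert buckets == [split_bucket(u, 0.8) for u in uris]
print(f"train fraction: {buckets.count(0) / len(buckets):.2f}")  # roughly 0.80
```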
```python
def process(self, element):
    image_uri, label = element.split(",")
    yield CSVRow(image_uri=image_uri, label=label)
```
The `ParseCsv.process` method assumes that every line in the input file contains a comma and can be split into exactly two parts. If a line is malformed (e.g., empty, or without a comma), unpacking the result of `element.split(",")` will raise a `ValueError` and cause the worker to fail. It's safer to handle potential errors, for example with a `try...except ValueError` block that skips malformed lines.
Example:

```python
def process(self, element):
    try:
        image_uri, label = element.split(",")
        yield CSVRow(image_uri=image_uri, label=label)
    except ValueError:
        # Log or count malformed rows if needed
        pass
```
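If visibility into skipped rows is useful, Beam's metrics API is one option. A sketch, assuming a `CSVRow` shaped like the script's row type:

```python
import apache_beam as beam
from apache_beam.metrics import Metrics
from typing import NamedTuple

class CSVRow(NamedTuple):  # assumed shape of the script's row type
    image_uri: str
    label: str

class ParseCsv(beam.DoFn):
    def __init__(self):
        # Reported in the pipeline result and the Dataflow monitoring UI.
        self.malformed_rows = Metrics.counter(self.__class__, "malformed_rows")

    def process(self, element):
        try:
            image_uri, label = element.split(",")
            yield CSVRow(image_uri=image_uri, label=label)
        except ValueError:
            # Count and skip instead of failing the worker.
            self.malformed_rows.inc()
```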
```python
class CreateTFExample(beam.DoFn):
    def process(self, element):
        CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
        img = tf.io.decode_jpeg(tf.io.read_file(element.image_uri))

        feature = {
            "image": _image_feature(img),
            "label": _int64_feature(CLASSES.index(element.label)),
        }

        yield tf.train.Example(features=tf.train.Features(feature=feature))
```
The `CLASSES` list is hardcoded inside the `process` method. This is inefficient, as it's recreated for every element, and it makes the `CreateTFExample` DoFn less reusable for datasets with different classes.

It would be better to pass the classes list into the `DoFn`'s constructor. You would then instantiate it in your pipeline like this: `beam.ParDo(CreateTFExample(CLASSES))`, where `CLASSES` is defined at a higher scope. This also makes it easier to parameterize the classes via command-line arguments in the future.
Suggested change:

```diff
 class CreateTFExample(beam.DoFn):
+    def __init__(self, classes):
+        self._classes = classes
+
     def process(self, element):
-        CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
         img = tf.io.decode_jpeg(tf.io.read_file(element.image_uri))
         feature = {
             "image": _image_feature(img),
-            "label": _int64_feature(CLASSES.index(element.label)),
+            "label": _int64_feature(self._classes.index(element.label)),
         }
         yield tf.train.Example(features=tf.train.Features(feature=feature))
```
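With that change, instantiation in the pipeline might look like this (assuming `rows` and a module-level `CLASSES`, as elsewhere in the script):

```python
CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
examples = rows | "Create TF Examples" >> beam.ParDo(CreateTFExample(CLASSES))
```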
```python
feature = {
    "image": _image_feature(img),
    "label": _int64_feature(CLASSES.index(element.label)),
```
`CLASSES.index(element.label)` will raise a `ValueError` if `element.label` is not found in the `CLASSES` list, which will cause the worker to fail on that element. It would be more robust to wrap this in a `try...except ValueError` block and skip elements with unknown labels, possibly with logging.
Example:

```python
try:
    label_index = self._classes.index(element.label)
except ValueError:
    # Optionally log the unknown label and skip the element
    return

feature = {
    "image": _image_feature(img),
    "label": _int64_feature(label_index),
}
```
| "Write Train" | ||
>> beam.io.tfrecordio.WriteToTFRecord( | ||
f"{OUTPUT_DIR}/train.tfrecord", num_shards=10 | ||
) | ||
) | ||
write_val = ( | ||
val | ||
| "Serialize Validation Examples" | ||
>> beam.Map(lambda x: x.SerializeToString()) | ||
| "Write Validation" | ||
>> beam.io.tfrecordio.WriteToTFRecord( | ||
f"{OUTPUT_DIR}/eval.tfrecord", num_shards=3 | ||
) |
Fixed and updated create_tfrecords notebook.