How Do I Track a List of Objects for Export? #21078

Closed
rlcauvin opened this issue Mar 21, 2025 · 3 comments
Labels
type:support User is asking for help / asking an implementation question. Stackoverflow would be better suited.

Comments

rlcauvin commented Mar 21, 2025

I define an EnsembleModel class that is constructed from a list of other Keras models.

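```python
from typing import Any, Callable, Dict, Iterable, Text

import keras
import tensorflow as tf

# `input_signature`, `model0`, and `model1` are defined elsewhere (not shown in this issue).

class EnsembleModel(keras.Model):

  def __init__(
    self,
    models: Iterable[keras.Model],
    reduce_fn: Callable = keras.ops.mean,
    **kwargs):

    super().__init__(**kwargs)

    # The sub-models are stored only inside a list attribute.
    self.models = models
    # self.model0 = models[0]
    # self.model1 = models[1]
    self.reduce_fn = reduce_fn

  @tf.function(input_signature=[input_signature])
  def call(
    self,
    input: Dict[Text, Any]) -> Any:

    # Flatten each sub-model's output to 1-D, then reduce across the models.
    all_outputs = [keras.ops.reshape(model(input), newshape=(-1,)) for model in self.models]
    output = self.reduce_fn(all_outputs, axis=0)

    return output

averaging_model = EnsembleModel(models=[model0, model1])
```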

I then wish to export the ensemble model:

```python
averaging_model.export("export/1/", input_signature=[input_signature])
```

But I get an error on the export:

```
AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
        Function name = b'__inference_signature_wrapper___call___10899653'
        Captured Tensor = <ResourceHandle(name="10671455", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[  ]")>
        Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x7fd62d126990>
        Internal Tensor = Tensor("10899255:0", shape=(), dtype=resource)
```

If I explicitly assign the models to attributes in the constructor:

```python
    self.model0 = models[0]
    self.model1 = models[1]
```

then the export works fine (even if I don't reference those attributes anywhere else). But I want an instance of the EnsembleModel class to support an arbitrary list of models. How can I ensure the models are "tracked" so that I don't get an error on export?
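The error message states the rule: captured objects must be reachable through attributes of a tracked object. The toy sketch below models that rule in plain Python. It rests on a simplifying assumption made only for illustration: the tracker records trackable objects assigned directly to attributes and does not look inside a list, which mirrors the behavior reported in this issue. It is not a description of how TensorFlow or Keras actually track containers, and `ToyTrackable`, `FakeModel`, `ListOnlyEnsemble`, `ExplicitEnsemble`, and `tracked_children` are all hypothetical names.

```python
class ToyTrackable:
    """Hypothetical stand-in for a framework base class that tracks
    children through attribute assignment (illustration only)."""

    def __init__(self):
        # Bypass our own __setattr__ so the registry itself is not tracked.
        object.__setattr__(self, "_children", {})

    def __setattr__(self, name, value):
        # Record the value only if it is itself trackable. Containers such as
        # lists are NOT inspected in this toy (a deliberate simplification).
        if isinstance(value, ToyTrackable):
            self._children[name] = value
        object.__setattr__(self, name, value)

    def tracked_children(self):
        return dict(self._children)


class FakeModel(ToyTrackable):
    """Hypothetical placeholder for a Keras sub-model."""


class ListOnlyEnsemble(ToyTrackable):
    def __init__(self, models):
        super().__init__()
        self.models = list(models)  # elements are invisible to the toy tracker


class ExplicitEnsemble(ToyTrackable):
    def __init__(self, models):
        super().__init__()
        self.models = list(models)
        self.model0 = models[0]  # direct attribute: tracked
        self.model1 = models[1]  # direct attribute: tracked


m0, m1 = FakeModel(), FakeModel()
print(sorted(ListOnlyEnsemble([m0, m1]).tracked_children()))  # []
print(sorted(ExplicitEnsemble([m0, m1]).tracked_children()))  # ['model0', 'model1']
```

Under this assumption, the list-only ensemble exposes no tracked children, while the explicit assignments make both models reachable, which matches the "works with `self.model0`/`self.model1`" observation.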

rlcauvin changed the title from "How Do I Track Objects for Export?" to "How Do I Track a List of Objects for Export?" on Mar 21, 2025
rlcauvin (Author) commented:

Adding this kludge to the `EnsembleModel` constructor seems to "track" the models in the list and avoid the export error:

```python
    # Register each model so that it is "tracked" for export.
    for i, model in enumerate(self.models):
      self.__setattr__(f"model_{i}", model)
```

Calling `__setattr__` dynamically assigns each model in the collection to an attribute of the `EnsembleModel` instance, so each one is tracked.
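A small stylistic note on the kludge: Python's built-in `setattr(obj, name, value)` dispatches to `type(obj).__setattr__`, so `setattr(self, f"model_{i}", model)` runs the same assignment hook as `self.__setattr__(...)` and is the more conventional spelling. The sketch below checks that equivalence with a hypothetical `Recorder` base class that logs every attribute assignment; it stands in for a tracked base class and is not Keras code.

```python
class Recorder:
    """Hypothetical base class whose __setattr__ logs each assignment,
    standing in for a framework's tracking hook."""

    def __init__(self):
        object.__setattr__(self, "log", [])

    def __setattr__(self, name, value):
        self.log.append(name)  # the "tracking" side effect
        object.__setattr__(self, name, value)


class Ensemble(Recorder):
    def __init__(self, models):
        super().__init__()
        self.models = models
        # Same effect as self.__setattr__(f"model_{i}", model), spelled idiomatically.
        for i, model in enumerate(self.models):
            setattr(self, f"model_{i}", model)


ens = Ensemble(["a", "b", "c"])
print(ens.log)      # ['models', 'model_0', 'model_1', 'model_2']
print(ens.model_2)  # c
```

Every dynamically named attribute passes through the overridden `__setattr__`, which is why the registration loop makes each list element visible to whatever tracking the base class performs.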

Is there a better way?

sonali-kumari1 added the type:support label on Mar 24, 2025
rlcauvin (Author) commented:

Closing this issue, as I am satisfied with the kludge mentioned in my previous comment. But I welcome any comments on the kludge or any alternative suggestions.
