2929from tensorflow .python .ops import math_ops
3030from tensorflow .python .training import session_run_hook
3131
32- from hybridbackend .tensorflow .data .detect_end . dataset import DetectEndDataset
33- from hybridbackend .tensorflow .data .prefetch . ops import Iterator
32+ from hybridbackend .tensorflow .data .prefetch . iterator import Iterator
33+ from hybridbackend .tensorflow .data .sync . dataset import _SyncReplicasDataset
3434from hybridbackend .tensorflow .distribute .collective import Collective
3535from hybridbackend .tensorflow .distribute .ops import CollectiveOps
3636from hybridbackend .tensorflow .framework .context import Context
@@ -82,14 +82,19 @@ def __init__(self):
8282 def wraps_make_iterator (self , fn ):
8383 r'''Wraps make_*_iterator.
8484 '''
85- def wrapped_make_iterator (ds ):
85+ def wrapped_make_iterator (ds , * args , ** kwargs ):
86+ if isinstance (ds , _SyncReplicasDataset ):
87+ return fn (ds , * args , ** kwargs )
88+ if isinstance (ds , dataset_ops .DatasetV1Adapter ):
89+ if isinstance (ds ._dataset , _SyncReplicasDataset ): # pylint: disable=protected-access
90+ return fn (ds , * args , ** kwargs )
8691 with ops .device ('/cpu:0' ):
8792 options = Context .get ().options
8893 if options .mode == ModeKeys .TRAIN :
89- return fn (DetectEndDataset (ds ))
94+ return fn (_SyncReplicasDataset (ds ), * args , ** kwargs )
9095 if options .mode == ModeKeys .EVAL :
91- return fn (ds .repeat ())
92- return fn (ds )
96+ return fn (ds .repeat (), * args , ** kwargs )
97+ return fn (ds , * args , ** kwargs )
9398 return wrapped_make_iterator
9499
95100 def wraps_iterator (self , cls ):
@@ -128,13 +133,22 @@ def begin(self):
128133 dataset_ops .make_one_shot_iterator = self .wraps_make_iterator (
129134 self ._prev_make_one_shot_iterator )
130135 tf .data .make_one_shot_iterator = dataset_ops .make_one_shot_iterator
131- self ._prev_make_initializable_iterator = \
132- dataset_ops .make_initializable_iterator
136+ self ._prev_make_one_shot_iterator_method = (
137+ dataset_ops .DatasetV1 ._make_one_shot_iterator ) # pylint: disable=protected-access
138+ dataset_ops .DatasetV1 ._make_one_shot_iterator = self .wraps_make_iterator ( # pylint: disable=protected-access
139+ self ._prev_make_one_shot_iterator_method )
140+ self ._prev_make_initializable_iterator = (
141+ dataset_ops .make_initializable_iterator )
133142 dataset_ops .make_initializable_iterator = self .wraps_make_iterator (
134143 self ._prev_make_initializable_iterator )
135144 self ._prev_keras_get_iterator = training_utils .get_iterator
136- tf .data .make_initializable_iterator = \
137- dataset_ops .make_initializable_iterator
145+ tf .data .make_initializable_iterator = (
146+ dataset_ops .make_initializable_iterator )
147+ self ._prev_make_initializable_iterator_method = (
148+ dataset_ops .DatasetV1 ._make_initializable_iterator ) # pylint: disable=protected-access
149+ dataset_ops .DatasetV1 ._make_initializable_iterator = ( # pylint: disable=protected-access
150+ self .wraps_make_iterator (
151+ self ._prev_make_initializable_iterator_method ))
138152 training_utils .get_iterator = self .wraps_make_iterator (
139153 self ._prev_keras_get_iterator )
140154 self ._prev_iterator = iterator_ops .Iterator
@@ -147,10 +161,14 @@ def end(self):
147161 iterator_ops .Iterator = self ._prev_iterator
148162 dataset_ops .make_one_shot_iterator = self ._prev_make_one_shot_iterator
149163 tf .data .make_one_shot_iterator = dataset_ops .make_one_shot_iterator
150- dataset_ops .make_initializable_iterator = \
151- self ._prev_make_initializable_iterator
152- tf .data .make_initializable_iterator = \
153- dataset_ops .make_initializable_iterator
164+ dataset_ops .DatasetV1 ._make_one_shot_iterator = ( # pylint: disable=protected-access
165+ self ._prev_make_one_shot_iterator_method )
166+ dataset_ops .make_initializable_iterator = (
167+ self ._prev_make_initializable_iterator )
168+ tf .data .make_initializable_iterator = (
169+ dataset_ops .make_initializable_iterator )
170+ dataset_ops .DatasetV1 ._make_initializable_iterator = ( # pylint: disable=protected-access
171+ self ._prev_make_initializable_iterator_method )
154172 training_utils .get_iterator = self ._prev_keras_get_iterator
155173
156174
@@ -187,7 +205,7 @@ def begin(self):
187205 else :
188206 with ops .device (should_stop_all_ops [0 ].device ):
189207 self ._should_stop_all = math_ops .reduce_max (
190- array_ops .concat (should_stop_all_ops , 0 ), 0 )
208+ array_ops .stack (should_stop_all_ops ), 0 )
191209
192210 def before_run (self , run_context ): # pylint: disable=unused-argument
193211 r'''Call this before sess run.
0 commit comments