Skip to content

Commit d8d1514

Browse files
authored
[DATA] Refines data transfer prefetching and synchronization
1 parent d70c104 commit d8d1514

File tree

15 files changed

+1223
-487
lines changed

15 files changed

+1223
-487
lines changed

build/etc/hbash

Lines changed: 0 additions & 19 deletions
This file was deleted.

hybridbackend/tensorflow/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
from hybridbackend.tensorflow.data.dataframe import unbatch_and_to_sparse
2727
from hybridbackend.tensorflow.data.parquet.dataset import ParquetDataset
2828
from hybridbackend.tensorflow.data.parquet.dataset import read_parquet
29-
from hybridbackend.tensorflow.data.prefetch.ops import Iterator
29+
from hybridbackend.tensorflow.data.prefetch.iterator import Iterator
3030
from hybridbackend.tensorflow.data.rebatch.dataset import RebatchDataset
3131
from hybridbackend.tensorflow.data.rebatch.dataset import rebatch
32+
from hybridbackend.tensorflow.data.sync.dataset import SyncReplicasDataset
3233
from hybridbackend.tensorflow.data.tabular.dataset import Dataset
3334

3435
# HybridBackend operators must be loaded before TensorFlow operators to

hybridbackend/tensorflow/data/iterators.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
from tensorflow.python.ops import math_ops
3030
from tensorflow.python.training import session_run_hook
3131

32-
from hybridbackend.tensorflow.data.detect_end.dataset import DetectEndDataset
33-
from hybridbackend.tensorflow.data.prefetch.ops import Iterator
32+
from hybridbackend.tensorflow.data.prefetch.iterator import Iterator
33+
from hybridbackend.tensorflow.data.sync.dataset import _SyncReplicasDataset
3434
from hybridbackend.tensorflow.distribute.collective import Collective
3535
from hybridbackend.tensorflow.distribute.ops import CollectiveOps
3636
from hybridbackend.tensorflow.framework.context import Context
@@ -82,14 +82,19 @@ def __init__(self):
8282
def wraps_make_iterator(self, fn):
8383
r'''Wraps make_*_iterator.
8484
'''
85-
def wrapped_make_iterator(ds):
85+
def wrapped_make_iterator(ds, *args, **kwargs):
86+
if isinstance(ds, _SyncReplicasDataset):
87+
return fn(ds, *args, **kwargs)
88+
if isinstance(ds, dataset_ops.DatasetV1Adapter):
89+
if isinstance(ds._dataset, _SyncReplicasDataset): # pylint: disable=protected-access
90+
return fn(ds, *args, **kwargs)
8691
with ops.device('/cpu:0'):
8792
options = Context.get().options
8893
if options.mode == ModeKeys.TRAIN:
89-
return fn(DetectEndDataset(ds))
94+
return fn(_SyncReplicasDataset(ds), *args, **kwargs)
9095
if options.mode == ModeKeys.EVAL:
91-
return fn(ds.repeat())
92-
return fn(ds)
96+
return fn(ds.repeat(), *args, **kwargs)
97+
return fn(ds, *args, **kwargs)
9398
return wrapped_make_iterator
9499

95100
def wraps_iterator(self, cls):
@@ -128,13 +133,22 @@ def begin(self):
128133
dataset_ops.make_one_shot_iterator = self.wraps_make_iterator(
129134
self._prev_make_one_shot_iterator)
130135
tf.data.make_one_shot_iterator = dataset_ops.make_one_shot_iterator
131-
self._prev_make_initializable_iterator = \
132-
dataset_ops.make_initializable_iterator
136+
self._prev_make_one_shot_iterator_method = (
137+
dataset_ops.DatasetV1._make_one_shot_iterator) # pylint: disable=protected-access
138+
dataset_ops.DatasetV1._make_one_shot_iterator = self.wraps_make_iterator( # pylint: disable=protected-access
139+
self._prev_make_one_shot_iterator_method)
140+
self._prev_make_initializable_iterator = (
141+
dataset_ops.make_initializable_iterator)
133142
dataset_ops.make_initializable_iterator = self.wraps_make_iterator(
134143
self._prev_make_initializable_iterator)
135144
self._prev_keras_get_iterator = training_utils.get_iterator
136-
tf.data.make_initializable_iterator = \
137-
dataset_ops.make_initializable_iterator
145+
tf.data.make_initializable_iterator = (
146+
dataset_ops.make_initializable_iterator)
147+
self._prev_make_initializable_iterator_method = (
148+
dataset_ops.DatasetV1._make_initializable_iterator) # pylint: disable=protected-access
149+
dataset_ops.DatasetV1._make_initializable_iterator = ( # pylint: disable=protected-access
150+
self.wraps_make_iterator(
151+
self._prev_make_initializable_iterator_method))
138152
training_utils.get_iterator = self.wraps_make_iterator(
139153
self._prev_keras_get_iterator)
140154
self._prev_iterator = iterator_ops.Iterator
@@ -147,10 +161,14 @@ def end(self):
147161
iterator_ops.Iterator = self._prev_iterator
148162
dataset_ops.make_one_shot_iterator = self._prev_make_one_shot_iterator
149163
tf.data.make_one_shot_iterator = dataset_ops.make_one_shot_iterator
150-
dataset_ops.make_initializable_iterator = \
151-
self._prev_make_initializable_iterator
152-
tf.data.make_initializable_iterator = \
153-
dataset_ops.make_initializable_iterator
164+
dataset_ops.DatasetV1._make_one_shot_iterator = ( # pylint: disable=protected-access
165+
self._prev_make_one_shot_iterator_method)
166+
dataset_ops.make_initializable_iterator = (
167+
self._prev_make_initializable_iterator)
168+
tf.data.make_initializable_iterator = (
169+
dataset_ops.make_initializable_iterator)
170+
dataset_ops.DatasetV1._make_initializable_iterator = ( # pylint: disable=protected-access
171+
self._prev_make_initializable_iterator_method)
154172
training_utils.get_iterator = self._prev_keras_get_iterator
155173

156174

@@ -187,7 +205,7 @@ def begin(self):
187205
else:
188206
with ops.device(should_stop_all_ops[0].device):
189207
self._should_stop_all = math_ops.reduce_max(
190-
array_ops.concat(should_stop_all_ops, 0), 0)
208+
array_ops.stack(should_stop_all_ops), 0)
191209

192210
def before_run(self, run_context): # pylint: disable=unused-argument
193211
r'''Call this before sess run.

0 commit comments

Comments
 (0)