Skip to content

Commit d70c104

Browse files
authored
[DIST] Support Non-intrusive embedding APIs
1 parent 9b3c00d commit d70c104

File tree

106 files changed

+6800
-10318
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

106 files changed

+6800
-10318
lines changed

Makefile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ HYBRIDBACKEND_DEBUG ?= OFF
2323
HYBRIDBACKEND_WHEEL_ALIAS ?=
2424
HYBRIDBACKEND_WHEEL_BUILD ?=
2525
HYBRIDBACKEND_WHEEL_REQUIRES ?=
26-
HYBRIDBACKEND_WHEEL_DEBUG ?= OFF
2726
HYBRIDBACKEND_WHEEL_REPAIR ?= ON
2827
HYBRIDBACKEND_WHEEL_POSTCHECK ?= ON
2928
HYBRIDBACKEND_CHECK_INSTANCE ?= OFF
@@ -239,7 +238,6 @@ build: $(CORE_DEPS)
239238
WHEEL_ALIAS="$(HYBRIDBACKEND_WHEEL_ALIAS)" \
240239
WHEEL_BUILD="$(HYBRIDBACKEND_WHEEL_BUILD)" \
241240
WHEEL_REQUIRES="$(HYBRIDBACKEND_WHEEL_REQUIRES)" \
242-
WHEEL_DEBUG="$(HYBRIDBACKEND_WHEEL_DEBUG)" \
243241
$(PYTHON) setup.py bdist_wheel -d build/wheel
244242
@ls build/wheel/*.whl
245243

docs/tutorial/ranking/criteo/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def main(args):
162162
if args.eval_every_n_iter is not None:
163163
hooks.append(hb.train.EvaluationHook(
164164
lambda: model.evaluate(eval_filenames),
165-
every_n_iter=args.eval_every_n_iter))
165+
every_n_iter=args.eval_every_n_iter,
166+
summary_dir=args.output_dir))
166167
if args.log_every_n_iter is not None:
167168
hooks.append(
168169
tf.train.LoggingTensorHook(

docs/tutorial/ranking/taobao/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def main(args):
144144
if args.eval_every_n_iter is not None:
145145
hooks.append(hb.train.EvaluationHook(
146146
lambda: model.evaluate(eval_filenames),
147-
every_n_iter=args.eval_every_n_iter))
147+
every_n_iter=args.eval_every_n_iter,
148+
summary_dir=args.output_dir))
148149
if args.log_every_n_iter is not None:
149150
hooks.append(
150151
tf.train.LoggingTensorHook(

hybridbackend/tensorflow/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
from __future__ import division
2121
from __future__ import print_function
2222

23+
import contextlib as _ctxlib
24+
2325
from hybridbackend.libhybridbackend import buildinfo
24-
from hybridbackend.tensorflow.feature_column.dense_features import \
25-
dense_features
26+
from hybridbackend.tensorflow.framework.config import get_session_config
27+
from hybridbackend.tensorflow.framework.config import wraps_session_config
2628
from hybridbackend.tensorflow.framework.context import Context
2729
from hybridbackend.tensorflow.framework.context import context
2830
from hybridbackend.tensorflow.framework.rewriting import function
@@ -33,10 +35,18 @@
3335
from . import data
3436
from . import distribute
3537
from . import estimator
36-
from . import feature_column
3738
from . import keras
3839
from . import metrics
3940
from . import plugins
4041
from . import training as train
4142

4243
__version__ = buildinfo()
44+
45+
46+
@_ctxlib.contextmanager
def embedding_scope(**kwargs):
  r'''Context manager opening a scope for defining embedding weights.

  All keyword arguments are forwarded to `scope`; `sharding` is switched on
  unless the caller supplied a value for it.

  Args:
    **kwargs: Options forwarded to `scope`.

  Yields:
    The object produced by entering `scope(**kwargs)`.
  '''
  if 'sharding' not in kwargs:
    kwargs['sharding'] = True
  with scope(**kwargs) as embedding_ctx:
    yield embedding_ctx
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# =============================================================================
15+
16+
r'''Collective benchmark.
17+
'''
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import argparse
24+
import time
25+
26+
import numpy as np
27+
import tensorflow as tf
28+
29+
import hybridbackend.tensorflow as hb
30+
31+
32+
# pylint: disable=missing-docstring
33+
def allreduce(message_floats, message_partitions, message_device, topology):
  r'''Build one benchmark step that all-reduces a partitioned message.

  Args:
    message_floats: Total number of floats in the message.
    message_partitions: Number of variables the message is split into.
    message_device: Device type holding the message (e.g. `gpu` or `cpu`);
      device index 0 of that type is used.
    topology: Unused; accepted so that every collective builder can be
      called with the same arguments.

  Returns:
    An op grouping the global step increment with one identity op per
    all-reduced partition.
  '''
  del topology  # Not used by this collective.
  device = f'/{message_device}:0'
  global_step = tf.train.get_or_create_global_step()
  step_ops = [global_step.assign_add(1)]
  for part_index in range(message_partitions):
    with tf.device(device):
      partition_shape = [int(message_floats / message_partitions)]
      initial_value = tf.random.normal(partition_shape, mean=100, stddev=80)
      coll_input = tf.get_variable(
        f'input{part_index}', initializer=initial_value)
    # The collective itself is created outside of the explicit device scope.
    coll_output = hb.distribute.allreduce(coll_input)
    with tf.device(device):
      step_ops.append(tf.identity(coll_output))
  return tf.group(step_ops)
49+
50+
51+
def alltoall(
    message_floats, message_partitions, message_device, topology=None):
  r'''Build one benchmark step that runs all-to-all on a partitioned message.

  Args:
    message_floats: Total number of floats in the message.
    message_partitions: Number of variables the message is split into.
    message_device: Device type holding the message (e.g. `gpu` or `cpu`);
      device index 0 of that type is used.
    topology: Unused; accepted so that every collective builder can be
      called with the same four arguments. Optional so that three-argument
      calls keep working.

  Returns:
    An op grouping the global step increment with one identity op per
    exchanged partition.
  '''
  # `topology` used to be missing from the signature, so this statement
  # raised UnboundLocalError and four-argument callers got a TypeError.
  del topology
  step = tf.train.get_or_create_global_step()
  results = [step.assign_add(1)]
  for p in range(message_partitions):
    with tf.device(f'/{message_device}:0'):
      coll_input = tf.get_variable(
        f'input{p}',
        initializer=tf.random.normal(
          [int(message_floats / message_partitions)],
          mean=100,
          stddev=80))
    # The collective itself is created outside of the explicit device scope.
    coll_output = hb.distribute.alltoall(coll_input)
    with tf.device(f'/{message_device}:0'):
      results.append(tf.identity(coll_output))
  return tf.group(results)
67+
68+
69+
def _uniform_sizes(total_size, active_size):
70+
uniform_distro = [
71+
1. / active_size for _ in range(active_size)]
72+
return np.random.multinomial(total_size, uniform_distro, size=1)[0].tolist()
73+
74+
75+
def alltoallv(
    message_floats, message_partitions, message_device, topology,
    random_sizes=True):
  r'''Build one benchmark step that runs all-to-all with explicit sizes.

  Args:
    message_floats: Total number of floats in the message.
    message_partitions: Number of variables the message is split into.
    message_device: Device type holding the message (e.g. `gpu` or `cpu`);
      device index 0 of that type is used.
    topology: Topology passed to `hb.distribute.active_size` and to the
      collective.
    random_sizes: If True, per-peer sizes are sampled by `_uniform_sizes`;
      otherwise every peer gets `message_count // active_size` elements.

  Returns:
    An op grouping the global step increment with one identity op per
    exchanged partition.
  '''
  device = f'/{message_device}:0'
  active_size = hb.distribute.active_size(topology)
  global_step = tf.train.get_or_create_global_step()
  step_ops = [global_step.assign_add(1)]
  for part_index in range(message_partitions):
    with tf.device(device):
      message_count = int(message_floats / message_partitions)
      initial_value = tf.random.normal(
        [message_count], mean=100, stddev=80)
      coll_input = tf.get_variable(
        f'input{part_index}', initializer=initial_value)
      if random_sizes:
        send_sizes = _uniform_sizes(message_count, active_size)
      else:
        send_sizes = [message_count // active_size] * active_size
      coll_input_sizes = tf.constant(send_sizes, dtype=tf.int32)
    # With `sizes` given the collective returns a pair; only the first
    # element is benchmarked.
    coll_output, _ = hb.distribute.alltoall(
      coll_input, sizes=coll_input_sizes, topology=topology)
    with tf.device(device):
      step_ops.append(tf.identity(coll_output))
  return tf.group(step_ops)
104+
105+
106+
def benchmark(args):
  r'''Build and run the requested collective benchmarks, printing throughput.

  One benchmark op is built for every combination of collective op, message
  size, partition count, device type and topology. Each op is then run for
  `args.warmup_steps` untimed steps followed by `args.num_steps` timed steps,
  and one tab-separated result row is printed per combination.

  Args:
    args: Parsed command line arguments providing `collective_ops`,
      `collective_topology`, `message_floats`, `message_partitions`,
      `message_devices`, `warmup_steps` and `num_steps`.

  Raises:
    ValueError: If a requested collective op is unknown, or a message size
      is not divisible by the world size or by a requested partition count.
  '''
  collective_ops = {
    'allreduce': allreduce,
    'alltoall': alltoall,
    # `alltoallv_` is alltoallv with evenly divided (non-random) sizes.
    'alltoallv_': lambda mf, mp, md, topo: alltoallv(
      mf, mp, md, topo, random_sizes=False),
    'alltoallv': alltoallv}
  for cop in args.collective_ops:
    if cop not in collective_ops:
      raise ValueError(
        f'Specified collective op type `{cop}` '
        f'not in {collective_ops.keys()}')
  for nf in args.message_floats:
    if nf % hb.context.world_size != 0:
      raise ValueError(
        f'#floats {nf} cannot be divided onto {hb.context.world_size} devices')
    for part in args.message_partitions:
      if nf % part != 0:
        raise ValueError(
          f'#floats {nf} cannot be divided into {part} partitions')

  with tf.Graph().as_default(), hb.scope():
    # bench_ops[cop][nf][part][dev][topo] -> op running one benchmark step.
    bench_ops = {
      cop: {
        nf: {
          part: {
            dev: {} for dev in args.message_devices}
          for part in args.message_partitions}
        for nf in args.message_floats}
      for cop in args.collective_ops}
    # Each combination gets its own name/variable scope so that the
    # `input{p}` variables created by the builders do not collide.
    for cop in args.collective_ops:
      with tf.name_scope(cop), tf.variable_scope(cop):
        for nf in args.message_floats:
          with tf.name_scope(f'{nf}floats'), tf.variable_scope(f'{nf}floats'):
            for p in args.message_partitions:
              with tf.name_scope(f'{p}parts'), tf.variable_scope(f'{p}parts'):
                for dev in args.message_devices:
                  with tf.name_scope(f'{dev}'), tf.variable_scope(f'{dev}'):
                    for topo in args.collective_topology:
                      with tf.name_scope(f'{topo}topology'), tf.variable_scope(
                          f'{topo}topology'):
                        bench_ops[cop][nf][p][dev][topo] = (
                          collective_ops[cop](
                            nf, p, dev, topo))
    with tf.train.MonitoredTrainingSession('') as sess:
      print('Rank\tCollective\tTopology\tDevice\tSize\t#Splits\tThroughput')
      # pylint: disable=too-many-nested-blocks
      for cop in args.collective_ops:
        for dev in args.message_devices:
          for nf in args.message_floats:
            for part in args.message_partitions:
              for topo in args.collective_topology:
                bench_op = bench_ops[cop][nf][part][dev][topo]
                for _ in range(args.warmup_steps):
                  sess.run(bench_op)
                prev_ts = time.time()
                for _ in range(args.num_steps):
                  sess.run(bench_op)
                duration = time.time() - prev_ts
                # Message size in MB, counting 4 bytes per float.
                message_mbs = nf * 4. / 1024. / 1024.
                # MB -> Gb: 8 bits per byte, 1024 Mb per Gb. The previous
                # formula omitted the division by 1024 and so reported
                # megabits per second under a `Gb/s` label.
                total_gbits = args.num_steps * message_mbs * 8.0 / 1024.
                throughput_gbps = total_gbits / duration
                print(
                  f'{hb.context.rank}/{hb.context.world_size}\t'
                  f'{cop}\tTopology-{topo}\t{dev}\t'
                  f'{message_mbs:.2f}MB\t{part}\t'
                  f'{throughput_gbps:.2f}Gb/s')
170+
171+
172+
if __name__ == '__main__':
  # Command line entry point: parse the benchmark matrix and run it.
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument(
    '--collective-ops',
    nargs='+',
    # `alltoallv_` (alltoallv with evenly divided sizes) is also accepted by
    # `benchmark`, so it is listed here as well.
    help='Collective ops in (allreduce, alltoall, alltoallv_, alltoallv)',
    default=['allreduce', 'alltoall', 'alltoallv'])
  parser.add_argument(
    '--collective-topology',
    type=int,
    nargs='+',
    help='All/Intra/Inter nodes participate collective_ops',
    default=[0, 1, 2])
  parser.add_argument(
    '--message-floats',
    type=int,
    nargs='+',
    help='Count of floats in each message',
    default=[65536, 262144, 1048576, 4194304, 16777216])
  parser.add_argument(
    '--message-partitions',
    type=int,
    nargs='+',
    help='Number of partitions of each message',
    default=[1, 8, 64])
  parser.add_argument(
    '--message-devices',
    nargs='+',
    # Values are device types used to build `/<type>:0` device strings, not
    # device counts.
    help='Device types holding each message (e.g. gpu, cpu)',
    default=['gpu', 'cpu'])
  parser.add_argument('--warmup-steps', type=int, default=100)
  parser.add_argument('--num-steps', type=int, default=500)
  benchmark(parser.parse_args())
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# =============================================================================
15+
16+
r'''Embedding training using unobtrusive API on single GPU benchmark.
17+
'''
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import argparse
24+
import os
25+
import time
26+
27+
import tensorflow as tf
28+
29+
import hybridbackend.tensorflow as hb
30+
31+
32+
# pylint: disable=missing-docstring
33+
def benchmark(params):
  r'''Train one embedding table per integer Parquet field and report speed.

  Reads the schema from the first file, keeps int32/int64 fields other than
  `label` and `ts`, looks up an embedding for every kept field, and minimizes
  the sum of all embeddings until the input is exhausted or the session
  requests a stop.

  Args:
    params: Parsed command line arguments providing `filenames`,
      `batch_size` and `dimension_size`.
  '''
  with tf.Graph().as_default():
    schema = hb.data.Dataset.schema_from_parquet(params.filenames[0])
    fields = [
      f for f in schema
      if f.name not in ('label', 'ts') and f.dtype in (tf.int32, tf.int64)]
    ds = hb.data.Dataset.from_parquet(params.filenames, fields=fields)
    ds = ds.batch(params.batch_size, drop_remainder=True).prefetch(1)
    iterator = hb.data.Iterator(tf.data.make_one_shot_iterator(ds))
    iterator_hook = hb.data.Iterator.Hook()
    inputs = iterator.get_next()

    outputs = []
    with tf.name_scope('features'):
      for field in inputs:
        # NOTE(review): the table always has 128 rows while ids are taken
        # modulo `params.dimension_size`; a value above 128 looks like it
        # would index out of range -- confirm the intended relation.
        with tf.device('/cpu:0'):
          embedding_weights = tf.get_variable(
            f'{field}_weight',
            shape=(128, 32),
            initializer=tf.random_uniform_initializer(-1e-3, 1e-3))
        ids = inputs[field]
        if not isinstance(ids, tf.Tensor):
          # Anything that is not a dense tensor is handled as sparse ids.
          bucketed_ids = tf.SparseTensor(
            ids.indices,
            ids.values % params.dimension_size,
            ids.dense_shape)
          embeddings = tf.nn.embedding_lookup_sparse(
            embedding_weights, bucketed_ids, None)
        else:
          # Dense ids: look up each distinct id once, then scatter back.
          flat_ids = tf.reshape(ids % params.dimension_size, shape=[-1])
          unique_ids, idx = tf.unique(flat_ids)
          unique_embeddings = tf.nn.embedding_lookup(
            embedding_weights, unique_ids)
          embeddings = tf.gather(unique_embeddings, idx)
        outputs.append(embeddings)
    loss = tf.math.add_n([tf.reduce_sum(t) for t in outputs])
    opt = hb.train.AdagradOptimizer(learning_rate=0.01)
    step = tf.train.get_or_create_global_step()
    train_op = opt.minimize(loss, global_step=step)

    with tf.train.MonitoredTrainingSession('', hooks=[iterator_hook]) as sess:
      count = 0
      prev_ts = time.time()
      try:
        while not sess.should_stop():
          sess.run(train_op)
          count += 1
      except tf.errors.OutOfRangeError:
        # Running out of input is the normal way for the benchmark to end.
        pass
      duration = time.time() - prev_ts
      if count <= 0:
        print('Training embedding layers stopped unexpectedly')
        return
      samples_per_sec = params.batch_size * count / duration
      msec_per_step = 1000. * duration / count
      print(
        'Training embedding layers elapsed in '
        f'{samples_per_sec:.2f} samples/sec ('
        f'{msec_per_step:.2f} msec/step)')
90+
91+
92+
if __name__ == '__main__':
  # Command line entry point of the embedding benchmark.
  # Expose only the first CUDA device to this process. The second variable
  # is HybridBackend specific -- presumably it turns on op relocation;
  # verify against the HybridBackend documentation.
  os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'HB_OP_RELOCATION_ENABLED': '1'})
  tf.logging.set_verbosity(tf.logging.INFO)
  cli = argparse.ArgumentParser()
  cli.add_argument('--batch-size', type=int, default=16384)
  cli.add_argument('--dimension-size', type=int, default=32)
  cli.add_argument('filenames', nargs='+')
  cli_args = cli.parse_args()
  benchmark(cli_args)

0 commit comments

Comments
 (0)