
Commit 9b3c00d

[DATA] Introduce new RebatchDataset to replace rebatch and rectify
1 parent 7fcffe6 commit 9b3c00d

29 files changed: +947 −2086 lines

README.md

Lines changed: 15 additions & 20 deletions
@@ -11,11 +11,8 @@ recommender systems on heterogeneous cluster.
 ## Features
 
 - Memory-efficient loading of categorical data
-
 - GPU-efficient orchestration of embedding layers
-
 - Communication-efficient training and evaluation at scale
-
 - Easy to use with existing AI workflows
 
 ## Usage
@@ -26,8 +23,8 @@ A minimal example:
 import tensorflow as tf
 import hybridbackend.tensorflow as hb
 
-ds = hb.data.ParquetDataset(filenames, batch_size=batch_size)
-ds = ds.apply(hb.data.parse())
+ds = hb.data.Dataset.from_parquet(filenames)
+ds = ds.batch(batch_size)
 # ...
 
 with tf.device('/gpu:0'):
@@ -44,16 +41,16 @@ more information.
 
 `pip install {PACKAGE}`
 
-`{PACKAGE}` | Dependency | Python | CUDA | GLIBC | Data Opt. | Embedding Opt. | Parallelism Opt.
------------ | ---------- | ------- | ---- | ----- | --------- | -------------- | -----------------
-[hybridbackend-deeprec2208-cu114](https://pypi.org/project/hybridbackend-deeprec2208-cu114/) | [DeepRec 22.08](https://github.com/alibaba/DeepRec/tree/deeprec2208) `1` | 3.6 | 11.4 | >=2.27 | ✓ | ✓ | ✓
-[hybridbackend-tf115-cu118](https://pypi.org/project/hybridbackend-tf115-cu118/) | [TensorFlow 1.15](https://github.com/NVIDIA/tensorflow) `2` | 3.8 | 11.8 | >=2.31 | ✓ | ✓ | ✓
-[hybridbackend-tf115-cu100](https://pypi.org/project/hybridbackend-tf115-cu100/) | [TensorFlow 1.15](https://github.com/tensorflow/tensorflow/tree/r1.15) | 3.6 | 10.0 | >=2.27 | ✓ | ✓ | ✗
-[hybridbackend-tf115-cpu](https://pypi.org/project/hybridbackend-tf115-cpu/) | [TensorFlow 1.15](https://github.com/tensorflow/tensorflow/tree/r1.15) | 3.6 | - | >=2.24 | ✓ | ✗ | ✗
+| `{PACKAGE}` | Dependency | Python | CUDA | GLIBC | Data Opt. | Embedding Opt. | Parallelism Opt. |
+| ----------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- | ------ | ---- | ------ | --------- | -------------- | ---------------- |
+| [hybridbackend-tf115-cu118](https://pypi.org/project/hybridbackend-tf115-cu118/) | [TensorFlow 1.15](https://github.com/NVIDIA/tensorflow) `1` | 3.8 | 11.8 | >=2.31 | ✓ | ✓ | ✓ |
+| [hybridbackend-tf115-cu100](https://pypi.org/project/hybridbackend-tf115-cu100/) | [TensorFlow 1.15](https://github.com/tensorflow/tensorflow/tree/r1.15) | 3.6 | 10.0 | >=2.27 | ✓ | ✓ | ✗ |
+| [hybridbackend-tf115-cpu](https://pypi.org/project/hybridbackend-tf115-cpu/) | [TensorFlow 1.15](https://github.com/tensorflow/tensorflow/tree/r1.15) | 3.6 | - | >=2.24 | ✓ | ✗ | ✗ |
+| [hybridbackend-deeprec2208-cu114](https://pypi.org/project/hybridbackend-deeprec2208-cu114/) | [DeepRec 22.08](https://github.com/alibaba/DeepRec/tree/deeprec2208) `2` | 3.6 | 11.4 | >=2.27 | ✓ | ✓ | ✓ |
 
-> `1`: Suggested docker image: `dsw-registry.cn-shanghai.cr.aliyuncs.com/pai/tensorflow-training:1.15PAI-gpu-py36-cu114-ubuntu18.04`
+> `1`: Suggested docker image: `nvcr.io/nvidia/tensorflow:22.12-tf1-py3`
 
-> `2`: Suggested docker image: `nvcr.io/nvidia/tensorflow:22.12-tf1-py3`
+> `2`: Suggested docker image: `dsw-registry.cn-shanghai.cr.aliyuncs.com/pai/tensorflow-training:1.15PAI-gpu-py36-cu114-ubuntu18.04`
 
 ### Method 2: Build from source
 
@@ -66,13 +63,11 @@ HybridBackend is licensed under the [Apache 2.0 License](LICENSE).
 ## Community
 
 - Please see [Contributing Guide](https://github.com/alibaba/HybridBackend/blob/main/CONTRIBUTING.md)
-before your first contribution.
-
+  before your first contribution.
 - Please [register as an adopter](https://github.com/alibaba/HybridBackend/blob/main/ADOPTERS.md)
-if your organization is interested in adoption. We will discuss
-[RoadMap](https://github.com/alibaba/HybridBackend/blob/main/ROADMAP.md) with
-registered adopters in advance.
-
+  if your organization is interested in adoption. We will discuss
+  [RoadMap](https://github.com/alibaba/HybridBackend/blob/main/ROADMAP.md) with
+  registered adopters in advance.
 - Please cite [HybridBackend](https://ieeexplore.ieee.org/document/9835450) in your publications if it helps:
 
 ```text
@@ -90,4 +85,4 @@ registered adopters in advance.
 If you would like to share your experiences with others, you are welcome to
 contact us in DingTalk:
 
-[![dingtalk](https://github.com/alibaba/HybridBackend/raw/main/docs/images/dingtalk.png)](https://qr.dingtalk.com/action/joingroup?code=v1,k1,VouhbeuTwXYEgaLzSOE8o6VF2kTHVJ8lw5h93WbZW8o=&_dt_no_comment=1&origin=11)
+[![dingtalk](https://github.com/alibaba/HybridBackend/raw/main/docs/images/dingtalk.png)](https://qr.dingtalk.com/action/joingroup?code=v1,k1,VouhbeuTwXYEgaLzSOE8o6VF2kTHVJ8lw5h93WbZW8o=&_dt_no_comment=1&origin=11)
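The usage hunk above is the headline API change: the old two-step `ParquetDataset(...)` plus `apply(hb.data.parse())` pattern becomes a builder followed by a plain `batch()` call. A runnable sketch of the new minimal example under TF 1.15, assuming `filenames` points at real Parquet files; the iterator plumbing is an assumption, not part of the diff:

```python
import tensorflow as tf
import hybridbackend.tensorflow as hb

filenames = ['/path/to/day_0.parquet']  # assumed sample input
batch_size = 4096                       # assumed batch size

ds = hb.data.Dataset.from_parquet(filenames)  # new builder introduced here
ds = ds.batch(batch_size)                     # plain tf.data-style batching
it = tf.data.make_one_shot_iterator(ds)
batch = it.get_next()  # dict of column tensors keyed by Parquet field name
```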

hybridbackend/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -20,6 +20,6 @@
 from __future__ import division
 from __future__ import print_function
 
-__version__ = '0.7.0a2'
+__version__ = '0.8.0'
 __author__ = 'Alibaba Group Holding Limited'
 __copyright__ = '2021 Alibaba Group Holding Limited'

hybridbackend/run.py

Lines changed: 7 additions & 0 deletions
@@ -69,6 +69,8 @@ def run(command):
     command: Function or command to run
   '''
   visible_devices = _query_visible_devices()
+  local_world_size_str = str(len(visible_devices))
+
   port = int(os.getenv('HB_RUN_BASE_PORT', '20001'))
   device_to_ports = []
   for d in visible_devices:
@@ -126,6 +128,8 @@ def run(command):
   new_tf_config['task']['type'] = task_type
   new_tf_config['task']['index'] = task_id
   os.environ['TF_CONFIG'] = json.dumps(new_tf_config)
+  os.environ['TF_TASK_TYPE'] = str(task_type)
+  os.environ['TF_TASK_INDEX'] = str(task_id)
   os.environ['CUDA_VISIBLE_DEVICES'] = ''
   os.environ['HB_OP_OPTIMIZATION_DISABLED'] = '1'
   if callable(command):
@@ -165,7 +169,10 @@ def run(command):
   gpu_tf_config['task']['index'] = gpu_index
   gpu_env = os.environ.copy()
   gpu_env['TF_CONFIG'] = json.dumps(gpu_tf_config)
+  gpu_env['TF_TASK_TYPE'] = gpu_tf_config['task']['type']
+  gpu_env['TF_TASK_INDEX'] = str(gpu_tf_config['task']['index'])
   gpu_env['CUDA_VISIBLE_DEVICES'] = device
+  gpu_env['LOCAL_WORLD_SIZE'] = local_world_size_str
   if interop_threads_gpu:
     gpu_env['TF_NUM_INTEROP_THREADS'] = str(interop_threads_gpu)
   if intraop_threads_gpu:
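The three environment variables added here (`TF_TASK_TYPE`, `TF_TASK_INDEX`, `LOCAL_WORLD_SIZE`) let a launched worker learn its role without re-parsing `TF_CONFIG`. A minimal sketch of a hypothetical consumer, not part of this commit:

```python
import os

# Exported by hybridbackend/run.py for each spawned task:
task_type = os.getenv('TF_TASK_TYPE', 'chief')              # e.g. 'chief' or 'worker'
task_index = int(os.getenv('TF_TASK_INDEX', '0'))           # rank within its task type
local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', '1'))  # visible devices on this host

print(f'{task_type}:{task_index} runs with {local_world_size} local device(s)')
```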

hybridbackend/tensorflow/benchmarks/data_benchmark_parquet.py

Lines changed: 2 additions & 2 deletions
@@ -51,13 +51,13 @@ def benchmark(params):
   with tf.Graph().as_default():
     step = tf.train.get_or_create_global_step()
     if params.baseline:
-      ds = hb.data.TabularDataset.from_parquet(params.filenames)
+      ds = hb.data.Dataset.from_parquet(params.filenames)
       ds = ds.map(lambda data: data)  # Prevent fusion
       if params.shuffle:
         ds = ds.shuffle(params.batch_size * 10)
       ds = ds.batch(params.batch_size, drop_remainder=True)
     else:
-      ds = hb.data.TabularDataset.from_parquet(params.filenames)
+      ds = hb.data.Dataset.from_parquet(params.filenames)
       if params.shuffle:
         ds = ds.shuffle_batch(
             params.batch_size, drop_remainder=True,
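The benchmark's two branches differ only in where shuffling happens: the baseline shuffles a row buffer and then batches in a separate stage, while the other path uses the fused `shuffle_batch`. A side-by-side sketch, assuming `filenames` and `batch_size` are in scope; arguments beyond those visible in the diff are elided:

```python
import hybridbackend.tensorflow as hb

ds = hb.data.Dataset.from_parquet(filenames)

# Baseline: two dataset stages, shuffle then batch.
baseline = ds.shuffle(batch_size * 10).batch(batch_size, drop_remainder=True)

# Fused path: one stage doing both, as exercised by the benchmark.
fused = hb.data.Dataset.from_parquet(filenames).shuffle_batch(
    batch_size, drop_remainder=True)
```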
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if HYBRIDBACKEND_TENSORFLOW
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <limits>
+
+#include <tensorflow/core/framework/register_types.h>
+#include <tensorflow/core/framework/tensor.h>
+#include <tensorflow/core/public/version.h>
+
+#include "hybridbackend/common/atomic.cu.h"
+#include "hybridbackend/tensorflow/common/device_functions.h"
+#include "hybridbackend/tensorflow/common/slice_sum.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace hybridbackend {
+
+namespace functor {
+
+template <typename T, int32 N = 256>
+__global__ void SliceSumKernel(const int32 num_rows, const int32 num_cols,
+                               const int32 col, const T* input, T* output_total,
+                               T* output) {
+  for (int32 idx : CudaGridRangeX(num_rows)) {
+    const T v = input[idx * num_cols + col];
+    output[idx] = v;
+    atomicAdd(output_total, v);
+  }
+}
+
+template <typename T>
+struct SliceSum<GPUDevice, T> {
+  void operator()(const int32 num_rows, const int32 num_cols, const int32 col,
+                  const T* input, T* output_total, T* output,
+                  const Eigen::GpuDevice& d) {
+    CudaLaunch(SliceSumKernel<T>, num_rows, 0, d, nullptr, num_rows, num_cols,
+               col, input, output_total, output);
+  }
+};
+
+template struct SliceSum<GPUDevice, int32>;
+template struct SliceSum<GPUDevice, int64>;
+template struct SliceSum<GPUDevice, uint32>;
+template struct SliceSum<GPUDevice, uint64>;
+
+template <typename T, int32 N = 256>
+__global__ void GroupSliceSumKernel(const int32 num_rows, const int32 num_cols,
+                                    const int32 col, const int32 num_inputs,
+                                    const T* inputs, T* output_totals,
+                                    T** outputs) {
+  for (int32 idx : CudaGridRangeX(num_inputs * num_rows)) {
+    const int32 s = idx / num_rows;
+    const int32 sidx = idx % num_rows;
+    const T v = inputs[idx * num_cols + col];
+    outputs[s][sidx] = v;
+    atomicAdd(output_totals + s, v);
+  }
+}
+
+template <typename T>
+struct SliceSumN<GPUDevice, T> {
+  void operator()(const int32 num_rows, const int32 num_cols, const int32 col,
+                  const int32 num_inputs, const T* inputs, T* output_totals,
+                  T** outputs, const Eigen::GpuDevice& d) {
+    CudaLaunch(GroupSliceSumKernel<T>, num_inputs * num_rows, 0, d, nullptr,
+               num_rows, num_cols, col, num_inputs, inputs, output_totals,
+               outputs);
+  }
+};
+
+template struct SliceSumN<GPUDevice, int32>;
+template struct SliceSumN<GPUDevice, int64>;
+template struct SliceSumN<GPUDevice, uint32>;
+template struct SliceSumN<GPUDevice, uint64>;
+
+}  // namespace functor
+}  // namespace hybridbackend
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+#endif  // HYBRIDBACKEND_TENSORFLOW
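On the host side the kernel's contract is simple: copy column `col` out of a row-major `num_rows x num_cols` matrix and accumulate its sum (the GPU kernel does the accumulation with `atomicAdd`). A NumPy model of that contract, for reference only:

```python
import numpy as np

def slice_sum(matrix: np.ndarray, col: int):
  """Reference semantics of SliceSum: one column plus its total."""
  output = matrix[:, col].copy()  # output[idx] = input[idx * num_cols + col]
  output_total = output.sum()     # accumulated with atomicAdd in the kernel
  return output_total, output

matrix = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.int64)
total, column = slice_sum(matrix, col=1)
assert total == 12 and column.tolist() == [2, 4, 6]
```

`SliceSumN` generalizes this to `num_inputs` stacked matrices, writing one column and one total per input.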

hybridbackend/tensorflow/data/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -29,8 +29,7 @@
 from hybridbackend.tensorflow.data.prefetch.ops import Iterator
 from hybridbackend.tensorflow.data.rebatch.dataset import RebatchDataset
 from hybridbackend.tensorflow.data.rebatch.dataset import rebatch
-from hybridbackend.tensorflow.data.rectify.dataset import rectify
-from hybridbackend.tensorflow.data.tabular.dataset import TabularDataset
+from hybridbackend.tensorflow.data.tabular.dataset import Dataset
 
 # HybridBackend operators must be loaded before TensorFlow operators to
 # make AWS SDK implementation correct.
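This import list defines the migration surface: `rectify` disappears and `TabularDataset` is re-exported as `Dataset`, with `rebatch`/`RebatchDataset` absorbing rectify's job. A hedged before/after sketch for downstream code, with the old call shape reconstructed from this diff and `filenames`/`batch_size` assumed in scope:

```python
import hybridbackend.tensorflow as hb

# Before this commit:
#   ds = hb.data.TabularDataset.from_parquet(filenames)
#   ds = ds.apply(hb.data.rectify(...))  # module removed here; arguments elided

# After this commit:
ds = hb.data.Dataset.from_parquet(filenames)  # TabularDataset -> Dataset
ds = ds.apply(hb.data.rebatch(batch_size))    # rebatch now covers rectify's role
```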
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef HYBRIDBACKEND_TENSORFLOW_DATA_REBATCH_BUFFER_H_
+#define HYBRIDBACKEND_TENSORFLOW_DATA_REBATCH_BUFFER_H_
+
+#include <deque>
+#include <vector>
+
+#include <tensorflow/core/framework/tensor.h>
+#include <tensorflow/core/lib/random/philox_random.h>
+#include <tensorflow/core/lib/random/random.h>
+#include <tensorflow/core/lib/random/random_distributions.h>
+
+namespace tensorflow {
+namespace hybridbackend {
+
+struct RebatchBufferItem {
+ public:
+  RebatchBufferItem(int64 batch_size, const std::vector<Tensor>& components)
+      : batch_size(batch_size), components(components) {}
+  int64 batch_size;
+  std::vector<Tensor> components;
+};
+
+class RebatchBuffer {
+ public:
+  RebatchBuffer(const DataTypeVector& output_dtypes,
+                const std::vector<PartialTensorShape>& output_shapes,
+                const std::vector<int32>& field_ranks);
+
+  int64 size() const { return size_; }
+
+  Status Put(const std::vector<Tensor>& input_tensors, const int64 num_rows);
+
+  Status PutSlice(const std::vector<Tensor>& input_tensors,
+                  const int64 row_start, const int64 row_limit);
+
+  Status Shuffle(random::SingleSampleAdapter<random::PhiloxRandom>& generator,
+                 const int64 num_rows);
+
+  Status Take(Allocator* alloc, std::vector<Tensor>* output_tensors,
+              const int64 num_rows);
+
+ private:
+  Status TakeDense(Allocator* alloc, std::vector<Tensor>* output_tensors,
+                   std::vector<Tensor>* residual_tensors, const int64 num_rows,
+                   const int64 remained_rows, const int64 rank,
+                   const int64 col);
+
+  Status TakeSparse(Allocator* alloc, std::vector<Tensor>* output_tensors,
+                    std::vector<Tensor>* residual_tensors,
+                    const int64 num_rows, const int64 remained_rows,
+                    const int64 rank, const int64 col);
+
+  const DataTypeVector& output_dtypes_;
+  const std::vector<PartialTensorShape>& output_shapes_;
+  const std::vector<int32> field_ranks_;
+
+  int64 size_;
+  std::deque<RebatchBufferItem> items_;
+};
+
+}  // namespace hybridbackend
+}  // namespace tensorflow
+
+#endif  // HYBRIDBACKEND_TENSORFLOW_DATA_REBATCH_BUFFER_H_
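Conceptually, `RebatchBuffer` absorbs input batches of arbitrary size (`Put`/`PutSlice`) and emits fixed-size batches (`Take`), carrying residual rows forward; `Shuffle` permutes buffered rows. A toy Python model of the Put/Take contract; the real class holds per-field Tensors and handles dense and sparse fields separately:

```python
from collections import deque

class RebatchBufferModel:
  """Toy model: rows go in at any batch size, come out at a fixed one."""

  def __init__(self):
    self._items = deque()
    self._size = 0  # total buffered rows, mirrors RebatchBuffer::size()

  def put(self, rows):
    self._items.append(list(rows))
    self._size += len(rows)

  def take(self, num_rows):
    if self._size < num_rows:
      return None  # not enough rows buffered yet; caller keeps putting
    out = []
    while len(out) < num_rows:
      item = self._items.popleft()
      need = num_rows - len(out)
      out.extend(item[:need])
      if len(item) > need:  # residual rows return to the front of the queue
        self._items.appendleft(item[need:])
    self._size -= num_rows
    return out

buf = RebatchBufferModel()
buf.put([1, 2, 3])
buf.put([4, 5])
assert buf.take(4) == [1, 2, 3, 4]  # one row (5) remains buffered
```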

hybridbackend/tensorflow/data/rebatch/dataset.py

Lines changed: 7 additions & 14 deletions
@@ -23,6 +23,8 @@
 import inspect
 
 # pylint: disable=ungrouped-imports
+from hybridbackend.tensorflow.data.dataframe import input_fields
+
 try:
   from tensorflow.python.data.ops.dataset_ops import DatasetV2 as _dataset  # pylint: disable=unused-import
 
@@ -43,28 +45,19 @@
 
 def rebatch(
     batch_size,
-    min_batch_size=None,
-    fields=None,
     drop_remainder=False,
-    num_parallel_scans=1):
+    fields=None):
   r'''Create a `RebatchDataset`.
 
   Args:
     batch_size: Maxium number of samples in an output batch.
-    min_batch_size: (Optional.) Minimum number of samples in a non-final
-      batch. Same to `batch_size` by default.
-    fields: (Optional.) List of DataFrame fields. Fetched from `input_dataset`
-      by default.
     drop_remainder: (Optional.) If True, smaller final batch is dropped.
      `False` by default.
-    num_parallel_scans: (Optional.) Number of concurrent scans against fields
-      of input dataset.
+    fields: (Optional.) List of DataFrame fields. Fetched from `input_dataset`
+      by default.
   '''
  def _apply_fn(dataset):
    return RebatchDataset(
-        dataset, batch_size,
-        min_batch_size=min_batch_size,
-        fields=fields,
-        drop_remainder=drop_remainder,
-        num_parallel_scans=num_parallel_scans)
+        dataset, input_fields(dataset, fields), batch_size,
+        drop_remainder=drop_remainder)
  return _apply_fn
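The slimmed-down `rebatch` now takes only `batch_size`, `drop_remainder` and `fields`, resolving fields eagerly via `input_fields` instead of deferring to `RebatchDataset`. A usage sketch under the new signature, assuming `filenames` is in scope; whether an explicit upstream `.batch()` is still needed is not shown in this diff:

```python
import hybridbackend.tensorflow as hb

ds = hb.data.Dataset.from_parquet(filenames)
ds = ds.apply(hb.data.rebatch(1024, drop_remainder=True))  # emit 1024-row batches
```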
