Commit fd385cc

[DOC] Add simple ranking example

1 parent 46e2bfd commit fd385cc

File tree

5 files changed: +810 -0 lines changed

examples/ranking/layers.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''Layers for the ranking model.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


class DotInteract(tf.layers.Layer):
  r'''DLRM: Deep Learning Recommendation Model for Personalization and
  Recommendation Systems.

  See https://github.com/facebookresearch/dlrm for more information.
  '''
  def call(self, x):
    r'''Call the DLRM dot-interact layer.
    '''
    # Pairwise dot products between all field vectors:
    # [batch, fields, dim] x [batch, dim, fields] -> [batch, fields, fields].
    x2 = tf.matmul(x, x, transpose_b=True)
    x2_dim = x2.shape[-1]
    x2_ones = tf.ones_like(x2)
    # band_part(..., 0, -1) keeps the upper triangle and the diagonal, so
    # x2_ones - x2_mask selects each unordered pair exactly once.
    x2_mask = tf.linalg.band_part(x2_ones, 0, -1)
    # tf.boolean_mask requires a bool mask, hence the cast.
    y = tf.boolean_mask(x2, tf.cast(x2_ones - x2_mask, tf.bool))
    # n fields yield n * (n - 1) / 2 distinct pairs per example.
    y = tf.reshape(y, [-1, x2_dim * (x2_dim - 1) // 2])
    return y


class Cross(tf.layers.Layer):
  r'''DCN V2: Improved Deep & Cross Network and Practical Lessons for
  Web-scale Learning to Rank Systems.

  See https://arxiv.org/abs/2008.13535 for more information.
  '''
  def call(self, x):
    r'''Call the DCN cross layer.
    '''
    x2 = tf.layers.dense(
        x, x.shape[-1],
        activation=tf.nn.relu,
        kernel_initializer=tf.truncated_normal_initializer(),
        bias_initializer=tf.zeros_initializer())
    # Element-wise crossing with a residual connection back to the input.
    y = x * x2 + x
    # Flatten [batch, fields, dim] into [batch, fields * dim].
    y = tf.reshape(y, [-1, x.shape[1] * x.shape[2]])
    return y


class Ranking(tf.layers.Layer):
  r'''A simple ranking model.
  '''
  def __init__(
      self,
      embedding_columns,
      bottom_mlp=None,
      top_mlp=None,
      feature_interaction=None,
      **kwargs):
    r'''Constructor.

    Args:
      embedding_columns: List of embedding columns.
      bottom_mlp: List of bottom MLP dimensions.
      top_mlp: List of top MLP dimensions.
      feature_interaction: Feature interaction layer class.
      **kwargs: Keyword named properties.
    '''
    super().__init__(**kwargs)

    if bottom_mlp is None:
      bottom_mlp = [512, 256, 64]
    self.bottom_mlp = bottom_mlp
    if top_mlp is None:
      top_mlp = [1024, 1024, 512, 256, 1]
    self.top_mlp = top_mlp
    if feature_interaction is None:
      feature_interaction = DotInteract
    self.feature_interaction = feature_interaction
    self.embedding_columns = embedding_columns
    dimensions = {c.dimension for c in embedding_columns}
    if len(dimensions) > 1:
      raise ValueError('All embedding columns must share one dimension')
    self.dimension = list(dimensions)[0]

  def call(self, values, embeddings):
    r'''Call the ranking model.
    '''
    with tf.name_scope('bottom_mlp'):
      # log(1 + x) keeps the dense feature magnitudes in a stable range.
      bot_mlp_input = tf.math.log(values + 1.)
      for i, d in enumerate(self.bottom_mlp):
        bot_mlp_input = tf.layers.dense(
            bot_mlp_input, d,
            activation=tf.nn.relu,
            kernel_initializer=tf.glorot_normal_initializer(),
            bias_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(1.0 / d)),
            name=f'bottom_mlp_{i}')
      # Project the dense features to the embedding dimension so they can
      # be interacted with the sparse embeddings.
      bot_mlp_output = tf.layers.dense(
          bot_mlp_input, self.dimension,
          activation=tf.nn.relu,
          kernel_initializer=tf.glorot_normal_initializer(),
          bias_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(1.0 / self.dimension)),
          name='bottom_mlp_output')

    with tf.name_scope('feature_interaction'):
      feat_interact_input = tf.concat([bot_mlp_output] + embeddings, axis=-1)
      feat_interact_input = tf.reshape(
          feat_interact_input,
          [-1, 1 + len(embeddings), self.dimension])
      feat_interact_output = self.feature_interaction()(feat_interact_input)

    with tf.name_scope('top_mlp'):
      top_mlp_input = tf.concat([bot_mlp_output, feat_interact_output], axis=1)
      # With DotInteract, the num_fields + 1 input vectors produce
      # (num_fields + 1) * num_fields / 2 pairs; prev_d also counts the
      # concatenated dense projection.
      num_fields = len(self.embedding_columns)
      prev_d = (num_fields * (num_fields + 1)) / 2 + self.dimension
      for i, d in enumerate(self.top_mlp[:-1]):
        top_mlp_input = tf.layers.dense(
            top_mlp_input, d,
            activation=tf.nn.relu,
            kernel_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(2.0 / (prev_d + d))),
            bias_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(1.0 / d)),
            name=f'top_mlp_{i}')
        prev_d = d
      top_mlp_output = tf.layers.dense(
          top_mlp_input, self.top_mlp[-1],
          activation=tf.nn.sigmoid,
          kernel_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(2.0 / (prev_d + self.top_mlp[-1]))),
          bias_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(1.0 / self.top_mlp[-1])),
          name=f'top_mlp_{len(self.top_mlp) - 1}')
    return top_mlp_output
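
For context, a minimal usage sketch (not part of the commit) of how the Ranking layer could be driven. The column names, bucket sizes, feature values, and the per-column tf.feature_column.input_layer lookup are illustrative assumptions; the only contract taken from the file is that call() receives dense values plus a list of already-looked-up embedding tensors of one shared dimension.

import tensorflow as tf
from layers import Ranking

# Two hypothetical id features, both embedded with the same dimension,
# as required by Ranking's constructor.
columns = [
    tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_identity(
            f'id{i}', num_buckets=1000),
        dimension=16)
    for i in range(2)]
features = {
    'id0': tf.constant([1, 2]),
    'id1': tf.constant([3, 4])}
# Look up each embedding column separately; Ranking expects a list of
# [batch, dimension] tensors.
embeddings = [tf.feature_column.input_layer(features, [c]) for c in columns]
dense_values = tf.random_uniform([2, 13])  # e.g. 13 dense features
logits = Ranking(columns)(dense_values, embeddings)  # [batch, 1], sigmoid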

examples/ranking/optimization.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''Functions for optimization.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def lr_with_linear_warmup_and_polynomial_decay(
    global_step,
    initial_value=24.,
    scaling_factor=1.,
    warmup_steps=None,
    decay_steps=None,
    decay_start_step=None,
    decay_exp=2,
    epsilon=1.e-7):
  r'''Calculates the learning rate with linear warmup and polynomial decay.

  Args:
    global_step: Variable representing the current step.
    initial_value: Initial value of the learning rate.
    scaling_factor: Factor for scaling the initial value.
    warmup_steps: Steps of warmup.
    decay_steps: Steps of decay.
    decay_start_step: Start step of decay.
    decay_exp: Exponent of the polynomial decay.
    epsilon: Lower bound of the decayed rate, also used as the tolerance
      when selecting the constant phase.

  Returns:
    New learning rate tensor.
  '''
  initial_lr = tf.constant(initial_value * scaling_factor, tf.float32)

  if warmup_steps is None:
    return initial_lr

  # Linear ramp from 0 up to initial_lr over warmup_steps.
  global_step = tf.cast(global_step, tf.float32)
  warmup_steps = tf.constant(warmup_steps, tf.float32)
  warmup_rate = initial_lr / warmup_steps
  warmup_lr = initial_lr - (warmup_steps - global_step) * warmup_rate

  if decay_steps is None or decay_start_step is None:
    return warmup_lr

  # Polynomial decay from initial_lr down to at least epsilon, starting at
  # decay_start_step and completing after decay_steps.
  decay_start_step = tf.constant(decay_start_step, tf.float32)
  steps_since_decay_start = global_step - decay_start_step
  decay_steps = tf.constant(decay_steps, tf.float32)
  decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
  to_decay_rate = (decay_steps - decayed_steps) / decay_steps
  decay_lr = initial_lr * to_decay_rate**decay_exp
  decay_lr = tf.maximum(decay_lr, tf.constant(epsilon))

  # 0/1 indicators that pick the warmup, decay, or constant phase.
  warmup_lambda = tf.cast(global_step < warmup_steps, tf.float32)
  decay_lambda = tf.cast(global_step > decay_start_step, tf.float32)
  initial_lambda = tf.cast(
      tf.math.abs(warmup_lambda + decay_lambda) < epsilon, tf.float32)

  lr = warmup_lambda * warmup_lr
  lr += decay_lambda * decay_lr
  lr += initial_lambda * initial_lr
  return lr


def sgd_decay_optimize(
    loss,
    lr_initial_value,
    lr_warmup_steps,
    lr_decay_start_step,
    lr_decay_steps):
  r'''Optimize using SGD and learning rate decay.
  '''
  step = tf.train.get_or_create_global_step()
  lr = lr_with_linear_warmup_and_polynomial_decay(
      step,
      initial_value=lr_initial_value,
      warmup_steps=lr_warmup_steps,
      decay_start_step=lr_decay_start_step,
      decay_steps=lr_decay_steps)
  opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
  return opt.minimize(loss, global_step=step)
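
And a minimal training sketch (again not part of the commit) for sgd_decay_optimize; the toy quadratic loss, the small learning rate, and the schedule step counts are illustrative assumptions.

import tensorflow as tf
from optimization import sgd_decay_optimize

w = tf.get_variable('w', shape=[], initializer=tf.ones_initializer())
loss = tf.square(w - 3.)  # toy loss with minimum at w = 3
train_op = sgd_decay_optimize(
    loss,
    lr_initial_value=0.1,  # small for this toy; the schedule default is 24.
    lr_warmup_steps=1000,
    lr_decay_start_step=5000,
    lr_decay_steps=2000)
with tf.train.MonitoredTrainingSession() as sess:
  for _ in range(10):
    sess.run(train_op)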
