Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 254254890
  • Loading branch information
xuanhuiwang authored and ramakumar1729 committed Jun 20, 2019
1 parent 8536405 commit d36d7fb
Show file tree
Hide file tree
Showing 13 changed files with 2,096 additions and 631 deletions.
82 changes: 69 additions & 13 deletions tensorflow_ranking/examples/tf_ranking_libsvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def example_feature_columns():
"""Returns the example feature columns."""
feature_names = ["{}".format(i + 1) for i in range(FLAGS.num_features)]
return {
name: tf.feature_column.numeric_column(
name, shape=(1,), default_value=0.0) for name in feature_names
name:
tf.feature_column.numeric_column(name, shape=(1,), default_value=0.0)
for name in feature_names
}


Expand Down Expand Up @@ -185,8 +186,8 @@ def _train_input_fn():
for k, v in six.iteritems(features)
}
labels_placeholder = tf.compat.v1.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder,
labels_placeholder))
dataset = tf.data.Dataset.from_tensor_slices(
(features_placeholder, labels_placeholder))
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
feed_dict = {labels_placeholder: labels}
Expand All @@ -210,8 +211,8 @@ def _eval_input_fn():
for k, v in six.iteritems(features)
}
labels_placeholder = tf.compat.v1.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensors((features_placeholder,
labels_placeholder))
dataset = tf.data.Dataset.from_tensors(
(features_placeholder, labels_placeholder))
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
feed_dict = {labels_placeholder: labels}
feed_dict.update(
Expand All @@ -223,6 +224,49 @@ def _eval_input_fn():
return _eval_input_fn, iterator_initializer_hook


def make_serving_input_fn():
"""Returns serving input fn to receive tf.Example."""
feature_spec = tf.feature_column.make_parse_example_spec(
example_feature_columns().values())
return tf.estimator.export.build_parsing_serving_input_receiver_fn(
feature_spec)


def make_transform_fn():
"""Returns a transform_fn that converts features to dense Tensors."""

def _transform_fn(features, mode):
"""Defines transform_fn."""
if mode == tf.estimator.ModeKeys.PREDICT:
# We expect tf.Example as input during serving. In this case, group_size
# must be set to 1.
if FLAGS.group_size != 1:
raise ValueError(
"group_size should be 1 to be able to export model, but get %s" %
FLAGS.group_size)
context_features, example_features = (
tfr.feature.encode_pointwise_features(
features=features,
context_feature_columns=None,
example_feature_columns=example_feature_columns(),
mode=mode,
scope="transform_layer"))
else:
example_name = next(six.iterkeys(example_feature_columns()))
input_size = tf.shape(input=features[example_name])[1]
context_features, example_features = tfr.feature.encode_listwise_features(
features=features,
input_size=input_size,
context_feature_columns=None,
example_feature_columns=example_feature_columns(),
mode=mode,
scope="transform_layer")

return context_features, example_features

return _transform_fn


def make_score_fn():
"""Returns a groupwise score fn to build `EstimatorSpec`."""

Expand Down Expand Up @@ -309,7 +353,7 @@ def _train_op_fn(loss):
model_fn=tfr.model.make_groupwise_ranking_fn(
group_score_fn=make_score_fn(),
group_size=FLAGS.group_size,
transform_fn=None,
transform_fn=make_transform_fn(),
ranking_head=ranking_head),
config=tf.estimator.RunConfig(
FLAGS.output_dir, save_checkpoints_steps=1000))
Expand All @@ -318,12 +362,24 @@ def _train_op_fn(loss):
input_fn=train_input_fn,
hooks=[train_hook],
max_steps=FLAGS.num_train_steps)
vali_spec = tf.estimator.EvalSpec(
input_fn=vali_input_fn,
hooks=[vali_hook],
steps=1,
start_delay_secs=0,
throttle_secs=30)
# Export model to accept tf.Example when group_size = 1.
if FLAGS.group_size == 1:
vali_spec = tf.estimator.EvalSpec(
input_fn=vali_input_fn,
hooks=[vali_hook],
steps=1,
exporters=tf.estimator.LatestExporter(
"latest_exporter",
serving_input_receiver_fn=make_serving_input_fn()),
start_delay_secs=0,
throttle_secs=30)
else:
vali_spec = tf.estimator.EvalSpec(
input_fn=vali_input_fn,
hooks=[vali_hook],
steps=1,
start_delay_secs=0,
throttle_secs=30)

# Train and validate
tf.estimator.train_and_evaluate(estimator, train_spec, vali_spec)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_ranking/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ py_test(
name = "losses_test",
size = "medium",
srcs = ["losses_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_pip",
Expand Down Expand Up @@ -140,6 +139,7 @@ py_library(
srcs = ["data.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
# py/tensorflow dep,
],
)
Expand Down
Loading

0 comments on commit d36d7fb

Please sign in to comment.