
Commit debbe73

brills authored and tfx-copybara committed
Lift stats generator: use Tuple instead of FeaturePath when encoding feature paths as keys in a PTable.

Also changed the output type of some internal functions from numpy arrays to Python lists: elements of a numpy array are of numpy types, which do not have deterministic Beam coders.

PiperOrigin-RevId: 364593214
1 parent 1a08318 commit debbe73
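
The rationale above can be illustrated with a small, hedged sketch (not part of the commit): keys handed to Beam's GroupByKey need a deterministic coder, and both numpy scalar elements and FeaturePath objects fall back to coders Beam does not treat as deterministic, whereas native Python ints and plain tuples of step names do.

```python
# Illustrative only -- the names below are hypothetical, not from the commit.
import numpy as np

arr = np.array([1, 2, 2])
print(type(arr[0]))           # a numpy scalar type (e.g. numpy.int64)
print(type(arr.tolist()[0]))  # <class 'int'> -- a native Python type

# A feature path carried as a plain tuple of step names can be key-encoded
# deterministically, unlike a custom FeaturePath object.
x_path_tuple = ('parent_feature', 'child_feature')  # hypothetical path steps
key = (x_path_tuple, arr.tolist()[0])
print(key)  # (('parent_feature', 'child_feature'), 1)
```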

File tree

3 files changed: +43 -37 lines


tensorflow_data_validation/statistics/generators/lift_stats_generator.py

Lines changed: 29 additions & 17 deletions
@@ -63,25 +63,34 @@
 _SlicedYKey = typing.NamedTuple('_SlicedYKey', [('slice_key', types.SliceKey),
                                                 ('y', _YType)])

-_SlicedXKey = typing.NamedTuple('_SlicedXKey', [('slice_key', types.SliceKey),
-                                                ('x_path', types.FeaturePath),
-                                                ('x', _XType)])

-_SlicedXYKey = typing.NamedTuple('_SlicedXYKey', [('slice_key', types.SliceKey),
-                                                  ('x_path', types.FeaturePath),
-                                                  ('x', _XType), ('y', _YType)])
+# TODO(embr,zhuo): FeaturePathTuple is used instead of FeaturePath because:
+#  - FeaturePath does not have a deterministic coder
+#  - Even if it does, beam does not automatically derive a coder for a
+#    NamedTuple.
+# Once the latter is supported we can change all FEaturePathTuples back to
+# FeaturePaths.
+_SlicedXKey = typing.NamedTuple('_SlicedXKey',
+                                [('slice_key', types.SliceKey),
+                                 ('x_path', types.FeaturePathTuple),
+                                 ('x', _XType)])
+
+_SlicedXYKey = typing.NamedTuple('_SlicedXYKey',
+                                 [('slice_key', types.SliceKey),
+                                  ('x_path', types.FeaturePathTuple),
+                                  ('x', _XType), ('y', _YType)])

 _LiftSeriesKey = typing.NamedTuple('_LiftSeriesKey',
                                    [('slice_key', types.SliceKey),
-                                    ('x_path', types.FeaturePath),
+                                    ('x_path', types.FeaturePathTuple),
                                     ('y', _YType), ('y_count', _CountType)])

 _SlicedFeatureKey = typing.NamedTuple('_SlicedFeatureKey',
                                       [('slice_key', types.SliceKey),
-                                       ('x_path', types.FeaturePath)])
+                                       ('x_path', types.FeaturePathTuple)])

 _ConditionalYRate = typing.NamedTuple('_ConditionalYRate',
-                                      [('x_path', types.FeaturePath),
+                                      [('x_path', types.FeaturePathTuple),
                                        ('x', _XType), ('xy_count', _CountType),
                                        ('x_count', _CountType)])
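
As a rough sketch of the key shape defined above (assuming types.FeaturePathTuple is simply a tuple of step names), every field of these NamedTuples is now a primitive or a tuple of primitives:

```python
# Minimal stand-alone sketch; SliceKey and FeaturePathTuple here are
# stand-ins, not the TFDV definitions.
import typing

SliceKey = str
FeaturePathTuple = typing.Tuple[str, ...]

_SlicedXKey = typing.NamedTuple('_SlicedXKey',
                                [('slice_key', SliceKey),
                                 ('x_path', FeaturePathTuple),
                                 ('x', int)])

key = _SlicedXKey(slice_key='', x_path=('x',), x=1)
print(key)  # _SlicedXKey(slice_key='', x_path=('x',), x=1)
```

As the TODO notes, Beam still does not derive a coder for a NamedTuple automatically, which is why the FeaturePathTuple workaround cannot yet be reverted.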

@@ -171,15 +180,15 @@ def _get_example_value_presence(
   if is_binary_like:
     # return binary like values a pd.Categorical wrapped in a Series. This makes
     # subsqeuent operations like pd.Merge cheaper.
-    values = arr_flat_dict[values]
+    values = arr_flat_dict[values].tolist()
   else:
     values = values.tolist()  # converts values to python native types.
   if weight_column_name:
     weights = arrow_util.get_weight_feature(record_batch, weight_column_name)
-    weights = np.asarray(weights)[example_indices]
+    weights = np.asarray(weights)[example_indices].tolist()
   else:
     weights = np.ones(len(example_indices), dtype=int).tolist()
-  return _ValuePresence(example_indices, values, weights)
+  return _ValuePresence(example_indices.tolist(), values, weights)


 def _to_partial_copresence_counts(
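
The .tolist() calls above matter because numpy indexing returns numpy scalar elements; a minimal sketch of the conversion:

```python
import numpy as np

example_indices = np.array([0, 2])
weights = np.asarray([0.5, 1.0, 2.0])[example_indices]
print(type(weights[0]))          # numpy scalar element (e.g. numpy.float64)
print(weights.tolist())          # [0.5, 2.0] -- list of native Python floats
print(example_indices.tolist())  # [0, 2] -- native Python ints
```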
@@ -246,7 +255,8 @@ def _to_partial_copresence_counts(
   if num_xy_pairs_batch_copresent:
     num_xy_pairs_batch_copresent.update(len(copresence_counts))
   for (x, y), count in copresence_counts.items():
-    yield _SlicedXYKey(slice_key=slice_key, x_path=x_path, x=x, y=y), count
+    yield (_SlicedXYKey(slice_key=slice_key, x_path=x_path.steps(), x=x,
+                        y=y), count)


 def _to_partial_counts(
@@ -283,7 +293,7 @@ def _to_partial_x_counts(
       x_path,
       boundaries=None,
       weight_column_name=example_weight_map.get(x_path)):
-    yield _SlicedXKey(slice_key, x_path, x), x_count
+    yield _SlicedXKey(slice_key, x_path.steps(), x), x_count


 def _get_unicode_value(value: Union[Text, bytes]) -> Text:
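
Both yields above swap the FeaturePath object for its steps; assuming FeaturePath.steps() returns the path's step names as a tuple (as its use here suggests), the conversion looks like this:

```python
from tensorflow_data_validation import types

x_path = types.FeaturePath(['document', 'terms'])
print(x_path.steps())  # ('document', 'terms') -- a plain, key-encodable tuple
```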
@@ -324,11 +334,12 @@ def _make_dataset_feature_stats_proto(
     The populated DatasetFeatureStatistics proto.
   """
   key, lift_series_list = lifts
+  x_path = types.FeaturePath(key.x_path)
   stats = statistics_pb2.DatasetFeatureStatistics()
   cross_stats = stats.cross_features.add(
-      path_x=key.x_path.to_proto(), path_y=y_path.to_proto())
+      path_x=x_path.to_proto(), path_y=y_path.to_proto())
   if output_custom_stats:
-    feature_stats = stats.features.add(path=key.x_path.to_proto())
+    feature_stats = stats.features.add(path=x_path.to_proto())
   for lift_series in sorted(lift_series_list):
     lift_series_proto = (
         cross_stats.categorical_cross_stats.lift.lift_series.add())
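
Here the tuple stored in the key is turned back into a FeaturePath before building the proto; a hedged sketch of that round trip (assuming FeaturePath accepts an iterable of steps and exposes to_proto()):

```python
from tensorflow_data_validation import types

x_path_tuple = ('document', 'terms')      # as carried in key.x_path
x_path = types.FeaturePath(x_path_tuple)  # rebuild the path object
print(x_path.to_proto())                  # Path proto used in the stats output
```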
@@ -392,7 +403,8 @@ def _make_dataset_feature_stats_proto(
 def _cross_join_y_keys(
     join_info: Tuple[types.SliceKey, Dict[Text, Sequence[Any]]]
     # TODO(b/147153346) update dict value list element type annotation to:
-    # Union[_YKey, Tuple[_YType, Tuple[types.FeaturePath, _XType, _CountType]]]
+    # Union[_YKey, Tuple[_YType,
+    #                    Tuple[types.FeaturePathTuple, _XType, _CountType]]]
 ) -> Iterator[Tuple[_SlicedXYKey, _CountType]]:
   slice_key, join_args = join_info
   for x_path, x, _ in join_args['x_counts']:

tensorflow_data_validation/statistics/generators/lift_stats_generator_test.py

Lines changed: 14 additions & 15 deletions
@@ -17,7 +17,6 @@

 from absl.testing import absltest
 import apache_beam as beam
-from apache_beam.options import pipeline_options
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -31,10 +30,6 @@
 from tensorflow_metadata.proto.v0 import statistics_pb2


-# TODO(b/181911927): Remove this workaround.
-pipeline_options.TypeOptions.allow_non_deterministic_key_coders = True
-
-
 def _get_example_value_presence_as_dataframe(
     record_batch: pa.RecordBatch, path: types.FeaturePath,
     boundaries: Optional[Sequence[float]],
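
For context, the deleted workaround globally disabled Beam's determinism check for key coders; with tuple-based keys it should no longer be needed. A per-pipeline equivalent of the removed setting would look roughly like this (illustrative, not part of the commit):

```python
from apache_beam.options import pipeline_options

options = pipeline_options.PipelineOptions()
options.view_as(
    pipeline_options.TypeOptions).allow_non_deterministic_key_coders = True
```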
@@ -221,8 +216,8 @@ def test_to_partial_x_counts_unweighted(self):
     ], ['x'])
     x_path = types.FeaturePath(['x'])
     expected_counts = [
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y=None), 2),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=2, y=None), 1),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1, y=None), 2),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=2, y=None), 1),
     ]
     for (expected_key, expected_count), (actual_key, actual_count) in zip(
         expected_counts,
@@ -241,8 +236,10 @@ def test_to_partial_x_counts_weighted(self):
     ], ['x', 'w'])
     x_path = types.FeaturePath(['x'])
     expected_counts = [
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y=None), 2.5),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=2, y=None), 0.5),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1,
+                                           y=None), 2.5),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=2,
+                                           y=None), 0.5),
     ]
     for (expected_key, expected_count), (actual_key, actual_count) in zip(
         expected_counts,
@@ -263,9 +260,9 @@ def test_to_partial_copresence_counts_unweighted(self):
     ], ['x', 'y'])
     x_path = types.FeaturePath(['x'])
     expected_counts = [
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y='a'), 1),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y='b'), 1),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=2, y='a'), 1)
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1, y='a'), 1),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1, y='b'), 1),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=2, y='a'), 1)
     ]
     actual_counts = list(
         lift_stats_generator._to_partial_copresence_counts(
@@ -284,9 +281,11 @@ def test_to_partial_copresence_counts_weighted(self):
     ], ['x', 'y', 'w'])
     x_path = types.FeaturePath(['x'])
     expected_counts = [
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y='a'), 0.5),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=1, y='b'), 2.0),
-        (lift_stats_generator._SlicedXYKey('', x_path, x=2, y='a'), 0.5)
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1,
+                                           y='a'), 0.5),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=1,
+                                           y='b'), 2.0),
+        (lift_stats_generator._SlicedXYKey('', x_path.steps(), x=2, y='a'), 0.5)
     ]
     actual_counts = list(
         lift_stats_generator._to_partial_copresence_counts(

tensorflow_data_validation/statistics/stats_impl_test.py

Lines changed: 0 additions & 5 deletions
@@ -21,7 +21,6 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 import apache_beam as beam
-from apache_beam.options import pipeline_options
 from apache_beam.testing import util
 import numpy as np
 import pyarrow as pa
@@ -41,10 +40,6 @@
 from tensorflow_metadata.proto.v0 import statistics_pb2


-# TODO(b/181911927): Remove this workaround.
-pipeline_options.TypeOptions.allow_non_deterministic_key_coders = True
-
-
 # Testing classes for 'custom_feature_generator' testcase.
 # They are defined module level in order to allow pickling.
 class _BaseCounter(stats_generator.CombinerFeatureStatsGenerator):
