Skip to content

Commit 715ecf9

Browse files
brillstfx-copybara
authored and committed
internal change
PiperOrigin-RevId: 358584927
1 parent ef66b9d commit 715ecf9

File tree

6 files changed

+136
-22
lines changed

6 files changed

+136
-22
lines changed

tensorflow_data_validation/anomalies/feature_util.cc

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -509,7 +509,8 @@ std::vector<Description> UpdateFeatureValueCounts(
509509
}
510510

511511
std::vector<Description> UpdateFeatureShape(
512-
const FeatureStatsView& feature_stats_view, Feature* feature) {
512+
const FeatureStatsView& feature_stats_view,
513+
const bool generate_legacy_feature_spec, Feature* feature) {
513514
if (!feature->has_shape()) {
514515
return {};
515516
}
@@ -533,10 +534,14 @@ std::vector<Description> UpdateFeatureShape(
533534
}
534535

535536
bool has_missing = false;
536-
for (const double num_missing : feature_stats_view.GetNumMissingNested()) {
537-
if (num_missing != 0) {
538-
has_missing = true;
539-
break;
537+
// If Schema.generate_legacy_feature_spec is true, feature absence is allowed.
538+
// See b/180761541.
539+
if (!generate_legacy_feature_spec) {
540+
for (const double num_missing : feature_stats_view.GetNumMissingNested()) {
541+
if (num_missing != 0) {
542+
has_missing = true;
543+
break;
544+
}
540545
}
541546
}
542547

tensorflow_data_validation/anomalies/feature_util.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ std::vector<Description> UpdateFeatureValueCounts(
3737
// to switch to value_counts constraints.
3838
std::vector<Description> UpdateFeatureShape(
3939
const FeatureStatsView& feature_stats_view,
40+
bool generate_legacy_feature_spec,
4041
tensorflow::metadata::v0::Feature* feature);
4142

4243
// If a feature occurs in too few examples, or a feature occurs in too small

tensorflow_data_validation/anomalies/feature_util_test.cc

Lines changed: 44 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1567,6 +1567,7 @@ struct UpdateShapeTestCase {
15671567
// Result of Update().
15681568
Feature expected;
15691569
bool expected_description_empty;
1570+
bool generate_legacy_feature_spec;
15701571
};
15711572

15721573
const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
@@ -1586,7 +1587,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
15861587
)"),
15871588
Feature(),
15881589
Feature(),
1589-
true,
1590+
/*expected_description_empty=*/true,
1591+
/*generate_legacy_feature_spec=*/false,
15901592
},
15911593
{
15921594
"validation passes",
@@ -1604,7 +1606,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
16041606
)"),
16051607
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
16061608
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
1607-
true,
1609+
/*expected_description_empty=*/true,
1610+
/*generate_legacy_feature_spec=*/false,
16081611
},
16091612
{
16101613
"validation passes: scalar",
@@ -1622,7 +1625,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
16221625
)"),
16231626
ParseTextProtoOrDie<Feature>(R"(shape {})"),
16241627
ParseTextProtoOrDie<Feature>(R"(shape {})"),
1625-
true,
1628+
/*expected_description_empty=*/true,
1629+
/*generate_legacy_feature_spec=*/false,
16261630
},
16271631
{
16281632
"validation passes: fancy shape",
@@ -1648,7 +1652,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
16481652
dim { size: 2 },
16491653
dim { size: 9 }
16501654
})"),
1651-
true,
1655+
/*expected_description_empty=*/true,
1656+
/*generate_legacy_feature_spec=*/false,
16521657
},
16531658
{
16541659
"validation passes: fancy shape, nested",
@@ -1689,7 +1694,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
16891694
dim { size: 2 },
16901695
dim { size: 9 }
16911696
})"),
1692-
true,
1697+
/*expected_description_empty=*/true,
1698+
/*generate_legacy_feature_spec=*/false,
16931699
},
16941700
{
16951701
"failure: num_missing",
@@ -1707,7 +1713,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
17071713
)"),
17081714
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
17091715
Feature(),
1710-
false,
1716+
/*expected_description_empty=*/false,
1717+
/*generate_legacy_feature_spec=*/false,
17111718
},
17121719
{
17131720
"failure: num_missing (nested)",
@@ -1737,7 +1744,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
17371744
)"),
17381745
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
17391746
Feature(),
1740-
false,
1747+
/*expected_description_empty=*/false,
1748+
/*generate_legacy_feature_spec=*/false,
17411749
},
17421750
{
17431751
"failure: num_value (nested)",
@@ -1767,7 +1775,8 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
17671775
)"),
17681776
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
17691777
Feature(),
1770-
false,
1778+
/*expected_description_empty=*/false,
1779+
/*generate_legacy_feature_spec=*/false,
17711780
},
17721781
{
17731782
"failure: shape not compatible",
@@ -1787,8 +1796,29 @@ const std::vector<UpdateShapeTestCase> GetUpdateShapeTestCases() {
17871796
dim { size: 3 }
17881797
})"),
17891798
Feature(),
1790-
false,
1791-
}};
1799+
/*expected_description_empty=*/false,
1800+
/*generate_legacy_feature_spec=*/false,
1801+
},
1802+
{
1803+
"success: num_missing but generate_legacy_feature_spec",
1804+
ParseTextProtoOrDie<FeatureNameStatistics>(R"(
1805+
name: 'f1'
1806+
type: INT
1807+
num_stats: {
1808+
common_stats: {
1809+
num_missing: 1
1810+
num_non_missing: 10
1811+
min_num_values: 1
1812+
max_num_values: 1
1813+
}
1814+
}
1815+
)"),
1816+
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
1817+
ParseTextProtoOrDie<Feature>(R"(shape { dim { size: 1 } })"),
1818+
/*expected_description_empty=*/true,
1819+
/*generate_legacy_feature_spec=*/true,
1820+
},
1821+
};
17921822
}
17931823

17941824
TEST(FeatureTypeTest, UpdateShapeTest) {
@@ -1801,8 +1831,10 @@ TEST(FeatureTypeTest, UpdateShapeTest) {
18011831
by_weight);
18021832
Feature updated = test.original;
18031833
auto descriptions =
1804-
UpdateFeatureShape(dataset.feature_stats_view(), &updated);
1805-
EXPECT_EQ(test.expected_description_empty, descriptions.empty());
1834+
UpdateFeatureShape(dataset.feature_stats_view(),
1835+
test.generate_legacy_feature_spec, &updated);
1836+
EXPECT_EQ(test.expected_description_empty, descriptions.empty())
1837+
<< "Test: " << test.name;
18061838
EXPECT_THAT(updated, EqualsProto(test.expected))
18071839
<< "Test:" << test.name << "(by_weight: " << by_weight
18081840
<< ") Reason: " << DescriptionsToString(descriptions);

tensorflow_data_validation/anomalies/schema.cc

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -987,7 +987,8 @@ void Schema::UpdateFeatureInternal(
987987
}
988988

989989
if (feature->has_shape()) {
990-
add_to_descriptions(UpdateFeatureShape(view, feature));
990+
add_to_descriptions(
991+
UpdateFeatureShape(view, generate_legacy_feature_spec(), feature));
991992
}
992993

993994
if (feature->has_presence()) {
@@ -1264,5 +1265,14 @@ std::vector<Description> Schema::UpdateDatasetConstraints(
12641265
return descriptions;
12651266
}
12661267

1268+
bool Schema::generate_legacy_feature_spec() const {
1269+
// This field is not available in the OSS TFMD schema, so we use proto
1270+
// reflection to get its value to avoid compilation errors.
1271+
const auto* field_desc =
1272+
schema_.GetDescriptor()->FindFieldByName("generate_legacy_feature_spec");
1273+
if (!field_desc) return false;
1274+
return schema_.GetReflection()->GetBool(schema_, field_desc);
1275+
}
1276+
12671277
} // namespace data_validation
12681278
} // namespace tensorflow

tensorflow_data_validation/anomalies/schema.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,8 @@ class Schema {
264264
::tensorflow::metadata::v0::DatasetConstraints*
265265
GetExistingDatasetConstraints();
266266

267+
bool generate_legacy_feature_spec() const;
268+
267269
// Note: do not manually add string_domains or features.
268270
// Call GetNewEnum() or GetNewFeature().
269271
tensorflow::metadata::v0::Schema schema_;

tensorflow_data_validation/anomalies/schema_test.cc

Lines changed: 68 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2729,10 +2729,10 @@ TEST(SchemaTest, UpdateFeatureShape) {
27292729
type: INT
27302730
num_stats: {
27312731
common_stats: {
2732-
num_missing: 1
2733-
num_non_missing: 9
2732+
num_missing: 0
2733+
num_non_missing: 10
27342734
min_num_values: 1
2735-
max_num_values: 1
2735+
max_num_values: 2
27362736
}
27372737
}
27382738
})");
@@ -2752,10 +2752,74 @@ TEST(SchemaTest, UpdateFeatureShape) {
27522752
feature {
27532753
name: "f1"
27542754
type: INT
2755-
presence { min_fraction: 0.9 min_count: 1 }
2755+
presence { min_fraction: 1 min_count: 1 }
27562756
})"));
27572757
}
27582758

2759+
// b/180761541
2760+
TEST(SchemaTest, UpdateFeatureShapeInferLegacyFeatureSpecWithNumMissing) {
2761+
const auto statistics =
2762+
ParseTextProtoOrDie<DatasetFeatureStatistics>(R"(
2763+
num_examples: 10
2764+
features: {
2765+
name: "f1"
2766+
type: INT
2767+
num_stats: {
2768+
common_stats: {
2769+
num_missing: 1
2770+
num_non_missing: 9
2771+
min_num_values: 1
2772+
max_num_values: 1
2773+
}
2774+
}
2775+
})");
2776+
auto schema_proto = ParseTextProtoOrDie<tensorflow::metadata::v0::Schema>(R"(
2777+
feature {
2778+
name: "f1"
2779+
type: INT
2780+
shape { dim { size: 1 } }
2781+
presence { min_fraction: 1 min_count: 1 }
2782+
}
2783+
)");
2784+
auto* field_desc = schema_proto.GetDescriptor()->FindFieldByName(
2785+
"generate_legacy_feature_spec");
2786+
if (!field_desc) {
2787+
// Skip the test because the schema does not have the legacy field (OSS).
2788+
return;
2789+
}
2790+
2791+
// The default value of that field is true.
2792+
ASSERT_TRUE(schema_proto.GetReflection()->GetBool(schema_proto, field_desc));
2793+
{
2794+
Schema schema;
2795+
TF_ASSERT_OK(schema.Init(schema_proto));
2796+
TF_ASSERT_OK(schema.Update(DatasetStatsView(statistics, false),
2797+
FeatureStatisticsToProtoConfig()));
2798+
EXPECT_THAT(schema.GetSchema(), EqualsProto(R"(
2799+
feature {
2800+
name: "f1"
2801+
type: INT
2802+
shape { dim { size: 1 } }
2803+
presence { min_fraction: 0.9 min_count: 1 }
2804+
})"));
2805+
}
2806+
2807+
schema_proto.GetReflection()->SetBool(&schema_proto, field_desc, false);
2808+
{
2809+
Schema schema;
2810+
TF_ASSERT_OK(schema.Init(schema_proto));
2811+
TF_ASSERT_OK(schema.Update(DatasetStatsView(statistics, false),
2812+
FeatureStatisticsToProtoConfig()));
2813+
EXPECT_THAT(schema.GetSchema(), EqualsProto(R"(
2814+
generate_legacy_feature_spec: false
2815+
feature {
2816+
name: "f1"
2817+
type: INT
2818+
presence { min_fraction: 0.9 min_count: 1 }
2819+
})"));
2820+
}
2821+
}
2822+
27592823
// Construct a schema from a proto field, and then write it to a
27602824
// DescriptorProto.
27612825
struct ValidTest {

0 commit comments

Comments
 (0)