@@ -413,6 +413,59 @@ def test_id_feature_with_num_buckets(
413
413
np .testing .assert_allclose (parsed_feat .values , np .array (expected_values ))
414
414
np .testing .assert_allclose (parsed_feat .lengths , np .array (expected_lengths ))
415
415
416
@parameterized.expand(
    [
        ["", "data/test/id_vocab_list_0", 4, [2, 3, 1], [2, 0, 1]],
        ["xyz", "data/test/id_vocab_list_1", 4, [2, 3, 0, 1], [2, 1, 1]],
        ["", "data/test/id_vocab_dict_2", 3, [2, 2, 1], [2, 0, 1]],
        ["xyz", "data/test/id_vocab_dict_3", 3, [2, 2, 0, 1], [2, 1, 1]],
    ],
    name_func=test_util.parameterized_name_func,
)
def test_id_feature_with_vocab_file(
    self,
    default_value,
    vocab_file,
    expected_num_embeddings,
    expected_values,
    expected_lengths,
):
    # Case layout: default_value, vocab_file, expected_num_embeddings,
    # expected_values, expected_lengths. The cases pair an empty and a
    # non-empty ("xyz") default_value with each of the vocab file paths.
    feature_name = "id_feat"
    table_name = "id_feat_emb"
    emb_dim = 16

    # Build the IdFeature message first, then wrap it in a FeatureConfig.
    id_proto = feature_pb2.IdFeature(
        feature_name=feature_name,
        embedding_dim=emb_dim,
        vocab_file=vocab_file,
        default_bucketize_value=1,
        expression="user:id_str",
        pooling="mean",
        default_value=default_value,
    )
    feature_cfg = feature_pb2.FeatureConfig(id_feature=id_proto)
    feature = id_feature_lib.IdFeature(feature_cfg, fg_mode=FgMode.FG_NORMAL)

    # Embedding configs are compared through repr(), not object equality.
    want_bag_cfg = EmbeddingBagConfig(
        num_embeddings=expected_num_embeddings,
        embedding_dim=emb_dim,
        name=table_name,
        feature_names=[feature_name],
        pooling=PoolingType.MEAN,
    )
    self.assertEqual(repr(feature.emb_bag_config), repr(want_bag_cfg))

    want_emb_cfg = EmbeddingConfig(
        num_embeddings=expected_num_embeddings,
        embedding_dim=emb_dim,
        name=table_name,
        feature_names=[feature_name],
    )
    self.assertEqual(repr(feature.emb_config), repr(want_emb_cfg))

    # Three input rows: one containing a "\x1d" character, one empty
    # string, and one plain token.
    rows = pa.array(["abc\x1d efg", "", "hij"])
    parsed = feature.parse({"id_str": rows})
    self.assertEqual(parsed.name, feature_name)

    want_values = np.array(expected_values)
    np.testing.assert_allclose(parsed.values, want_values)
    want_lengths = np.array(expected_lengths)
    np.testing.assert_allclose(parsed.lengths, want_lengths)
468
+
416
469
417
470
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments