33The name with which a validator is registered is the name that is used
44in the `RAIL` spec to specify formatters.
55"""
6-
76import ast
87import logging
8+ import os
99from collections import defaultdict
1010from dataclasses import dataclass
1111from typing import Any , Callable , Dict , List , Optional , Union
@@ -28,6 +28,107 @@ class Refrain:
2828 pass
2929
3030
def check_refrain_in_list(schema: List) -> bool:
    """Report whether a Refrain instance occurs anywhere inside a list.

    Nested lists and dicts are searched recursively; every other kind of
    element is ignored.

    Args:
        schema: A list that can contain lists, dicts or scalars.

    Returns:
        bool: True as soon as a Refrain instance is found, False otherwise.
    """

    def _holds_refrain(element: Any) -> bool:
        # Classify a single element; containers are delegated to the
        # matching recursive checker.
        if isinstance(element, Refrain):
            return True
        if isinstance(element, list):
            return check_refrain_in_list(element)
        if isinstance(element, dict):
            return check_refrain_in_dict(element)
        return False

    # any() short-circuits on the first hit, like an early return in a loop.
    return any(_holds_refrain(element) for element in schema)
51+
52+
def check_refrain_in_dict(schema: Dict) -> bool:
    """Report whether a Refrain instance occurs among a dict's values.

    Only values are inspected (keys are never examined). Values that are
    lists or dicts are searched recursively.

    Args:
        schema: A dict that can contain lists, dicts or scalars.

    Returns:
        True as soon as a Refrain instance is found, False otherwise.
    """
    for entry in schema.values():
        if isinstance(entry, Refrain):
            return True

        # Containers are searched recursively; scalars can never match.
        if isinstance(entry, list):
            found = check_refrain_in_list(entry)
        elif isinstance(entry, dict):
            found = check_refrain_in_dict(entry)
        else:
            found = False

        if found:
            return True

    return False
74+
75+
def filter_in_list(schema: List) -> List:
    """Return a copy of a list with every Filter object stripped out.

    Nested lists and dicts are filtered recursively. A nested container that
    ends up empty after filtering is dropped from the result entirely.

    Args:
        schema: A list that can contain lists, dicts or scalars.

    Returns:
        A new list with all Filter objects removed.
    """
    kept = []

    for element in schema:
        if isinstance(element, Filter):
            continue

        if isinstance(element, list):
            pruned = filter_in_list(element)
        elif isinstance(element, dict):
            pruned = filter_in_dict(element)
        else:
            # Scalars are carried over untouched.
            kept.append(element)
            continue

        # Only keep nested containers that still hold something.
        if pruned:
            kept.append(pruned)

    return kept
103+
104+
def filter_in_dict(schema: Dict) -> Dict:
    """Return a copy of a dictionary with every Filter object stripped out.

    Values that are lists or dicts are filtered recursively. A list value
    that ends up empty after filtering is dropped together with its key,
    whereas a nested dict is always kept, even when filtering leaves it empty.

    Args:
        schema: A dictionary that can contain lists, dicts or scalars.

    Returns:
        A new dictionary with all Filter objects removed.
    """
    kept = {}

    for name, entry in schema.items():
        if isinstance(entry, Filter):
            continue

        if isinstance(entry, list):
            pruned = filter_in_list(entry)
            # Lists that were filtered down to nothing lose their key.
            if pruned:
                kept[name] = pruned
            continue

        # Nested dicts are stored unconditionally; scalars pass through as-is.
        kept[name] = filter_in_dict(entry) if isinstance(entry, dict) else entry

    return kept
130+
131+
31132def register_validator (name : str , data_type : Union [str , List [str ]]):
32133 """Register a validator for a data type."""
33134
@@ -487,102 +588,75 @@ def validate(self, key: str, value: Any, schema: Union[Dict, List]) -> Dict:
487588 return schema
488589
489590
490- def check_refrain_in_list (schema : List ) -> bool :
491- """Check if a Refrain object exists in a list.
@register_validator(name="is-profanity-free", data_type="string")
class IsProfanityFree(Validator):
    """Validate that a string does not contain profane language.

    This validator uses the `alt-profanity-check` package to check if a string
    contains profane language.

    - Name for `format` attribute: `is-profanity-free`
    - Supported data types: `string`
    - Programmatic fix: ""
    """

    def validate(self, key, value, schema) -> Dict:
        """Check `value` for profanity and return `schema` if it is clean.

        Args:
            key: Passed through to `EventDetail` when validation fails.
            value: The string handed to the profanity classifier.
            schema: Returned unchanged when `value` passes validation.

        Returns:
            The `schema` argument, unchanged.

        Raises:
            ImportError: If the `profanity_check` module cannot be imported.
            EventDetail: If the classifier labels `value` as profane; the
                empty string is passed as the last argument (the programmatic
                fix listed in the class docstring).
        """
        # Imported lazily so the optional dependency is only required when
        # this validator is actually used.
        try:
            from profanity_check import predict
        except ImportError as err:
            raise ImportError(
                "`is-profanity-free` validator requires the "
                "`alt-profanity-check` package. Please install it with "
                "`pip install alt-profanity-check`."
            ) from err

        # predict() takes a batch of strings; a label of 1 marks profanity.
        prediction = predict([value])
        if prediction[0] == 1:
            raise EventDetail(
                key,
                value,
                schema,
                f"{value} contains profanity. Please return a profanity-free output.",
                "",
            )
        return schema
534622
535- def filter_in_list (schema : List ) -> List :
536- """Remove out all Filter objects from a list.
537623
538- Args:
539- schema: A list that can contain lists, dicts or scalars.
@register_validator(name="is-high-quality-translation", data_type="string")
class IsHighQualityTranslation(Validator):
    """Use `inspiredco.critique` to check if a translation is high quality.

    The translation is scored with the `comet` metric (model
    `unbabel_comet/wmt21-comet-qe-da`); a score below -0.1 fails validation.

    - Name for `format` attribute: `is-high-quality-translation`
    - Supported data types: `string`
    - Programmatic fix: ""
    """

    def __init__(self, *args, **kwargs):
        """Create the validator and its `Critique` client.

        All arguments are forwarded to the parent initializer.

        Raises:
            ImportError: If the `inspiredco` package cannot be imported.
            KeyError: If the `INSPIREDCO_API_KEY` environment variable is
                not set.
        """
        super().__init__(*args, **kwargs)
        # Only the import is guarded, so that unrelated failures are not
        # caught by the ImportError handler.
        try:
            from inspiredco.critique import Critique
        except ImportError as err:
            raise ImportError(
                "`is-high-quality-translation` validator requires the "
                "`inspiredco` package. Please install it with "
                "`pip install inspiredco`."
            ) from err

        self.critique = Critique(api_key=os.environ["INSPIREDCO_API_KEY"])

    def validate(self, key, value, schema) -> Dict:
        """Score `value` as a translation and return `schema` if it passes.

        Args:
            key: Sent to the metric as the translation "source" and passed
                through to `EventDetail` on failure.
            value: The translated string, sent as the "target".
            schema: Returned unchanged when `value` passes validation.

        Returns:
            The `schema` argument, unchanged.

        Raises:
            EventDetail: If the reported score is below -0.1; the empty
                string is passed as the last argument (the programmatic fix
                listed in the class docstring).
        """
        # NOTE(review): `key` is used as the source text here -- confirm
        # that callers really pass the source sentence as the key.
        prediction = self.critique.evaluate(
            metric="comet",
            config={"model": "unbabel_comet/wmt21-comet-qe-da"},
            dataset=[{"source": key, "target": value}],
        )
        quality = prediction["examples"][0]["value"]
        if quality < -0.1:
            raise EventDetail(
                key,
                value,
                schema,
                f"{value} is a low quality translation. "
                "Please return a higher quality output.",
                "",
            )
        return schema
0 commit comments