Skip to content

Commit 75303f5

Browse files
authored
added new validators (#59)
1 parent 23830f1 commit 75303f5

File tree

1 file changed

+161
-87
lines changed

1 file changed

+161
-87
lines changed

guardrails/validators.py

Lines changed: 161 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
The name with which a validator is registered is the name that is used
44
in the `RAIL` spec to specify formatters.
55
"""
6-
76
import ast
87
import logging
8+
import os
99
from collections import defaultdict
1010
from dataclasses import dataclass
1111
from typing import Any, Callable, Dict, List, Optional, Union
@@ -28,6 +28,107 @@ class Refrain:
2828
pass
2929

3030

31+
def check_refrain_in_list(schema: List) -> bool:
    """Report whether a Refrain object appears anywhere inside a list.

    Nested lists and dicts are searched recursively; any other item is only
    tested for being a Refrain instance itself.

    Args:
        schema: A list that can contain lists, dicts or scalars.

    Returns:
        bool: True if a Refrain object is found at any depth, else False.
    """
    for element in schema:
        if isinstance(element, Refrain):
            return True
        # Only containers are descended into; the search stops at the first hit.
        if isinstance(element, list) and check_refrain_in_list(element):
            return True
        if isinstance(element, dict) and check_refrain_in_dict(element):
            return True
    return False
51+
52+
53+
def check_refrain_in_dict(schema: Dict) -> bool:
    """Report whether a Refrain object appears among a dict's values.

    Only values are inspected (keys are ignored). Nested lists and dicts are
    searched recursively.

    Args:
        schema: A dict that can contain lists, dicts or scalars.

    Returns:
        True if a Refrain object is found at any depth, else False.
    """
    for entry in schema.values():
        if isinstance(entry, Refrain):
            return True
        if isinstance(entry, list):
            found = check_refrain_in_list(entry)
        elif isinstance(entry, dict):
            found = check_refrain_in_dict(entry)
        else:
            # Scalars that are not Refrain objects can never match.
            found = False
        if found:
            return True
    return False
74+
75+
76+
def filter_in_list(schema: List) -> List:
    """Build a copy of a list with every Filter object removed.

    Nested lists and dicts are filtered recursively. A nested list or dict
    that is empty once filtered is left out of the result entirely.

    Args:
        schema: A list that can contain lists, dicts or scalars.

    Returns:
        A new list with all Filter objects removed.
    """
    kept = []
    for element in schema:
        if isinstance(element, Filter):
            continue
        if isinstance(element, list):
            nested = filter_in_list(element)
        elif isinstance(element, dict):
            nested = filter_in_dict(element)
        else:
            # Scalars are carried over untouched.
            kept.append(element)
            continue
        # Containers that end up empty are dropped rather than kept as [] / {}.
        if nested:
            kept.append(nested)
    return kept
103+
104+
105+
def filter_in_dict(schema: Dict) -> Dict:
    """Build a copy of a dictionary with every Filter value removed.

    Nested lists and dicts are filtered recursively. A list value that is
    empty once filtered has its key omitted, whereas a dict value is always
    kept under its key, even when filtering leaves it empty.

    Args:
        schema: A dictionary that can contain lists, dicts or scalars.

    Returns:
        A new dictionary with all Filter objects removed.
    """
    kept = {}
    for name, entry in schema.items():
        if isinstance(entry, Filter):
            continue
        if isinstance(entry, list):
            nested = filter_in_list(entry)
            # A list left empty by filtering drops its key altogether.
            if nested:
                kept[name] = nested
            continue
        kept[name] = filter_in_dict(entry) if isinstance(entry, dict) else entry
    return kept
130+
131+
31132
def register_validator(name: str, data_type: Union[str, List[str]]):
32133
"""Register a validator for a data type."""
33134

@@ -487,102 +588,75 @@ def validate(self, key: str, value: Any, schema: Union[Dict, List]) -> Dict:
487588
return schema
488589

489590

490-
def check_refrain_in_list(schema: List) -> bool:
491-
"""Check if a Refrain object exists in a list.
591+
@register_validator(name="is-profanity-free", data_type="string")
592+
class IsProfanityFree(Validator):
593+
"""Validate that a translated text does not contain profanity language.
492594
493-
Args:
494-
schema: A list that can contain lists, dicts or scalars.
495-
496-
Returns:
497-
bool: True if a Refrain object exists in the list.
498-
"""
499-
for item in schema:
500-
if isinstance(item, Refrain):
501-
return True
502-
elif isinstance(item, list):
503-
if check_refrain_in_list(item):
504-
return True
505-
elif isinstance(item, dict):
506-
if check_refrain_in_dict(item):
507-
return True
508-
509-
return False
595+
This validator uses the `alt-profanity-check` package to check if a string
596+
contains profanity language.
510597
511-
512-
def check_refrain_in_dict(schema: Dict) -> bool:
513-
"""Check if a Refrain object exists in a dict.
514-
515-
Args:
516-
schema: A dict that can contain lists, dicts or scalars.
517-
518-
Returns:
519-
True if a Refrain object exists in the dict.
598+
- Name for `format` attribute: `is-profanity-free`
599+
- Supported data types: `string`
600+
- Programmatic fix: ""
520601
"""
521602

522-
for key, value in schema.items():
523-
if isinstance(value, Refrain):
524-
return True
525-
elif isinstance(value, list):
526-
if check_refrain_in_list(value):
527-
return True
528-
elif isinstance(value, dict):
529-
if check_refrain_in_dict(value):
530-
return True
531-
532-
return False
603+
def validate(self, key, value, schema) -> Dict:
    """Reject `value` when the profanity classifier flags it.

    Args:
        key: Identifier passed through to the raised `EventDetail`.
        value: The string handed to `profanity_check.predict`.
        schema: The schema under validation; returned unchanged on success.

    Returns:
        `schema`, when the prediction for `value` is not 1.

    Raises:
        ImportError: If `profanity_check` cannot be imported.
        EventDetail: If the prediction for `value` is 1. The programmatic
            fix carried by the event is the empty string.
    """
    # Imported lazily so the optional dependency is only needed when this
    # validator is actually used.
    try:
        from profanity_check import predict
    except ImportError as err:
        # The PyPI distribution is `alt-profanity-check`; it is imported as
        # `profanity_check`.
        raise ImportError(
            "`is-profanity-free` validator requires the `alt-profanity-check` "
            "package. Please install it with "
            "`pip install alt-profanity-check`."
        ) from err

    # `predict` takes a batch, so the single value is wrapped in a list.
    prediction = predict([value])
    if prediction[0] == 1:
        raise EventDetail(
            key,
            value,
            schema,
            f"{value} contains profanity. Please return a profanity-free output.",
            "",
        )
    return schema
534622

535-
def filter_in_list(schema: List) -> List:
536-
"""Remove out all Filter objects from a list.
537623

538-
Args:
539-
schema: A list that can contain lists, dicts or scalars.
624+
@register_validator(name="is-high-quality-translation", data_type="string")
625+
class IsHighQualityTranslation(Validator):
626+
"""Use inspiredco.critique to check if a translation is high quality.
540627
541-
Returns:
542-
A list with all Filter objects removed.
628+
- Name for `format` attribute: `is-high-quality-translation`
629+
- Supported data types: `string`
630+
- Programmatic fix: ""
543631
"""
544632

545-
filtered_list = []
546-
547-
for item in schema:
548-
if isinstance(item, Filter):
549-
pass
550-
elif isinstance(item, list):
551-
filtered_item = filter_in_list(item)
552-
if len(filtered_item):
553-
filtered_list.append(filtered_item)
554-
elif isinstance(item, dict):
555-
filtered_dict = filter_in_dict(item)
556-
if len(filtered_dict):
557-
filtered_list.append(filtered_dict)
558-
else:
559-
filtered_list.append(item)
560-
561-
return filtered_list
562-
563-
564-
def filter_in_dict(schema: Dict) -> Dict:
565-
"""Remove out all Filter objects from a dictionary.
566-
567-
Args:
568-
schema: A dictionary that can contain lists, dicts or scalars.
569-
570-
Returns:
571-
A dictionary with all Filter objects removed.
572-
"""
633+
def __init__(self, *args, **kwargs):
634+
super().__init__(*args, **kwargs)
635+
try:
636+
from inspiredco.critique import Critique
573637

574-
filtered_dict = {}
638+
self.critique = Critique(api_key=os.environ["INSPIREDCO_API_KEY"])
575639

576-
for key, value in schema.items():
577-
if isinstance(value, Filter):
578-
pass
579-
elif isinstance(value, list):
580-
filtered_item = filter_in_list(value)
581-
if len(filtered_item):
582-
filtered_dict[key] = filtered_item
583-
elif isinstance(value, dict):
584-
filtered_dict[key] = filter_in_dict(value)
585-
else:
586-
filtered_dict[key] = value
640+
except ImportError:
641+
raise ImportError(
642+
"`is-high-quality-translation` validator requires the `inspiredco`"
643+
"package. Please install it with `pip install inspiredco`."
644+
)
587645

588-
return filtered_dict
646+
def validate(self, key, value, schema) -> Dict:
    """Score `value` with the Critique COMET metric and reject low scores.

    Args:
        key: Sent to Critique as the `source` of the example and passed
            through to the raised `EventDetail`.
        value: Sent to Critique as the `target` (the translation to score).
        schema: The schema under validation; returned unchanged on success.

    Returns:
        `schema`, when the returned quality score is at least -0.1.

    Raises:
        EventDetail: If the quality score is below -0.1. The programmatic
            fix carried by the event is the empty string.
    """
    prediction = self.critique.evaluate(
        metric="comet",
        config={"model": "unbabel_comet/wmt21-comet-qe-da"},
        # NOTE(review): `key` is submitted as the translation source; confirm
        # callers pass the source text here rather than a field name.
        dataset=[{"source": key, "target": value}],
    )
    quality = prediction["examples"][0]["value"]
    if quality < -0.1:
        raise EventDetail(
            key,
            value,
            schema,
            # Trailing space keeps the two sentences apart once the adjacent
            # literals are concatenated.
            f"{value} is a low quality translation. "
            "Please return a higher quality output.",
            "",
        )
    return schema

0 commit comments

Comments
 (0)