Skip to content

Commit ff395e8

Browse files
ShreyaRzsimjee
andauthored
add competitor check validation (#474)
* add competitor check validation * lint * lint fix continued * update deps * typing * 3.8 typing * comp check tests, notebook * do not run comp check nb in workflow --------- Co-authored-by: zsimjee <[email protected]>
1 parent 9cdd821 commit ff395e8

File tree

10 files changed

+588
-128
lines changed

10 files changed

+588
-128
lines changed

.github/workflows/scripts/run_notebooks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cd docs/examples
99
# Function to process a notebook
1010
process_notebook() {
1111
notebook="$1"
12-
if [ "$notebook" != "valid_chess_moves.ipynb" ] && [ "$notebook" != "translation_with_quality_check.ipynb" ]; then
12+
if [ "$notebook" != "valid_chess_moves.ipynb" ] && [ "$notebook" != "translation_with_quality_check.ipynb" ] && [ "$notebook" != "competitors_check.ipynb" ]; then
1313
echo "Processing $notebook..."
1414
poetry run jupyter nbconvert --to notebook --execute "$notebook"
1515
if [ $? -ne 0 ]; then

docs/examples/competitors_check.ipynb

Lines changed: 240 additions & 0 deletions
Large diffs are not rendered by default.

guardrails/utils/openai_utils/v0.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def construct_nonchat_response(
8484
) -> LLMResponse:
8585
"""Construct an LLMResponse from an OpenAI response.
8686
87-
Splits execution based on whether the `stream` parameter
88-
is set in the kwargs.
87+
Splits execution based on whether the `stream` parameter is set
88+
in the kwargs.
8989
"""
9090
if stream:
9191
# If stream is defined and set to True,
@@ -152,8 +152,8 @@ def construct_chat_response(
152152
) -> LLMResponse:
153153
"""Construct an LLMResponse from an OpenAI response.
154154
155-
Splits execution based on whether the `stream` parameter
156-
is set in the kwargs.
155+
Splits execution based on whether the `stream` parameter is set
156+
in the kwargs.
157157
"""
158158
if stream:
159159
# If stream is defined and set to True,
@@ -296,8 +296,8 @@ async def construct_chat_response(
296296
) -> LLMResponse:
297297
"""Construct an LLMResponse from an OpenAI response.
298298
299-
Splits execution based on whether the `stream` parameter
300-
is set in the kwargs.
299+
Splits execution based on whether the `stream` parameter is set
300+
in the kwargs.
301301
"""
302302
if stream:
303303
# If stream is defined and set to True,

guardrails/utils/openai_utils/v1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def construct_nonchat_response(
7676
) -> LLMResponse:
7777
"""Construct an LLMResponse from an OpenAI response.
7878
79-
Splits execution based on whether the `stream` parameter
80-
is set in the kwargs.
79+
Splits execution based on whether the `stream` parameter is set
80+
in the kwargs.
8181
"""
8282
if stream:
8383
# If stream is defined and set to True,
@@ -140,8 +140,8 @@ def construct_chat_response(
140140
) -> LLMResponse:
141141
"""Construct an LLMResponse from an OpenAI response.
142142
143-
Splits execution based on whether the `stream` parameter
144-
is set in the kwargs.
143+
Splits execution based on whether the `stream` parameter is set
144+
in the kwargs.
145145
"""
146146
if stream:
147147
# If stream is defined and set to True,
@@ -298,8 +298,8 @@ async def construct_chat_response(
298298
) -> LLMResponse:
299299
"""Construct an LLMResponse from an OpenAI response.
300300
301-
Splits execution based on whether the `stream` parameter
302-
is set in the kwargs.
301+
Splits execution based on whether the `stream` parameter is set
302+
in the kwargs.
303303
"""
304304
if stream:
305305
# If stream is defined and set to True,

guardrails/validators.py

Lines changed: 159 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
except LookupError:
6565
nltk.download("punkt")
6666

67+
try:
68+
import spacy
69+
except ImportError:
70+
spacy = None
71+
6772

6873
logger = logging.getLogger(__name__)
6974

@@ -703,7 +708,7 @@ def __init__(
703708
if not _HAS_NUMPY:
704709
raise ImportError(
705710
f"The {self.__class__.__name__} validator requires the numpy package.\n"
706-
"`pip install numpy` to install it."
711+
"`poetry add numpy` to install it."
707712
)
708713

709714
self.client = OpenAIClient()
@@ -775,7 +780,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
775780
except ImportError:
776781
raise ImportError(
777782
"`is-profanity-free` validator requires the `alt-profanity-check`"
778-
"package. Please install it with `pip install profanity-check`."
783+
"package. Please install it with `poetry add profanity-check`."
779784
)
780785

781786
prediction = predict([value])
@@ -823,7 +828,7 @@ def __init__(self, *args, **kwargs):
823828
except ImportError:
824829
raise ImportError(
825830
"`is-high-quality-translation` validator requires the `inspiredco`"
826-
"package. Please install it with `pip install inspiredco`."
831+
"package. Please install it with `poetry add inspiredco`."
827832
)
828833

829834
def validate(self, value: Any, metadata: Dict) -> ValidationResult:
@@ -1122,7 +1127,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
11221127
except ImportError:
11231128
raise ImportError(
11241129
"`thefuzz` library is required for `extractive-summary` validator. "
1125-
"Please install it with `pip install thefuzz`."
1130+
"Please install it with `poetry add thefuzz`."
11261131
)
11271132

11281133
# Split the value into sentences.
@@ -1217,7 +1222,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
12171222
except ImportError:
12181223
raise ImportError(
12191224
"`thefuzz` library is required for `remove-redundant-sentences` "
1220-
"validator. Please install it with `pip install thefuzz`."
1225+
"validator. Please install it with `poetry add thefuzz`."
12211226
)
12221227

12231228
# Split the value into sentences.
@@ -1613,7 +1618,7 @@ def validate_each_sentence(
16131618
if nltk is None:
16141619
raise ImportError(
16151620
"`nltk` library is required for `provenance-v0` validator. "
1616-
"Please install it with `pip install nltk`."
1621+
"Please install it with `poetry add nltk`."
16171622
)
16181623
# Split the value into sentences using nltk sentence tokenizer.
16191624
sentences = nltk.sent_tokenize(value)
@@ -1973,7 +1978,7 @@ def validate_each_sentence(
19731978
if nltk is None:
19741979
raise ImportError(
19751980
"`nltk` library is required for `provenance-v0` validator. "
1976-
"Please install it with `pip install nltk`."
1981+
"Please install it with `poetry add nltk`."
19771982
)
19781983
# Split the value into sentences using nltk sentence tokenizer.
19791984
sentences = nltk.sent_tokenize(value)
@@ -2535,3 +2540,150 @@ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
25352540
fix_value=modified_value,
25362541
)
25372542
return PassResult()
2543+
2544+
2545+
@register_validator(name="competitor-check", data_type="string")
2546+
class CompetitorCheck(Validator):
2547+
"""Validates that LLM-generated text is not naming any competitors from a
2548+
given list.
2549+
2550+
In order to use this validator you need to provide an extensive list of the
2551+
competitors you want to avoid naming including all common variations.
2552+
2553+
Args:
2554+
competitors (List[str]): List of competitors you want to avoid naming
2555+
"""
2556+
2557+
def __init__(
2558+
self,
2559+
competitors: List[str],
2560+
on_fail: Optional[Callable] = None,
2561+
):
2562+
super().__init__(competitors=competitors, on_fail=on_fail)
2563+
self._competitors = competitors
2564+
model = "en_core_web_trf"
2565+
if spacy is None:
2566+
raise ImportError(
2567+
"You must install spacy in order to use the CompetitorCheck validator."
2568+
)
2569+
2570+
if not spacy.util.is_package(model):
2571+
logger.info(
2572+
f"Spacy model {model} not installed. "
2573+
"Download should start now and take a few minutes."
2574+
)
2575+
spacy.cli.download(model) # type: ignore
2576+
2577+
self.nlp = spacy.load(model)
2578+
2579+
def exact_match(self, text: str, competitors: List[str]) -> List[str]:
2580+
"""Performs exact match to find competitors from a list in a given
2581+
text.
2582+
2583+
Args:
2584+
text (str): The text to search for competitors.
2585+
competitors (list): A list of competitor entities to match.
2586+
2587+
Returns:
2588+
list: A list of matched entities.
2589+
"""
2590+
2591+
found_entities = []
2592+
for entity in competitors:
2593+
pattern = rf"\b{re.escape(entity)}\b"
2594+
match = re.search(pattern.lower(), text.lower())
2595+
if match:
2596+
found_entities.append(entity)
2597+
return found_entities
2598+
2599+
def perform_ner(self, text: str, nlp) -> List[str]:
2600+
"""Performs named entity recognition on text using a provided NLP
2601+
model.
2602+
2603+
Args:
2604+
text (str): The text to perform named entity recognition on.
2605+
nlp: The NLP model to use for entity recognition.
2606+
2607+
Returns:
2608+
entities: A list of entities found.
2609+
"""
2610+
2611+
doc = nlp(text)
2612+
entities = []
2613+
for ent in doc.ents:
2614+
entities.append(ent.text)
2615+
return entities
2616+
2617+
def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List:
2618+
"""Checks if any entity from a list is present in a given list of
2619+
competitors.
2620+
2621+
Args:
2622+
entities (list): A list of entities to check
2623+
competitors (list): A list of competitor names to match
2624+
2625+
Returns:
2626+
List: List of found competitors
2627+
"""
2628+
2629+
found_competitors = []
2630+
for entity in entities:
2631+
for item in competitors:
2632+
pattern = rf"\b{re.escape(item)}\b"
2633+
match = re.search(pattern.lower(), entity.lower())
2634+
if match:
2635+
found_competitors.append(item)
2636+
return found_competitors
2637+
2638+
def validate(self, value: str, metadata=Dict) -> ValidationResult:
2639+
"""Checks a text to find competitors' names in it.
2640+
2641+
While running, store sentences naming competitors and generate a fixed output
2642+
filtering out all flagged sentences.
2643+
2644+
Args:
2645+
value (str): The value to be validated.
2646+
metadata (Dict, optional): Additional metadata. Defaults to empty dict.
2647+
2648+
Returns:
2649+
ValidationResult: The validation result.
2650+
"""
2651+
2652+
if nltk is None:
2653+
raise ImportError(
2654+
"`nltk` library is required for `competitors-check` validator. "
2655+
"Please install it with `poetry add nltk`."
2656+
)
2657+
sentences = nltk.sent_tokenize(value)
2658+
flagged_sentences = []
2659+
filtered_sentences = []
2660+
list_of_competitors_found = []
2661+
2662+
for sentence in sentences:
2663+
entities = self.exact_match(sentence, self._competitors)
2664+
if entities:
2665+
ner_entities = self.perform_ner(sentence, self.nlp)
2666+
found_competitors = self.is_entity_in_list(ner_entities, entities)
2667+
2668+
if found_competitors:
2669+
flagged_sentences.append((found_competitors, sentence))
2670+
list_of_competitors_found.append(found_competitors)
2671+
logger.debug(f"Found: {found_competitors} named in '{sentence}'")
2672+
else:
2673+
filtered_sentences.append(sentence)
2674+
2675+
else:
2676+
filtered_sentences.append(sentence)
2677+
2678+
filtered_output = " ".join(filtered_sentences)
2679+
2680+
if len(flagged_sentences):
2681+
return FailResult(
2682+
error_message=(
2683+
f"Found the following competitors: {list_of_competitors_found}. "
2684+
"Please avoid naming those competitors next time"
2685+
),
2686+
fix_value=filtered_output,
2687+
)
2688+
else:
2689+
return PassResult()

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ nav:
5757
- 'Check key info present in generated summary': examples/text_summarization_quality.ipynb
5858
- 'Detect and limit hallucinations in generated text': examples/provenance.ipynb
5959
- 'Check whether a value is similar to a set of other values': examples/value_within_distribution.ipynb
60+
- 'Check if a competitor is named': examples/competitors_check.ipynb
6061
- 'Integrations':
6162
- 'Azure OpenAI': integrations/azure_openai.ipynb
6263
- 'OpenAI Functions': integrations/openai_functions.ipynb

0 commit comments

Comments
 (0)