|
64 | 64 | except LookupError: |
65 | 65 | nltk.download("punkt") |
66 | 66 |
|
| 67 | +try: |
| 68 | + import spacy |
| 69 | +except ImportError: |
| 70 | + spacy = None |
| 71 | + |
67 | 72 |
|
68 | 73 | logger = logging.getLogger(__name__) |
69 | 74 |
|
@@ -703,7 +708,7 @@ def __init__( |
703 | 708 | if not _HAS_NUMPY: |
704 | 709 | raise ImportError( |
705 | 710 | f"The {self.__class__.__name__} validator requires the numpy package.\n" |
706 | | - "`pip install numpy` to install it." |
| 711 | + "`poetry add numpy` to install it." |
707 | 712 | ) |
708 | 713 |
|
709 | 714 | self.client = OpenAIClient() |
@@ -775,7 +780,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult: |
775 | 780 | except ImportError: |
776 | 781 | raise ImportError( |
777 | 782 | "`is-profanity-free` validator requires the `alt-profanity-check`" |
778 | | - "package. Please install it with `pip install profanity-check`." |
| 783 | +                " package. Please install it with `poetry add alt-profanity-check`." |
779 | 784 | ) |
780 | 785 |
|
781 | 786 | prediction = predict([value]) |
@@ -823,7 +828,7 @@ def __init__(self, *args, **kwargs): |
823 | 828 | except ImportError: |
824 | 829 | raise ImportError( |
825 | 830 | "`is-high-quality-translation` validator requires the `inspiredco`" |
826 | | - "package. Please install it with `pip install inspiredco`." |
| 831 | +                " package. Please install it with `poetry add inspiredco`." |
827 | 832 | ) |
828 | 833 |
|
829 | 834 | def validate(self, value: Any, metadata: Dict) -> ValidationResult: |
@@ -1122,7 +1127,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult: |
1122 | 1127 | except ImportError: |
1123 | 1128 | raise ImportError( |
1124 | 1129 | "`thefuzz` library is required for `extractive-summary` validator. " |
1125 | | - "Please install it with `pip install thefuzz`." |
| 1130 | + "Please install it with `poetry add thefuzz`." |
1126 | 1131 | ) |
1127 | 1132 |
|
1128 | 1133 | # Split the value into sentences. |
@@ -1217,7 +1222,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult: |
1217 | 1222 | except ImportError: |
1218 | 1223 | raise ImportError( |
1219 | 1224 | "`thefuzz` library is required for `remove-redundant-sentences` " |
1220 | | - "validator. Please install it with `pip install thefuzz`." |
| 1225 | + "validator. Please install it with `poetry add thefuzz`." |
1221 | 1226 | ) |
1222 | 1227 |
|
1223 | 1228 | # Split the value into sentences. |
@@ -1613,7 +1618,7 @@ def validate_each_sentence( |
1613 | 1618 | if nltk is None: |
1614 | 1619 | raise ImportError( |
1615 | 1620 | "`nltk` library is required for `provenance-v0` validator. " |
1616 | | - "Please install it with `pip install nltk`." |
| 1621 | + "Please install it with `poetry add nltk`." |
1617 | 1622 | ) |
1618 | 1623 | # Split the value into sentences using nltk sentence tokenizer. |
1619 | 1624 | sentences = nltk.sent_tokenize(value) |
@@ -1973,7 +1978,7 @@ def validate_each_sentence( |
1973 | 1978 | if nltk is None: |
1974 | 1979 | raise ImportError( |
1975 | 1980 | "`nltk` library is required for `provenance-v0` validator. " |
1976 | | - "Please install it with `pip install nltk`." |
| 1981 | + "Please install it with `poetry add nltk`." |
1977 | 1982 | ) |
1978 | 1983 | # Split the value into sentences using nltk sentence tokenizer. |
1979 | 1984 | sentences = nltk.sent_tokenize(value) |
@@ -2535,3 +2540,150 @@ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult: |
2535 | 2540 | fix_value=modified_value, |
2536 | 2541 | ) |
2537 | 2542 | return PassResult() |
| 2543 | + |
| 2544 | + |
| 2545 | +@register_validator(name="competitor-check", data_type="string") |
| 2546 | +class CompetitorCheck(Validator): |
| 2547 | +    """Validates that LLM-generated text does not name any competitors from a |
| 2548 | +    given list. |
| 2549 | + |
| 2550 | +    To use this validator, provide an extensive list of the competitors you |
| 2551 | +    want to avoid naming, including all common variations. |
| 2552 | +
|
| 2553 | + Args: |
| 2554 | + competitors (List[str]): List of competitors you want to avoid naming |
| 2555 | + """ |
| 2556 | + |
| 2557 | + def __init__( |
| 2558 | + self, |
| 2559 | + competitors: List[str], |
| 2560 | + on_fail: Optional[Callable] = None, |
| 2561 | + ): |
| 2562 | + super().__init__(competitors=competitors, on_fail=on_fail) |
| 2563 | + self._competitors = competitors |
| 2564 | + model = "en_core_web_trf" |
| 2565 | + if spacy is None: |
| 2566 | + raise ImportError( |
| 2567 | + "You must install spacy in order to use the CompetitorCheck validator." |
| 2568 | + ) |
| 2569 | + |
| 2570 | + if not spacy.util.is_package(model): |
| 2571 | + logger.info( |
| 2572 | + f"Spacy model {model} not installed. " |
| 2573 | + "Download should start now and take a few minutes." |
| 2574 | + ) |
| 2575 | + spacy.cli.download(model) # type: ignore |
| 2576 | + |
| 2577 | + self.nlp = spacy.load(model) |
| 2578 | + |
| 2579 | + def exact_match(self, text: str, competitors: List[str]) -> List[str]: |
| 2580 | + """Performs exact match to find competitors from a list in a given |
| 2581 | + text. |
| 2582 | +
|
| 2583 | + Args: |
| 2584 | + text (str): The text to search for competitors. |
| 2585 | + competitors (list): A list of competitor entities to match. |
| 2586 | +
|
| 2587 | + Returns: |
| 2588 | + list: A list of matched entities. |
| 2589 | + """ |
| 2590 | + |
| 2591 | + found_entities = [] |
| 2592 | + for entity in competitors: |
| 2593 | + pattern = rf"\b{re.escape(entity)}\b" |
| 2594 | + match = re.search(pattern.lower(), text.lower()) |
| 2595 | + if match: |
| 2596 | + found_entities.append(entity) |
| 2597 | + return found_entities |
| 2598 | + |
| 2599 | + def perform_ner(self, text: str, nlp) -> List[str]: |
| 2600 | + """Performs named entity recognition on text using a provided NLP |
| 2601 | + model. |
| 2602 | +
|
| 2603 | + Args: |
| 2604 | + text (str): The text to perform named entity recognition on. |
| 2605 | + nlp: The NLP model to use for entity recognition. |
| 2606 | +
|
| 2607 | + Returns: |
| 2608 | + entities: A list of entities found. |
| 2609 | + """ |
| 2610 | + |
| 2611 | + doc = nlp(text) |
| 2612 | + entities = [] |
| 2613 | + for ent in doc.ents: |
| 2614 | + entities.append(ent.text) |
| 2615 | + return entities |
| 2616 | + |
| 2617 | + def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List: |
| 2618 | + """Checks if any entity from a list is present in a given list of |
| 2619 | + competitors. |
| 2620 | +
|
| 2621 | + Args: |
| 2622 | + entities (list): A list of entities to check |
| 2623 | + competitors (list): A list of competitor names to match |
| 2624 | +
|
| 2625 | + Returns: |
| 2626 | + List: List of found competitors |
| 2627 | + """ |
| 2628 | + |
| 2629 | + found_competitors = [] |
| 2630 | + for entity in entities: |
| 2631 | + for item in competitors: |
| 2632 | + pattern = rf"\b{re.escape(item)}\b" |
| 2633 | + match = re.search(pattern.lower(), entity.lower()) |
| 2634 | + if match: |
| 2635 | + found_competitors.append(item) |
| 2636 | + return found_competitors |
| 2637 | + |
| 2638 | +    def validate(self, value: str, metadata: Dict) -> ValidationResult: |
| 2639 | +        """Checks the given text for competitor names. |
| 2640 | + |
| 2641 | +        Sentences that name a competitor are flagged, and a fixed output is |
| 2642 | +        generated by filtering out all the flagged sentences. |
| 2643 | +
|
| 2644 | + Args: |
| 2645 | + value (str): The value to be validated. |
| 2646 | +            metadata (Dict): Additional metadata. |
| 2647 | +
|
| 2648 | + Returns: |
| 2649 | + ValidationResult: The validation result. |
| 2650 | + """ |
| 2651 | + |
| 2652 | + if nltk is None: |
| 2653 | + raise ImportError( |
| 2654 | +                "`nltk` library is required for `competitor-check` validator. " |
| 2655 | + "Please install it with `poetry add nltk`." |
| 2656 | + ) |
| 2657 | + sentences = nltk.sent_tokenize(value) |
| 2658 | + flagged_sentences = [] |
| 2659 | + filtered_sentences = [] |
| 2660 | + list_of_competitors_found = [] |
| 2661 | + |
| 2662 | + for sentence in sentences: |
| 2663 | + entities = self.exact_match(sentence, self._competitors) |
| 2664 | + if entities: |
| 2665 | + ner_entities = self.perform_ner(sentence, self.nlp) |
| 2666 | + found_competitors = self.is_entity_in_list(ner_entities, entities) |
| 2667 | + |
| 2668 | + if found_competitors: |
| 2669 | + flagged_sentences.append((found_competitors, sentence)) |
| 2670 | + list_of_competitors_found.append(found_competitors) |
| 2671 | + logger.debug(f"Found: {found_competitors} named in '{sentence}'") |
| 2672 | + else: |
| 2673 | + filtered_sentences.append(sentence) |
| 2674 | + |
| 2675 | + else: |
| 2676 | + filtered_sentences.append(sentence) |
| 2677 | + |
| 2678 | + filtered_output = " ".join(filtered_sentences) |
| 2679 | + |
| 2680 | +        if flagged_sentences: |
| 2681 | + return FailResult( |
| 2682 | + error_message=( |
| 2683 | + f"Found the following competitors: {list_of_competitors_found}. " |
| 2684 | +                    "Please avoid naming those competitors next time." |
| 2685 | + ), |
| 2686 | + fix_value=filtered_output, |
| 2687 | + ) |
| 2688 | + else: |
| 2689 | + return PassResult() |
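
For reference, here is a minimal sketch of how the new `CompetitorCheck` validator could be exercised on its own. The `guardrails.validators` import path and the competitor names are assumptions for illustration, not part of this diff; also note that constructing the validator downloads the `en_core_web_trf` spaCy model if it is not already installed.

```python
# Illustrative only: the import path and the competitor list are assumptions.
from guardrails.validators import CompetitorCheck, FailResult

validator = CompetitorCheck(
    competitors=["Acme Corp", "Acme", "Globex"],  # include common variations
)

text = "Acme Corp shipped a rival feature last week. Our roadmap is unchanged."
result = validator.validate(text, metadata={})

if isinstance(result, FailResult):
    print(result.error_message)  # names the competitors that were found
    print(result.fix_value)      # input text with the flagged sentences removed
else:
    print("No competitors named.")
```

The `fix_value` on a `FailResult` is simply the input with every sentence that named a competitor dropped, i.e. the `filtered_output` assembled in `validate` above.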