From 0b126e43eada56e86108795bc0342e69b03f6466 Mon Sep 17 00:00:00 2001
From: John Vandenberg
Date: Tue, 22 Sep 2020 17:34:03 +0700
Subject: [PATCH] Custom sampler

---
Notes (ignored by git am / patch; not part of the commit message):

* sre_yield/__init__.py: adds SamplingRepetitiveSequence, a SlicedSequence
  over a RepetitiveSequence restricted with slice(0, 2), and makes
  RegexMembershipSequence.max_repeat_values build its result through the
  class attribute _RepetitiveSequence (default: RepetitiveSequence) so a
  subclass can substitute its own repeater class.
* sre_yield/tests/test_sre_yield.py: adds testCustomRepeater, which
  overrides _RepetitiveSequence in a RegexMembershipSequence subclass.
* requirements-dev.txt: adds sre-tools==0.0.1; the new test imports
  sre_tools.simplify.simplify_regex when simplify=True.

 requirements-dev.txt              |  1 +
 sre_yield/__init__.py             | 10 ++++++-
 sre_yield/tests/test_sre_yield.py | 44 +++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 0479c95..86fe0d9 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,3 +7,4 @@ tox==3.14.1
 twine==3.1.1
 volatile==2.1.0
 wheel==0.33.6
+sre-tools==0.0.1
diff --git a/sre_yield/__init__.py b/sre_yield/__init__.py
index b1539d1..06598b9 100644
--- a/sre_yield/__init__.py
+++ b/sre_yield/__init__.py
@@ -325,6 +325,12 @@ def __repr__(self):
         )


+class SamplingRepetitiveSequence(SlicedSequence):
+    def __init__(self, content, lowest=1, highest=1):
+        real = RepetitiveSequence(content, lowest, highest)
+        super().__init__(real, slicer=slice(0, 2))
+
+
 class SaveCaptureGroup(WrappedSequence):
     def __init__(self, parsed, key):
         self.key = key
@@ -353,6 +359,8 @@ def get_item(self, i, d=None):
 class RegexMembershipSequence(WrappedSequence):
     """Creates a sequence from the regex, knows how to test membership."""

+    _RepetitiveSequence = RepetitiveSequence
+
     def empty_list(self, *_):
         return []

@@ -372,7 +380,7 @@ def max_repeat_values(self, min_count, max_count, items):
         """Sequential expansion of the count to be combinatorics."""
         max_count = min(max_count, self.max_count)
         max_count = max(max_count, min_count)
-        return RepetitiveSequence(self.sub_values(items), min_count, max_count)
+        return self._RepetitiveSequence(self.sub_values(items), min_count, max_count)

     def in_values(self, items):
         # Special case which distinguishes branch from charset operator
diff --git a/sre_yield/tests/test_sre_yield.py b/sre_yield/tests/test_sre_yield.py
index bb649d1..3786105 100644
--- a/sre_yield/tests/test_sre_yield.py
+++ b/sre_yield/tests/test_sre_yield.py
@@ -17,6 +17,7 @@

 import io
 import re
+import sre_constants
 import sre_parse
 import sys
 import unittest
@@ -208,6 +209,49 @@ def testSlicingMatches(self):
         self.assertEqual(["a-a", "b-b"], [x.group(0) for x in parsed[:2]])
         self.assertEqual(["a", "b"], [x.group(1) for x in parsed[:2]])

+    def testCustomRepeater(self):
+        sample_codes = ["au", "us", "eu"]
+
+        class DomainExpander(sre_yield.SlicedSequence):
+            # Expand any \w\w pattern to be a country domain code
+            def __init__(self, content, lowest=1, highest=1):
+                if content.__len__() == len(
+                    sre_yield.CATEGORIES[sre_constants.CATEGORY_WORD]
+                ):
+                    if (
+                        "".join(iter(content))
+                        == sre_yield.CATEGORIES[sre_constants.CATEGORY_WORD]
+                    ):
+                        if lowest == 2 and highest == 2:
+                            super().__init__(sample_codes)
+                            return
+
+                real = sre_yield.RepetitiveSequence(content, lowest, highest)
+                super().__init__(real)
+
+        class CustomStrings(sre_yield.RegexMembershipSequence):
+            _RepetitiveSequence = DomainExpander
+
+            def __init__(
+                self,
+                pattern,
+                flags=0,
+                charset=sre_yield.CHARSET,
+                max_count=None,
+                relaxed=False,
+                simplify=False,
+            ):
+                if simplify:
+                    from sre_tools.simplify import simplify_regex
+
+                    pattern = simplify_regex(pattern)
+                super().__init__(pattern, flags, charset, max_count, relaxed)
+
+        self.assertEqual(list(CustomStrings(r"\w{2}")), sample_codes)
+        self.assertEqual(
+            list(CustomStrings(r"\w\w", simplify=True)), sample_codes
+        )
+
     def testSlicingMatchesMultichar(self):
         parsed = sre_yield.AllMatches("z([ab]{2})")
         self.assertEqual(4, len(parsed))