Skip to content

Commit d89ce22

Browse files
committed
Custom sampler
1 parent aec5852 commit d89ce22

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tox==3.14.1
77
twine==3.1.1
88
volatile==2.1.0
99
wheel==0.33.6
10+
sre-tools==0.0.1

sre_yield/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ def __repr__(self):
325325
)
326326

327327

328+
class SamplingRepetitiveSequence(SlicedSequence):
    """A ``RepetitiveSequence`` viewed through a fixed ``slice(0, 2)``.

    Takes the same constructor arguments as ``RepetitiveSequence`` and
    forwards them unchanged; the resulting sequence is then handed to
    ``SlicedSequence`` together with the fixed slicer.
    NOTE(review): presumably this exposes only the first two expansions of
    the repeat -- confirm against ``SlicedSequence``.
    """

    def __init__(self, content, lowest=1, highest=1):
        # Build the full repetition first, then narrow it with the slicer.
        sample_window = slice(0, 2)
        super().__init__(
            RepetitiveSequence(content, lowest, highest),
            slicer=sample_window,
        )
332+
333+
328334
class SaveCaptureGroup(WrappedSequence):
329335
def __init__(self, parsed, key):
330336
self.key = key
@@ -353,6 +359,8 @@ def get_item(self, i, d=None):
353359
class RegexMembershipSequence(WrappedSequence):
354360
"""Creates a sequence from the regex, knows how to test membership."""
355361

362+
_RepetitiveSequence = RepetitiveSequence
363+
356364
def empty_list(self, *_):
357365
return []
358366

@@ -372,7 +380,7 @@ def max_repeat_values(self, min_count, max_count, items):
372380
"""Sequential expansion of the count to be combinatorics."""
373381
max_count = min(max_count, self.max_count)
374382
max_count = max(max_count, min_count)
375-
return RepetitiveSequence(self.sub_values(items), min_count, max_count)
383+
return self._RepetitiveSequence(self.sub_values(items), min_count, max_count)
376384

377385
def in_values(self, items):
378386
# Special case which distinguishes branch from charset operator

sre_yield/tests/test_sre_yield.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io
1919
import re
20+
import sre_constants
2021
import sre_parse
2122
import sys
2223
import unittest
@@ -208,6 +209,48 @@ def testSlicingMatches(self):
208209
self.assertEqual(["a-a", "b-b"], [x.group(0) for x in parsed[:2]])
209210
self.assertEqual(["a", "b"], [x.group(1) for x in parsed[:2]])
210211

212+
def testCustomRepeater(self):
    # Exercises the ``_RepetitiveSequence`` class attribute: a subclass of
    # RegexMembershipSequence swaps in its own repeat expander and the
    # expansion of the pattern is expected to equal ``sample_codes``.
    sample_codes = ["au", "us", "eu"]

    class DomainExpander(sre_yield.SlicedSequence):
        # Expand any \w\w pattern to be a country domain code

        def __init__(self, content, lowest=1, highest=1):
            word_chars = sre_yield.CATEGORIES[sre_constants.CATEGORY_WORD]
            # Conditions are evaluated left to right and short-circuit, so
            # the join is only attempted once the lengths already agree.
            is_two_word_chars = (
                content.__len__() == len(word_chars)
                and "".join(iter(content)) == word_chars
                and lowest == 2
                and highest == 2
            )
            if is_two_word_chars:
                backing = sample_codes
            else:
                # Anything else falls back to the stock repeat expansion.
                backing = sre_yield.RepetitiveSequence(content, lowest, highest)
            super().__init__(backing)

    class CustomStrings(sre_yield.RegexMembershipSequence):
        # Route every repeat expansion through the custom expander above.
        _RepetitiveSequence = DomainExpander

        def __init__(
            self,
            pattern,
            flags=0,
            charset=sre_yield.CHARSET,
            max_count=None,
            relaxed=False,
            simplify=False,
        ):
            if simplify:
                # Imported lazily so sre_tools is only needed when the
                # caller actually asks for simplification.
                from sre_tools.simplify import simplify_regex

                pattern = simplify_regex(pattern)
            super().__init__(pattern, flags, charset, max_count, relaxed)

    # Explicit repeat count in the pattern.
    explicit_repeat = CustomStrings(r"\w{2}")
    self.assertEqual(list(explicit_repeat), sample_codes)

    # Same pattern spelled out twice, passed through simplify_regex first.
    simplified_repeat = CustomStrings(r"\w\w", simplify=True)
    self.assertEqual(list(simplified_repeat), sample_codes)
253+
211254
def testSlicingMatchesMultichar(self):
212255
parsed = sre_yield.AllMatches("z([ab]{2})")
213256
self.assertEqual(4, len(parsed))

0 commit comments

Comments
 (0)