Skip to content

Commit 0b126e4

Browse files
committed
Custom sampler
1 parent aec5852 commit 0b126e4

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tox==3.14.1
77
twine==3.1.1
88
volatile==2.1.0
99
wheel==0.33.6
10+
sre-tools==0.0.1

sre_yield/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ def __repr__(self):
325325
)
326326

327327

328+
class SamplingRepetitiveSequence(SlicedSequence):
    """Sliced view over a RepetitiveSequence built from the same arguments.

    Takes the same ``(content, lowest, highest)`` arguments as
    ``RepetitiveSequence``, constructs one, and hands it to ``SlicedSequence``
    together with a fixed ``slice(0, 2)``.
    """

    def __init__(self, content, lowest=1, highest=1):
        # NOTE(review): presumably this limits the view to the first two
        # entries of the full repetition -- confirm against SlicedSequence.
        window = slice(0, 2)
        full_expansion = RepetitiveSequence(content, lowest, highest)
        super().__init__(full_expansion, slicer=window)
332+
333+
328334
class SaveCaptureGroup(WrappedSequence):
329335
def __init__(self, parsed, key):
330336
self.key = key
@@ -353,6 +359,8 @@ def get_item(self, i, d=None):
353359
class RegexMembershipSequence(WrappedSequence):
354360
"""Creates a sequence from the regex, knows how to test membership."""
355361

362+
_RepetitiveSequence = RepetitiveSequence
363+
356364
def empty_list(self, *_):
357365
return []
358366

@@ -372,7 +380,7 @@ def max_repeat_values(self, min_count, max_count, items):
372380
"""Sequential expansion of the count to be combinatorics."""
373381
max_count = min(max_count, self.max_count)
374382
max_count = max(max_count, min_count)
375-
return RepetitiveSequence(self.sub_values(items), min_count, max_count)
383+
return self._RepetitiveSequence(self.sub_values(items), min_count, max_count)
376384

377385
def in_values(self, items):
378386
# Special case which distinguishes branch from charset operator

sre_yield/tests/test_sre_yield.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io
1919
import re
20+
import sre_constants
2021
import sre_parse
2122
import sys
2223
import unittest
@@ -208,6 +209,49 @@ def testSlicingMatches(self):
208209
self.assertEqual(["a-a", "b-b"], [x.group(0) for x in parsed[:2]])
209210
self.assertEqual(["a", "b"], [x.group(1) for x in parsed[:2]])
210211

212+
def testCustomRepeater(self):
    # Checks that a RegexMembershipSequence subclass can substitute its own
    # repeat expander through the `_RepetitiveSequence` class attribute.
    sample_codes = ["au", "us", "eu"]

    class DomainExpander(sre_yield.SlicedSequence):
        # Expand any \w\w pattern to be a country domain code; every other
        # repeat is delegated to the stock RepetitiveSequence.
        def __init__(self, content, lowest=1, highest=1):
            word_chars = sre_yield.CATEGORIES[sre_constants.CATEGORY_WORD]
            # NOTE(review): __len__() is invoked directly rather than via
            # len(); presumably to sidestep len()'s size limit on very
            # large sequences -- confirm before changing.
            is_doubled_word = (
                content.__len__() == len(word_chars)
                and "".join(iter(content)) == word_chars
                and lowest == 2
                and highest == 2
            )
            if is_doubled_word:
                backing = sample_codes
            else:
                backing = sre_yield.RepetitiveSequence(content, lowest, highest)
            super().__init__(backing)

    class CustomStrings(sre_yield.RegexMembershipSequence):
        # Route every repeat through the expander defined above.
        _RepetitiveSequence = DomainExpander

        def __init__(
            self,
            pattern,
            flags=0,
            charset=sre_yield.CHARSET,
            max_count=None,
            relaxed=False,
            simplify=False,
        ):
            if not simplify:
                regex = pattern
            else:
                # Imported lazily so sre_tools is only required when the
                # simplify option is actually requested.
                from sre_tools.simplify import simplify_regex

                regex = simplify_regex(pattern)
            super().__init__(regex, flags, charset, max_count, relaxed)

    # An explicit {2} repeat goes through the custom expander directly.
    explicit_repeat = list(CustomStrings(r"\w{2}"))
    self.assertEqual(explicit_repeat, sample_codes)

    # Two adjacent \w atoms are expected to yield the same codes once the
    # pattern has been passed through simplify_regex.
    simplified_pair = list(CustomStrings(r"\w\w", simplify=True))
    self.assertEqual(simplified_pair, sample_codes)
254+
211255
def testSlicingMatchesMultichar(self):
212256
parsed = sre_yield.AllMatches("z([ab]{2})")
213257
self.assertEqual(4, len(parsed))

0 commit comments

Comments
 (0)