Skip to content

Commit

Permalink
gh-67877: Fix memory leaks in terminated RE matching
Browse files Browse the repository at this point in the history
If the SRE(match) function terminates abruptly, either because of a signal
or because memory allocation fails, the allocated SRE_REPEAT blocks might
never be released.
  • Loading branch information
serhiy-storchaka committed Nov 14, 2024
1 parent 7577307 commit 2b46b21
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 28 deletions.
21 changes: 21 additions & 0 deletions Lib/test/test_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,6 +2681,27 @@ def test_character_set_none(self):
self.assertIsNone(re.search(p, s))
self.assertIsNone(re.search('(?s:.)' + p, s))

def check_interrupt(self, pattern, string, maxcount):
    """Run ``pattern.match(string)`` under increasing ``_fail_after`` counts.

    For each n in ``range(maxcount)`` the compiled pattern is armed with
    ``_fail_after(n, Interrupt)`` and then matched against *string*.
    Returns the first n for which the match finishes without raising
    ``Interrupt``; returns None implicitly if every attempt is interrupted.
    """
    class Interrupt(Exception):
        pass

    compiled = re.compile(pattern)
    for budget in range(maxcount):
        compiled._fail_after(budget, Interrupt)
        try:
            compiled.match(string)
        except Interrupt:
            # This attempt was cut short; retry with a larger count.
            continue
        return budget

@unittest.skipUnless(hasattr(re.Pattern, '_fail_after'), 'requires debug build')
def test_memory_leaks(self):
    # Exercise interrupted matching for several repeat forms
    # (*, *?, *+, {2,4}, {2,4}?, {2,4}+), each against the same subject.
    repeat_patterns = (
        r'(.)*:',
        r'([^:])*?:',
        r'([^:])*+:',
        r'(.){2,4}:',
        r'([^:]){2,4}?:',
        r'([^:]){2,4}+:',
    )
    for repeat_pattern in repeat_patterns:
        self.check_interrupt(repeat_pattern, 'abc:', 100)


def get_debug_out(pat):
with captured_stdout() as out:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix memory leaks when :mod:`regular expression <re>` matching terminates
abruptly, either because of a signal or because memory allocation fails.
44 changes: 43 additions & 1 deletion Modules/_sre/clinic/sre.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

107 changes: 86 additions & 21 deletions Modules/_sre/sre.c
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,25 @@ _sre_unicode_tolower_impl(PyObject *module, int character)
return sre_lower_unicode(character);
}

/* Free every SRE_REPEAT block still attached to the state.
 *
 * Two chains are released: the one headed by state->repeat, walked through
 * its `prev` links, and the one headed by state->repstack, walked through
 * its `next` links.  Each head is set to NULL before its chain is freed.
 */
LOCAL(void)
state_clean_repeat_data(SRE_STATE* state)
{
    SRE_REPEAT *node;
    SRE_REPEAT *follow;

    node = state->repeat;
    state->repeat = NULL;
    for (; node != NULL; node = follow) {
        /* Read the link first: the node is gone after PyMem_Free(). */
        follow = node->prev;
        PyMem_Free(node);
    }

    node = state->repstack;
    state->repstack = NULL;
    for (; node != NULL; node = follow) {
        follow = node->next;
        PyMem_Free(node);
    }
}

LOCAL(void)
state_reset(SRE_STATE* state)
{
Expand All @@ -406,8 +425,7 @@ state_reset(SRE_STATE* state)
state->lastmark = -1;
state->lastindex = -1;

state->repeat = NULL;

state_clean_repeat_data(state);
data_stack_dealloc(state);
}

Expand Down Expand Up @@ -511,6 +529,11 @@ state_init(SRE_STATE* state, PatternObject* pattern, PyObject* string,
state->pos = start;
state->endpos = end;

#ifdef Py_DEBUG
state->fail_after_count = pattern->fail_after_count;
state->fail_after_exc = pattern->fail_after_exc; // borrowed ref
#endif

return string;
err:
/* We add an explicit cast here because MSVC has a bug when
Expand All @@ -524,15 +547,21 @@ state_init(SRE_STATE* state, PatternObject* pattern, PyObject* string,
}

/* Tear down a match state: release the buffer view if one is held, drop the
   reference to the string, free the remaining repeat blocks, call
   data_stack_dealloc() and free the mark array.  In Py_DEBUG builds a
   non-NULL pattern also gets its fail_after_count set back to -1.
   NOTE(review): two signature lines follow; this looks like a diff view
   showing the old one-argument form next to the new two-argument form --
   confirm against the real file. */
LOCAL(void)
state_fini(SRE_STATE* state)
state_fini(SRE_STATE* state, PatternObject *pattern)
{
/* Only release the buffer when its buf pointer is set. */
if (state->buffer.buf)
PyBuffer_Release(&state->buffer);
Py_XDECREF(state->string);
/* Frees the SRE_REPEAT chains hanging off state->repeat / state->repstack. */
state_clean_repeat_data(state);
data_stack_dealloc(state);
/* See above PyMem_Free() for why we explicitly cast here. */
PyMem_Free((void*) state->mark);
/* Clear the pointer so the freed mark array is not referenced again. */
state->mark = NULL;
#ifdef Py_DEBUG
/* Debug builds only: set the pattern's fail_after_count back to -1. */
if (pattern) {
pattern->fail_after_count = -1;
}
#endif
}

/* calculate offset from start of string */
Expand Down Expand Up @@ -619,6 +648,9 @@ pattern_traverse(PatternObject *self, visitproc visit, void *arg)
Py_VISIT(self->groupindex);
Py_VISIT(self->indexgroup);
Py_VISIT(self->pattern);
#ifdef Py_DEBUG
Py_VISIT(self->fail_after_exc);
#endif
return 0;
}

Expand All @@ -628,6 +660,9 @@ pattern_clear(PatternObject *self)
Py_CLEAR(self->groupindex);
Py_CLEAR(self->indexgroup);
Py_CLEAR(self->pattern);
#ifdef Py_DEBUG
Py_CLEAR(self->fail_after_exc);
#endif
return 0;
}

Expand Down Expand Up @@ -690,7 +725,7 @@ _sre_SRE_Pattern_match_impl(PatternObject *self, PyTypeObject *cls,
Py_ssize_t status;
PyObject *match;

if (!state_init(&state, (PatternObject *)self, string, pos, endpos))
if (!state_init(&state, self, string, pos, endpos))
return NULL;

INIT_TRACE(&state);
Expand All @@ -702,12 +737,12 @@ _sre_SRE_Pattern_match_impl(PatternObject *self, PyTypeObject *cls,

TRACE(("|%p|%p|END\n", PatternObject_GetCode(self), state.ptr));
if (PyErr_Occurred()) {
state_fini(&state);
state_fini(&state, self);
return NULL;
}

match = pattern_new_match(module_state, self, &state, status);
state_fini(&state);
state_fini(&state, self);
return match;
}

Expand Down Expand Up @@ -747,12 +782,12 @@ _sre_SRE_Pattern_fullmatch_impl(PatternObject *self, PyTypeObject *cls,

TRACE(("|%p|%p|END\n", PatternObject_GetCode(self), state.ptr));
if (PyErr_Occurred()) {
state_fini(&state);
state_fini(&state, self);
return NULL;
}

match = pattern_new_match(module_state, self, &state, status);
state_fini(&state);
state_fini(&state, self);
return match;
}

Expand Down Expand Up @@ -792,12 +827,12 @@ _sre_SRE_Pattern_search_impl(PatternObject *self, PyTypeObject *cls,
TRACE(("|%p|%p|END\n", PatternObject_GetCode(self), state.ptr));

if (PyErr_Occurred()) {
state_fini(&state);
state_fini(&state, self);
return NULL;
}

match = pattern_new_match(module_state, self, &state, status);
state_fini(&state);
state_fini(&state, self);
return match;
}

Expand Down Expand Up @@ -826,7 +861,7 @@ _sre_SRE_Pattern_findall_impl(PatternObject *self, PyObject *string,

list = PyList_New(0);
if (!list) {
state_fini(&state);
state_fini(&state, self);
return NULL;
}

Expand Down Expand Up @@ -888,12 +923,12 @@ _sre_SRE_Pattern_findall_impl(PatternObject *self, PyObject *string,
state.start = state.ptr;
}

state_fini(&state);
state_fini(&state, self);
return list;

error:
Py_DECREF(list);
state_fini(&state);
state_fini(&state, self);
return NULL;

}
Expand Down Expand Up @@ -989,7 +1024,7 @@ _sre_SRE_Pattern_split_impl(PatternObject *self, PyObject *string,

list = PyList_New(0);
if (!list) {
state_fini(&state);
state_fini(&state, self);
return NULL;
}

Expand Down Expand Up @@ -1053,12 +1088,12 @@ _sre_SRE_Pattern_split_impl(PatternObject *self, PyObject *string,
if (status < 0)
goto error;

state_fini(&state);
state_fini(&state, self);
return list;

error:
Py_DECREF(list);
state_fini(&state);
state_fini(&state, self);
return NULL;

}
Expand Down Expand Up @@ -1185,7 +1220,7 @@ pattern_subx(_sremodulestate* module_state,
list = PyList_New(0);
if (!list) {
Py_DECREF(filter);
state_fini(&state);
state_fini(&state, self);
return NULL;
}

Expand Down Expand Up @@ -1271,7 +1306,7 @@ pattern_subx(_sremodulestate* module_state,
goto error;
}

state_fini(&state);
state_fini(&state, self);

Py_DECREF(filter);

Expand Down Expand Up @@ -1303,7 +1338,7 @@ pattern_subx(_sremodulestate* module_state,

error:
Py_DECREF(list);
state_fini(&state);
state_fini(&state, self);
Py_DECREF(filter);
return NULL;

Expand Down Expand Up @@ -1381,6 +1416,29 @@ _sre_SRE_Pattern___deepcopy__(PatternObject *self, PyObject *memo)
return Py_NewRef(self);
}

#ifdef Py_DEBUG
/*[clinic input]
_sre.SRE_Pattern._fail_after
count: int
exception: object
/
For debugging.
[clinic start generated code]*/

static PyObject *
_sre_SRE_Pattern__fail_after_impl(PatternObject *self, int count,
                                  PyObject *exception)
/*[clinic end generated code: output=9a6bf12135ac50c2 input=ef80a45c66c5499d]*/
{
    /* Debug-build helper: store the count and the exception object on the
       pattern; state_init() copies both into each new match state. */
    PyObject *replacement;

    self->fail_after_count = count;

    /* Take our own reference first, then swap it in; Py_XSETREF releases
       whatever was stored before (if anything). */
    replacement = Py_NewRef(exception);
    Py_XSETREF(self->fail_after_exc, replacement);

    Py_RETURN_NONE;
}
#endif /* Py_DEBUG */

static PyObject *
pattern_repr(PatternObject *obj)
{
Expand Down Expand Up @@ -1506,6 +1564,11 @@ _sre_compile_impl(PyObject *module, PyObject *pattern, int flags,
self->pattern = NULL;
self->groupindex = NULL;
self->indexgroup = NULL;
#ifdef Py_DEBUG
self->fail_after_count = -1;
self->fail_after_exc = NULL;
self->fail_after_exc = Py_NewRef(PyExc_RuntimeError);
#endif

self->codesize = n;

Expand Down Expand Up @@ -2680,7 +2743,7 @@ scanner_dealloc(ScannerObject* self)
PyTypeObject *tp = Py_TYPE(self);

PyObject_GC_UnTrack(self);
state_fini(&self->state);
state_fini(&self->state, self->pattern);
(void)scanner_clear(self);
tp->tp_free(self);
Py_DECREF(tp);
Expand Down Expand Up @@ -2826,7 +2889,8 @@ pattern_scanner(_sremodulestate *module_state,
return NULL;
}

scanner->pattern = Py_NewRef(self);
Py_INCREF(self);
scanner->pattern = self;

PyObject_GC_Track(scanner);
return (PyObject*) scanner;
Expand Down Expand Up @@ -3020,6 +3084,7 @@ static PyMethodDef pattern_methods[] = {
_SRE_SRE_PATTERN_SCANNER_METHODDEF
_SRE_SRE_PATTERN___COPY___METHODDEF
_SRE_SRE_PATTERN___DEEPCOPY___METHODDEF
_SRE_SRE_PATTERN__FAIL_AFTER_METHODDEF
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS,
PyDoc_STR("See PEP 585")},
{NULL, NULL}
Expand Down
Loading

0 comments on commit 2b46b21

Please sign in to comment.