Skip to content

Commit 72f9ae5

Browse files
committed
Implement bytes.startswith in mypyc
1 parent 4eb6b50 commit 72f9ae5

File tree

6 files changed

+79
-1
lines changed

6 files changed

+79
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
784784
CPyTagged CPyBytes_Ord(PyObject *obj);
785785
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
786786
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);
787-
787+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj);
788788

789789
int CPyBytes_Compare(PyObject *left, PyObject *right);
790790

mypyc/lib-rt/bytes_ops.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,42 @@ PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) {
220220
}
221221
return PyObject_CallMethodOneArg(bytes, name, table);
222222
}
223+
224+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj) {
225+
if (PyBytes_CheckExact(self) && PyBytes_CheckExact(subobj)) {
226+
if (self == subobj) {
227+
return 1;
228+
}
229+
230+
Py_ssize_t self_len = PyBytes_GET_SIZE(self);
231+
Py_ssize_t subobj_len = PyBytes_GET_SIZE(subobj);
232+
233+
if (subobj_len > self_len) {
234+
return 0;
235+
}
236+
237+
const char *self_buf = PyBytes_AS_STRING(self);
238+
const char *subobj_buf = PyBytes_AS_STRING(subobj);
239+
240+
if (subobj_len == 0) {
241+
return 1;
242+
}
243+
244+
return memcmp(self_buf, subobj_buf, (size_t)subobj_len) == 0 ? 1 : 0;
245+
}
246+
_Py_IDENTIFIER(startswith);
247+
PyObject *name = _PyUnicode_FromId(&PyId_startswith);
248+
if (name == NULL) {
249+
return 2;
250+
}
251+
PyObject *result = PyObject_CallMethodOneArg(self, name, subobj);
252+
if (result == NULL) {
253+
return 2;
254+
}
255+
int ret = PyObject_IsTrue(result);
256+
Py_DECREF(result);
257+
if (ret < 0) {
258+
return 2;
259+
}
260+
return ret;
261+
}

mypyc/primitives/bytes_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from mypyc.ir.rtypes import (
77
RUnion,
88
bit_rprimitive,
9+
bool_rprimitive,
910
bytes_rprimitive,
1011
c_int_rprimitive,
1112
c_pyssize_t_rprimitive,
@@ -137,6 +138,16 @@
137138
error_kind=ERR_MAGIC,
138139
)
139140

141+
# bytes.startswith(bytes)
142+
method_op(
143+
name="startswith",
144+
arg_types=[bytes_rprimitive, bytes_rprimitive],
145+
return_type=c_int_rprimitive,
146+
c_function_name="CPyBytes_Startswith",
147+
truncated_type=bool_rprimitive,
148+
error_kind=ERR_MAGIC,
149+
)
150+
140151
# Join bytes objects and return a new bytes.
141152
# The first argument is the total number of the following bytes.
142153
bytes_build_op = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def __getitem__(self, i: slice) -> bytes: ...
179179
def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181181
def translate(self, t: bytes) -> bytes: ...
182+
def startswith(self, t: bytes) -> bool: ...
182183
def __iter__(self) -> Iterator[int]: ...
183184

184185
class bytearray:

mypyc/test-data/irbuild-bytes.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,17 @@ def f(b, table):
248248
L0:
249249
r0 = CPyBytes_Translate(b, table)
250250
return r0
251+
252+
[case testBytesStartsWith]
253+
def f(a: bytes, b: bytes) -> bool:
254+
return a.startswith(b)
255+
[out]
256+
def f(a, b):
257+
a, b :: bytes
258+
r0 :: i32
259+
r1 :: bool
260+
L0:
261+
r0 = CPyBytes_Startswith(a, b)
262+
r1 = truncate r0: i32 to builtins.bool
263+
return r1
264+

mypyc/test-data/run-bytes.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,19 @@ def test_translate() -> None:
200200
with assertRaises(ValueError, "translation table must be 256 characters long"):
201201
b'test'.translate(bytes(100))
202202

203+
def test_startswith() -> None:
204+
# Test default behavior
205+
test = b'some string'
206+
assert test.startswith(b'some')
207+
assert test.startswith(b'some string')
208+
assert not test.startswith(b'other')
209+
assert not test.startswith(b'some string but longer')
210+
211+
# Test empty cases
212+
assert test.startswith(b'')
213+
assert b''.startswith(b'')
214+
assert not b''.startswith(test)
215+
203216
[case testBytesSlicing]
204217
def test_bytes_slicing() -> None:
205218
b = b'abcdefg'

0 commit comments

Comments
 (0)