Skip to content

Commit d08c76d

Browse files
authored
c.Struct cleanup (tinygrad#15640)
1 parent 742b389 commit d08c76d

File tree

2 files changed

+74
-35
lines changed

2 files changed

+74
-35
lines changed

test/null/test_autogen.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,51 @@ def compile(self, src):
1212
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
1313
return DLL("test", f.name)
1414

15+
def test_struct_array_init(self):
16+
@record
17+
class Foo:
18+
SIZE = 12
19+
a: Annotated[ctypes.c_int * 3, 0]
20+
init_records()
21+
22+
f = Foo((1,2,3))
23+
assert f.a[0] == 1
24+
assert f.a[1] == 2
25+
assert f.a[2] == 3
26+
f = Foo((ctypes.c_int * 3)(1,2,3))
27+
assert f.a[0] == 1
28+
assert f.a[1] == 2
29+
assert f.a[2] == 3
30+
31+
def test_field_ranges(self):
32+
@record
33+
class Foo:
34+
SIZE = 2
35+
s: Annotated[ctypes.c_int8, 0]
36+
u: Annotated[ctypes.c_uint8, 1]
37+
init_records()
38+
39+
f = Foo()
40+
f.s = -1
41+
f.u = -1
42+
assert f.s == -1
43+
assert f.u == 255
44+
45+
# this syntax is inherited from ctypes, but it seems a bit nonsensical?
46+
def test_voidp_none(self):
47+
@record
48+
class Foo:
49+
SIZE = 8
50+
p: Annotated[ctypes.c_void_p, 0]
51+
init_records()
52+
53+
f = Foo(None)
54+
assert f.p is None
55+
f.p = ctypes.c_void_p(0xDEADBEEF)
56+
assert f.p == 0xDEADBEEF
57+
f.p = None
58+
assert f.p is None
59+
1560
def test_packed_struct(self):
1661
@record
1762
class Baz:

tinygrad/runtime/support/c.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22
import ctypes, functools, os, pathlib, re, sys, sysconfig
33
from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN
4-
from _ctypes import Array as _CArray, _SimpleCData, _Pointer
54
from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar
65

76
def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs):
@@ -34,22 +33,22 @@ def del_an(ty):
3433
from _ctypes import _CData
3534
class Array(Generic[T, U], _CData):
3635
@overload
37-
def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ...
36+
def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ...
3837
@overload
3938
def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ...
4039
@overload
4140
def __getitem__(self: Array[T, Any], key: int) -> T: ...
4241
def __getitem__(self, key) -> Any: ...
4342
@overload
44-
def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ...
43+
def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ...
4544
@overload
4645
def __setitem__(self: Array[T, Any], key: int, val: T): ...
4746
@overload
4847
def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ...
4948
def __setitem__(self, key, val): ...
50-
class POINTER(Generic[T], _Pointer): ...
49+
class POINTER(Generic[T], ctypes._Pointer): ...
5150
class CFUNCTYPE(Generic[T, P], _CFunctionType): ...
52-
class Enum(_SimpleCData):
51+
class Enum(ctypes._SimpleCData):
5352
@classmethod
5453
def get(cls, val:int, default="unknown") -> str: ...
5554
@classmethod
@@ -80,14 +79,9 @@ def define(cls, name:str, val:int) -> int:
8079
return val
8180
def pointer(obj): return ctypes.pointer(obj)
8281

83-
def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder)
84-
def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder)
85-
def mv(st) -> memoryview: return memoryview(st).cast('B')
86-
8782
class Struct(ctypes.Structure):
8883
def __init__(self, *args, **kwargs):
8984
ctypes.Structure.__init__(self)
90-
self._objects_ = {}
9185
for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v)
9286

9387
def record(cls) -> type[Struct]:
@@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
9892
def init_records() -> None:
9993
for cls, struct, ns in _pending_records:
10094
setattr(struct, '_real_fields_', [])
101-
for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items():
102-
if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__)))
103-
else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__))))
104-
struct._real_fields_.append((nm,) + f) # type: ignore
95+
for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()):
96+
struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore
97+
setattr(struct, nm, Field(nm, i, *f))
10598
_pending_records.clear()
10699

107-
class Field(property):
108-
def __init__(self, typ, off:int, bit_width=None, bit_off=0):
109-
if bit_width is not None:
110-
sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off)
100+
class Field:
101+
def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0):
102+
self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off
103+
104+
# lazily resolve field descriptors
105+
def _resolve(self, cls):
106+
if self.bit_width: # handle bitfields ourselves
107+
sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off)
108+
def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder)
109+
def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder)
111110
# FIXME: signedness
112-
super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask,
113-
lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz)))
114-
else:
115-
sl = slice(off, off + ctypes.sizeof(typ))
116-
def set_with_objs(f):
117-
def wrapper(self, v):
118-
if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})}
119-
mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v)))
120-
return wrapper
121-
if issubclass(typ, _CArray):
122-
getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl]))
123-
super().__init__(getter, set_with_objs(lambda v: typ(*v)))
124-
else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ))
125-
self.offset = off
111+
cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset)
112+
# pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
113+
else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] +
114+
[("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore
115+
setattr(cls, self.nm, cf)
116+
return cf
117+
118+
def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self
119+
def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value)
126120

127121
@functools.cache
128122
def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
129123
CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []})
130-
for nm,ty,*args in fields:
131-
setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args))))
132-
CStruct._real_fields_.append((nm,) + f) # type: ignore
124+
for i,(nm,ty,*args) in enumerate(fields):
125+
CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore
126+
setattr(CStruct, nm, Field(nm, i, *f))
133127
return CStruct
134128
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
135129

0 commit comments

Comments
 (0)