11from __future__ import annotations
22import ctypes , functools , os , pathlib , re , sys , sysconfig
33from tinygrad .helpers import ceildiv , getenv , unwrap , DEBUG , OSX , WIN
4- from _ctypes import Array as _CArray , _SimpleCData , _Pointer
54from typing import TYPE_CHECKING , get_type_hints , get_args , get_origin , overload , Annotated , Any , Generic , Iterable , ParamSpec , TypeVar
65
76def _do_ioctl (__idir , __base , __nr , __struct , __fd , * args , __payload = None , ** kwargs ):
@@ -34,22 +33,22 @@ def del_an(ty):
3433 from _ctypes import _CData
3534 class Array (Generic [T , U ], _CData ):
3635 @overload
37- def __getitem__ (self : Array [_SimpleCData [V ], Any ], key : int ) -> V : ...
36+ def __getitem__ (self : Array [ctypes . _SimpleCData [V ], Any ], key : int ) -> V : ...
3837 @overload
3938 def __getitem__ (self : Array [T , Any ], key : slice ) -> list [T ]: ...
4039 @overload
4140 def __getitem__ (self : Array [T , Any ], key : int ) -> T : ...
4241 def __getitem__ (self , key ) -> Any : ...
4342 @overload
44- def __setitem__ (self : Array [_SimpleCData [V ], Any ], key : int , val : V ): ...
43+ def __setitem__ (self : Array [ctypes . _SimpleCData [V ], Any ], key : int , val : V ): ...
4544 @overload
4645 def __setitem__ (self : Array [T , Any ], key : int , val : T ): ...
4746 @overload
4847 def __setitem__ (self : Array [T , Any ], key : slice , val : Iterable [T ]): ...
4948 def __setitem__ (self , key , val ): ...
50- class POINTER (Generic [T ], _Pointer ): ...
49+ class POINTER (Generic [T ], ctypes . _Pointer ): ...
5150 class CFUNCTYPE (Generic [T , P ], _CFunctionType ): ...
52- class Enum (_SimpleCData ):
51+ class Enum (ctypes . _SimpleCData ):
5352 @classmethod
5453 def get (cls , val :int , default = "unknown" ) -> str : ...
5554 @classmethod
@@ -80,14 +79,9 @@ def define(cls, name:str, val:int) -> int:
8079 return val
8180 def pointer (obj ): return ctypes .pointer (obj )
8281
83- def i2b (i :int , sz :int ) -> bytes : return i .to_bytes (sz , sys .byteorder )
84- def b2i (b :bytes ) -> int : return int .from_bytes (b , sys .byteorder )
85- def mv (st ) -> memoryview : return memoryview (st ).cast ('B' )
86-
8782class Struct (ctypes .Structure ):
8883 def __init__ (self , * args , ** kwargs ):
8984 ctypes .Structure .__init__ (self )
90- self ._objects_ = {}
9185 for f ,v in [* zip ((rf [0 ] for rf in self ._real_fields_ ), args ), * kwargs .items ()]: setattr (self , f , v )
9286
9387def record (cls ) -> type [Struct ]:
@@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
9892def init_records () -> None :
9993 for cls , struct , ns in _pending_records :
10094 setattr (struct , '_real_fields_' , [])
101- for nm , t in get_type_hints (cls , globalns = ns , include_extras = True ).items ():
102- if t .__origin__ in (bool , bytes , str , int , float ): setattr (struct , nm , Field (* (f := t .__metadata__ )))
103- else : setattr (struct , nm , Field (* (f := (del_an (t .__origin__ ), * t .__metadata__ ))))
104- struct ._real_fields_ .append ((nm ,) + f ) # type: ignore
95+ for i , (nm , t ) in enumerate (get_type_hints (cls , globalns = ns , include_extras = True ).items ()):
96+ struct ._real_fields_ .append ((nm , * (f := (del_an (t .__origin__ ), * t .__metadata__ ) if isinstance (t .__metadata__ [0 ], int ) else t .__metadata__ ))) # type: ignore
97+ setattr (struct , nm , Field (nm , i , * f ))
10598 _pending_records .clear ()
10699
107- class Field (property ):
108- def __init__ (self , typ , off :int , bit_width = None , bit_off = 0 ):
109- if bit_width is not None :
110- sl , set_mask = slice (off ,off + (sz := ceildiv (bit_width + bit_off , 8 ))), ~ ((mask := (1 << bit_width ) - 1 ) << bit_off )
100+ class Field :
101+ def __init__ (self , nm , idx , typ , off , bit_width = None , bit_off = 0 ):
102+ self .nm , self .idx , self .typ , self .off , self .bit_width , self .bit_off = nm , idx , typ , off , bit_width , bit_off
103+
104+ # lazily resolve field descriptors
105+ def _resolve (self , cls ):
106+ if self .bit_width : # handle bitfields ourselves
107+ sl , set_mask = slice (self .off , self .off + (sz := ceildiv (self .bit_width + self .bit_off , 8 ))), ~ ((mask := (1 << self .bit_width ) - 1 ) << self .bit_off )
108+ def b2i (obj ): return int .from_bytes (memoryview (obj ).cast ("B" )[sl ], sys .byteorder )
109+ def bset (obj , v ): memoryview (obj ).cast ("B" )[sl ] = ((b2i (obj ) & set_mask ) | v << self .bit_off ).to_bytes (sz , sys .byteorder )
111110 # FIXME: signedness
112- super ().__init__ (lambda self : (b2i (mv (self )[sl ]) >> bit_off ) & mask ,
113- lambda self ,v : mv (self ).__setitem__ (sl , i2b ((b2i (mv (self )[sl ]) & set_mask ) | (v << bit_off ), sz )))
114- else :
115- sl = slice (off , off + ctypes .sizeof (typ ))
116- def set_with_objs (f ):
117- def wrapper (self , v ):
118- if hasattr (v , '_objects' ) and hasattr (self , '_objects_' ): self ._objects_ [off ] = {'_self_' : v , ** (v ._objects or {})}
119- mv (self ).__setitem__ (sl , bytes (v if isinstance (v , typ ) else f (v )))
120- return wrapper
121- if issubclass (typ , _CArray ):
122- getter = (lambda self : typ .from_buffer (mv (self )[sl ]).value ) if typ ._type_ is ctypes .c_char else (lambda self : typ .from_buffer (mv (self )[sl ]))
123- super ().__init__ (getter , set_with_objs (lambda v : typ (* v )))
124- else : super ().__init__ (lambda self : v .value if isinstance (v := typ .from_buffer (mv (self )[sl ]), _SimpleCData ) else v , set_with_objs (typ ))
125- self .offset = off
111+ cf = property (lambda obj : b2i (obj ) >> self .bit_off & mask , bset )
112+ # pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
113+ else : cf = type (self .nm , (ctypes .Structure ,), {"_layout_" : "ms" , "_pack_" : 1 , "_fields_" : [(str (i ), ctypes .c_byte * 0 ) for i in range (self .idx )] +
114+ [("_" , ctypes .c_byte * self .off ), ("v" , self .typ )]}).v # type: ignore
115+ setattr (cls , self .nm , cf )
116+ return cf
117+
118+ def __get__ (self , obj , objtype = None ): return self ._resolve (objtype ).__get__ (obj , objtype ) if objtype else self
119+ def __set__ (self , obj , value ): self ._resolve (obj .__class__ ).__set__ (obj , value )
126120
127121@functools .cache
128122def init_c_struct_t (sz :int , fields : tuple [tuple , ...]):
129123 CStruct = type ("CStruct" , (Struct ,), {'_fields_' : [('_mem_' , ctypes .c_byte * sz )], '_real_fields_' : []})
130- for nm ,ty ,* args in fields :
131- setattr ( CStruct , nm , Field ( * (f := (del_an (ty ), * args ))))
132- CStruct . _real_fields_ . append (( nm ,) + f ) # type: ignore
124+ for i ,( nm ,ty ,* args ) in enumerate ( fields ) :
125+ CStruct . _real_fields_ . append (( nm , * (f := (del_an (ty ), * args )))) # type: ignore
126+ setattr ( CStruct , nm , Field ( nm , i , * f ))
133127 return CStruct
134128def init_c_var (ty , creat_cb ): return (creat_cb (v := del_an (ty )()), v )[1 ]
135129
0 commit comments