1- from typing import TYPE_CHECKING , List , Optional , Tuple
1+ from typing import TYPE_CHECKING , Any , List , Optional , Tuple
22
33if TYPE_CHECKING :
44 from ..classes .generated import PackedBitVector
55
66
7- def reshape (data : list , shape : Optional [Tuple [int , ...]] = None ) -> list :
7+ def reshape (data : list , shape : Optional [Tuple [int , ...]] = None ) -> List [ Any ] :
88 if shape is None :
99 return data
1010 if len (shape ) == 1 :
1111 m = shape [0 ]
1212 return [data [i : i + m ] for i in range (0 , len (data ), m )]
1313 elif len (shape ) == 2 :
1414 m , n = shape
15- return [[[ data [i + j : i + j + n ] for j in range (0 , m * n , n )] ] for i in range (0 , len (data ), m * n )]
15+ return [[data [i + j : i + j + n ] for j in range (0 , m * n , n )] for i in range (0 , len (data ), m * n )]
1616 else :
1717 raise ValueError ("Invalid shape" )
1818
@@ -22,7 +22,7 @@ def unpack_ints(
2222 start : int = 0 ,
2323 count : Optional [int ] = None ,
2424 shape : Optional [Tuple [int , ...]] = None ,
25- ) -> List [int ]:
25+ ) -> List [Any ]:
2626 assert packed .m_BitSize is not None
2727
2828 m_BitSize = packed .m_BitSize
@@ -70,13 +70,18 @@ def unpack_floats(
7070 start : int = 0 ,
7171 count : Optional [int ] = None ,
7272 shape : Optional [Tuple [int , ...]] = None ,
73- ) -> List [float ]:
73+ ) -> List [Any ]:
7474 assert packed .m_BitSize is not None and packed .m_Range is not None and packed .m_Start is not None
7575
76- # read as int and cast up to double to prevent loss of precision
77- quantized_f64 = unpack_ints (packed , start , count )
78- scale = packed .m_Range / ((1 << packed .m_BitSize ) - 1 )
79- quantized = [x * scale + packed .m_Start for x in quantized_f64 ]
76+ # avoid zero division of scale
77+ if packed .m_BitSize == 0 :
78+ quantized = [packed .m_Start ] * (packed .m_NumItems if count is None else count )
79+ else :
80+ # read as int and cast up to double to prevent loss of precision
81+ quantized_f64 = unpack_ints (packed , start , count )
82+ scale = packed .m_Range / ((1 << packed .m_BitSize ) - 1 )
83+ quantized = [x * scale + packed .m_Start for x in quantized_f64 ]
84+
8085 return reshape (quantized , shape )
8186
8287
0 commit comments