@@ -17,10 +17,8 @@ pub const DecompressingHashReader = struct {
1717 input_reader : * std.Io.Reader ,
1818 expected_hash : [32 ]u8 ,
1919 in_buffer : []u8 ,
20- out_buffer : []u8 ,
21- out_pos : usize ,
22- out_end : usize ,
23- finished : bool ,
20+ in_pos : usize ,
21+ in_end : usize ,
2422 hash_verified : bool ,
2523 interface : std.Io.Reader ,
2624
@@ -62,19 +60,16 @@ pub const DecompressingHashReader = struct {
6260 .input_reader = input_reader ,
6361 .expected_hash = expected_hash ,
6462 .in_buffer = in_buffer ,
65- .out_buffer = out_buffer ,
66- .out_pos = 0 ,
67- .out_end = 0 ,
68- .finished = false ,
63+ .in_pos = 0 ,
64+ .in_end = 0 ,
6965 .hash_verified = false ,
7066 .interface = undefined ,
7167 };
7268 result .interface = .{
7369 .vtable = &.{
7470 .stream = stream ,
75- .discard = discard ,
7671 },
77- .buffer = &.{}, // No buffer needed, we have internal buffering
72+ .buffer = out_buffer ,
7873 .seek = 0 ,
7974 .end = 0 ,
8075 };
@@ -84,172 +79,84 @@ pub const DecompressingHashReader = struct {
/// Tear down this reader.
///
/// Frees the zstd decompression context, the compressed-input buffer,
/// and the decompressed-output buffer — the latter is owned by this
/// struct but stored as `interface.buffer`, so it is freed from there.
pub fn deinit(self: *Self) void {
    _ = c.ZSTD_freeDCtx(self.dctx);
    const gpa = self.allocator_ptr;
    gpa.free(self.in_buffer);
    gpa.free(self.interface.buffer);
}
8984
/// `std.Io.Reader` vtable implementation.
///
/// This implementation never writes to `w`: it decompresses directly into the
/// interface's own buffer (`r.buffer`, advanced via `r.end`) and returns 0,
/// letting the generic reader machinery drain the buffer. The SHA-256 hasher
/// is fed the *compressed* bytes as they arrive, and is finalized exactly once
/// when the underlying stream is exhausted and all buffered input has been
/// consumed; a mismatch is reported later by `verifyComplete` via
/// `hash_verified` remaining false.
fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
    _ = w;
    _ = limit;

    // Reclaim the interface buffer once the caller has drained it.
    if (r.end == r.seek) {
        r.end = 0;
        r.seek = 0;
    }

    const self: *Self = @alignCast(@fieldParentPtr("interface", r));

    // Compact leftover compressed input to the front of the buffer so the
    // dead space at 0..in_pos cannot starve the refill below.
    if (self.in_pos > 0) {
        const leftover = self.in_end - self.in_pos;
        std.mem.copyForwards(u8, self.in_buffer[0..leftover], self.in_buffer[self.in_pos..self.in_end]);
        self.in_pos = 0;
        self.in_end = leftover;
    }

    // Refill the input buffer. The limit must not exceed the writer's free
    // space, otherwise the fixed writer reports WriteFailed — which is
    // (correctly, given this sizing) treated as unreachable below.
    var in_writer = std.Io.Writer.fixed(self.in_buffer[self.in_end..]);
    var reached_end = false;
    const bytes_read = self.input_reader.stream(&in_writer, std.Io.Limit.limited(self.in_buffer.len - self.in_end)) catch |err| switch (err) {
        error.EndOfStream => blk: {
            reached_end = true;
            break :blk 0;
        },
        error.ReadFailed => return error.ReadFailed,
        error.WriteFailed => unreachable, // limit is sized to the fixed writer's capacity
    };

    // Hash the compressed stream as it arrives (no-op when bytes_read == 0).
    self.hasher.update(self.in_buffer[self.in_end..][0..bytes_read]);
    self.in_end += bytes_read;

    // Only signal end-of-stream once every buffered compressed byte has been
    // handed to the decompressor; otherwise fall through and decompress the
    // remainder first.
    if (reached_end and self.in_pos == self.in_end) {
        if (!self.hash_verified) {
            var actual_hash: [32]u8 = undefined;
            self.hasher.final(&actual_hash);
            if (std.mem.eql(u8, &actual_hash, &self.expected_hash)) {
                self.hash_verified = true;
            }
        }
        // NOTE(review): if the frame was truncated, zstd's internal state may
        // still hold unflushed output here — callers rely on verifyComplete
        // to reject such streams via the hash check.
        return error.EndOfStream;
    }

    // Decompress into the interface buffer, appending after existing data.
    var in_buf = c.ZSTD_inBuffer{ .src = self.in_buffer.ptr, .size = self.in_end, .pos = self.in_pos };
    var out_buf = c.ZSTD_outBuffer{ .dst = r.buffer.ptr, .size = r.buffer.len, .pos = r.end };

    const zstd_result = c.ZSTD_decompressStream(self.dctx, &out_buf, &in_buf);
    if (c.ZSTD_isError(zstd_result) != 0) {
        // ReadFailed (not WriteFailed): we write to our own buffer, not `w`.
        return error.ReadFailed;
    }

    // Remember any input zstd did not consume (output buffer was full).
    if (in_buf.pos < in_buf.size) {
        self.in_pos = in_buf.pos;
        self.in_end = in_buf.size;
    } else {
        self.in_pos = 0;
        self.in_end = 0;
    }

    r.end = out_buf.pos;

    return 0;
}
243145
146+ /// Verify that the hash matches. This should be called after reading is complete.
147+ /// If there is remaining data, it will be discarded.
244148 pub fn verifyComplete (self : * Self ) ! void {
245149 // Read any remaining data to ensure we process the entire stream
246- var discard_buffer : [1024 ]u8 = undefined ;
247150 while (true ) {
248- const n = try self .read (& discard_buffer );
249- if (n == 0 ) break ;
151+ _ = self .interface .discard (std .Io .Limit .unlimited ) catch | err | {
152+ switch (err ) {
153+ error .EndOfStream = > break ,
154+ error .ReadFailed = > return error .ReadFailed ,
155+ }
156+ };
250157 }
251158
252- // The hash should have been verified during reading
159+ // The hash should have been verified during stream
253160 if (! self .hash_verified ) {
254161 return error .HashMismatch ;
255162 }
0 commit comments