Commit 6159ea3

Simplify DecompressingHashReader

1 parent 1c83cf2 commit 6159ea3

5 files changed: +74 -183 lines changed

build.zig.zon

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 .{
     .name = .roc,
     .version = "0.0.0",
-    .minimum_zig_version = "0.15.1",
+    .minimum_zig_version = "0.15.2",
     .dependencies = .{
         .afl_kit = .{
             .url = "git+https://github.com/bhansconnect/zig-afl-kit?ref=main#b863c41ca47ed05729e0b509fb1926c111aa2800",

src/build/builtin_compiler/main.zig

Lines changed: 0 additions & 1 deletion
@@ -166,7 +166,6 @@ fn compileModule(
     module_env.* = try ModuleEnv.init(gpa, source);
     errdefer module_env.deinit();

-    module_env.common.source = source;
     module_env.module_name = module_name;
     try module_env.common.calcLineStarts(gpa);

src/bundle/bundle.zig

Lines changed: 5 additions & 5 deletions
@@ -648,11 +648,11 @@ pub fn unbundleStream(
     }

     // Ensure all data was read and hash was verified
-    decompress_reader.verifyComplete() catch |err| switch (err) {
-        error.HashMismatch => return error.HashMismatch,
-        error.UnexpectedEndOfStream => return error.UnexpectedEndOfStream,
-        error.DecompressionFailed => return error.DecompressionFailed,
-        error.OutOfMemory => return error.OutOfMemory,
+    decompress_reader.verifyComplete() catch |err| {
+        switch (err) {
+            error.ReadFailed => return error.DecompressionFailed,
+            error.HashMismatch => return error.HashMismatch,
+        }
     };
 }
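
Worth noting: the switch above is exhaustive, so after this commit verifyComplete exposes only two error cases instead of four. A minimal sketch of the implied error set (hypothetical name; the function itself uses an inferred error set, not a named one):

    // Hypothetical named form of verifyComplete's post-commit error set,
    // inferred from the exhaustive switch in unbundleStream above.
    const VerifyCompleteError = error{
        ReadFailed, // draining/decompressing the remaining stream failed
        HashMismatch, // hash of the compressed bytes != expected_hash
    };

unbundleStream remaps ReadFailed to its existing DecompressionFailed, so its own callers keep seeing the same error names as before.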

src/bundle/streaming_reader.zig

Lines changed: 60 additions & 153 deletions
@@ -17,10 +17,8 @@ pub const DecompressingHashReader = struct {
     input_reader: *std.Io.Reader,
     expected_hash: [32]u8,
     in_buffer: []u8,
-    out_buffer: []u8,
-    out_pos: usize,
-    out_end: usize,
-    finished: bool,
+    in_pos: usize,
+    in_end: usize,
     hash_verified: bool,
     interface: std.Io.Reader,

@@ -62,19 +60,16 @@
             .input_reader = input_reader,
             .expected_hash = expected_hash,
             .in_buffer = in_buffer,
-            .out_buffer = out_buffer,
-            .out_pos = 0,
-            .out_end = 0,
-            .finished = false,
+            .in_pos = 0,
+            .in_end = 0,
             .hash_verified = false,
             .interface = undefined,
         };
         result.interface = .{
             .vtable = &.{
                 .stream = stream,
-                .discard = discard,
             },
-            .buffer = &.{}, // No buffer needed, we have internal buffering
+            .buffer = out_buffer,
             .seek = 0,
             .end = 0,
         };
@@ -84,172 +79,84 @@
     pub fn deinit(self: *Self) void {
         _ = c.ZSTD_freeDCtx(self.dctx);
         self.allocator_ptr.free(self.in_buffer);
-        self.allocator_ptr.free(self.out_buffer);
+        self.allocator_ptr.free(self.interface.buffer);
     }

     fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
-        const self: *Self = @alignCast(@fieldParentPtr("interface", r));
-        const dest = limit.slice(try w.writableSliceGreedy(1));
-        const n = self.read(dest) catch |err| switch (err) {
-            error.DecompressionFailed, error.HashMismatch => return std.Io.Reader.StreamError.ReadFailed,
-            error.UnexpectedEndOfStream => return std.Io.Reader.StreamError.EndOfStream,
-            error.OutOfMemory => return std.Io.Reader.StreamError.ReadFailed,
-        };
-        if (n == 0) {
-            return std.Io.Reader.StreamError.EndOfStream;
+        // This implementation just adds the decompressed data to the buffer and returns 0.
+        // This simplifies the logic a bit, which is encouraged by the Zig reader API.
+        _ = w;
+        _ = limit;
+        if (r.end == r.seek) {
+            r.end = 0;
+            r.seek = 0;
         }
-        w.advance(n);
-        return n;
-    }
-
-    fn discard(r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize {
         const self: *Self = @alignCast(@fieldParentPtr("interface", r));

-        var total: usize = 0;
-        var remaining: ?usize = limit.toInt();
+        var in_writer = std.Io.Writer.fixed(self.in_buffer[self.in_end..]);
+        var reached_end = false;
+        const bytes_read = self.input_reader.stream(&in_writer, std.Io.Limit.limited(self.in_buffer.len)) catch |err| switch (err) {
+            error.EndOfStream => blk: {
+                reached_end = true;
+                break :blk 0;
+            },
+            error.ReadFailed => return error.ReadFailed,
+            error.WriteFailed => unreachable, // fixed buffer writer doesn't fail
+        };

-        // Consume any buffered output data first.
-        if (self.out_pos < self.out_end) {
-            const available = self.out_end - self.out_pos;
-            const to_consume = if (remaining) |rem| @min(available, rem) else available;
-            self.out_pos += to_consume;
-            total += to_consume;
-            if (remaining) |*rem| {
-                rem.* -= to_consume;
-                if (rem.* == 0) return total;
+        if (reached_end) {
+            // verify hash if not already done
+            if (!self.hash_verified) {
+                var actual_hash: [32]u8 = undefined;
+                self.hasher.final(&actual_hash);
+                if (std.mem.eql(u8, &actual_hash, &self.expected_hash)) {
+                    self.hash_verified = true;
+                }
             }
+            return error.EndOfStream;
         }

-        var discard_buffer: [4096]u8 = undefined;
+        // Update hash with compressed data
+        self.hasher.update(self.in_buffer[self.in_end..][0..bytes_read]);
+        self.in_end += bytes_read;

-        while (true) {
-            if (remaining) |rem| {
-                if (rem == 0) break;
-            }
-
-            const chunk_len = if (remaining) |rem| @min(discard_buffer.len, rem) else discard_buffer.len;
-            const n = self.read(discard_buffer[0..chunk_len]) catch |err| switch (err) {
-                error.DecompressionFailed, error.HashMismatch => return std.Io.Reader.Error.ReadFailed,
-                error.UnexpectedEndOfStream => return std.Io.Reader.Error.EndOfStream,
-                error.OutOfMemory => return std.Io.Reader.Error.ReadFailed,
-            };
+        // Decompress just to fill the buffer
+        var in_buf = c.ZSTD_inBuffer{ .src = self.in_buffer.ptr, .size = self.in_end, .pos = self.in_pos };

-            if (n == 0) break;
+        var out_buf = c.ZSTD_outBuffer{ .dst = r.buffer.ptr, .size = r.buffer.len, .pos = r.end };

-            total += n;
-            if (remaining) |*rem| {
-                rem.* -= n;
-                if (rem.* == 0) break;
-            }
+        const result = c.ZSTD_decompressStream(self.dctx, &out_buf, &in_buf);
+        if (c.ZSTD_isError(result) != 0) {
+            // Still ReadFailed: we are writing to the internal buffer, not the writer.
+            return error.ReadFailed;
         }
-
-        return total;
-    }
-
-    pub fn read(self: *Self, dest: []u8) Error!usize {
-        if (dest.len == 0) return 0;
-
-        var total_read: usize = 0;
-
-        while (total_read < dest.len) {
-            // If we have data in the output buffer, copy it
-            if (self.out_pos < self.out_end) {
-                const available = self.out_end - self.out_pos;
-                const to_copy = @min(available, dest.len - total_read);
-                @memcpy(dest[total_read..][0..to_copy], self.out_buffer[self.out_pos..][0..to_copy]);
-                self.out_pos += to_copy;
-                total_read += to_copy;
-
-                if (total_read == dest.len) {
-                    return total_read;
-                }
-            }
-
-            // If finished and no more data in buffer, we're done
-            if (self.finished) {
-                break;
-            }
-
-            // Read more compressed data using a fixed writer
-            var in_writer = std.Io.Writer.fixed(self.in_buffer);
-            var reached_end = false;
-            const bytes_read = self.input_reader.stream(&in_writer, std.Io.Limit.limited(self.in_buffer.len)) catch |err| switch (err) {
-                error.EndOfStream => blk: {
-                    reached_end = true;
-                    break :blk 0;
-                },
-                error.ReadFailed => return error.UnexpectedEndOfStream,
-                error.WriteFailed => unreachable, // fixed buffer writer doesn't fail
-            };
-
-            if (bytes_read == 0) {
-                if (reached_end) {
-                    if (!self.hash_verified) {
-                        var actual_hash: [32]u8 = undefined;
-                        self.hasher.final(&actual_hash);
-                        if (!std.mem.eql(u8, &actual_hash, &self.expected_hash)) {
-                            return error.HashMismatch;
-                        }
-                        self.hash_verified = true;
-                    }
-                    self.finished = true;
-                    break;
-                }
-
-                if (total_read > 0) {
-                    break;
-                }
-                continue;
-            }
-
-            // Update hash with compressed data
-            self.hasher.update(self.in_buffer[0..bytes_read]);
-
-            // Decompress
-            var in_buf = c.ZSTD_inBuffer{ .src = self.in_buffer.ptr, .size = bytes_read, .pos = 0 };
-
-            while (in_buf.pos < in_buf.size) {
-                var out_buf = c.ZSTD_outBuffer{ .dst = self.out_buffer.ptr, .size = self.out_buffer.len, .pos = 0 };
-
-                const result = c.ZSTD_decompressStream(self.dctx, &out_buf, &in_buf);
-                if (c.ZSTD_isError(result) != 0) {
-                    return error.DecompressionFailed;
-                }
-
-                if (out_buf.pos > 0) {
-                    self.out_pos = 0;
-                    self.out_end = out_buf.pos;
-
-                    // Copy what we can to dest
-                    const to_copy = @min(out_buf.pos, dest.len - total_read);
-                    @memcpy(dest[total_read..][0..to_copy], self.out_buffer[0..to_copy]);
-                    self.out_pos = to_copy;
-                    total_read += to_copy;
-
-                    if (total_read == dest.len) {
-                        return total_read;
-                    }
-                }
-
-                // If decompression is complete
-                if (result == 0) {
-                    break;
-                }
-            }
+        if (in_buf.pos < in_buf.size) {
+            self.in_pos = in_buf.pos;
+            self.in_end = in_buf.size;
+        } else {
+            self.in_pos = 0;
+            self.in_end = 0;
         }

-        return total_read;
+        r.end = out_buf.pos;
+
+        return 0;
     }

+    /// Verify that the hash matches. This should be called after reading is complete.
+    /// If there is remaining data, it will be discarded.
     pub fn verifyComplete(self: *Self) !void {
         // Read any remaining data to ensure we process the entire stream
-        var discard_buffer: [1024]u8 = undefined;
         while (true) {
-            const n = try self.read(&discard_buffer);
-            if (n == 0) break;
+            _ = self.interface.discard(std.Io.Limit.unlimited) catch |err| {
+                switch (err) {
+                    error.EndOfStream => break,
+                    error.ReadFailed => return error.ReadFailed,
+                }
+            };
        }

-        // The hash should have been verified during reading
+        // The hash should have been verified during stream
         if (!self.hash_verified) {
             return error.HashMismatch;
         }
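
The new stream implementation relies on a pattern of the Zig std.Io.Reader vtable: a stream function may ignore the destination writer, deposit bytes directly into the reader's own buffer (advancing r.end), and return 0; the generic Reader machinery then serves consumers from that buffer. A minimal self-contained sketch of the same pattern, using a hypothetical CountingReader unrelated to this commit's types:

    const std = @import("std");

    // Emits `remaining` sequential byte values, using the same
    // buffer-filling convention as DecompressingHashReader.stream.
    const CountingReader = struct {
        byte: u8 = 0,
        remaining: usize,
        interface: std.Io.Reader,

        fn init(buffer: []u8, count: usize) CountingReader {
            return .{
                .remaining = count,
                .interface = .{
                    .vtable = &.{ .stream = stream },
                    .buffer = buffer,
                    .seek = 0,
                    .end = 0,
                },
            };
        }

        fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
            _ = w; // data goes into r.buffer, not the writer
            _ = limit;
            const self: *CountingReader = @alignCast(@fieldParentPtr("interface", r));
            if (self.remaining == 0) return error.EndOfStream;
            // Reclaim buffer space once the consumer has drained it.
            if (r.seek == r.end) {
                r.seek = 0;
                r.end = 0;
            }
            while (r.end < r.buffer.len and self.remaining > 0) {
                r.buffer[r.end] = self.byte;
                self.byte +%= 1;
                r.end += 1;
                self.remaining -= 1;
            }
            return 0; // nothing streamed to w; the interface buffer was filled instead
        }
    };

    test "buffer-filling stream pattern" {
        var buf: [16]u8 = undefined;
        var counting = CountingReader.init(&buf, 40);
        var out: std.Io.Writer.Allocating = .init(std.testing.allocator);
        defer out.deinit();
        _ = try counting.interface.streamRemaining(&out.writer);
        try std.testing.expectEqual(@as(usize, 40), out.written().len);
    }

This is why DecompressingHashReader no longer needs its own out_buffer bookkeeping or a custom discard: the interface buffer (sized by the old out_buffer allocation) plays that role, and the default discard just drains it.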

src/bundle/test_streaming.zig

Lines changed: 8 additions & 23 deletions
@@ -162,17 +162,8 @@ test "streaming read with hash mismatch" {
     );
     defer reader.deinit();

-    var buffer: [1024]u8 = undefined;
-    while (true) {
-        const n = reader.read(&buffer) catch |err| {
-            try std.testing.expectEqual(err, error.HashMismatch);
-            return;
-        };
-        if (n == 0) break;
-    }
-
-    // Should have gotten hash mismatch error
-    try std.testing.expect(false);
+    // verifyComplete discards remaining data and checks hash
+    try std.testing.expectEqual(error.HashMismatch, reader.verifyComplete());
 }

 test "different compression levels" {
@@ -218,17 +209,12 @@
     );
     defer reader.deinit();

-    var decompressed = std.array_list.Managed(u8).init(allocator);
-    defer decompressed.deinit();
+    var decompressed_writer: std.Io.Writer.Allocating = .init(allocator);
+    defer decompressed_writer.deinit();

-    var buffer: [1024]u8 = undefined;
-    while (true) {
-        const n = try reader.read(&buffer);
-        if (n == 0) break;
-        try decompressed.appendSlice(buffer[0..n]);
-    }
+    _ = try reader.interface.streamRemaining(&decompressed_writer.writer);

-    try std.testing.expectEqualStrings(test_data, decompressed.items);
+    try std.testing.expectEqualStrings(test_data, decompressed_writer.written());
 }

 // Higher compression levels should generally produce smaller output
@@ -280,8 +266,7 @@ test "large file streaming extraction" {
     defer allocator.free(filename);

     // Just verify we successfully bundled a large file
-    var bundle_list = bundle_writer.toArrayList();
-    defer bundle_list.deinit(allocator);
-    try std.testing.expect(bundle_list.items.len > 512); // Should include header and compressed data
+    const bundle_list = bundle_writer.written();
+    try std.testing.expect(bundle_list.len > 512); // Should include header and compressed data
     // Note: Full round-trip testing with unbundle is done in integration tests
 }
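
Taken together, the post-commit consumer pattern looks like this. A sketch only: the init arguments are elided in the hunks above, so the argument list shown here (allocator, compressed-input reader, expected hash) is an assumption; check DecompressingHashReader.init for the real signature:

    // init arguments assumed, not shown in this diff
    var reader = try DecompressingHashReader.init(allocator, input_reader, expected_hash);
    defer reader.deinit();

    // Read decompressed data through the std.Io.Reader interface;
    // there is no longer a custom `read` method.
    var out: std.Io.Writer.Allocating = .init(allocator);
    defer out.deinit();
    _ = try reader.interface.streamRemaining(&out.writer);

    // Drain any trailing input and check the hash of the compressed stream.
    // Possible errors are now just ReadFailed and HashMismatch.
    try reader.verifyComplete();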
