zig icon indicating copy to clipboard operation
zig copied to clipboard

change std.io.BufferedReader and std.io.BufferedWriter to accept the buffer as a runtime parameter, and use AnyReader / AnyWriter

Open andrewrk opened this issue 1 year ago • 1 comment

Problem statement:

  1. BufferedReader/BufferedWriter is instantiated many times, causing code bloat.
  2. BufferedReader/BufferedWriter as a parameter requires use of anytype, and it's infectious.

Proposed solution (starting point):

--- a/lib/std/io/buffered_writer.zig
+++ b/lib/std/io/buffered_writer.zig
@@ -3,10 +3,9 @@ const std = @import("../std.zig");
 const io = std.io;
 const mem = std.mem;
 
-pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) type {
-    return struct {
-        unbuffered_writer: WriterType,
-        buf: [buffer_size]u8 = undefined,
+const BufferedWriter = struct {
+    unbuffered_writer: AnyWriter,
+    user_provided_buffer: []u8,
     end: usize = 0,
 
     pub const Error = WriterType.Error;
@@ -36,8 +35,10 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty
         return bytes.len;
     }
 };
-}
 
-pub fn bufferedWriter(underlying_stream: anytype) BufferedWriter(4096, @TypeOf(underlying_stream)) {
-    return .{ .unbuffered_writer = underlying_stream };
+pub fn bufferedWriter(buffer: []u8, unbuffered_writer: AnyWriter) BufferedWriter {
+    return .{
+        .user_provided_buffer = buffer,
+        .unbuffered_writer = unbuffered_writer,
+    };
 }

Downsides:

  • async functions don't really work across function pointer boundaries
  • opportunities for the programmer to bungle the lifetime of the buffer, causing UAF

Related:

  • #4501

andrewrk avatar Mar 14 '24 02:03 andrewrk

Related sketch:

/// Underlying reader that `readv` calls when no buffered data is available.
unbuffered_reader: io.AnyReader,
/// Storage for data read ahead of what the caller requested; supplied by
/// the caller of `init`.
buffer: []u8,
/// Index into `buffer` of the first byte not yet handed to the caller.
start: usize,
/// Index into `buffer` one past the last buffered byte; `start == end`
/// means nothing is buffered.
end: usize,

/// Creates a `BufferedReader` that pulls data from `unbuffered_reader` and
/// stages it in `buffer`. The slice is stored as-is (not copied), and the
/// reader starts out with nothing buffered (`start == end == 0`).
pub fn init(unbuffered_reader: io.AnyReader, buffer: []u8) BufferedReader {
    const empty: BufferedReader = .{
        .buffer = buffer,
        .unbuffered_reader = unbuffered_reader,
        .end = 0,
        .start = 0,
    };
    return empty;
}

/// Fills the caller's vectors in order, serving already-buffered data first.
///
/// If the amount requested can be fulfilled from already-buffered data, no
/// underlying read occurs. If the amount requested exceeds the amount of
/// buffered data, one underlying read occurs and this function returns right
/// after it, so fewer bytes than the total capacity of `vecs` may be stored.
///
/// Calling this function will cause at most one underlying read call.
///
/// Returns the total number of bytes stored into `vecs`. An error from the
/// underlying reader is returned as-is.
///
/// NOTE(review): when more than one vector remains, entries of `vecs` are
/// temporarily overwritten and restored before returning, which conflicts
/// with the `[]const` qualifier on the parameter — confirm the intended
/// mutability of `vecs`.
pub fn readv(br: *BufferedReader, vecs: []const io.Vec) anyerror!usize {
    var total_read_len: usize = 0;
    var vec_i: usize = 0;
    while (vec_i < vecs.len) : (vec_i += 1) {
        // Local copy that shrinks as bytes are copied into this vector.
        var vec = vecs[vec_i];
        while (vec.len > 0) {
            if (br.end == br.start) {
                // Caller wants more data but we have none buffered.
                // If the caller has only one vector remaining, then we'll pass
                // it along with the main buffer to underlying read.
                // Otherwise we'll pass the rest of the caller's vectors directly
                // to the underlying reader with one tweak: if the last vector
                // is smaller than the main buffer, then we swap it with the main
                // buffer instead.
                const remaining_vecs = vecs[vec_i..];
                switch (remaining_vecs.len) {
                    0 => unreachable, // vec_i < vecs.len, so the current vector remains
                    1 => {
                        var my_vecs: [2]io.Vec = .{
                            vec,
                            .{
                                .ptr = br.buffer.ptr,
                                .len = br.buffer.len,
                            },
                        };
                        const n = try br.unbuffered_reader.readv(&my_vecs);
                        if (n <= vec.len) {
                            // Everything landed in the caller's vector.
                            total_read_len += n;
                            return total_read_len;
                        }
                        // The surplus beyond the caller's vector is now buffered.
                        br.start = 0;
                        br.end = n - vec.len;
                        total_read_len += vec.len;
                        return total_read_len;
                    },
                    else => {
                        const last_i = remaining_vecs.len - 1;
                        const first = remaining_vecs[0];
                        const last = remaining_vecs[last_i];

                        // The current vector may already be partially filled
                        // from buffered data, so the underlying reader must be
                        // handed the unfilled remainder (`vec`) rather than the
                        // original entry; otherwise it would overwrite bytes
                        // already delivered and they would be counted twice.
                        // This applies to both paths below.
                        remaining_vecs[0] = vec;
                        defer remaining_vecs[0] = first;

                        if (last.len < br.buffer.len) {
                            // Read into the larger main buffer in place of the
                            // last vector, then copy what fits into `last`.
                            remaining_vecs[last_i] = .{
                                .ptr = br.buffer.ptr,
                                .len = br.buffer.len,
                            };
                            defer remaining_vecs[last_i] = last;

                            var n = try br.unbuffered_reader.readv(remaining_vecs);
                            for (remaining_vecs[0..last_i]) |v| {
                                if (n >= v.len) {
                                    n -= v.len;
                                    total_read_len += v.len;
                                    continue;
                                }
                                // The read ended inside this vector.
                                total_read_len += n;
                                return total_read_len;
                            }
                            // `n` is now the number of bytes that landed in the
                            // main buffer. Source and destination of `@memcpy`
                            // must have equal lengths, so both are limited to
                            // `copy_len`; the rest stays buffered.
                            const copy_len = @min(last.len, n);
                            @memcpy(last.ptr[0..copy_len], br.buffer[0..copy_len]);
                            total_read_len += copy_len;
                            br.start = copy_len;
                            br.end = n;
                            return total_read_len;
                        }
                        // The last vector is at least as large as the main
                        // buffer: read straight into the caller's vectors.
                        total_read_len += try br.unbuffered_reader.readv(remaining_vecs);
                        return total_read_len;
                    },
                }
            }
            // Serve as much as possible from the buffered bytes.
            const copy_len = @min(vec.len, br.end - br.start);
            @memcpy(vec.ptr[0..copy_len], br.buffer[br.start..][0..copy_len]);
            vec.ptr += copy_len;
            vec.len -= copy_len;
            total_read_len += copy_len;
            br.start += copy_len;
        }
    }
    return total_read_len;
}

/// Wraps this buffered reader in the type-erased `io.AnyReader` interface,
/// with `br` as the context and this file's `readv` as the read function.
pub fn reader(br: *BufferedReader) io.AnyReader {
    const any_reader: io.AnyReader = .{
        .readv = readv,
        .context = br,
    };
    return any_reader;
}

const std = @import("../std.zig");
const io = std.io;
const BufferedReader = @This();

andrewrk avatar Mar 20 '24 23:03 andrewrk