From 83c03b605df05ee4c3b60f1aea0411a0636ca0b5 Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sat, 21 Feb 2026 18:11:49 +0800 Subject: iolimit: Add ExpectLengthReader --- internal/iolimit/expect_length_reader.go | 76 +++++++++++++++++++++++++++ internal/iolimit/expect_length_reader_test.go | 70 ++++++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 internal/iolimit/expect_length_reader.go create mode 100644 internal/iolimit/expect_length_reader_test.go (limited to 'internal') diff --git a/internal/iolimit/expect_length_reader.go b/internal/iolimit/expect_length_reader.go new file mode 100644 index 00000000..477c207f --- /dev/null +++ b/internal/iolimit/expect_length_reader.go @@ -0,0 +1,76 @@ +// Package iolimit provides small internal reader wrappers for length-constrained +// stream I/O. +package iolimit + +import ( + "errors" + "io" +) + +// ErrExpectedLengthExceeded reports that a stream produced bytes beyond the +// expected length. +var ErrExpectedLengthExceeded = errors.New("iolimit: stream exceeded expected length") + +// ExpectLengthReader wraps src and enforces an expected byte length. +// +// It returns io.ErrUnexpectedEOF if src ends before expected bytes are read. +// It returns ErrExpectedLengthExceeded if reads continue beyond the expected +// boundary and src still produces bytes. +// +// This reader does not drain src on close or at the expected boundary. As a +// result, overlength streams are detected only when a caller reads at or past +// the boundary. +func ExpectLengthReader(src io.Reader, expected int64) io.Reader { + return &expectLengthReader{ + src: src, + remaining: expected, + } +} + +type expectLengthReader struct { + src io.Reader + remaining int64 +} + +func (reader *expectLengthReader) Read(dst []byte) (int, error) { + if len(dst) == 0 { + return 0, nil + } + + if reader.remaining == 0 { + var probe [1]byte + n, err := reader.src.Read(probe[:]) + if n > 0 { + return 0, ErrExpectedLengthExceeded + } + if err == nil { + return 0, nil + } + return 0, err + } + + if reader.remaining < 0 { + return 0, ErrExpectedLengthExceeded + } + + if int64(len(dst)) > reader.remaining { + dst = dst[:reader.remaining] + } + + n, err := reader.src.Read(dst) + if n > 0 { + reader.remaining -= int64(n) + } + + if err == io.EOF { + if reader.remaining > 0 { + return n, io.ErrUnexpectedEOF + } + if n > 0 { + return n, nil + } + return 0, io.EOF + } + + return n, err +} diff --git a/internal/iolimit/expect_length_reader_test.go b/internal/iolimit/expect_length_reader_test.go new file mode 100644 index 00000000..503c88ed --- /dev/null +++ b/internal/iolimit/expect_length_reader_test.go @@ -0,0 +1,70 @@ +package iolimit_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "codeberg.org/lindenii/furgit/internal/iolimit" +) + +func TestExpectLengthReaderExact(t *testing.T) { + t.Parallel() + + r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hello")), 5) + got, err := io.ReadAll(r) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(got, []byte("hello")) { + t.Fatalf("ReadAll = %q, want %q", got, "hello") + } + + buf := make([]byte, 1) + n, err := r.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("post-boundary Read = (%d,%v), want (0,EOF)", n, err) + } +} + +func TestExpectLengthReaderShort(t *testing.T) { + t.Parallel() + + r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hey")), 5) + _, err := io.ReadAll(r) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("ReadAll error = %v, want ErrUnexpectedEOF", err) + } +} + +func TestExpectLengthReaderLongDetectedOnNextRead(t *testing.T) { + t.Parallel() + + r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hello!")), 5) + buf := make([]byte, 5) + n, err := io.ReadFull(r, buf) + if err != nil { + t.Fatalf("ReadFull error: %v", err) + } + if n != 5 || !bytes.Equal(buf, []byte("hello")) { + t.Fatalf("ReadFull = (%d,%q), want (5,hello)", n, buf) + } + + probe := make([]byte, 1) + n, err = r.Read(probe) + if n != 0 || !errors.Is(err, iolimit.ErrExpectedLengthExceeded) { + t.Fatalf("overflow Read = (%d,%v), want (0,ErrExpectedLengthExceeded)", n, err) + } +} + +func TestExpectLengthReaderEmptyExpected(t *testing.T) { + t.Parallel() + + r := iolimit.ExpectLengthReader(bytes.NewReader(nil), 0) + buf := make([]byte, 1) + n, err := r.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("Read = (%d,%v), want (0,EOF)", n, err) + } +} -- cgit v1.3.1-10-gc9f91