From 7f84e2e246aa9a9b5adb28ccd86fc61635d5c0ea Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Thu, 5 Mar 2026 17:26:39 +0800 Subject: internal/zlib: Add counting flate reader --- internal/zlib/reader.go | 55 +++++++++++++++++++++++++++++++++++++------ internal/zlib/reader_reset.go | 12 ++++++---- 2 files changed, 56 insertions(+), 11 deletions(-) (limited to 'internal') diff --git a/internal/zlib/reader.go b/internal/zlib/reader.go index 24db9875..30e98cdb 100644 --- a/internal/zlib/reader.go +++ b/internal/zlib/reader.go @@ -59,36 +59,64 @@ var ( var readerPool = sync.Pool{ New: func() any { - r := new(reader) + r := new(Reader) return r }, } -type reader struct { +// Reader reads and verifies one zlib stream. +// +// Reader implements io.ReadCloser. +type Reader struct { r flate.Reader decompressor io.ReadCloser digest hash.Hash32 + counter *countingFlateReader err error scratch [4]byte } +// countingFlateReader wraps flate input and tracks consumed bytes. +type countingFlateReader struct { + inner flate.Reader + read uint64 +} + +// Read implements io.Reader. +func (reader *countingFlateReader) Read(dst []byte) (int, error) { + n, err := reader.inner.Read(dst) + reader.read += uint64(n) + + return n, err +} + +// ReadByte implements io.ByteReader. +func (reader *countingFlateReader) ReadByte() (byte, error) { + b, err := reader.inner.ReadByte() + if err == nil { + reader.read++ + } + + return b, err +} + // NewReader creates a new ReadCloser. // Reads from the returned ReadCloser read and decompress data from r. // If r does not implement [io.ByteReader], the decompressor may read more // data than necessary from r. // It is the caller's responsibility to call Close on the ReadCloser when done. -func NewReader(r io.Reader) (io.ReadCloser, error) { +func NewReader(r io.Reader) (*Reader, error) { return NewReaderDict(r, nil) } // NewReaderDict is like [NewReader] but uses a preset dictionary. // NewReaderDict ignores the dictionary if the compressed data does not refer to it. // If the compressed data refers to a different dictionary, NewReaderDict returns [ErrDictionary]. -func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) { +func NewReaderDict(r io.Reader, dict []byte) (*Reader, error) { v := readerPool.Get() - z, ok := v.(*reader) + z, ok := v.(*Reader) if !ok { panic("zlib: pool returned unexpected type") } @@ -101,7 +129,8 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) { return z, nil } -func (z *reader) Read(p []byte) (int, error) { +// Read decompresses bytes from receiver into p. +func (z *Reader) Read(p []byte) (int, error) { if z.err != nil { return 0, z.err } @@ -144,10 +173,22 @@ func (z *reader) Read(p []byte) (int, error) { return n, io.EOF } +// InputConsumed returns compressed bytes consumed from stream input. +// +// This count includes the zlib header, deflate payload, and zlib checksum +// trailer bytes read by the reader. +func (z *Reader) InputConsumed() uint64 { + if z.counter == nil { + return 0 + } + + return z.counter.read +} + // Close does not close the wrapped [io.Reader] originally passed to [NewReader]. // In order for the ZLIB checksum to be verified, the reader must be // fully consumed until the [io.EOF]. -func (z *reader) Close() error { +func (z *Reader) Close() error { if z.err != nil && !errors.Is(z.err, io.EOF) { return z.err } diff --git a/internal/zlib/reader_reset.go b/internal/zlib/reader_reset.go index a39337f7..6f15b681 100644 --- a/internal/zlib/reader_reset.go +++ b/internal/zlib/reader_reset.go @@ -15,13 +15,17 @@ import ( "github.com/klauspost/compress/flate" ) -func (z *reader) Reset(r io.Reader, dict []byte) error { - *z = reader{decompressor: z.decompressor} +// Reset resets receiver to read a new zlib stream. +func (z *Reader) Reset(r io.Reader, dict []byte) error { + *z = Reader{decompressor: z.decompressor} + var input flate.Reader if fr, ok := r.(flate.Reader); ok { - z.r = fr + input = fr } else { - z.r = bufio.NewReader(r) + input = bufio.NewReader(r) } + z.counter = &countingFlateReader{inner: input} + z.r = z.counter // Read the header (RFC 1950 section 2.2.). _, z.err = io.ReadFull(z.r, z.scratch[0:2]) -- cgit v1.3.1-10-gc9f91