diff options
| author | 2026-03-05 18:07:04 +0800 | |
|---|---|---|
| committer | 2026-03-05 18:38:33 +0800 | |
| commit | c8f00194c617796e2b83f715b4d2ece80a34a716 (patch) | |
| tree | 52837aea0f3056611275ebc91d84860d65de3561 /internal/compress | |
| parent | internal/compress/flate: Add InputConsumed (diff) | |
| signature | No signature | |
internal/compress/zlib: Use flate's compression consumed counter
Diffstat (limited to 'internal/compress')
| -rw-r--r-- | internal/compress/zlib/reader.go | 38 | ||||
| -rw-r--r-- | internal/compress/zlib/reader_reset.go | 20 |
2 files changed, 25 insertions, 33 deletions
diff --git a/internal/compress/zlib/reader.go b/internal/compress/zlib/reader.go index 2d009887..75ef864c 100644 --- a/internal/compress/zlib/reader.go +++ b/internal/compress/zlib/reader.go @@ -71,36 +71,14 @@ var readerPool = sync.Pool{ type Reader struct { r flate.Reader decompressor io.ReadCloser + progress flate.InputProgress digest hash.Hash32 - counter *countingFlateReader + headerRead uint64 + trailerRead uint64 err error scratch [4]byte } -// countingFlateReader wraps flate input and tracks consumed bytes. -type countingFlateReader struct { - inner flate.Reader - read uint64 -} - -// Read implements io.Reader. -func (reader *countingFlateReader) Read(dst []byte) (int, error) { - n, err := reader.inner.Read(dst) - reader.read += uint64(n) - - return n, err -} - -// ReadByte implements io.ByteReader. -func (reader *countingFlateReader) ReadByte() (byte, error) { - b, err := reader.inner.ReadByte() - if err == nil { - reader.read++ - } - - return b, err -} - // NewReader creates a new ReadCloser. // Reads from the returned ReadCloser read and decompress data from r. // If r does not implement [io.ByteReader], the decompressor may read more @@ -152,7 +130,8 @@ func (z *Reader) Read(p []byte) (int, error) { } // Finished file; check checksum. - _, err = io.ReadFull(z.r, z.scratch[0:4]) + readN, err := io.ReadFull(z.r, z.scratch[0:4]) + z.trailerRead += uint64(readN) if err != nil { if errors.Is(err, io.EOF) { err = io.ErrUnexpectedEOF @@ -178,11 +157,12 @@ func (z *Reader) Read(p []byte) (int, error) { // This count includes the zlib header, deflate payload, and zlib checksum // trailer bytes read by the reader. func (z *Reader) InputConsumed() uint64 { - if z.counter == nil { - return 0 + out := z.headerRead + z.trailerRead + if z.progress != nil { + out += uint64(z.progress.InputConsumed()) } - return z.counter.read + return out } // Close does not close the wrapped [io.Reader] originally passed to [NewReader]. diff --git a/internal/compress/zlib/reader_reset.go b/internal/compress/zlib/reader_reset.go index f374111c..19f06cf0 100644 --- a/internal/compress/zlib/reader_reset.go +++ b/internal/compress/zlib/reader_reset.go @@ -25,11 +25,12 @@ func (z *Reader) reset(r io.Reader, dict []byte) error { input = bufio.NewReader(r) } - z.counter = &countingFlateReader{inner: input} - z.r = z.counter + z.r = input // Read the header (RFC 1950 section 2.2.). - _, z.err = io.ReadFull(z.r, z.scratch[0:2]) + readN, err := io.ReadFull(z.r, z.scratch[0:2]) + z.headerRead += uint64(readN) + z.err = err if z.err != nil { if errors.Is(z.err, io.EOF) { z.err = io.ErrUnexpectedEOF @@ -47,7 +48,8 @@ func (z *Reader) reset(r io.Reader, dict []byte) error { haveDict := z.scratch[1]&0x20 != 0 if haveDict { - _, z.err = io.ReadFull(z.r, z.scratch[0:4]) + readN, z.err = io.ReadFull(z.r, z.scratch[0:4]) + z.headerRead += uint64(readN) if z.err != nil { if errors.Is(z.err, io.EOF) { z.err = io.ErrUnexpectedEOF @@ -74,6 +76,11 @@ func (z *Reader) reset(r io.Reader, dict []byte) error { if z.err != nil { return z.err } + progress, ok := z.decompressor.(flate.InputProgress) + if !ok { + panic("zlib: pooled decompressor does not implement flate.InputProgress") + } + z.progress = progress z.digest = adler32.New() @@ -85,6 +92,11 @@ func (z *Reader) reset(r io.Reader, dict []byte) error { } else { z.decompressor = flate.NewReader(z.r) } + progress, ok := z.decompressor.(flate.InputProgress) + if !ok { + panic("zlib: decompressor does not implement flate.InputProgress") + } + z.progress = progress z.digest = adler32.New() |
