aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorGravatar Runxi Yu2026-03-05 17:26:39 +0800
committerGravatar Runxi Yu2026-03-05 17:54:14 +0800
commit7f84e2e246aa9a9b5adb28ccd86fc61635d5c0ea (patch)
tree4b2db9ae70667ddb8e941219099eb1dd6053955f /internal
parenttestgit: Add pack object reader and many object maker (diff)
signatureNo signature
internal/zlib: Add counting flate reader
Diffstat (limited to 'internal')
-rw-r--r--internal/zlib/reader.go55
-rw-r--r--internal/zlib/reader_reset.go12
2 files changed, 56 insertions, 11 deletions
diff --git a/internal/zlib/reader.go b/internal/zlib/reader.go
index 24db9875..30e98cdb 100644
--- a/internal/zlib/reader.go
+++ b/internal/zlib/reader.go
@@ -59,36 +59,64 @@ var (
var readerPool = sync.Pool{
New: func() any {
- r := new(reader)
+ r := new(Reader)
return r
},
}
-type reader struct {
+// Reader reads and verifies one zlib stream.
+//
+// Reader implements io.ReadCloser.
+type Reader struct {
r flate.Reader
decompressor io.ReadCloser
digest hash.Hash32
+ counter *countingFlateReader
err error
scratch [4]byte
}
+// countingFlateReader wraps flate input and tracks consumed bytes.
+type countingFlateReader struct {
+ inner flate.Reader
+ read uint64
+}
+
+// Read implements io.Reader.
+func (reader *countingFlateReader) Read(dst []byte) (int, error) {
+ n, err := reader.inner.Read(dst)
+ reader.read += uint64(n)
+
+ return n, err
+}
+
+// ReadByte implements io.ByteReader.
+func (reader *countingFlateReader) ReadByte() (byte, error) {
+ b, err := reader.inner.ReadByte()
+ if err == nil {
+ reader.read++
+ }
+
+ return b, err
+}
+
// NewReader creates a new ReadCloser.
// Reads from the returned ReadCloser read and decompress data from r.
// If r does not implement [io.ByteReader], the decompressor may read more
// data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when done.
-func NewReader(r io.Reader) (io.ReadCloser, error) {
+func NewReader(r io.Reader) (*Reader, error) {
return NewReaderDict(r, nil)
}
// NewReaderDict is like [NewReader] but uses a preset dictionary.
// NewReaderDict ignores the dictionary if the compressed data does not refer to it.
// If the compressed data refers to a different dictionary, NewReaderDict returns [ErrDictionary].
-func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
+func NewReaderDict(r io.Reader, dict []byte) (*Reader, error) {
v := readerPool.Get()
- z, ok := v.(*reader)
+ z, ok := v.(*Reader)
if !ok {
panic("zlib: pool returned unexpected type")
}
@@ -101,7 +129,8 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
return z, nil
}
-func (z *reader) Read(p []byte) (int, error) {
+// Read decompresses bytes from receiver into p.
+func (z *Reader) Read(p []byte) (int, error) {
if z.err != nil {
return 0, z.err
}
@@ -144,10 +173,22 @@ func (z *reader) Read(p []byte) (int, error) {
return n, io.EOF
}
+// InputConsumed returns compressed bytes consumed from stream input.
+//
+// This count includes the zlib header, deflate payload, and zlib checksum
+// trailer bytes read by the reader.
+func (z *Reader) InputConsumed() uint64 {
+ if z.counter == nil {
+ return 0
+ }
+
+ return z.counter.read
+}
+
// Close does not close the wrapped [io.Reader] originally passed to [NewReader].
// In order for the ZLIB checksum to be verified, the reader must be
// fully consumed until the [io.EOF].
-func (z *reader) Close() error {
+func (z *Reader) Close() error {
if z.err != nil && !errors.Is(z.err, io.EOF) {
return z.err
}
diff --git a/internal/zlib/reader_reset.go b/internal/zlib/reader_reset.go
index a39337f7..6f15b681 100644
--- a/internal/zlib/reader_reset.go
+++ b/internal/zlib/reader_reset.go
@@ -15,13 +15,17 @@ import (
"github.com/klauspost/compress/flate"
)
-func (z *reader) Reset(r io.Reader, dict []byte) error {
- *z = reader{decompressor: z.decompressor}
+// Reset resets receiver to read a new zlib stream.
+func (z *Reader) Reset(r io.Reader, dict []byte) error {
+ *z = Reader{decompressor: z.decompressor}
+ var input flate.Reader
if fr, ok := r.(flate.Reader); ok {
- z.r = fr
+ input = fr
} else {
- z.r = bufio.NewReader(r)
+ input = bufio.NewReader(r)
}
+ z.counter = &countingFlateReader{inner: input}
+ z.r = z.counter
// Read the header (RFC 1950 section 2.2.).
_, z.err = io.ReadFull(z.r, z.scratch[0:2])