aboutsummaryrefslogtreecommitdiff
path: root/internal/compress/zlib/reader_reset.go
blob: fe675c73483b2688b658e348d3ae22efa0a7085a (about) (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zlib

import (
	"bufio"
	"encoding/binary"
	"errors"
	"io"

	"codeberg.org/lindenii/furgit/internal/adler32"
	"codeberg.org/lindenii/furgit/internal/compress/flate"
	"codeberg.org/lindenii/furgit/internal/intconv"
)

// reset reinitializes z to decode a new zlib stream read from r.
//
// The decompressor left over from a previous stream, if any, is kept and
// reused through flate.Resetter; every other field of z is cleared. If r
// does not implement flate.Reader it is wrapped in a bufio.Reader.
//
// reset consumes the two-byte zlib header (RFC 1950 section 2.2) and, when
// the header's FDICT bit is set, the four-byte dictionary checksum that
// follows it. Every header byte consumed, including those of a short read,
// is added to z.headerRead. dict is used as the preset dictionary only when
// FDICT is set, and its Adler-32 checksum must then match the one in the
// stream.
//
// On failure the error is stored in z.err and returned: io.ErrUnexpectedEOF
// for a truncated header, ErrHeader for a malformed one, ErrDictionary for
// a dictionary mismatch, or whatever the underlying reader or decompressor
// reset reported. reset panics if the decompressor does not implement the
// flate interfaces it relies on.
func (z *Reader) reset(r io.Reader, dict []byte) error {
	*z = Reader{decompressor: z.decompressor}

	if fr, ok := r.(flate.Reader); ok {
		z.r = fr
	} else {
		z.r = bufio.NewReader(r)
	}

	// readHeader fills buf from z.r and counts the bytes it consumed toward
	// z.headerRead, even on a short read. A clean EOF becomes
	// io.ErrUnexpectedEOF because the header is mandatory.
	readHeader := func(buf []byte) error {
		n, err := io.ReadFull(z.r, buf)

		consumed, convErr := intconv.IntToUint64(n)
		if convErr != nil {
			return convErr
		}

		z.headerRead += consumed

		if errors.Is(err, io.EOF) {
			return io.ErrUnexpectedEOF
		}

		return err
	}

	// CMF and FLG bytes.
	if z.err = readHeader(z.scratch[0:2]); z.err != nil {
		return z.err
	}

	cmf, flg := z.scratch[0], z.scratch[1]

	// The compression method must be deflate, the window size must not exceed
	// the maximum, and CMF*256+FLG must be a multiple of 31.
	h := binary.BigEndian.Uint16(z.scratch[:2])
	if (cmf&0x0f != zlibDeflate) || (cmf>>4 > zlibMaxWindow) || (h%31 != 0) {
		z.err = ErrHeader

		return z.err
	}

	haveDict := flg&0x20 != 0
	if haveDict {
		if z.err = readHeader(z.scratch[0:4]); z.err != nil {
			return z.err
		}

		if binary.BigEndian.Uint32(z.scratch[:4]) != adler32.Checksum(dict) {
			z.err = ErrDictionary

			return z.err
		}
	}

	if z.decompressor != nil {
		resetter, ok := z.decompressor.(flate.Resetter)
		if !ok {
			panic("zlib: pooled decompressor does not implement flate.Resetter")
		}

		// Only preload the window when the stream announced a preset
		// dictionary, matching the freshly allocated path below.
		var preset []byte
		if haveDict {
			preset = dict
		}

		if z.err = resetter.Reset(z.r, preset); z.err != nil {
			return z.err
		}

		progress, ok := z.decompressor.(flate.InputProgress)
		if !ok {
			panic("zlib: pooled decompressor does not implement flate.InputProgress")
		}

		z.progress = progress
	} else {
		if haveDict {
			z.decompressor = flate.NewReaderDict(z.r, dict)
		} else {
			z.decompressor = flate.NewReader(z.r)
		}

		progress, ok := z.decompressor.(flate.InputProgress)
		if !ok {
			panic("zlib: decompressor does not implement flate.InputProgress")
		}

		z.progress = progress
	}

	z.digest = adler32.New()

	return nil
}