package iolimit

import (
	"errors"
	"fmt"
	"io"

	"lindenii.org/go/lgo/intconv"
)

// ErrExpectedLengthExceeded reports that a stream
// produced bytes beyond the expected length.
var ErrExpectedLengthExceeded = errors.New("iolimit: stream exceeded expected length")

// ExpectLengthReader wraps src and enforces an expected byte length.
//
// It returns io.ErrUnexpectedEOF
// if src ends before expected bytes are read.
// It returns ErrExpectedLengthExceeded
// if a read is attempted after expected bytes have been delivered
// and src still produces bytes.
// Once ErrExpectedLengthExceeded has been returned,
// every later Read with a non-empty buffer returns it again.
//
// The reader never drains src.
// As a result,
// overlength streams are detected only
// when a caller reads again after reaching the expected boundary.
// That detection consumes (and discards) one byte from src.
func ExpectLengthReader(src io.Reader, expected uint64) io.Reader {
	return &expectLengthReader{
		src:       src,
		remaining: expected,
	}
}

// expectLengthReader is the io.Reader returned by ExpectLengthReader.
type expectLengthReader struct {
	// src is the wrapped stream.
	src io.Reader
	// remaining is the number of expected bytes not yet delivered.
	remaining uint64
	// exceeded records that src produced a byte past the boundary.
	// The probe discards that byte, so without this flag a later Read
	// could observe a clean io.EOF and hide the violation.
	exceeded bool
}

// Read implements io.Reader.
//
// It never hands the caller more than the remaining expected bytes.
// An io.EOF from src is translated to io.ErrUnexpectedEOF while bytes
// are still owed, and is suppressed when it arrives together with the
// final expected bytes (the next Read then probes for excess data).
// Errors from src other than io.EOF are returned unchanged.
func (r *expectLengthReader) Read(dst []byte) (int, error) {
	if len(dst) == 0 {
		return 0, nil
	}
	if r.exceeded {
		return 0, ErrExpectedLengthExceeded
	}
	if r.remaining == 0 {
		return r.probe()
	}

	if uint64(len(dst)) > r.remaining {
		// remaining is smaller than len(dst) here, so the conversion is
		// not expected to fail; the error path is kept as a safeguard.
		limit, err := intconv.Uint64ToInt(r.remaining)
		if err != nil {
			return 0, fmt.Errorf("iolimit: %w", err)
		}
		dst = dst[:limit]
	}

	n, err := r.src.Read(dst)
	if n > 0 {
		read, convErr := intconv.IntToUint64(n)
		if convErr != nil {
			return n, fmt.Errorf("iolimit: %w", convErr)
		}
		r.remaining -= read
	}

	if !errors.Is(err, io.EOF) {
		return n, err
	}
	switch {
	case r.remaining > 0:
		// src ended before delivering every expected byte.
		return n, io.ErrUnexpectedEOF
	case n > 0:
		// The final expected bytes arrived together with EOF; report the
		// data now and let the next Read check the boundary.
		return n, nil
	default:
		return 0, io.EOF
	}
}

// probe reads a single byte from src once the expected length has been
// delivered. Any byte obtained marks the stream as overlength; otherwise
// the result of the underlying Read (nil, io.EOF, or another error) is
// passed through with a zero count.
func (r *expectLengthReader) probe() (int, error) {
	var extra [1]byte
	n, err := r.src.Read(extra[:])
	if n > 0 {
		r.exceeded = true
		return 0, ErrExpectedLengthExceeded
	}
	return 0, err
}