diff options
| author | 2026-01-30 08:37:45 +0100 | |
|---|---|---|
| committer | 2026-01-30 08:37:45 +0100 | |
| commit | 4c136fc653775d7a5f460ceaa3f7204a64ab4aef (patch) | |
| tree | 46752c5da759ea370030aaa8c6c63b2e2f2d98a9 /internal | |
| parent | tree: Add unit test for TreeEntryNameCompare (diff) | |
| signature | No signature | |
pktline: Add basic pktline support
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/pktline/pktline.go | 163 | ||||
| -rw-r--r-- | internal/pktline/pktline_test.go | 88 |
2 files changed, 251 insertions, 0 deletions
diff --git a/internal/pktline/pktline.go b/internal/pktline/pktline.go new file mode 100644 index 00000000..0b799444 --- /dev/null +++ b/internal/pktline/pktline.go @@ -0,0 +1,163 @@ +package pktline + +import ( + "errors" + "io" +) + +const ( + maxPacketSize = 65520 + maxPacketDataLen = maxPacketSize - 4 +) + +var ( + ErrInvalidHeader = errors.New("pktline: invalid header") + ErrPacketTooLarge = errors.New("pktline: packet too large") + ErrBufferTooSmall = errors.New("pktline: buffer too small") +) + +type Status uint8 + +const ( + StatusEOF Status = iota + StatusData + StatusFlush + StatusDelim + StatusResponseEnd +) + +// ReadLine reads a single pkt-line from r into buf. +// It returns the payload slice, number of payload bytes, and a status. +func ReadLine(r io.Reader, buf []byte) ([]byte, int, Status, error) { + if r == nil { + return nil, 0, StatusEOF, ErrInvalidHeader + } + var header [4]byte + if _, err := io.ReadFull(r, header[:]); err != nil { + if errors.Is(err, io.EOF) { + return nil, 0, StatusEOF, io.EOF + } + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, 0, StatusEOF, io.ErrUnexpectedEOF + } + return nil, 0, StatusEOF, err + } + + n, err := parseHeader(header[:]) + if err != nil { + return nil, 0, StatusEOF, err + } + switch n { + case 0: + return nil, 0, StatusFlush, nil + case 1: + return nil, 0, StatusDelim, nil + case 2: + return nil, 0, StatusResponseEnd, nil + } + if n < 4 { + return nil, 0, StatusEOF, ErrInvalidHeader + } + n -= 4 + if n > maxPacketDataLen { + return nil, 0, StatusEOF, ErrPacketTooLarge + } + if n > len(buf) { + return nil, 0, StatusEOF, ErrBufferTooSmall + } + if _, err := io.ReadFull(r, buf[:n]); err != nil { + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, 0, StatusEOF, io.ErrUnexpectedEOF + } + return nil, 0, StatusEOF, err + } + return buf[:n], n, StatusData, nil +} + +// WriteLine writes a single pkt-line with data as its payload. +func WriteLine(w io.Writer, data []byte) error { + if w == nil { + return ErrInvalidHeader + } + if len(data) > maxPacketDataLen { + return ErrPacketTooLarge + } + var header [4]byte + setHeader(header[:], len(data)+4) + if _, err := w.Write(header[:]); err != nil { + return err + } + if len(data) == 0 { + return nil + } + _, err := w.Write(data) + return err +} + +// Flush writes a flush-pkt ("0000"). +func Flush(w io.Writer) error { + return writeLiteral(w, "0000") +} + +// Delim writes a delim-pkt ("0001"). +func Delim(w io.Writer) error { + return writeLiteral(w, "0001") +} + +// ResponseEnd writes a response-end pkt ("0002"). +func ResponseEnd(w io.Writer) error { + return writeLiteral(w, "0002") +} + +func writeLiteral(w io.Writer, s string) error { + if w == nil { + return ErrInvalidHeader + } + _, err := io.WriteString(w, s) + return err +} + +func parseHeader(b []byte) (int, error) { + if len(b) < 4 { + return 0, ErrInvalidHeader + } + v0, ok := hexVal(b[0]) + if !ok { + return 0, ErrInvalidHeader + } + v1, ok := hexVal(b[1]) + if !ok { + return 0, ErrInvalidHeader + } + v2, ok := hexVal(b[2]) + if !ok { + return 0, ErrInvalidHeader + } + v3, ok := hexVal(b[3]) + if !ok { + return 0, ErrInvalidHeader + } + return (v0 << 12) | (v1 << 8) | (v2 << 4) | v3, nil +} + +func setHeader(buf []byte, size int) { + const hex = "0123456789abcdef" + buf[0] = hex[(size>>12)&0x0f] + buf[1] = hex[(size>>8)&0x0f] + buf[2] = hex[(size>>4)&0x0f] + buf[3] = hex[size&0x0f] +} + +// IIRC strconv.ParseUint, encoding/hex.Decode, etc., allocate memory. +func hexVal(b byte) (int, bool) { + switch { + case b >= '0' && b <= '9': + return int(b - '0'), true + case b >= 'a' && b <= 'f': + return int(b-'a') + 10, true + case b >= 'A' && b <= 'F': + return int(b-'A') + 10, true + default: + return 0, false + } +} diff --git a/internal/pktline/pktline_test.go b/internal/pktline/pktline_test.go new file mode 100644 index 00000000..4dae708b --- /dev/null +++ b/internal/pktline/pktline_test.go @@ -0,0 +1,88 @@ +package pktline + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestWriteReadLineRoundtrip(t *testing.T) { + var buf bytes.Buffer + payload := []byte("hello\n") + if err := WriteLine(&buf, payload); err != nil { + t.Fatalf("WriteLine: %v", err) + } + + dst := make([]byte, 64) + line, n, status, err := ReadLine(&buf, dst) + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + if status != StatusData { + t.Fatalf("status: got %v, want %v", status, StatusData) + } + if n != len(payload) { + t.Fatalf("n: got %d, want %d", n, len(payload)) + } + if !bytes.Equal(line, payload) { + t.Fatalf("payload: got %q, want %q", line, payload) + } +} + +func TestReadLineSpecialPackets(t *testing.T) { + tests := []struct { + name string + input string + status Status + }{ + {"flush", "0000", StatusFlush}, + {"delim", "0001", StatusDelim}, + {"response_end", "0002", StatusResponseEnd}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewBufferString(tt.input) + dst := make([]byte, 16) + line, n, status, err := ReadLine(r, dst) + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + if status != tt.status { + t.Fatalf("status: got %v, want %v", status, tt.status) + } + if n != 0 || len(line) != 0 { + t.Fatalf("expected empty payload, got %d bytes", n) + } + }) + } +} + +func TestReadLineInvalidHeader(t *testing.T) { + r := bytes.NewBufferString("zzzz") + dst := make([]byte, 16) + _, _, _, err := ReadLine(r, dst) + if !errors.Is(err, ErrInvalidHeader) { + t.Fatalf("expected ErrInvalidHeader, got %v", err) + } +} + +func TestReadLineBufferTooSmall(t *testing.T) { + var buf bytes.Buffer + payload := []byte("abcd") + if err := WriteLine(&buf, payload); err != nil { + t.Fatalf("WriteLine: %v", err) + } + dst := make([]byte, 2) + _, _, _, err := ReadLine(&buf, dst) + if !errors.Is(err, ErrBufferTooSmall) { + t.Fatalf("expected ErrBufferTooSmall, got %v", err) + } +} + +func TestWriteLineTooLarge(t *testing.T) { + payload := make([]byte, maxPacketDataLen+1) + if err := WriteLine(io.Discard, payload); !errors.Is(err, ErrPacketTooLarge) { + t.Fatalf("expected ErrPacketTooLarge, got %v", err) + } +} |
