diff options
| -rw-r--r-- | object/store/memory/writer.go | 8 | ||||
| -rw-r--r-- | object/store/memory/writer_test.go | 183 | ||||
| -rw-r--r-- | object/store/writer.go | 4 |
3 files changed, 191 insertions, 4 deletions
diff --git a/object/store/memory/writer.go b/object/store/memory/writer.go index b661df0a..185b082b 100644 --- a/object/store/memory/writer.go +++ b/object/store/memory/writer.go @@ -1,12 +1,12 @@ package memory import ( - "errors" "fmt" "io" "lindenii.org/go/furgit/object/header" "lindenii.org/go/furgit/object/id" + "lindenii.org/go/furgit/object/store" "lindenii.org/go/furgit/object/typ" "lindenii.org/go/lgo/intconv" ) @@ -31,7 +31,7 @@ func (memory *Memory) WriteBytesFull(raw []byte) (id.ObjectID, error) { content := raw[consumed:] if uint64(len(content)) != size { - return id.ObjectID{}, errors.New("object/store/memory: object header size/content mismatch") + return id.ObjectID{}, fmt.Errorf("%w: header size/content mismatch", store.ErrInvalidObject) } return memory.WriteBytesContent(ty, content) @@ -51,9 +51,9 @@ func (memory *Memory) WriteReaderContent(ty typ.Type, size uint64, src io.Reader switch { case uint64(len(content)) > size: - return id.ObjectID{}, errors.New("object/store/memory: object content longer than declared size") + return id.ObjectID{}, fmt.Errorf("%w: content longer than declared size", store.ErrInvalidObject) case uint64(len(content)) < size: - return id.ObjectID{}, errors.New("object/store/memory: object content shorter than declared size") + return id.ObjectID{}, fmt.Errorf("%w: content shorter than declared size", store.ErrInvalidObject) } return memory.WriteBytesContent(ty, content) diff --git a/object/store/memory/writer_test.go b/object/store/memory/writer_test.go new file mode 100644 index 00000000..e68e1671 --- /dev/null +++ b/object/store/memory/writer_test.go @@ -0,0 +1,183 @@ +package memory_test + +import ( + "bytes" + "testing" + + "lindenii.org/go/furgit/object/header" + "lindenii.org/go/furgit/object/id" + "lindenii.org/go/furgit/object/store/memory" + "lindenii.org/go/furgit/object/typ" +) + +func TestWriteReaderContent(t *testing.T) { + t.Parallel() + + for _, objectFormat := range id.SupportedObjectFormats() { + t.Run(objectFormat.String(), func(t *testing.T) { + t.Parallel() + + store := memory.New(objectFormat) + content := []byte("memory-content\n") + raw := append(header.Append(nil, typ.TypeBlob, uint64(len(content))), content...) + + gotID, err := store.WriteReaderContent(typ.TypeBlob, uint64(len(content)), bytes.NewReader(content)) + if err != nil { + t.Fatalf("WriteReaderContent: %v", err) + } + + wantID := objectFormat.Sum(raw) + if gotID != wantID { + t.Fatalf("WriteReaderContent id = %s, want %s", gotID, wantID) + } + + gotType, gotContent, err := store.ReadBytesContent(gotID) + if err != nil { + t.Fatalf("ReadBytesContent: %v", err) + } + + if gotType != typ.TypeBlob { + t.Fatalf("ReadBytesContent type = %v, want %v", gotType, typ.TypeBlob) + } + + if !bytes.Equal(gotContent, content) { + t.Fatalf("ReadBytesContent content = %q, want %q", gotContent, content) + } + }) + } +} + +func TestWriteReaderFull(t *testing.T) { + t.Parallel() + + for _, objectFormat := range id.SupportedObjectFormats() { + t.Run(objectFormat.String(), func(t *testing.T) { + t.Parallel() + + store := memory.New(objectFormat) + content := []byte("memory-full\n") + raw := append(header.Append(nil, typ.TypeBlob, uint64(len(content))), content...) + + gotID, err := store.WriteReaderFull(bytes.NewReader(raw)) + if err != nil { + t.Fatalf("WriteReaderFull: %v", err) + } + + wantID := objectFormat.Sum(raw) + if gotID != wantID { + t.Fatalf("WriteReaderFull id = %s, want %s", gotID, wantID) + } + + gotRaw, err := store.ReadBytesFull(gotID) + if err != nil { + t.Fatalf("ReadBytesFull: %v", err) + } + + if !bytes.Equal(gotRaw, raw) { + t.Fatalf("ReadBytesFull = %q, want %q", gotRaw, raw) + } + }) + } +} + +func TestWriteBytes(t *testing.T) { + t.Parallel() + + for _, objectFormat := range id.SupportedObjectFormats() { + t.Run(objectFormat.String(), func(t *testing.T) { + t.Parallel() + + store := memory.New(objectFormat) + content := []byte("memory-bytes\n") + raw := append(header.Append(nil, typ.TypeBlob, uint64(len(content))), content...) + + gotID, err := store.WriteBytesContent(typ.TypeBlob, content) + if err != nil { + t.Fatalf("WriteBytesContent: %v", err) + } + + wantID := objectFormat.Sum(raw) + if gotID != wantID { + t.Fatalf("WriteBytesContent id = %s, want %s", gotID, wantID) + } + + gotID2, err := store.WriteBytesFull(raw) + if err != nil { + t.Fatalf("WriteBytesFull: %v", err) + } + + if gotID2 != wantID { + t.Fatalf("WriteBytesFull id = %s, want %s", gotID2, wantID) + } + }) + } +} + +func TestWriteValidationErrors(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + run func(store *memory.Memory) error + }{ + { + name: "content overflow", + run: func(store *memory.Memory) error { + _, err := store.WriteReaderContent(typ.TypeBlob, 1, bytes.NewReader([]byte("hello"))) + + return err //nolint:wrapcheck + }, + }, + { + name: "content short", + run: func(store *memory.Memory) error { + _, err := store.WriteReaderContent(typ.TypeBlob, 5, bytes.NewReader([]byte("x"))) + + return err //nolint:wrapcheck + }, + }, + { + name: "full malformed header", + run: func(store *memory.Memory) error { + _, err := store.WriteReaderFull(bytes.NewReader([]byte("not-a-header"))) + + return err //nolint:wrapcheck + }, + }, + { + name: "full size mismatch", + run: func(store *memory.Memory) error { + _, err := store.WriteReaderFull(bytes.NewReader([]byte("blob 1\x00hello"))) + + return err //nolint:wrapcheck + }, + }, + { + name: "bytes malformed header", + run: func(store *memory.Memory) error { + _, err := store.WriteBytesFull([]byte("not-a-header")) + + return err //nolint:wrapcheck + }, + }, + } + + for _, objectFormat := range id.SupportedObjectFormats() { + t.Run(objectFormat.String(), func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + store := memory.New(objectFormat) + + err := tc.run(store) + if err == nil { + t.Fatalf("expected error") + } + }) + } + }) + } +} diff --git a/object/store/writer.go b/object/store/writer.go index 289e6df4..e7efaab4 100644 --- a/object/store/writer.go +++ b/object/store/writer.go @@ -1,6 +1,7 @@ package store import ( + "errors" "io" "lindenii.org/go/furgit/common/iowrap" @@ -8,6 +9,9 @@ import ( "lindenii.org/go/furgit/object/typ" ) +// ErrInvalidObject indicates a malformed object passed to a write. +var ErrInvalidObject = errors.New("object/store: invalid object") + // ObjectWriter writes individual Git objects. type ObjectWriter interface { // WriteBytesFull writes one full serialized object byte slice as "type size\x00content". |
