diff options
Diffstat (limited to 'objectid')
| -rw-r--r-- | objectid/objectid.go | 186 | ||||
| -rw-r--r-- | objectid/objectid_test.go | 144 |
2 files changed, 330 insertions, 0 deletions
// Reconstructed from the diff that adds objectid/objectid.go and
// objectid/objectid_test.go, both in package objectid. The library code comes
// first; the test code starts at the "objectid/objectid_test.go" marker below.
//
// Package objectid provides object ID and algorithm primitives for Git objects.
//
// Imports used by objectid.go: crypto/sha1, crypto/sha256, encoding/hex,
// errors, fmt, hash.
// Imports used by objectid_test.go: bytes, testing.

// ---------------------------------------------------------------------------
// objectid/objectid.go
// ---------------------------------------------------------------------------

var (
	// ErrInvalidAlgorithm indicates an unsupported object ID algorithm.
	ErrInvalidAlgorithm = errors.New("objectid: invalid algorithm")
	// ErrInvalidObjectID indicates malformed object ID data.
	ErrInvalidObjectID = errors.New("objectid: invalid object id")
)

// maxObjectIDSize MUST be >= the largest supported algorithm size, because
// every ObjectID stores its digest in a fixed array of this length.
const maxObjectIDSize = sha256.Size

// Algorithm identifies the hash algorithm used for Git object IDs.
//
// Any value without an entry in algorithmTable (including values past the end
// of the table) is treated exactly like AlgorithmUnknown.
type Algorithm uint8

const (
	// AlgorithmUnknown is the zero value; it has no name, size, or hash.
	AlgorithmUnknown Algorithm = iota
	// AlgorithmSHA1 selects SHA-1 object IDs.
	AlgorithmSHA1
	// AlgorithmSHA256 selects SHA-256 object IDs.
	AlgorithmSHA256
)

// algorithmDetails holds the per-algorithm properties looked up by info.
// The zero value describes an unsupported algorithm: empty name, size 0, and
// nil sum/new functions.
type algorithmDetails struct {
	name string
	size int
	sum  func([]byte) ObjectID
	new  func() hash.Hash
}

// algorithmTable is indexed by Algorithm value.
var algorithmTable = [...]algorithmDetails{
	AlgorithmUnknown: {},
	AlgorithmSHA1: {
		name: "sha1",
		size: sha1.Size,
		sum: func(data []byte) ObjectID {
			digest := sha1.Sum(data)
			return newObjectID(AlgorithmSHA1, digest[:])
		},
		new: func() hash.Hash {
			return sha1.New()
		},
	},
	AlgorithmSHA256: {
		name: "sha256",
		size: sha256.Size,
		sum: func(data []byte) ObjectID {
			digest := sha256.Sum256(data)
			return newObjectID(AlgorithmSHA256, digest[:])
		},
		new: func() hash.Hash {
			return sha256.New()
		},
	},
}

// algorithmByName maps each canonical algorithm name to its Algorithm value.
// It is derived from algorithmTable once, during package initialization.
var algorithmByName = indexAlgorithmsByName()

// indexAlgorithmsByName builds the name lookup from algorithmTable, skipping
// entries that have no name (AlgorithmUnknown).
func indexAlgorithmsByName() map[string]Algorithm {
	byName := make(map[string]Algorithm, len(algorithmTable))
	for algo, details := range algorithmTable {
		if details.name == "" {
			continue
		}
		byName[details.name] = Algorithm(algo)
	}
	return byName
}

// newObjectID returns an ObjectID tagged with algo whose leading bytes are
// copied from raw. Callers are responsible for validating len(raw) first;
// copy silently truncates anything longer than maxObjectIDSize.
func newObjectID(algo Algorithm, raw []byte) ObjectID {
	id := ObjectID{algo: algo}
	copy(id.data[:], raw)
	return id
}

// info returns the table entry for algo. Values outside algorithmTable yield
// the zero algorithmDetails instead of an index-out-of-range panic, so an
// Algorithm built from untrusted data degrades to "unknown" everywhere.
func (algo Algorithm) info() algorithmDetails {
	if int(algo) >= len(algorithmTable) {
		return algorithmDetails{}
	}
	return algorithmTable[algo]
}

// ParseAlgorithm parses a canonical algorithm name (e.g. "sha1", "sha256").
// The boolean result is false when the name is not recognized, in which case
// the returned Algorithm is AlgorithmUnknown.
func ParseAlgorithm(s string) (Algorithm, bool) {
	algo, ok := algorithmByName[s]
	return algo, ok
}

// Size returns the hash size in bytes, or 0 for an unsupported algorithm.
func (algo Algorithm) Size() int {
	return algo.info().size
}

// String returns the canonical algorithm name, or "unknown" for an
// unsupported algorithm.
func (algo Algorithm) String() string {
	if name := algo.info().name; name != "" {
		return name
	}
	return "unknown"
}

// HexLen returns the encoded hexadecimal length (two characters per byte), or
// 0 for an unsupported algorithm.
func (algo Algorithm) HexLen() int {
	return algo.Size() * 2
}

// Sum computes an object ID from raw data using the selected algorithm.
//
// Sum has no error result, so it panics when algo is not a supported
// algorithm; validate with Size or New first if algo is not a known constant.
func (algo Algorithm) Sum(data []byte) ObjectID {
	sum := algo.info().sum
	if sum == nil {
		panic("objectid: Sum called on unsupported algorithm")
	}
	return sum(data)
}

// New returns a new hash.Hash for this algorithm, or ErrInvalidAlgorithm if
// the algorithm is unsupported.
func (algo Algorithm) New() (hash.Hash, error) {
	newFn := algo.info().new
	if newFn == nil {
		return nil, ErrInvalidAlgorithm
	}
	return newFn(), nil
}

// ObjectID represents a Git object ID.
//
// The zero value has AlgorithmUnknown and size 0. Only the first Size() bytes
// of data are meaningful; the rest stay zero.
type ObjectID struct {
	algo Algorithm
	data [maxObjectIDSize]byte
}

// Algorithm returns the object ID's hash algorithm.
func (id ObjectID) Algorithm() Algorithm {
	return id.algo
}

// Size returns the object ID size in bytes.
func (id ObjectID) Size() int {
	return id.algo.Size()
}

// String returns the canonical (lowercase) hex representation.
func (id ObjectID) String() string {
	return hex.EncodeToString(id.data[:id.Size()])
}

// Bytes returns a copy of the object ID bytes; mutating the result does not
// affect id.
func (id ObjectID) Bytes() []byte {
	return append([]byte(nil), id.data[:id.Size()]...)
}

// ParseHex parses an object ID from hex for the specified algorithm.
//
// It returns ErrInvalidAlgorithm when algo is unsupported, and an error
// wrapping ErrInvalidObjectID when s has the wrong length or is not valid hex.
// On error the returned ObjectID is the zero value.
func ParseHex(algo Algorithm, s string) (ObjectID, error) {
	if algo.Size() == 0 {
		return ObjectID{}, ErrInvalidAlgorithm
	}
	if len(s)%2 != 0 {
		return ObjectID{}, fmt.Errorf("%w: odd hex length %d", ErrInvalidObjectID, len(s))
	}
	if want := algo.HexLen(); len(s) != want {
		return ObjectID{}, fmt.Errorf("%w: got %d chars, expected %d", ErrInvalidObjectID, len(s), want)
	}
	decoded, err := hex.DecodeString(s)
	if err != nil {
		return ObjectID{}, fmt.Errorf("%w: decode: %v", ErrInvalidObjectID, err)
	}
	return newObjectID(algo, decoded), nil
}

// FromBytes builds an object ID from raw bytes for the specified algorithm.
//
// It returns ErrInvalidAlgorithm when algo is unsupported, and an error
// wrapping ErrInvalidObjectID when len(b) differs from algo.Size(). The bytes
// are copied, so b may be reused by the caller afterwards.
func FromBytes(algo Algorithm, b []byte) (ObjectID, error) {
	if algo.Size() == 0 {
		return ObjectID{}, ErrInvalidAlgorithm
	}
	if want := algo.Size(); len(b) != want {
		return ObjectID{}, fmt.Errorf("%w: got %d bytes, expected %d", ErrInvalidObjectID, len(b), want)
	}
	return newObjectID(algo, b), nil
}

// ---------------------------------------------------------------------------
// objectid/objectid_test.go
// ---------------------------------------------------------------------------

// TestParseAlgorithm checks that canonical names resolve and unknown names
// are rejected.
func TestParseAlgorithm(t *testing.T) {
	t.Parallel()

	algo, ok := ParseAlgorithm("sha1")
	if !ok || algo != AlgorithmSHA1 {
		t.Fatalf("ParseAlgorithm(sha1) = (%v,%v)", algo, ok)
	}

	algo, ok = ParseAlgorithm("sha256")
	if !ok || algo != AlgorithmSHA256 {
		t.Fatalf("ParseAlgorithm(sha256) = (%v,%v)", algo, ok)
	}

	if _, ok := ParseAlgorithm("md5"); ok {
		t.Fatal("ParseAlgorithm(md5) should fail")
	}
}

// TestParseHexRoundtrip checks hex -> ObjectID -> hex and
// ObjectID -> bytes -> ObjectID for every supported algorithm.
func TestParseHexRoundtrip(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		algo Algorithm
		hex  string
	}{
		{
			name: "sha1",
			algo: AlgorithmSHA1,
			hex:  "0123456789abcdef0123456789abcdef01234567",
		},
		{
			name: "sha256",
			algo: AlgorithmSHA256,
			hex:  "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			id, err := ParseHex(tt.algo, tt.hex)
			if err != nil {
				t.Fatalf("ParseHex failed: %v", err)
			}
			if got := id.String(); got != tt.hex {
				t.Fatalf("String() = %q, want %q", got, tt.hex)
			}
			if got := id.Size(); got != tt.algo.Size() {
				t.Fatalf("Size() = %d, want %d", got, tt.algo.Size())
			}

			raw := id.Bytes()
			if len(raw) != tt.algo.Size() {
				t.Fatalf("Bytes len = %d, want %d", len(raw), tt.algo.Size())
			}

			id2, err := FromBytes(tt.algo, raw)
			if err != nil {
				t.Fatalf("FromBytes failed: %v", err)
			}
			if id2.String() != tt.hex {
				t.Fatalf("FromBytes roundtrip = %q, want %q", id2.String(), tt.hex)
			}
		})
	}
}

// TestParseHexErrors checks that every malformed input is rejected.
func TestParseHexErrors(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		algo Algorithm
		hex  string
	}{
		{"unknown algo", AlgorithmUnknown, "00"},
		{"out of range algo", Algorithm(200), "00"},
		{"odd len", AlgorithmSHA1, "0"},
		{"wrong len", AlgorithmSHA1, "0123"},
		{"invalid hex", AlgorithmSHA1, "zz23456789abcdef0123456789abcdef01234567"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if _, err := ParseHex(tt.algo, tt.hex); err == nil {
				t.Fatal("expected ParseHex error")
			}
		})
	}
}

// TestFromBytesErrors checks that FromBytes rejects unsupported algorithms
// and wrongly sized input.
func TestFromBytesErrors(t *testing.T) {
	t.Parallel()

	if _, err := FromBytes(AlgorithmUnknown, []byte{1, 2}); err == nil {
		t.Fatal("expected FromBytes unknown algo error")
	}
	if _, err := FromBytes(AlgorithmSHA1, []byte{1, 2}); err == nil {
		t.Fatal("expected FromBytes wrong size error")
	}
}

// TestBytesReturnsCopy checks that slices returned by Bytes do not alias each
// other.
func TestBytesReturnsCopy(t *testing.T) {
	t.Parallel()

	id, err := ParseHex(AlgorithmSHA1, "0123456789abcdef0123456789abcdef01234567")
	if err != nil {
		t.Fatalf("ParseHex failed: %v", err)
	}

	b1 := id.Bytes()
	b2 := id.Bytes()
	if !bytes.Equal(b1, b2) {
		t.Fatal("Bytes mismatch")
	}
	b1[0] ^= 0xff
	if bytes.Equal(b1, b2) {
		t.Fatal("Bytes should return independent copies")
	}
}

// TestAlgorithmSum checks that Sum tags its result with the right algorithm
// and that the two algorithms produce different IDs for the same input.
func TestAlgorithmSum(t *testing.T) {
	t.Parallel()

	id1 := AlgorithmSHA1.Sum([]byte("hello"))
	if id1.Algorithm() != AlgorithmSHA1 || id1.Size() != AlgorithmSHA1.Size() {
		t.Fatal("sha1 sum produced invalid object id")
	}

	id2 := AlgorithmSHA256.Sum([]byte("hello"))
	if id2.Algorithm() != AlgorithmSHA256 || id2.Size() != AlgorithmSHA256.Size() {
		t.Fatal("sha256 sum produced invalid object id")
	}

	if id1.String() == id2.String() {
		t.Fatal("sha1 and sha256 should differ")
	}
}

// TestOutOfRangeAlgorithm is a regression test: Algorithm values beyond the
// table used to panic with an index-out-of-range error instead of behaving
// like AlgorithmUnknown.
func TestOutOfRangeAlgorithm(t *testing.T) {
	t.Parallel()

	bogus := Algorithm(200)
	if got := bogus.Size(); got != 0 {
		t.Fatalf("Size() = %d, want 0", got)
	}
	if got := bogus.HexLen(); got != 0 {
		t.Fatalf("HexLen() = %d, want 0", got)
	}
	if got := bogus.String(); got != "unknown" {
		t.Fatalf("String() = %q, want %q", got, "unknown")
	}
	if _, err := bogus.New(); !errors.Is(err, ErrInvalidAlgorithm) {
		t.Fatalf("New() error = %v, want ErrInvalidAlgorithm", err)
	}
	if _, err := FromBytes(bogus, []byte{1}); !errors.Is(err, ErrInvalidAlgorithm) {
		t.Fatalf("FromBytes error = %v, want ErrInvalidAlgorithm", err)
	}
}
