aboutsummaryrefslogtreecommitdiff
path: root/reachability
diff options
context:
space:
mode:
authorGravatar Runxi Yu2026-03-04 08:26:56 +0800
committerGravatar Runxi Yu2026-03-04 08:59:53 +0800
commitab7501be34032fb9e5c48726a68ae90a917af9eb (patch)
tree20d005647569befea8133e953c3270e8fd2a2a5b /reachability
parent*: gofumpt (diff)
signatureNo signature
*: Lint
Diffstat (limited to 'reachability')
-rw-r--r--reachability/errors.go2
-rw-r--r--reachability/helpers.go7
-rw-r--r--reachability/integration_test.go89
-rw-r--r--reachability/peel.go8
-rw-r--r--reachability/reachability.go13
-rw-r--r--reachability/unit_test.go61
-rw-r--r--reachability/walk.go33
7 files changed, 191 insertions, 22 deletions
diff --git a/reachability/errors.go b/reachability/errors.go
index e52bf0a4..7d0d9a18 100644
--- a/reachability/errors.go
+++ b/reachability/errors.go
@@ -29,9 +29,11 @@ func (e *ErrObjectType) Error() string {
if !gotOK {
gotName = fmt.Sprintf("type(%d)", e.Got)
}
+
wantName, wantOK := objecttype.Name(e.Want)
if !wantOK {
wantName = fmt.Sprintf("type(%d)", e.Want)
}
+
return fmt.Sprintf("reachability: object %s has type %s, want %s", e.OID, gotName, wantName)
}
diff --git a/reachability/helpers.go b/reachability/helpers.go
index 1368a3f1..41a2f80b 100644
--- a/reachability/helpers.go
+++ b/reachability/helpers.go
@@ -22,7 +22,9 @@ func containsOID(set map[objectid.ObjectID]struct{}, id objectid.ObjectID) bool
if len(set) == 0 {
return false
}
+
_, ok := set[id]
+
return ok
}
@@ -39,8 +41,10 @@ func (r *Reachability) readHeaderType(id objectid.ObjectID) (objecttype.Type, er
if errors.Is(err, objectstore.ErrObjectNotFound) {
return objecttype.TypeInvalid, &ErrObjectMissing{OID: id}
}
+
return objecttype.TypeInvalid, err
}
+
return ty, nil
}
@@ -49,6 +53,7 @@ func (walk *Walk) readBytesContent(id objectid.ObjectID) ([]byte, error) {
if err != nil {
return nil, err
}
+
return content, nil
}
@@ -58,7 +63,9 @@ func (r *Reachability) readBytesContent(id objectid.ObjectID) ([]byte, error) {
if errors.Is(err, objectstore.ErrObjectNotFound) {
return nil, &ErrObjectMissing{OID: id}
}
+
return nil, err
}
+
return content, nil
}
diff --git a/reachability/integration_test.go b/reachability/integration_test.go
index 10668006..079ce5fc 100644
--- a/reachability/integration_test.go
+++ b/reachability/integration_test.go
@@ -47,8 +47,11 @@ func TestWalkCommitsMatchesGitRevList(t *testing.T) {
nil,
map[objectid.ObjectID]struct{}{merge: {}},
)
+
got := oidSetFromSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
@@ -62,12 +65,17 @@ func TestWalkCommitsMatchesGitRevList(t *testing.T) {
nil,
map[objectid.ObjectID]struct{}{tag2: {}},
)
+
peelGot := oidSetFromSeq(peelWalk.Seq())
- if err := peelWalk.Err(); err != nil {
+
+ err = peelWalk.Err()
+ if err != nil {
t.Fatalf("peelWalk.Err(): %v", err)
}
+
wantWithTags := maps.Clone(want)
wantWithTags[tag1] = struct{}{}
+
wantWithTags[tag2] = struct{}{}
if !maps.Equal(peelGot, wantWithTags) {
t.Fatalf("tag-root commit walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantWithTags))
@@ -104,8 +112,11 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) {
nil,
map[objectid.ObjectID]struct{}{head: {}},
)
+
got := oidSetFromSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
@@ -119,10 +130,14 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) {
nil,
map[objectid.ObjectID]struct{}{tag: {}},
)
+
peelGot := oidSetFromSeq(peelWalk.Seq())
- if err := peelWalk.Err(); err != nil {
+
+ err = peelWalk.Err()
+ if err != nil {
t.Fatalf("peelWalk.Err(): %v", err)
}
+
wantFromTag := gitRevListSet(t, testRepo, true, []objectid.ObjectID{tag}, nil)
if !maps.Equal(peelGot, wantFromTag) {
t.Fatalf("tag-root object walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantFromTag))
@@ -133,11 +148,16 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) {
map[objectid.ObjectID]struct{}{base: {}},
map[objectid.ObjectID]struct{}{head: {}},
)
+
withHave := oidSetFromSeq(walkWithHave.Seq())
- if err := walkWithHave.Err(); err != nil {
+
+ err = walkWithHave.Err()
+ if err != nil {
t.Fatalf("walkWithHave.Err(): %v", err)
}
- if _, ok := withHave[base]; ok {
+
+ _, ok := withHave[base]
+ if ok {
t.Fatalf("walk output unexpectedly contains have commit %s", base)
}
})
@@ -170,7 +190,9 @@ func TestIsAncestorMatchesGitMergeBase(t *testing.T) {
if err != nil {
t.Fatalf("IsAncestor(c1, tag): %v", err)
}
- if want := gitMergeBaseIsAncestor(t, testRepo, c1, c2); got != want {
+
+ want := gitMergeBaseIsAncestor(t, testRepo, c1, c2)
+ if got != want {
t.Fatalf("IsAncestor(c1, tag)=%v, want %v", got, want)
}
@@ -178,7 +200,9 @@ func TestIsAncestorMatchesGitMergeBase(t *testing.T) {
if err != nil {
t.Fatalf("IsAncestor(c3, c2): %v", err)
}
- if want := gitMergeBaseIsAncestor(t, testRepo, c3, c2); got != want {
+
+ want = gitMergeBaseIsAncestor(t, testRepo, c3, c2)
+ if got != want {
t.Fatalf("IsAncestor(c3, c2)=%v, want %v", got, want)
}
})
@@ -195,12 +219,15 @@ func TestCheckConnectedMissingObject(t *testing.T) {
})
_, treeID, commitID := testRepo.MakeCommit(t, "missing")
- if err := os.Remove(looseObjectPath(testRepo.Dir(), treeID)); err != nil {
+
+ err := os.Remove(looseObjectPath(testRepo.Dir(), treeID))
+ if err != nil {
t.Fatalf("remove tree object: %v", err)
}
r := openReachabilityFromTestRepo(t, testRepo)
- err := r.CheckConnected(
+
+ err = r.CheckConnected(
reachability.DomainObjects,
nil,
map[objectid.ObjectID]struct{}{commitID: {}},
@@ -208,10 +235,12 @@ func TestCheckConnectedMissingObject(t *testing.T) {
if err == nil {
t.Fatal("expected error")
}
+
var missing *reachability.ErrObjectMissing
if !errors.As(err, &missing) {
t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err)
}
+
if missing.OID != treeID {
t.Fatalf("missing oid = %s, want %s", missing.OID, treeID)
}
@@ -246,14 +275,21 @@ func TestWalkOnPackedOnlyRepo(t *testing.T) {
nil,
map[objectid.ObjectID]struct{}{c2: {}},
)
+
got := oidSetFromSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
- if _, ok := got[c2]; !ok {
+
+ _, ok := got[c2]
+ if !ok {
t.Fatalf("walk output missing HEAD commit %s", c2)
}
- if _, ok := got[c1]; !ok {
+
+ _, ok = got[c1]
+ if !ok {
t.Fatalf("walk output missing parent commit %s", c1)
}
})
@@ -261,16 +297,19 @@ func TestWalkOnPackedOnlyRepo(t *testing.T) {
func openReachabilityFromTestRepo(t *testing.T, testRepo *testgit.TestRepo) *reachability.Reachability {
t.Helper()
+
root, err := os.OpenRoot(testRepo.Dir())
if err != nil {
t.Fatalf("os.OpenRoot: %v", err)
}
+
t.Cleanup(func() { _ = root.Close() })
repo, err := repository.Open(root)
if err != nil {
t.Fatalf("repository.Open: %v", err)
}
+
t.Cleanup(func() { _ = repo.Close() })
return reachability.New(repo.Objects())
@@ -278,10 +317,13 @@ func openReachabilityFromTestRepo(t *testing.T, testRepo *testgit.TestRepo) *rea
func oidSetFromSeq(seq func(func(objectid.ObjectID) bool)) map[objectid.ObjectID]struct{} {
out := make(map[objectid.ObjectID]struct{})
+
seq(func(id objectid.ObjectID) bool {
out[id] = struct{}{}
+
return true
})
+
return out
}
@@ -298,9 +340,11 @@ func gitRevListSet(
if includeObjects {
args = append(args, "--objects")
}
+
for _, want := range wants {
args = append(args, want.String())
}
+
if len(haves) > 0 {
args = append(args, "--not")
for _, have := range haves {
@@ -310,21 +354,28 @@ func gitRevListSet(
out := testRepo.Run(t, args...)
set := make(map[objectid.ObjectID]struct{})
+
for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
+
tok := line
- if i := strings.IndexByte(tok, ' '); i >= 0 {
+
+ i := strings.IndexByte(tok, ' ')
+ if i >= 0 {
tok = tok[:i]
}
+
id, err := objectid.ParseHex(testRepo.Algorithm(), tok)
if err != nil {
t.Fatalf("parse rev-list oid %q: %v", tok, err)
}
+
set[id] = struct{}{}
}
+
return set
}
@@ -332,6 +383,7 @@ func gitMergeBaseIsAncestor(t *testing.T, testRepo *testgit.TestRepo, a, b objec
t.Helper()
// testgit.Run fatals on non-zero status, so we compare merge-base output.
mb := testRepo.Run(t, "merge-base", a.String(), b.String())
+
return mb == a.String()
}
@@ -340,33 +392,40 @@ func sortedOIDStrings(set map[objectid.ObjectID]struct{}) []string {
for id := range set {
out = append(out, id.String())
}
+
slices.Sort(out)
+
return out
}
func looseObjectPath(repoDir string, id objectid.ObjectID) string {
hex := id.String()
+
return filepath.Join(repoDir, "objects", hex[:2], hex[2:])
}
func assertPackedOnly(t *testing.T, repoDir string) {
t.Helper()
+
objectsDir := filepath.Join(repoDir, "objects")
entries, err := os.ReadDir(objectsDir)
if err != nil {
t.Fatalf("ReadDir(objects): %v", err)
}
+
for _, entry := range entries {
name := entry.Name()
if name == "pack" || name == "info" {
continue
}
+
if len(name) == 2 && isHexDirName(name) {
subEntries, err := os.ReadDir(filepath.Join(objectsDir, name))
if err != nil {
t.Fatalf("ReadDir(objects/%s): %v", name, err)
}
+
if len(subEntries) != 0 {
t.Fatalf("found loose objects in %s", filepath.Join(objectsDir, name))
}
@@ -378,11 +437,13 @@ func isHexDirName(name string) bool {
if len(name) != 2 {
return false
}
+
for i := range 2 {
c := name[i]
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
return false
}
}
+
return true
}
diff --git a/reachability/peel.go b/reachability/peel.go
index 9df9bb4d..7b9e7bf1 100644
--- a/reachability/peel.go
+++ b/reachability/peel.go
@@ -7,18 +7,22 @@ import (
)
func (r *Reachability) peelRootToDomain(id objectid.ObjectID, domain Domain) (objectid.ObjectID, error) {
- if err := validateDomain(domain); err != nil {
+ err := validateDomain(domain)
+ if err != nil {
return objectid.ObjectID{}, err
}
+
for {
ty, err := r.readHeaderType(id)
if err != nil {
return objectid.ObjectID{}, err
}
+
if ty != objecttype.TypeTag {
if domain == DomainCommits && ty != objecttype.TypeCommit {
return objectid.ObjectID{}, &ErrObjectType{OID: id, Got: ty, Want: objecttype.TypeCommit}
}
+
return id, nil
}
@@ -26,10 +30,12 @@ func (r *Reachability) peelRootToDomain(id objectid.ObjectID, domain Domain) (ob
if err != nil {
return objectid.ObjectID{}, err
}
+
tag, err := object.ParseTag(content, id.Algorithm())
if err != nil {
return objectid.ObjectID{}, err
}
+
id = tag.Target
}
}
diff --git a/reachability/reachability.go b/reachability/reachability.go
index 0bec055f..93bc840b 100644
--- a/reachability/reachability.go
+++ b/reachability/reachability.go
@@ -26,10 +26,12 @@ func (r *Reachability) IsAncestor(ancestor, descendant objectid.ObjectID) (bool,
if err != nil {
return false, err
}
+
descendantCommit, err := r.peelRootToDomain(descendant, DomainCommits)
if err != nil {
return false, err
}
+
if ancestorCommit == descendantCommit {
return true, nil
}
@@ -40,9 +42,12 @@ func (r *Reachability) IsAncestor(ancestor, descendant objectid.ObjectID) (bool,
return true, nil
}
}
- if err := walk.Err(); err != nil {
+
+ err = walk.Err()
+ if err != nil {
return false, err
}
+
return false, nil
}
@@ -53,6 +58,7 @@ func (r *Reachability) CheckConnected(domain Domain, haves, wants map[objectid.O
walk := r.Walk(domain, haves, wants)
for range walk.Seq() {
}
+
return walk.Err()
}
@@ -64,8 +70,11 @@ func (r *Reachability) Walk(domain Domain, haves, wants map[objectid.ObjectID]st
haves: haves,
wants: wants,
}
- if err := validateDomain(domain); err != nil {
+
+ err := validateDomain(domain)
+ if err != nil {
walk.err = err
}
+
return walk
}
diff --git a/reachability/unit_test.go b/reachability/unit_test.go
index ec177938..8f19cfcd 100644
--- a/reachability/unit_test.go
+++ b/reachability/unit_test.go
@@ -42,13 +42,16 @@ func (store *memStore) ReadBytesFull(id objectid.ObjectID) ([]byte, error) {
if !ok {
return nil, objectstore.ErrObjectNotFound
}
+
header, ok := objectheader.Encode(obj.ty, int64(len(obj.content)))
if !ok {
panic("failed to encode object header")
}
+
raw := make([]byte, len(header)+len(obj.content))
copy(raw, header)
copy(raw[len(header):], obj.content)
+
return raw, nil
}
@@ -57,7 +60,9 @@ func (store *memStore) ReadBytesContent(id objectid.ObjectID) (objecttype.Type,
if !ok {
return objecttype.TypeInvalid, nil, objectstore.ErrObjectNotFound
}
+
store.readBytesByObjectID[id]++
+
return obj.ty, append([]byte(nil), obj.content...), nil
}
@@ -66,6 +71,7 @@ func (store *memStore) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, erro
if err != nil {
return nil, err
}
+
return io.NopCloser(bytes.NewReader(raw)), nil
}
@@ -74,6 +80,7 @@ func (store *memStore) ReadReaderContent(id objectid.ObjectID) (objecttype.Type,
if err != nil {
return objecttype.TypeInvalid, 0, nil, err
}
+
return ty, int64(len(content)), io.NopCloser(bytes.NewReader(content)), nil
}
@@ -82,6 +89,7 @@ func (store *memStore) ReadSize(id objectid.ObjectID) (int64, error) {
if err != nil {
return 0, err
}
+
return size, nil
}
@@ -90,6 +98,7 @@ func (store *memStore) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64,
if !ok {
return objecttype.TypeInvalid, 0, objectstore.ErrObjectNotFound
}
+
return obj.ty, int64(len(obj.content)), nil
}
@@ -102,7 +111,9 @@ func commitBody(tree objectid.ObjectID, parents ...objectid.ObjectID) []byte {
for _, parent := range parents {
buf = append(buf, fmt.Appendf(nil, "parent %s\n", parent.String())...)
}
+
buf = append(buf, []byte("\nmsg\n")...)
+
return buf
}
@@ -111,15 +122,19 @@ func tagBody(target objectid.ObjectID, targetType objecttype.Type) []byte {
if !ok {
panic("invalid tag target type")
}
+
return fmt.Appendf(nil, "object %s\ntype %s\ntag t\n\nmsg\n", target.String(), targetName)
}
func collectSeq(seq func(func(objectid.ObjectID) bool)) []objectid.ObjectID {
var out []objectid.ObjectID
+
seq(func(id objectid.ObjectID) bool {
out = append(out, id)
+
return true
})
+
return out
}
@@ -128,6 +143,7 @@ func toSet(ids []objectid.ObjectID) map[objectid.ObjectID]struct{} {
for _, id := range ids {
set[id] = struct{}{}
}
+
return set
}
@@ -149,12 +165,16 @@ func TestWalkDomainCommitsIncludesTagNodes(t *testing.T) {
r := reachability.New(store)
walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag2: {}})
+
got := collectSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
gotSet := toSet(got)
+
wantSet := map[objectid.ObjectID]struct{}{tag2: {}, tag1: {}, commit2: {}, commit1: {}}
if !maps.Equal(gotSet, wantSet) {
t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
@@ -177,10 +197,14 @@ func TestWalkExcludesHavesCompletely(t *testing.T) {
r := reachability.New(store)
walk := r.Walk(reachability.DomainCommits, map[objectid.ObjectID]struct{}{commit: {}}, map[objectid.ObjectID]struct{}{commit: {}})
+
got := collectSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
+
if len(got) != 0 {
t.Fatalf("expected empty output, got %v", got)
}
@@ -203,14 +227,17 @@ func TestWalkDomainCommitsRejectsNonCommitRootAfterPeel(t *testing.T) {
r := reachability.New(store)
walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag: {}})
_ = collectSeq(walk.Seq())
+
err := walk.Err()
if err == nil {
t.Fatal("expected error")
}
+
var typeErr *reachability.ErrObjectType
if !errors.As(err, &typeErr) {
t.Fatalf("expected ErrObjectType, got %T (%v)", err, err)
}
+
if typeErr.Got != objecttype.TypeTree || typeErr.Want != objecttype.TypeCommit {
t.Fatalf("unexpected type error: %+v", typeErr)
}
@@ -239,12 +266,16 @@ func TestWalkDomainCommitsHaveTagStopsTraversal(t *testing.T) {
map[objectid.ObjectID]struct{}{tag1: {}},
map[objectid.ObjectID]struct{}{tag2: {}},
)
+
got := collectSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
gotSet := toSet(got)
+
wantSet := map[objectid.ObjectID]struct{}{tag2: {}}
if !maps.Equal(gotSet, wantSet) {
t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
@@ -276,16 +307,21 @@ func TestWalkDomainObjectsRecursesTreesAndSkipsBlobContentReads(t *testing.T) {
r := reachability.New(store)
walk := r.Walk(reachability.DomainObjects, nil, map[objectid.ObjectID]struct{}{commit: {}})
+
got := collectSeq(walk.Seq())
- if err := walk.Err(); err != nil {
+
+ err := walk.Err()
+ if err != nil {
t.Fatalf("walk.Err(): %v", err)
}
gotSet := toSet(got)
+
wantSet := map[objectid.ObjectID]struct{}{commit: {}, rootTree: {}, subtree: {}, blob1: {}, blob2: {}}
if !maps.Equal(gotSet, wantSet) {
t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
}
+
if store.readBytesByObjectID[blob1] != 0 || store.readBytesByObjectID[blob2] != 0 {
t.Fatalf("blob contents should not be read; counts: blob1=%d blob2=%d", store.readBytesByObjectID[blob1], store.readBytesByObjectID[blob2])
}
@@ -307,14 +343,17 @@ func TestCheckConnectedReturnsConcreteMissingObject(t *testing.T) {
commit := store.addObject(objecttype.TypeCommit, commitBody(tree, missingParent))
r := reachability.New(store)
+
err := r.CheckConnected(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{commit: {}})
if err == nil {
t.Fatal("expected error")
}
+
var missing *reachability.ErrObjectMissing
if !errors.As(err, &missing) {
t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err)
}
+
if missing.OID != missingParent {
t.Fatalf("unexpected missing oid: got %s want %s", missing.OID, missingParent)
}
@@ -327,8 +366,11 @@ func TestWalkInvalidDomainReturnsPlainError(t *testing.T) {
testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
r := reachability.New(newMemStore(algo))
walk := r.Walk(reachability.Domain(99), nil, nil)
+
_ = collectSeq(walk.Seq())
- if err := walk.Err(); err == nil {
+
+ err := walk.Err()
+ if err == nil {
t.Fatal("expected error")
}
})
@@ -357,10 +399,12 @@ func TestIsAncestor(t *testing.T) {
tag := store.addObject(objecttype.TypeTag, tagBody(c2, objecttype.TypeCommit))
r := reachability.New(store)
+
ok, err := r.IsAncestor(c1, tag)
if err != nil {
t.Fatalf("IsAncestor(c1, tag): %v", err)
}
+
if !ok {
t.Fatal("expected c1 to be ancestor of tag->c2")
}
@@ -369,6 +413,7 @@ func TestIsAncestor(t *testing.T) {
if err != nil {
t.Fatalf("IsAncestor(c3, c2): %v", err)
}
+
if ok {
t.Fatal("did not expect c3 to be ancestor of c2")
}
@@ -390,10 +435,12 @@ func TestIsAncestorRejectsNonCommitAfterPeel(t *testing.T) {
tagToTree := store.addObject(objecttype.TypeTag, tagBody(tree, objecttype.TypeTree))
r := reachability.New(store)
+
_, err := r.IsAncestor(commit, tagToTree)
if err == nil {
t.Fatal("expected error")
}
+
var typeErr *reachability.ErrObjectType
if !errors.As(err, &typeErr) {
t.Fatalf("expected ErrObjectType, got %T (%v)", err, err)
@@ -403,10 +450,12 @@ func TestIsAncestorRejectsNonCommitAfterPeel(t *testing.T) {
func mustSerializeTree(tb testing.TB, tree *object.Tree) []byte {
tb.Helper()
+
body, err := tree.SerializeWithoutHeader()
if err != nil {
tb.Fatalf("SerializeWithoutHeader: %v", err)
}
+
return body
}
@@ -415,8 +464,10 @@ func (store *memStore) addObject(ty objecttype.Type, body []byte) objectid.Objec
if !ok {
panic("failed to encode object header")
}
+
raw := append(append([]byte(nil), header...), body...)
id := store.algo.Sum(raw)
store.objects[id] = storeObject{ty: ty, content: append([]byte(nil), body...)}
+
return id
}
diff --git a/reachability/walk.go b/reachability/walk.go
index 2b592a45..89de19e8 100644
--- a/reachability/walk.go
+++ b/reachability/walk.go
@@ -26,18 +26,24 @@ func (walk *Walk) Seq() iter.Seq[objectid.ObjectID] {
if walk.seqUsed {
return func(yield func(objectid.ObjectID) bool) {
_ = yield
+
if walk.err == nil {
walk.err = errors.New("reachability: walk sequence already consumed")
}
}
}
+
walk.seqUsed = true
+
return func(yield func(objectid.ObjectID) bool) {
if walk.err != nil {
return
}
+
stack := walk.initialStack()
+
var err error
+
visited := make(map[objectid.ObjectID]struct{}, len(stack))
for len(stack) > 0 {
item := stack[len(stack)-1]
@@ -46,20 +52,26 @@ func (walk *Walk) Seq() iter.Seq[objectid.ObjectID] {
if containsOID(walk.haves, item.id) {
continue
}
+
if _, ok := visited[item.id]; ok {
continue
}
+
visited[item.id] = struct{}{}
var next []walkItem
+
next, err = walk.expand(item)
if err != nil {
walk.err = err
+
return
}
+
if !yield(item.id) {
return
}
+
stack = append(stack, next...)
}
}
@@ -79,10 +91,12 @@ func (walk *Walk) initialStack() []walkItem {
if len(walk.wants) == 0 {
return nil
}
+
stack := make([]walkItem, 0, len(walk.wants))
for want := range walk.wants {
stack = append(stack, walkItem{id: want, want: objecttype.TypeInvalid})
}
+
return stack
}
@@ -90,6 +104,7 @@ func (walk *Walk) expand(item walkItem) ([]walkItem, error) {
if walk.domain == DomainCommits {
return walk.expandCommits(item)
}
+
return walk.expandObjects(item)
}
@@ -98,35 +113,42 @@ func (walk *Walk) expandCommits(item walkItem) ([]walkItem, error) {
if err != nil {
return nil, err
}
+
switch ty {
case objecttype.TypeCommit:
content, err := walk.readBytesContent(item.id)
if err != nil {
return nil, err
}
+
commit, err := object.ParseCommit(content, item.id.Algorithm())
if err != nil {
return nil, err
}
+
next := make([]walkItem, 0, len(commit.Parents))
for _, parent := range commit.Parents {
next = append(next, walkItem{id: parent, want: objecttype.TypeInvalid})
}
+
return next, nil
case objecttype.TypeTag:
content, err := walk.readBytesContent(item.id)
if err != nil {
return nil, err
}
+
tag, err := object.ParseTag(content, item.id.Algorithm())
if err != nil {
return nil, err
}
+
return []walkItem{{id: tag.Target, want: objecttype.TypeInvalid}}, nil
case objecttype.TypeTree, objecttype.TypeBlob, objecttype.TypeInvalid,
objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta:
return nil, &ErrObjectType{OID: item.id, Got: ty, Want: objecttype.TypeCommit}
}
+
return nil, fmt.Errorf("reachability: unreachable object type %d", ty)
}
@@ -135,6 +157,7 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) {
if err != nil {
return nil, err
}
+
if item.want != objecttype.TypeInvalid && ty != item.want {
return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want}
}
@@ -147,25 +170,31 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) {
if err != nil {
return nil, err
}
+
commit, err := object.ParseCommit(content, item.id.Algorithm())
if err != nil {
return nil, err
}
+
next := make([]walkItem, 0, len(commit.Parents)+1)
+
next = append(next, walkItem{id: commit.Tree, want: objecttype.TypeTree})
for _, parent := range commit.Parents {
next = append(next, walkItem{id: parent, want: objecttype.TypeCommit})
}
+
return next, nil
case objecttype.TypeTree:
content, err := walk.readBytesContent(item.id)
if err != nil {
return nil, err
}
+
tree, err := object.ParseTree(content, item.id.Algorithm())
if err != nil {
return nil, err
}
+
next := make([]walkItem, 0, len(tree.Entries))
for _, entry := range tree.Entries {
switch entry.Mode {
@@ -177,19 +206,23 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) {
next = append(next, walkItem{id: entry.ID, want: objecttype.TypeBlob})
}
}
+
return next, nil
case objecttype.TypeTag:
content, err := walk.readBytesContent(item.id)
if err != nil {
return nil, err
}
+
tag, err := object.ParseTag(content, item.id.Algorithm())
if err != nil {
return nil, err
}
+
return []walkItem{{id: tag.Target, want: tag.TargetType}}, nil
case objecttype.TypeInvalid, objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta:
return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want}
}
+
return nil, fmt.Errorf("reachability: unreachable object type %d", ty)
}