aboutsummaryrefslogtreecommitdiff
path: root/reachability
diff options
context:
space:
mode:
Diffstat (limited to 'reachability')
-rw-r--r--reachability/reachability.go382
-rw-r--r--reachability/reachability_integration_test.go388
-rw-r--r--reachability/reachability_unit_test.go422
3 files changed, 1192 insertions, 0 deletions
diff --git a/reachability/reachability.go b/reachability/reachability.go
new file mode 100644
index 00000000..5ab41944
--- /dev/null
+++ b/reachability/reachability.go
@@ -0,0 +1,382 @@
+package reachability
+
+import (
+ "errors"
+ "fmt"
+ "iter"
+
+ "codeberg.org/lindenii/furgit/object"
+ "codeberg.org/lindenii/furgit/objectid"
+ "codeberg.org/lindenii/furgit/objectstore"
+ "codeberg.org/lindenii/furgit/objecttype"
+)
+
+// Domain specifies which graph edges are traversed.
+type Domain uint8
+
+const (
+ // DomainCommits traverses commit-parent edges and annotated-tag target edges.
+ DomainCommits Domain = iota
+ // DomainObjects traverses full commit/tree/blob objects.
+ DomainObjects
+)
+
+// Reachability provides graph traversal over objects in one object store.
+//
+// It is not safe for concurrent use.
+type Reachability struct {
+ Store objectstore.Store
+}
+
+// New builds a Reachability over one object store.
+func New(store objectstore.Store) *Reachability {
+ return &Reachability{Store: store}
+}
+
+// IsAncestor reports whether ancestor is reachable from descendant via commit
+// parent edges.
+//
+// Both inputs are peeled through annotated tags before commit traversal.
+func (r *Reachability) IsAncestor(ancestor, descendant objectid.ObjectID) (bool, error) {
+ ancestorCommit, err := r.peelRootToDomain(ancestor, DomainCommits)
+ if err != nil {
+ return false, err
+ }
+ descendantCommit, err := r.peelRootToDomain(descendant, DomainCommits)
+ if err != nil {
+ return false, err
+ }
+ if ancestorCommit == descendantCommit {
+ return true, nil
+ }
+
+ walk := r.Walk(DomainCommits, nil, map[objectid.ObjectID]struct{}{descendantCommit: {}})
+ for id := range walk.Seq() {
+ if id == ancestorCommit {
+ return true, nil
+ }
+ }
+ if err := walk.Err(); err != nil {
+ return false, err
+ }
+ return false, nil
+}
+
+// CheckConnected verifies that all objects reachable from wants (under the
+// selected domain) can be fully traversed without missing-object/type/parse
+// errors, excluding subgraphs rooted at haves.
+func (r *Reachability) CheckConnected(domain Domain, haves, wants map[objectid.ObjectID]struct{}) error {
+ walk := r.Walk(domain, haves, wants)
+ for range walk.Seq() {
+ }
+ return walk.Err()
+}
+
+// Walk creates one single-use traversal over the selected domain.
+func (r *Reachability) Walk(domain Domain, haves, wants map[objectid.ObjectID]struct{}) *Walk {
+ walk := &Walk{
+ reachability: r,
+ domain: domain,
+ haves: haves,
+ wants: wants,
+ }
+ if err := validateDomain(domain); err != nil {
+ walk.err = err
+ }
+ return walk
+}
+
+// ErrObjectMissing indicates that a referenced object is absent from the store.
+type ErrObjectMissing struct {
+ OID objectid.ObjectID
+}
+
+func (e *ErrObjectMissing) Error() string {
+ return fmt.Sprintf("reachability: missing object %s", e.OID)
+}
+
+// ErrObjectType indicates that a referenced object has a different type than
+// what traversal expected on that edge.
+type ErrObjectType struct {
+ OID objectid.ObjectID
+ Got objecttype.Type
+ Want objecttype.Type
+}
+
+func (e *ErrObjectType) Error() string {
+ gotName, gotOK := objecttype.Name(e.Got)
+ if !gotOK {
+ gotName = fmt.Sprintf("type(%d)", e.Got)
+ }
+ wantName, wantOK := objecttype.Name(e.Want)
+ if !wantOK {
+ wantName = fmt.Sprintf("type(%d)", e.Want)
+ }
+ return fmt.Sprintf("reachability: object %s has type %s, want %s", e.OID, gotName, wantName)
+}
+
+// Walk is one single-use iterator-style traversal.
+type Walk struct {
+ reachability *Reachability
+ domain Domain
+ haves map[objectid.ObjectID]struct{}
+ wants map[objectid.ObjectID]struct{}
+
+ seqUsed bool
+ err error
+}
+
+// Seq returns the traversal sequence. It is single-use.
+func (walk *Walk) Seq() iter.Seq[objectid.ObjectID] {
+ if walk.seqUsed {
+ return func(yield func(objectid.ObjectID) bool) {
+ _ = yield
+ if walk.err == nil {
+ walk.err = errors.New("reachability: walk sequence already consumed")
+ }
+ }
+ }
+ walk.seqUsed = true
+ return func(yield func(objectid.ObjectID) bool) {
+ if walk.err != nil {
+ return
+ }
+ stack := walk.initialStack()
+ var err error
+ visited := make(map[objectid.ObjectID]struct{}, len(stack))
+ for len(stack) > 0 {
+ item := stack[len(stack)-1]
+ stack = stack[:len(stack)-1]
+
+ if containsOID(walk.haves, item.id) {
+ continue
+ }
+ if _, ok := visited[item.id]; ok {
+ continue
+ }
+ visited[item.id] = struct{}{}
+
+ var next []walkItem
+ next, err = walk.expand(item)
+ if err != nil {
+ walk.err = err
+ return
+ }
+ if !yield(item.id) {
+ return
+ }
+ stack = append(stack, next...)
+ }
+ }
+}
+
+// Err returns the terminal error, if any, once Seq has been consumed.
+func (walk *Walk) Err() error {
+ return walk.err
+}
+
+type walkItem struct {
+ id objectid.ObjectID
+ want objecttype.Type
+}
+
+func (walk *Walk) initialStack() []walkItem {
+ if len(walk.wants) == 0 {
+ return nil
+ }
+ stack := make([]walkItem, 0, len(walk.wants))
+ for want := range walk.wants {
+ stack = append(stack, walkItem{id: want, want: objecttype.TypeInvalid})
+ }
+ return stack
+}
+
+func (walk *Walk) expand(item walkItem) ([]walkItem, error) {
+ if walk.domain == DomainCommits {
+ return walk.expandCommits(item)
+ }
+ return walk.expandObjects(item)
+}
+
+func (walk *Walk) expandCommits(item walkItem) ([]walkItem, error) {
+ ty, err := walk.readHeaderType(item.id)
+ if err != nil {
+ return nil, err
+ }
+ switch ty {
+ case objecttype.TypeCommit:
+ content, err := walk.readBytesContent(item.id)
+ if err != nil {
+ return nil, err
+ }
+ commit, err := object.ParseCommit(content, item.id.Algorithm())
+ if err != nil {
+ return nil, err
+ }
+ next := make([]walkItem, 0, len(commit.Parents))
+ for _, parent := range commit.Parents {
+ next = append(next, walkItem{id: parent, want: objecttype.TypeInvalid})
+ }
+ return next, nil
+ case objecttype.TypeTag:
+ content, err := walk.readBytesContent(item.id)
+ if err != nil {
+ return nil, err
+ }
+ tag, err := object.ParseTag(content, item.id.Algorithm())
+ if err != nil {
+ return nil, err
+ }
+ return []walkItem{{id: tag.Target, want: objecttype.TypeInvalid}}, nil
+ case objecttype.TypeTree, objecttype.TypeBlob, objecttype.TypeInvalid,
+ objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta:
+ return nil, &ErrObjectType{OID: item.id, Got: ty, Want: objecttype.TypeCommit}
+ }
+ return nil, fmt.Errorf("reachability: unreachable object type %d", ty)
+}
+
+func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) {
+ ty, err := walk.readHeaderType(item.id)
+ if err != nil {
+ return nil, err
+ }
+ if item.want != objecttype.TypeInvalid && ty != item.want {
+ return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want}
+ }
+
+ switch ty {
+ case objecttype.TypeBlob:
+ return nil, nil
+ case objecttype.TypeCommit:
+ content, err := walk.readBytesContent(item.id)
+ if err != nil {
+ return nil, err
+ }
+ commit, err := object.ParseCommit(content, item.id.Algorithm())
+ if err != nil {
+ return nil, err
+ }
+ next := make([]walkItem, 0, len(commit.Parents)+1)
+ next = append(next, walkItem{id: commit.Tree, want: objecttype.TypeTree})
+ for _, parent := range commit.Parents {
+ next = append(next, walkItem{id: parent, want: objecttype.TypeCommit})
+ }
+ return next, nil
+ case objecttype.TypeTree:
+ content, err := walk.readBytesContent(item.id)
+ if err != nil {
+ return nil, err
+ }
+ tree, err := object.ParseTree(content, item.id.Algorithm())
+ if err != nil {
+ return nil, err
+ }
+ next := make([]walkItem, 0, len(tree.Entries))
+ for _, entry := range tree.Entries {
+ switch entry.Mode {
+ case object.FileModeGitlink:
+ continue
+ case object.FileModeDir:
+ next = append(next, walkItem{id: entry.ID, want: objecttype.TypeTree})
+ case object.FileModeRegular, object.FileModeExecutable, object.FileModeSymlink:
+ next = append(next, walkItem{id: entry.ID, want: objecttype.TypeBlob})
+ }
+ }
+ return next, nil
+ case objecttype.TypeTag:
+ content, err := walk.readBytesContent(item.id)
+ if err != nil {
+ return nil, err
+ }
+ tag, err := object.ParseTag(content, item.id.Algorithm())
+ if err != nil {
+ return nil, err
+ }
+ return []walkItem{{id: tag.Target, want: tag.TargetType}}, nil
+ case objecttype.TypeInvalid, objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta:
+ return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want}
+ }
+ return nil, fmt.Errorf("reachability: unreachable object type %d", ty)
+}
+
+func (r *Reachability) peelRootToDomain(id objectid.ObjectID, domain Domain) (objectid.ObjectID, error) {
+ if err := validateDomain(domain); err != nil {
+ return objectid.ObjectID{}, err
+ }
+ for {
+ ty, err := r.readHeaderType(id)
+ if err != nil {
+ return objectid.ObjectID{}, err
+ }
+ if ty != objecttype.TypeTag {
+ if domain == DomainCommits && ty != objecttype.TypeCommit {
+ return objectid.ObjectID{}, &ErrObjectType{OID: id, Got: ty, Want: objecttype.TypeCommit}
+ }
+ return id, nil
+ }
+
+ content, err := r.readBytesContent(id)
+ if err != nil {
+ return objectid.ObjectID{}, err
+ }
+ tag, err := object.ParseTag(content, id.Algorithm())
+ if err != nil {
+ return objectid.ObjectID{}, err
+ }
+ id = tag.Target
+ }
+}
+
+func validateDomain(domain Domain) error {
+ switch domain {
+ case DomainCommits, DomainObjects:
+ return nil
+ default:
+ return fmt.Errorf("reachability: invalid domain %d", domain)
+ }
+}
+
+func containsOID(set map[objectid.ObjectID]struct{}, id objectid.ObjectID) bool {
+ if len(set) == 0 {
+ return false
+ }
+ _, ok := set[id]
+ return ok
+}
+
+// The following helpers exist because we don't have unified error handling across the entire project.
+// This will be fixed later.
+
+func (walk *Walk) readHeaderType(id objectid.ObjectID) (objecttype.Type, error) {
+ return walk.reachability.readHeaderType(id)
+}
+
+func (r *Reachability) readHeaderType(id objectid.ObjectID) (objecttype.Type, error) {
+ ty, _, err := r.Store.ReadHeader(id)
+ if err != nil {
+ if errors.Is(err, objectstore.ErrObjectNotFound) {
+ return objecttype.TypeInvalid, &ErrObjectMissing{OID: id}
+ }
+ return objecttype.TypeInvalid, err
+ }
+ return ty, nil
+}
+
+func (walk *Walk) readBytesContent(id objectid.ObjectID) ([]byte, error) {
+ content, err := walk.reachability.readBytesContent(id)
+ if err != nil {
+ return nil, err
+ }
+ return content, nil
+}
+
+func (r *Reachability) readBytesContent(id objectid.ObjectID) ([]byte, error) {
+ _, content, err := r.Store.ReadBytesContent(id)
+ if err != nil {
+ if errors.Is(err, objectstore.ErrObjectNotFound) {
+ return nil, &ErrObjectMissing{OID: id}
+ }
+ return nil, err
+ }
+ return content, nil
+}
diff --git a/reachability/reachability_integration_test.go b/reachability/reachability_integration_test.go
new file mode 100644
index 00000000..10668006
--- /dev/null
+++ b/reachability/reachability_integration_test.go
@@ -0,0 +1,388 @@
+package reachability_test
+
+import (
+ "errors"
+ "fmt"
+ "maps"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "testing"
+
+ "codeberg.org/lindenii/furgit/internal/testgit"
+ "codeberg.org/lindenii/furgit/objectid"
+ "codeberg.org/lindenii/furgit/reachability"
+ "codeberg.org/lindenii/furgit/repository"
+)
+
+func TestWalkCommitsMatchesGitRevList(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ testRepo := testgit.NewRepo(t, testgit.RepoOptions{
+ ObjectFormat: algo,
+ Bare: true,
+ RefFormat: "files",
+ })
+
+ _, tree1 := testRepo.MakeSingleFileTree(t, "base.txt", []byte("base\n"))
+ base := testRepo.CommitTree(t, tree1, "base")
+
+ _, tree2 := testRepo.MakeSingleFileTree(t, "left.txt", []byte("left\n"))
+ left := testRepo.CommitTree(t, tree2, "left", base)
+
+ _, tree3 := testRepo.MakeSingleFileTree(t, "right.txt", []byte("right\n"))
+ right := testRepo.CommitTree(t, tree3, "right", base)
+
+ _, tree4 := testRepo.MakeSingleFileTree(t, "merge.txt", []byte("merge\n"))
+ merge := testRepo.CommitTree(t, tree4, "merge", left, right)
+
+ tag1 := testRepo.TagAnnotated(t, "v1", merge, "v1")
+ tag2 := testRepo.TagAnnotated(t, "v2", tag1, "v2")
+
+ r := openReachabilityFromTestRepo(t, testRepo)
+ walk := r.Walk(
+ reachability.DomainCommits,
+ nil,
+ map[objectid.ObjectID]struct{}{merge: {}},
+ )
+ got := oidSetFromSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+
+ want := gitRevListSet(t, testRepo, false, []objectid.ObjectID{merge}, nil)
+ if !maps.Equal(got, want) {
+ t.Fatalf("commit walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(got), sortedOIDStrings(want))
+ }
+
+ peelWalk := r.Walk(
+ reachability.DomainCommits,
+ nil,
+ map[objectid.ObjectID]struct{}{tag2: {}},
+ )
+ peelGot := oidSetFromSeq(peelWalk.Seq())
+ if err := peelWalk.Err(); err != nil {
+ t.Fatalf("peelWalk.Err(): %v", err)
+ }
+ wantWithTags := maps.Clone(want)
+ wantWithTags[tag1] = struct{}{}
+ wantWithTags[tag2] = struct{}{}
+ if !maps.Equal(peelGot, wantWithTags) {
+ t.Fatalf("tag-root commit walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantWithTags))
+ }
+ })
+}
+
+func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ testRepo := testgit.NewRepo(t, testgit.RepoOptions{
+ ObjectFormat: algo,
+ Bare: true,
+ RefFormat: "files",
+ })
+
+ aBlob := testRepo.HashObject(t, "blob", []byte("a\n"))
+ bBlob := testRepo.HashObject(t, "blob", []byte("b\n"))
+ nestedTree := testRepo.Mktree(t, fmt.Sprintf("100644 blob %s\tb.txt\n", bBlob))
+ rootTree := testRepo.Mktree(t,
+ fmt.Sprintf("100644 blob %s\ta.txt\n040000 tree %s\tdir\n", aBlob, nestedTree),
+ )
+ base := testRepo.CommitTree(t, rootTree, "base")
+
+ cBlob := testRepo.HashObject(t, "blob", []byte("c\n"))
+ tree2 := testRepo.Mktree(t, fmt.Sprintf("100644 blob %s\tc.txt\n", cBlob))
+ head := testRepo.CommitTree(t, tree2, "head", base)
+ tag := testRepo.TagAnnotated(t, "objtag", head, "objtag")
+
+ r := openReachabilityFromTestRepo(t, testRepo)
+ walk := r.Walk(
+ reachability.DomainObjects,
+ nil,
+ map[objectid.ObjectID]struct{}{head: {}},
+ )
+ got := oidSetFromSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+
+ want := gitRevListSet(t, testRepo, true, []objectid.ObjectID{head}, nil)
+ if !maps.Equal(got, want) {
+ t.Fatalf("object walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(got), sortedOIDStrings(want))
+ }
+
+ peelWalk := r.Walk(
+ reachability.DomainObjects,
+ nil,
+ map[objectid.ObjectID]struct{}{tag: {}},
+ )
+ peelGot := oidSetFromSeq(peelWalk.Seq())
+ if err := peelWalk.Err(); err != nil {
+ t.Fatalf("peelWalk.Err(): %v", err)
+ }
+ wantFromTag := gitRevListSet(t, testRepo, true, []objectid.ObjectID{tag}, nil)
+ if !maps.Equal(peelGot, wantFromTag) {
+ t.Fatalf("tag-root object walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantFromTag))
+ }
+
+ walkWithHave := r.Walk(
+ reachability.DomainObjects,
+ map[objectid.ObjectID]struct{}{base: {}},
+ map[objectid.ObjectID]struct{}{head: {}},
+ )
+ withHave := oidSetFromSeq(walkWithHave.Seq())
+ if err := walkWithHave.Err(); err != nil {
+ t.Fatalf("walkWithHave.Err(): %v", err)
+ }
+ if _, ok := withHave[base]; ok {
+ t.Fatalf("walk output unexpectedly contains have commit %s", base)
+ }
+ })
+}
+
+func TestIsAncestorMatchesGitMergeBase(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ testRepo := testgit.NewRepo(t, testgit.RepoOptions{
+ ObjectFormat: algo,
+ Bare: true,
+ RefFormat: "files",
+ })
+
+ _, tree1 := testRepo.MakeSingleFileTree(t, "one.txt", []byte("one\n"))
+ c1 := testRepo.CommitTree(t, tree1, "c1")
+
+ _, tree2 := testRepo.MakeSingleFileTree(t, "two.txt", []byte("two\n"))
+ c2 := testRepo.CommitTree(t, tree2, "c2", c1)
+
+ _, tree3 := testRepo.MakeSingleFileTree(t, "three.txt", []byte("three\n"))
+ c3 := testRepo.CommitTree(t, tree3, "c3", c2)
+
+ tag := testRepo.TagAnnotated(t, "tip", c2, "tip")
+
+ r := openReachabilityFromTestRepo(t, testRepo)
+
+ got, err := r.IsAncestor(c1, tag)
+ if err != nil {
+ t.Fatalf("IsAncestor(c1, tag): %v", err)
+ }
+ if want := gitMergeBaseIsAncestor(t, testRepo, c1, c2); got != want {
+ t.Fatalf("IsAncestor(c1, tag)=%v, want %v", got, want)
+ }
+
+ got, err = r.IsAncestor(c3, c2)
+ if err != nil {
+ t.Fatalf("IsAncestor(c3, c2): %v", err)
+ }
+ if want := gitMergeBaseIsAncestor(t, testRepo, c3, c2); got != want {
+ t.Fatalf("IsAncestor(c3, c2)=%v, want %v", got, want)
+ }
+ })
+}
+
+func TestCheckConnectedMissingObject(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ testRepo := testgit.NewRepo(t, testgit.RepoOptions{
+ ObjectFormat: algo,
+ Bare: true,
+ RefFormat: "files",
+ })
+
+ _, treeID, commitID := testRepo.MakeCommit(t, "missing")
+ if err := os.Remove(looseObjectPath(testRepo.Dir(), treeID)); err != nil {
+ t.Fatalf("remove tree object: %v", err)
+ }
+
+ r := openReachabilityFromTestRepo(t, testRepo)
+ err := r.CheckConnected(
+ reachability.DomainObjects,
+ nil,
+ map[objectid.ObjectID]struct{}{commitID: {}},
+ )
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ var missing *reachability.ErrObjectMissing
+ if !errors.As(err, &missing) {
+ t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err)
+ }
+ if missing.OID != treeID {
+ t.Fatalf("missing oid = %s, want %s", missing.OID, treeID)
+ }
+ })
+}
+
+func TestWalkOnPackedOnlyRepo(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ testRepo := testgit.NewRepo(t, testgit.RepoOptions{
+ ObjectFormat: algo,
+ Bare: true,
+ RefFormat: "files",
+ })
+
+ _, tree1 := testRepo.MakeSingleFileTree(t, "one.txt", []byte("one\n"))
+ c1 := testRepo.CommitTree(t, tree1, "one")
+ _, tree2 := testRepo.MakeSingleFileTree(t, "two.txt", []byte("two\n"))
+ c2 := testRepo.CommitTree(t, tree2, "two", c1)
+ testRepo.UpdateRef(t, "refs/heads/main", c2)
+ testRepo.SymbolicRef(t, "HEAD", "refs/heads/main")
+
+ testRepo.Repack(t, "-ad")
+ testRepo.Run(t, "prune-packed")
+
+ assertPackedOnly(t, testRepo.Dir())
+
+ r := openReachabilityFromTestRepo(t, testRepo)
+ walk := r.Walk(
+ reachability.DomainCommits,
+ nil,
+ map[objectid.ObjectID]struct{}{c2: {}},
+ )
+ got := oidSetFromSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+ if _, ok := got[c2]; !ok {
+ t.Fatalf("walk output missing HEAD commit %s", c2)
+ }
+ if _, ok := got[c1]; !ok {
+ t.Fatalf("walk output missing parent commit %s", c1)
+ }
+ })
+}
+
+func openReachabilityFromTestRepo(t *testing.T, testRepo *testgit.TestRepo) *reachability.Reachability {
+ t.Helper()
+ root, err := os.OpenRoot(testRepo.Dir())
+ if err != nil {
+ t.Fatalf("os.OpenRoot: %v", err)
+ }
+ t.Cleanup(func() { _ = root.Close() })
+
+ repo, err := repository.Open(root)
+ if err != nil {
+ t.Fatalf("repository.Open: %v", err)
+ }
+ t.Cleanup(func() { _ = repo.Close() })
+
+ return reachability.New(repo.Objects())
+}
+
+func oidSetFromSeq(seq func(func(objectid.ObjectID) bool)) map[objectid.ObjectID]struct{} {
+ out := make(map[objectid.ObjectID]struct{})
+ seq(func(id objectid.ObjectID) bool {
+ out[id] = struct{}{}
+ return true
+ })
+ return out
+}
+
+func gitRevListSet(
+ t *testing.T,
+ testRepo *testgit.TestRepo,
+ includeObjects bool,
+ wants []objectid.ObjectID,
+ haves []objectid.ObjectID,
+) map[objectid.ObjectID]struct{} {
+ t.Helper()
+
+ args := []string{"rev-list"}
+ if includeObjects {
+ args = append(args, "--objects")
+ }
+ for _, want := range wants {
+ args = append(args, want.String())
+ }
+ if len(haves) > 0 {
+ args = append(args, "--not")
+ for _, have := range haves {
+ args = append(args, have.String())
+ }
+ }
+
+ out := testRepo.Run(t, args...)
+ set := make(map[objectid.ObjectID]struct{})
+ for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
+ line = strings.TrimSpace(line)
+ if line == "" {
+ continue
+ }
+ tok := line
+ if i := strings.IndexByte(tok, ' '); i >= 0 {
+ tok = tok[:i]
+ }
+ id, err := objectid.ParseHex(testRepo.Algorithm(), tok)
+ if err != nil {
+ t.Fatalf("parse rev-list oid %q: %v", tok, err)
+ }
+ set[id] = struct{}{}
+ }
+ return set
+}
+
+func gitMergeBaseIsAncestor(t *testing.T, testRepo *testgit.TestRepo, a, b objectid.ObjectID) bool {
+ t.Helper()
+ // testgit.Run fatals on non-zero status, so we compare merge-base output.
+ mb := testRepo.Run(t, "merge-base", a.String(), b.String())
+ return mb == a.String()
+}
+
+func sortedOIDStrings(set map[objectid.ObjectID]struct{}) []string {
+ out := make([]string, 0, len(set))
+ for id := range set {
+ out = append(out, id.String())
+ }
+ slices.Sort(out)
+ return out
+}
+
+func looseObjectPath(repoDir string, id objectid.ObjectID) string {
+ hex := id.String()
+ return filepath.Join(repoDir, "objects", hex[:2], hex[2:])
+}
+
+func assertPackedOnly(t *testing.T, repoDir string) {
+ t.Helper()
+ objectsDir := filepath.Join(repoDir, "objects")
+
+ entries, err := os.ReadDir(objectsDir)
+ if err != nil {
+ t.Fatalf("ReadDir(objects): %v", err)
+ }
+ for _, entry := range entries {
+ name := entry.Name()
+ if name == "pack" || name == "info" {
+ continue
+ }
+ if len(name) == 2 && isHexDirName(name) {
+ subEntries, err := os.ReadDir(filepath.Join(objectsDir, name))
+ if err != nil {
+ t.Fatalf("ReadDir(objects/%s): %v", name, err)
+ }
+ if len(subEntries) != 0 {
+ t.Fatalf("found loose objects in %s", filepath.Join(objectsDir, name))
+ }
+ }
+ }
+}
+
+func isHexDirName(name string) bool {
+ if len(name) != 2 {
+ return false
+ }
+ for i := range 2 {
+ c := name[i]
+ if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
+ return false
+ }
+ }
+ return true
+}
diff --git a/reachability/reachability_unit_test.go b/reachability/reachability_unit_test.go
new file mode 100644
index 00000000..ec177938
--- /dev/null
+++ b/reachability/reachability_unit_test.go
@@ -0,0 +1,422 @@
+package reachability_test
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "slices"
+ "testing"
+
+ "codeberg.org/lindenii/furgit/internal/testgit"
+ "codeberg.org/lindenii/furgit/object"
+ "codeberg.org/lindenii/furgit/objectheader"
+ "codeberg.org/lindenii/furgit/objectid"
+ "codeberg.org/lindenii/furgit/objectstore"
+ "codeberg.org/lindenii/furgit/objecttype"
+ "codeberg.org/lindenii/furgit/reachability"
+)
+
+type storeObject struct {
+ ty objecttype.Type
+ content []byte
+}
+
+type memStore struct {
+ algo objectid.Algorithm
+ objects map[objectid.ObjectID]storeObject
+ readBytesByObjectID map[objectid.ObjectID]int
+}
+
+func newMemStore(algo objectid.Algorithm) *memStore {
+ return &memStore{
+ algo: algo,
+ objects: make(map[objectid.ObjectID]storeObject),
+ readBytesByObjectID: make(map[objectid.ObjectID]int),
+ }
+}
+
+func (store *memStore) ReadBytesFull(id objectid.ObjectID) ([]byte, error) {
+ obj, ok := store.objects[id]
+ if !ok {
+ return nil, objectstore.ErrObjectNotFound
+ }
+ header, ok := objectheader.Encode(obj.ty, int64(len(obj.content)))
+ if !ok {
+ panic("failed to encode object header")
+ }
+ raw := make([]byte, len(header)+len(obj.content))
+ copy(raw, header)
+ copy(raw[len(header):], obj.content)
+ return raw, nil
+}
+
+func (store *memStore) ReadBytesContent(id objectid.ObjectID) (objecttype.Type, []byte, error) {
+ obj, ok := store.objects[id]
+ if !ok {
+ return objecttype.TypeInvalid, nil, objectstore.ErrObjectNotFound
+ }
+ store.readBytesByObjectID[id]++
+ return obj.ty, append([]byte(nil), obj.content...), nil
+}
+
+func (store *memStore) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, error) {
+ raw, err := store.ReadBytesFull(id)
+ if err != nil {
+ return nil, err
+ }
+ return io.NopCloser(bytes.NewReader(raw)), nil
+}
+
+func (store *memStore) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, int64, io.ReadCloser, error) {
+ ty, content, err := store.ReadBytesContent(id)
+ if err != nil {
+ return objecttype.TypeInvalid, 0, nil, err
+ }
+ return ty, int64(len(content)), io.NopCloser(bytes.NewReader(content)), nil
+}
+
+func (store *memStore) ReadSize(id objectid.ObjectID) (int64, error) {
+ _, size, err := store.ReadHeader(id)
+ if err != nil {
+ return 0, err
+ }
+ return size, nil
+}
+
+func (store *memStore) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64, error) {
+ obj, ok := store.objects[id]
+ if !ok {
+ return objecttype.TypeInvalid, 0, objectstore.ErrObjectNotFound
+ }
+ return obj.ty, int64(len(obj.content)), nil
+}
+
+func (store *memStore) Close() error {
+ return nil
+}
+
+func commitBody(tree objectid.ObjectID, parents ...objectid.ObjectID) []byte {
+ buf := fmt.Appendf(nil, "tree %s\n", tree.String())
+ for _, parent := range parents {
+ buf = append(buf, fmt.Appendf(nil, "parent %s\n", parent.String())...)
+ }
+ buf = append(buf, []byte("\nmsg\n")...)
+ return buf
+}
+
+func tagBody(target objectid.ObjectID, targetType objecttype.Type) []byte {
+ targetName, ok := objecttype.Name(targetType)
+ if !ok {
+ panic("invalid tag target type")
+ }
+ return fmt.Appendf(nil, "object %s\ntype %s\ntag t\n\nmsg\n", target.String(), targetName)
+}
+
+func collectSeq(seq func(func(objectid.ObjectID) bool)) []objectid.ObjectID {
+ var out []objectid.ObjectID
+ seq(func(id objectid.ObjectID) bool {
+ out = append(out, id)
+ return true
+ })
+ return out
+}
+
+func toSet(ids []objectid.ObjectID) map[objectid.ObjectID]struct{} {
+ set := make(map[objectid.ObjectID]struct{}, len(ids))
+ for _, id := range ids {
+ set[id] = struct{}{}
+ }
+ return set
+}
+
+func TestWalkDomainCommitsIncludesTagNodes(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ commit1 := store.addObject(objecttype.TypeCommit, commitBody(tree))
+ commit2 := store.addObject(objecttype.TypeCommit, commitBody(tree, commit1))
+ tag1 := store.addObject(objecttype.TypeTag, tagBody(commit2, objecttype.TypeCommit))
+ tag2 := store.addObject(objecttype.TypeTag, tagBody(tag1, objecttype.TypeTag))
+
+ r := reachability.New(store)
+ walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag2: {}})
+ got := collectSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+
+ gotSet := toSet(got)
+ wantSet := map[objectid.ObjectID]struct{}{tag2: {}, tag1: {}, commit2: {}, commit1: {}}
+ if !maps.Equal(gotSet, wantSet) {
+ t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
+ }
+ })
+}
+
+func TestWalkExcludesHavesCompletely(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ commit := store.addObject(objecttype.TypeCommit, commitBody(tree))
+
+ r := reachability.New(store)
+ walk := r.Walk(reachability.DomainCommits, map[objectid.ObjectID]struct{}{commit: {}}, map[objectid.ObjectID]struct{}{commit: {}})
+ got := collectSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+ if len(got) != 0 {
+ t.Fatalf("expected empty output, got %v", got)
+ }
+ })
+}
+
+func TestWalkDomainCommitsRejectsNonCommitRootAfterPeel(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ tag := store.addObject(objecttype.TypeTag, tagBody(tree, objecttype.TypeTree))
+
+ r := reachability.New(store)
+ walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag: {}})
+ _ = collectSeq(walk.Seq())
+ err := walk.Err()
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ var typeErr *reachability.ErrObjectType
+ if !errors.As(err, &typeErr) {
+ t.Fatalf("expected ErrObjectType, got %T (%v)", err, err)
+ }
+ if typeErr.Got != objecttype.TypeTree || typeErr.Want != objecttype.TypeCommit {
+ t.Fatalf("unexpected type error: %+v", typeErr)
+ }
+ })
+}
+
+func TestWalkDomainCommitsHaveTagStopsTraversal(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ commit1 := store.addObject(objecttype.TypeCommit, commitBody(tree))
+ commit2 := store.addObject(objecttype.TypeCommit, commitBody(tree, commit1))
+ tag1 := store.addObject(objecttype.TypeTag, tagBody(commit2, objecttype.TypeCommit))
+ tag2 := store.addObject(objecttype.TypeTag, tagBody(tag1, objecttype.TypeTag))
+
+ r := reachability.New(store)
+ walk := r.Walk(
+ reachability.DomainCommits,
+ map[objectid.ObjectID]struct{}{tag1: {}},
+ map[objectid.ObjectID]struct{}{tag2: {}},
+ )
+ got := collectSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+
+ gotSet := toSet(got)
+ wantSet := map[objectid.ObjectID]struct{}{tag2: {}}
+ if !maps.Equal(gotSet, wantSet) {
+ t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
+ }
+ })
+}
+
+func TestWalkDomainObjectsRecursesTreesAndSkipsBlobContentReads(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+
+ blob1 := store.addObject(objecttype.TypeBlob, []byte("b1\n"))
+ blob2 := store.addObject(objecttype.TypeBlob, []byte("b2\n"))
+ gitlinkTarget := store.algo.Sum([]byte("external-submodule"))
+
+ subtree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("nested"),
+ ID: blob2,
+ }}}))
+ rootTree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{
+ {Mode: object.FileModeRegular, Name: []byte("a"), ID: blob1},
+ {Mode: object.FileModeDir, Name: []byte("dir"), ID: subtree},
+ {Mode: object.FileModeGitlink, Name: []byte("submodule"), ID: gitlinkTarget},
+ }}))
+ commit := store.addObject(objecttype.TypeCommit, commitBody(rootTree))
+
+ r := reachability.New(store)
+ walk := r.Walk(reachability.DomainObjects, nil, map[objectid.ObjectID]struct{}{commit: {}})
+ got := collectSeq(walk.Seq())
+ if err := walk.Err(); err != nil {
+ t.Fatalf("walk.Err(): %v", err)
+ }
+
+ gotSet := toSet(got)
+ wantSet := map[objectid.ObjectID]struct{}{commit: {}, rootTree: {}, subtree: {}, blob1: {}, blob2: {}}
+ if !maps.Equal(gotSet, wantSet) {
+ t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet)))
+ }
+ if store.readBytesByObjectID[blob1] != 0 || store.readBytesByObjectID[blob2] != 0 {
+ t.Fatalf("blob contents should not be read; counts: blob1=%d blob2=%d", store.readBytesByObjectID[blob1], store.readBytesByObjectID[blob2])
+ }
+ })
+}
+
+func TestCheckConnectedReturnsConcreteMissingObject(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ missingParent := store.algo.Sum([]byte("missing-parent"))
+ commit := store.addObject(objecttype.TypeCommit, commitBody(tree, missingParent))
+
+ r := reachability.New(store)
+ err := r.CheckConnected(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{commit: {}})
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ var missing *reachability.ErrObjectMissing
+ if !errors.As(err, &missing) {
+ t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err)
+ }
+ if missing.OID != missingParent {
+ t.Fatalf("unexpected missing oid: got %s want %s", missing.OID, missingParent)
+ }
+ })
+}
+
+func TestWalkInvalidDomainReturnsPlainError(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ r := reachability.New(newMemStore(algo))
+ walk := r.Walk(reachability.Domain(99), nil, nil)
+ _ = collectSeq(walk.Seq())
+ if err := walk.Err(); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+}
+
+func TestIsAncestor(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ c1 := store.addObject(objecttype.TypeCommit, commitBody(tree))
+ c2 := store.addObject(objecttype.TypeCommit, commitBody(tree, c1))
+ otherBlob := store.addObject(objecttype.TypeBlob, []byte("other-blob\n"))
+ otherTree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("g"),
+ ID: otherBlob,
+ }}}))
+ c3 := store.addObject(objecttype.TypeCommit, commitBody(otherTree))
+ tag := store.addObject(objecttype.TypeTag, tagBody(c2, objecttype.TypeCommit))
+
+ r := reachability.New(store)
+ ok, err := r.IsAncestor(c1, tag)
+ if err != nil {
+ t.Fatalf("IsAncestor(c1, tag): %v", err)
+ }
+ if !ok {
+ t.Fatal("expected c1 to be ancestor of tag->c2")
+ }
+
+ ok, err = r.IsAncestor(c3, c2)
+ if err != nil {
+ t.Fatalf("IsAncestor(c3, c2): %v", err)
+ }
+ if ok {
+ t.Fatal("did not expect c3 to be ancestor of c2")
+ }
+ })
+}
+
+func TestIsAncestorRejectsNonCommitAfterPeel(t *testing.T) {
+ t.Parallel()
+
+ testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper
+ store := newMemStore(algo)
+ blob := store.addObject(objecttype.TypeBlob, []byte("blob\n"))
+ tree := store.addObject(objecttype.TypeTree, mustSerializeTree(t, &object.Tree{Entries: []object.TreeEntry{{
+ Mode: object.FileModeRegular,
+ Name: []byte("f"),
+ ID: blob,
+ }}}))
+ commit := store.addObject(objecttype.TypeCommit, commitBody(tree))
+ tagToTree := store.addObject(objecttype.TypeTag, tagBody(tree, objecttype.TypeTree))
+
+ r := reachability.New(store)
+ _, err := r.IsAncestor(commit, tagToTree)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ var typeErr *reachability.ErrObjectType
+ if !errors.As(err, &typeErr) {
+ t.Fatalf("expected ErrObjectType, got %T (%v)", err, err)
+ }
+ })
+}
+
+func mustSerializeTree(tb testing.TB, tree *object.Tree) []byte {
+ tb.Helper()
+ body, err := tree.SerializeWithoutHeader()
+ if err != nil {
+ tb.Fatalf("SerializeWithoutHeader: %v", err)
+ }
+ return body
+}
+
+func (store *memStore) addObject(ty objecttype.Type, body []byte) objectid.ObjectID {
+ header, ok := objectheader.Encode(ty, int64(len(body)))
+ if !ok {
+ panic("failed to encode object header")
+ }
+ raw := append(append([]byte(nil), header...), body...)
+ id := store.algo.Sum(raw)
+ store.objects[id] = storeObject{ty: ty, content: append([]byte(nil), body...)}
+ return id
+}