diff options
| author | 2026-03-04 08:26:56 +0800 | |
|---|---|---|
| committer | 2026-03-04 08:59:53 +0800 | |
| commit | ab7501be34032fb9e5c48726a68ae90a917af9eb (patch) | |
| tree | 20d005647569befea8133e953c3270e8fd2a2a5b | |
| parent | *: gofumpt (diff) | |
| signature | No signature | |
*: Lint
129 files changed, 2049 insertions, 214 deletions
diff --git a/.golangci.yaml b/.golangci.yaml index 6e9611f1..9abc2f16 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -10,26 +10,22 @@ linters: - lll # poor standard - ireturn # not an issue - perfsprint # silly fmt.Errorf vs errors.New suggestion - - noinlineerr # not an issue - gosmopolitan # completely normal to have CJK and such in tests - gochecknoglobals # unlikely to be introduce accidentally and are usually intentional - nonamedreturns # named returns are often good for clarity - errname # ErrXXX is better than XXXError + - wsl # outdated, use wsl_v5 instead + - varnamelen # it's rather reasonable to have counters like i, even when it spans quite a bit + - gocyclo # cyclomatic metrics aren't that good + - cyclop # cyclomatic metrics aren't that good + - godox # TODO/etc comments are allowed in our codebase + - funlen # long functions are fine + - wrapcheck # rules around interface-return methods are a bit silly + - exhaustruct # tmp: should fix... but too annoying at the moment - - wsl_v5 # tmp - - wsl # tmp - err113 # tmp: will enable when we properly use defined errors - - gochecknoinits # tmp - - nlreturn # tmp - - cyclop # tmp - - gocognit # tmp - - varnamelen # tmp - - funlen # tmp - - godox # tmp - - nestif # tmp - - maintidx # tmp - - gocyclo # tmp - - wrapcheck # unsure + - gocognit # tmp: should consider sometime + settings: gosec: excludes: diff --git a/cmd/show-object/main.go b/cmd/show-object/main.go index 6d27ffad..b3f5a9cd 100644 --- a/cmd/show-object/main.go +++ b/cmd/show-object/main.go @@ -16,13 +16,15 @@ import ( func main() { repoPath := flag.String("r", "", "path to git dir (.git or bare repo root)") name := flag.String("h", "", "reference name or object id") + flag.Parse() if *repoPath == "" || *name == "" { log.Fatal("must provide -r <repo> and -h <ref-or-object-id>") } - if err := run(repoPath, name); err != nil { + err := run(repoPath, name) + if err != nil { log.Fatalf("run: %v", err) } } @@ -32,6 +34,7 @@ func run(repoPath, name *string) error { if err != nil { return fmt.Errorf("open repo root: %w", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) @@ -42,17 +45,21 @@ func run(repoPath, name *string) error { id, err := resolveInput(repo, *name) if err != nil { _ = repo.Close() + return fmt.Errorf("resolve %q: %w", *name, err) } stored, err := repo.ReadStored(id) if err != nil { _ = repo.Close() + return fmt.Errorf("read object %s: %w", id, err) } printStored(stored) - if err := repo.Close(); err != nil { + + err = repo.Close() + if err != nil { return fmt.Errorf("close repository: %w", err) } @@ -60,13 +67,16 @@ func run(repoPath, name *string) error { } func resolveInput(repo *repository.Repository, input string) (objectid.ObjectID, error) { - if id, err := objectid.ParseHex(repo.Algorithm(), strings.TrimSpace(input)); err == nil { + id, err := objectid.ParseHex(repo.Algorithm(), strings.TrimSpace(input)) + if err == nil { return id, nil } + resolved, err := repo.Refs().ResolveFully(input) if err != nil { return objectid.ObjectID{}, err } + return resolved.ID, nil } @@ -75,10 +85,12 @@ func printStored(stored objectstored.StoredObject) { id := stored.ID() ty := stored.Object().ObjectType() + tyName, ok := objecttype.Name(ty) if !ok { tyName = fmt.Sprintf("type %d", ty) } + fmt.Fprintf(&b, "id: %s\n", id) fmt.Fprintf(&b, "type: %s\n", tyName) @@ -90,29 +102,36 @@ func printStored(stored objectstored.StoredObject) { case *objectstored.StoredTree: tree := stored.Tree() fmt.Fprintf(&b, "entries: %d\n", len(tree.Entries)) + for _, entry := range tree.Entries { fmt.Fprintf(&b, "%06o %s\t%s\n", entry.Mode, entry.ID, entry.Name) } case *objectstored.StoredCommit: commit := stored.Commit() fmt.Fprintf(&b, "tree: %s\n", commit.Tree) + for _, parent := range commit.Parents { fmt.Fprintf(&b, "parent: %s\n", parent) } + fmt.Fprintf(&b, "author: %s <%s>\n", commit.Author.Name, commit.Author.Email) fmt.Fprintf(&b, "committer: %s <%s>\n", commit.Committer.Name, commit.Committer.Email) fmt.Fprintf(&b, "message:\n%s\n", string(commit.Message)) case *objectstored.StoredTag: tag := stored.Tag() + targetTy, ok := objecttype.Name(tag.TargetType) if !ok { targetTy = fmt.Sprintf("type %d", tag.TargetType) } + fmt.Fprintf(&b, "target: %s (%s)\n", tag.Target, targetTy) fmt.Fprintf(&b, "name: %s\n", tag.Name) + if tag.Tagger != nil { fmt.Fprintf(&b, "tagger: %s <%s>\n", tag.Tagger.Name, tag.Tagger.Email) } + fmt.Fprintf(&b, "message:\n%s\n", string(tag.Message)) default: fmt.Fprintf(&b, "%#v\n", stored.Object()) diff --git a/config/config.go b/config/config.go index a4853990..b761dce5 100644 --- a/config/config.go +++ b/config/config.go @@ -121,6 +121,7 @@ func ParseConfig(r io.Reader) (*Config, error) { reader: bufio.NewReader(r), lineNum: 1, } + return parser.parse() } @@ -128,6 +129,7 @@ func ParseConfig(r io.Reader) (*Config, error) { // and key. func (c *Config) Lookup(section, subsection, key string) LookupResult { section = strings.ToLower(section) + key = strings.ToLower(key) for _, entry := range c.entries { if strings.EqualFold(entry.Section, section) && @@ -139,6 +141,7 @@ func (c *Config) Lookup(section, subsection, key string) LookupResult { } } } + return LookupResult{Kind: ValueMissing} } @@ -147,7 +150,9 @@ func (c *Config) Lookup(section, subsection, key string) LookupResult { func (c *Config) LookupAll(section, subsection, key string) []LookupResult { section = strings.ToLower(section) key = strings.ToLower(key) + var values []LookupResult + for _, entry := range c.entries { if strings.EqualFold(entry.Section, section) && entry.Subsection == subsection && @@ -158,6 +163,7 @@ func (c *Config) LookupAll(section, subsection, key string) []LookupResult { }) } } + return values } @@ -166,6 +172,7 @@ func (c *Config) LookupAll(section, subsection, key string) []LookupResult { func (c *Config) Entries() []ConfigEntry { result := make([]ConfigEntry, len(c.entries)) copy(result, c.entries) + return result } @@ -181,7 +188,8 @@ type configParser struct { func (p *configParser) parse() (*Config, error) { cfg := &Config{} - if err := p.skipBOM(); err != nil { + err := p.skipBOM() + if err != nil { return nil, err } @@ -190,6 +198,7 @@ func (p *configParser) parse() (*Config, error) { if errors.Is(err, io.EOF) { break } + if err != nil { return nil, err } @@ -201,26 +210,33 @@ func (p *configParser) parse() (*Config, error) { // Comments if ch == '#' || ch == ';' { - if err := p.skipToEOL(); err != nil && !errors.Is(err, io.EOF) { + err := p.skipToEOL() + if err != nil && !errors.Is(err, io.EOF) { return nil, err } + continue } // Section header if ch == '[' { - if err := p.parseSection(); err != nil { + err := p.parseSection() + if err != nil { return nil, fmt.Errorf("furgit: config: line %d: %w", p.lineNum, err) } + continue } // Key-value pair if isLetter(ch) { p.unreadChar(ch) - if err := p.parseKeyValue(cfg); err != nil { + + err := p.parseKeyValue(cfg) + if err != nil { return nil, fmt.Errorf("furgit: config: line %d: %w", p.lineNum, err) } + continue } @@ -233,6 +249,7 @@ func (p *configParser) parse() (*Config, error) { func (p *configParser) nextChar() (byte, error) { if p.hasPeeked { p.hasPeeked = false + return p.peeked, nil } @@ -260,6 +277,7 @@ func (p *configParser) nextChar() (byte, error) { func (p *configParser) unreadChar(ch byte) { p.peeked = ch + p.hasPeeked = true if ch == '\n' && p.lineNum > 1 { p.lineNum-- @@ -271,36 +289,48 @@ func (p *configParser) skipBOM() error { if errors.Is(err, io.EOF) { return nil } + if err != nil { return err } + if first != 0xef { _ = p.reader.UnreadByte() + return nil } + second, err := p.reader.ReadByte() if err != nil { if errors.Is(err, io.EOF) { _ = p.reader.UnreadByte() + return nil } + return err } + third, err := p.reader.ReadByte() if err != nil { if errors.Is(err, io.EOF) { _ = p.reader.UnreadByte() _ = p.reader.UnreadByte() + return nil } + return err } + if second == 0xbb && third == 0xbf { return nil } + _ = p.reader.UnreadByte() _ = p.reader.UnreadByte() _ = p.reader.UnreadByte() + return nil } @@ -310,6 +340,7 @@ func (p *configParser) skipToEOL() error { if err != nil { return err } + if ch == '\n' { return nil } @@ -330,8 +361,10 @@ func (p *configParser) parseSection() error { if !isValidSection(section) { return fmt.Errorf("invalid section name: %q", section) } + p.currentSection = strings.ToLower(section) p.currentSubsec = "" + return nil } @@ -353,15 +386,18 @@ func (p *configParser) parseExtendedSection(sectionName *bytes.Buffer) error { if err != nil { return errors.New("unexpected EOF in section header") } + if !isWhitespace(ch) { if ch != '"' { return errors.New("expected quote after section name") } + break } } var subsec bytes.Buffer + for { ch, err := p.nextChar() if err != nil { @@ -381,9 +417,11 @@ func (p *configParser) parseExtendedSection(sectionName *bytes.Buffer) error { if err != nil { return errors.New("unexpected EOF after backslash in subsection") } + if next == '\n' { return errors.New("newline after backslash in subsection") } + subsec.WriteByte(next) } else { subsec.WriteByte(ch) @@ -394,6 +432,7 @@ func (p *configParser) parseExtendedSection(sectionName *bytes.Buffer) error { if err != nil { return errors.New("unexpected EOF after subsection") } + if ch != ']' { return fmt.Errorf("expected ']' after subsection, got %q", ch) } @@ -405,6 +444,7 @@ func (p *configParser) parseExtendedSection(sectionName *bytes.Buffer) error { p.currentSection = strings.ToLower(section) p.currentSubsec = subsec.String() + return nil } @@ -414,17 +454,20 @@ func (p *configParser) parseKeyValue(cfg *Config) error { } var key bytes.Buffer + for { ch, err := p.nextChar() if errors.Is(err, io.EOF) { break } + if err != nil { return err } if ch == '=' || ch == '\n' || isSpace(ch) { p.unreadChar(ch) + break } @@ -439,6 +482,7 @@ func (p *configParser) parseKeyValue(cfg *Config) error { if len(keyStr) == 0 { return errors.New("empty key name") } + if !isLetter(keyStr[0]) { return errors.New("key must start with a letter") } @@ -453,8 +497,10 @@ func (p *configParser) parseKeyValue(cfg *Config) error { Kind: ValueValueless, Value: "", }) + return nil } + if err != nil { return err } @@ -467,13 +513,16 @@ func (p *configParser) parseKeyValue(cfg *Config) error { Kind: ValueValueless, Value: "", }) + return nil } if ch == '#' || ch == ';' { - if err := p.skipToEOL(); err != nil && !errors.Is(err, io.EOF) { + err := p.skipToEOL() + if err != nil && !errors.Is(err, io.EOF) { return err } + cfg.entries = append(cfg.entries, ConfigEntry{ Section: p.currentSection, Subsection: p.currentSubsec, @@ -481,6 +530,7 @@ func (p *configParser) parseKeyValue(cfg *Config) error { Kind: ValueValueless, Value: "", }) + return nil } @@ -510,9 +560,12 @@ func (p *configParser) parseKeyValue(cfg *Config) error { } func (p *configParser) parseValue() (string, error) { - var value bytes.Buffer - var inQuote bool - var inComment bool + var ( + value bytes.Buffer + inQuote bool + inComment bool + ) + trimLen := 0 for { @@ -521,11 +574,14 @@ func (p *configParser) parseValue() (string, error) { if inQuote { return "", errors.New("unexpected EOF in quoted value") } + if trimLen > 0 { return truncateAtNUL(value.String()[:trimLen]), nil } + return truncateAtNUL(value.String()), nil } + if err != nil { return "", err } @@ -534,9 +590,11 @@ func (p *configParser) parseValue() (string, error) { if inQuote { return "", errors.New("newline in quoted value") } + if trimLen > 0 { return truncateAtNUL(value.String()[:trimLen]), nil } + return truncateAtNUL(value.String()), nil } @@ -548,14 +606,17 @@ func (p *configParser) parseValue() (string, error) { if trimLen == 0 && value.Len() > 0 { trimLen = value.Len() } + if value.Len() > 0 { value.WriteByte(ch) } + continue } if !inQuote && (ch == '#' || ch == ';') { inComment = true + continue } @@ -568,6 +629,7 @@ func (p *configParser) parseValue() (string, error) { if errors.Is(err, io.EOF) { return "", errors.New("unexpected EOF after backslash") } + if err != nil { return "", err } @@ -586,11 +648,13 @@ func (p *configParser) parseValue() (string, error) { default: return "", fmt.Errorf("invalid escape sequence: \\%c", next) } + continue } if ch == '"' { inQuote = !inQuote + continue } @@ -602,12 +666,14 @@ func isValidSection(s string) bool { if len(s) == 0 { return false } + for i := range len(s) { ch := s[i] if !isLetter(ch) && !isDigit(ch) && ch != '-' && ch != '.' { return false } } + return true } @@ -632,6 +698,7 @@ func parseBool(value string) (bool, error) { if err != nil { return false, fmt.Errorf("invalid boolean value %q", value) } + return n != 0, nil } @@ -640,6 +707,7 @@ func parseInt32(value string) (int32, error) { if err != nil { return 0, err } + return intconv.Int64ToInt32(n64) } @@ -648,6 +716,7 @@ func parseInt(value string) (int, error) { if err != nil { return 0, err } + return int(n64), nil } @@ -667,6 +736,7 @@ func parseInt64WithMax(value string, maxValue int64) (int64, error) { numPart := trimmed factor := int64(1) + if last := trimmed[len(trimmed)-1]; last == 'k' || last == 'K' || last == 'm' || last == 'M' || last == 'g' || last == 'G' { switch toLower(last) { case 'k': @@ -676,8 +746,10 @@ func parseInt64WithMax(value string, maxValue int64) (int64, error) { case 'g': factor = 1024 * 1024 * 1024 } + numPart = trimmed[:len(trimmed)-1] } + if numPart == "" { return 0, errors.New("missing integer value") } @@ -689,14 +761,17 @@ func parseInt64WithMax(value string, maxValue int64) (int64, error) { intMax := maxValue intMin := -maxValue - 1 + if n > 0 && n > intMax/factor { return 0, errors.New("integer overflow") } + if n < 0 && n < intMin/factor { return 0, errors.New("integer overflow") } n *= factor + return n, nil } @@ -706,6 +781,7 @@ func truncateAtNUL(value string) string { return value[:i] } } + return value } @@ -729,5 +805,6 @@ func toLower(ch byte) byte { if ch >= 'A' && ch <= 'Z' { return ch + ('a' - 'A') } + return ch } diff --git a/config/config_test.go b/config/config_test.go index 416222e7..a87b2d7a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -15,15 +15,18 @@ import ( func openConfig(t *testing.T, testRepo *testgit.TestRepo) *os.File { t.Helper() + cfgFile, err := os.Open(filepath.Join(testRepo.Dir(), "config")) if err != nil { t.Fatalf("failed to open config: %v", err) } + return cfgFile } func gitConfigGet(t *testing.T, testRepo *testgit.TestRepo, key string) string { t.Helper() + return testRepo.Run(t, "config", "--get", key) } @@ -31,11 +34,13 @@ func gitConfigGetE(testRepo *testgit.TestRepo, key string) (string, error) { //nolint:noctx cmd := exec.Command("git", "config", "--get", key) //#nosec G204 cmd.Dir = testRepo.Dir() + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null", ) out, err := cmd.CombinedOutput() + return strings.TrimSpace(string(out)), err } @@ -44,15 +49,18 @@ func lookupValue(cfg *config.Config, section, subsection, key string) string { if result.Kind == config.ValueMissing { return "" } + return result.Value } func lookupAllValues(cfg *config.Config, section, subsection, key string) []string { results := cfg.LookupAll(section, subsection, key) + values := make([]string, 0, len(results)) for _, result := range results { values = append(values, result.Value) } + return values } @@ -66,6 +74,7 @@ func TestConfigAgainstGit(t *testing.T) { testRepo.Run(t, "config", "user.email", "jane@example.org") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -76,12 +85,15 @@ func TestConfigAgainstGit(t *testing.T) { if got := lookupValue(cfg, "core", "", "bare"); got != "true" { t.Errorf("core.bare: got %q, want %q", got, "true") } + if got := lookupValue(cfg, "core", "", "filemode"); got != "false" { t.Errorf("core.filemode: got %q, want %q", got, "false") } + if got := lookupValue(cfg, "user", "", "name"); got != "Jane Doe" { t.Errorf("user.name: got %q, want %q", got, "Jane Doe") } + if got := lookupValue(cfg, "user", "", "email"); got != "jane@example.org" { t.Errorf("user.email: got %q, want %q", got, "jane@example.org") } @@ -96,6 +108,7 @@ func TestConfigSubsectionAgainstGit(t *testing.T) { testRepo.Run(t, "config", "remote.origin.fetch", "+refs/heads/*:refs/remotes/origin/*") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -106,6 +119,7 @@ func TestConfigSubsectionAgainstGit(t *testing.T) { if got := lookupValue(cfg, "remote", "origin", "url"); got != "https://example.org/repo.git" { t.Errorf("remote.origin.url: got %q, want %q", got, "https://example.org/repo.git") } + if got := lookupValue(cfg, "remote", "origin", "fetch"); got != "+refs/heads/*:refs/remotes/origin/*" { t.Errorf("remote.origin.fetch: got %q, want %q", got, "+refs/heads/*:refs/remotes/origin/*") } @@ -121,6 +135,7 @@ func TestConfigMultiValueAgainstGit(t *testing.T) { testRepo.Run(t, "config", "--add", "remote.origin.fetch", "+refs/tags/*:refs/tags/*") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -157,6 +172,7 @@ func TestConfigCaseInsensitiveAgainstGit(t *testing.T) { gitVerifyFilemode := gitConfigGet(t, testRepo, "core.filemode") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -167,9 +183,11 @@ func TestConfigCaseInsensitiveAgainstGit(t *testing.T) { if got := lookupValue(cfg, "core", "", "bare"); got != gitVerifyBare { t.Errorf("core.bare: got %q, want %q (from git)", got, gitVerifyBare) } + if got := lookupValue(cfg, "CORE", "", "BARE"); got != gitVerifyBare { t.Errorf("CORE.BARE: got %q, want %q (from git)", got, gitVerifyBare) } + if got := lookupValue(cfg, "core", "", "filemode"); got != gitVerifyFilemode { t.Errorf("core.filemode: got %q, want %q (from git)", got, gitVerifyFilemode) } @@ -186,6 +204,7 @@ func TestConfigBooleanAgainstGit(t *testing.T) { testRepo.Run(t, "config", "test.flag4", "no") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -213,7 +232,9 @@ func TestConfigBooleanAgainstGit(t *testing.T) { func TestConfigLookupKindsAndBool(t *testing.T) { t.Parallel() + cfgText := "[test]\nnovalue\nempty =\ntruthy = yes\nnumeric = -2\nleadspace = \" 1\"\nleadtab = \"\t-2\"\nksuffix = 1k\nhex = 0x10\nmaxi32 = 2147483647\ntoobig = 2147483648\ntoosmall = -2147483649\nbadnum = \" 2x\"\n" + cfg, err := config.ParseConfig(strings.NewReader(cfgText)) if err != nil { t.Fatalf("ParseConfig failed: %v", err) @@ -223,6 +244,7 @@ func TestConfigLookupKindsAndBool(t *testing.T) { if novalue.Kind != config.ValueValueless { t.Fatalf("novalue kind: got %v, want %v", novalue.Kind, config.ValueValueless) } + novalueBool, err := novalue.Bool() if err != nil || !novalueBool { t.Fatalf("novalue bool: got (%v, %v), want (true, nil)", novalueBool, err) @@ -232,6 +254,7 @@ func TestConfigLookupKindsAndBool(t *testing.T) { if empty.Kind != config.ValueString || empty.Value != "" { t.Fatalf("empty: got (%v, %q), want (%v, %q)", empty.Kind, empty.Value, config.ValueString, "") } + emptyBool, err := empty.Bool() if err != nil || emptyBool { t.Fatalf("empty bool: got (%v, %v), want (false, nil)", emptyBool, err) @@ -241,39 +264,52 @@ func TestConfigLookupKindsAndBool(t *testing.T) { if err != nil || !truthyBool { t.Fatalf("truthy bool: got (%v, %v), want (true, nil)", truthyBool, err) } + numericBool, err := cfg.Lookup("test", "", "numeric").Bool() if err != nil || !numericBool { t.Fatalf("numeric bool: got (%v, %v), want (true, nil)", numericBool, err) } + leadspaceBool, err := cfg.Lookup("test", "", "leadspace").Bool() if err != nil || !leadspaceBool { t.Fatalf("leadspace bool: got (%v, %v), want (true, nil)", leadspaceBool, err) } + leadtabBool, err := cfg.Lookup("test", "", "leadtab").Bool() if err != nil || !leadtabBool { t.Fatalf("leadtab bool: got (%v, %v), want (true, nil)", leadtabBool, err) } + ksuffixBool, err := cfg.Lookup("test", "", "ksuffix").Bool() if err != nil || !ksuffixBool { t.Fatalf("ksuffix bool: got (%v, %v), want (true, nil)", ksuffixBool, err) } + maxi32Bool, err := cfg.Lookup("test", "", "maxi32").Bool() if err != nil || !maxi32Bool { t.Fatalf("maxi32 bool: got (%v, %v), want (true, nil)", maxi32Bool, err) } - if _, err := cfg.Lookup("test", "", "toobig").Bool(); err == nil { + + _, err = cfg.Lookup("test", "", "toobig").Bool() + if err == nil { t.Fatal("toobig bool: expected error") } - if _, err := cfg.Lookup("test", "", "toosmall").Bool(); err == nil { + + _, err = cfg.Lookup("test", "", "toosmall").Bool() + if err == nil { t.Fatal("toosmall bool: expected error") } - if _, err := cfg.Lookup("test", "", "badnum").Bool(); err == nil { + + _, err = cfg.Lookup("test", "", "badnum").Bool() + if err == nil { t.Fatal("badnum bool: expected error") } - if _, err := novalue.String(); err == nil { + _, err = novalue.String() + if err == nil { t.Fatal("novalue string: expected error") } + emptyString, err := empty.String() if err != nil || emptyString != "" { t.Fatalf("empty string: got (%q, %v), want (%q, nil)", emptyString, err, "") @@ -283,15 +319,19 @@ func TestConfigLookupKindsAndBool(t *testing.T) { if err != nil || numericInt != -2 { t.Fatalf("numeric int: got (%v, %v), want (-2, nil)", numericInt, err) } + ksuffixInt, err := cfg.Lookup("test", "", "ksuffix").Int() if err != nil || ksuffixInt != 1024 { t.Fatalf("ksuffix int: got (%v, %v), want (1024, nil)", ksuffixInt, err) } + hexInt64, err := cfg.Lookup("test", "", "hex").Int64() if err != nil || hexInt64 != 16 { t.Fatalf("hex int64: got (%v, %v), want (16, nil)", hexInt64, err) } - if _, err := cfg.Lookup("test", "", "badnum").Int(); err == nil { + + _, err = cfg.Lookup("test", "", "badnum").Int() + if err == nil { t.Fatal("badnum int: expected error") } @@ -299,13 +339,19 @@ func TestConfigLookupKindsAndBool(t *testing.T) { if missing.Kind != config.ValueMissing { t.Fatalf("missing kind: got %v, want %v", missing.Kind, config.ValueMissing) } - if _, err := missing.Bool(); err == nil { + + _, err = missing.Bool() + if err == nil { t.Fatal("missing bool: expected error") } - if _, err := missing.Int(); err == nil { + + _, err = missing.Int() + if err == nil { t.Fatal("missing int: expected error") } - if _, err := missing.String(); err == nil { + + _, err = missing.String() + if err == nil { t.Fatal("missing string: expected error") } } @@ -320,6 +366,7 @@ func TestConfigComplexValuesAgainstGit(t *testing.T) { testRepo.Run(t, "config", "test.number", "12345") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -346,6 +393,7 @@ func TestConfigEntriesAgainstGit(t *testing.T) { testRepo.Run(t, "config", "user.name", "Test User") cfgFile := openConfig(t, testRepo) + defer func() { _ = cfgFile.Close() }() cfg, err := config.ParseConfig(cfgFile) @@ -359,11 +407,13 @@ func TestConfigEntriesAgainstGit(t *testing.T) { } found := make(map[string]bool) + for _, entry := range entries { key := entry.Section + "." + entry.Key if entry.Subsection != "" { key = entry.Section + "." + entry.Subsection + "." + entry.Key } + found[key] = true gitValue := gitConfigGet(t, testRepo, key) @@ -376,6 +426,7 @@ func TestConfigEntriesAgainstGit(t *testing.T) { func TestConfigErrorCases(t *testing.T) { t.Parallel() + tests := []struct { name string config string @@ -405,7 +456,9 @@ func TestConfigErrorCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + r := strings.NewReader(tt.config) + _, err := config.ParseConfig(r) if err == nil { t.Errorf("expected error for %s", tt.name) @@ -420,7 +473,9 @@ func TestConfigEOFAfterKeyAgainstGit(t *testing.T) { cfgPath := filepath.Join(testRepo.Dir(), "config") cfgData := []byte("[Core]BAre") - if err := os.WriteFile(cfgPath, cfgData, 0o600); err != nil { + + err := os.WriteFile(cfgPath, cfgData, 0o600) + if err != nil { t.Fatalf("failed to write config: %v", err) } @@ -430,6 +485,7 @@ func TestConfigEOFAfterKeyAgainstGit(t *testing.T) { if (gitErr == nil) != (furErr == nil) { t.Fatalf("git: %v\nfur: %v", gitErr, furErr) } + if furErr != nil { return } @@ -445,7 +501,9 @@ func TestConfigNULValueAgainstGit(t *testing.T) { cfgPath := filepath.Join(testRepo.Dir(), "config") cfgData := []byte("[Core]BAre=\x00") - if err := os.WriteFile(cfgPath, cfgData, 0o600); err != nil { + + err := os.WriteFile(cfgPath, cfgData, 0o600) + if err != nil { t.Fatalf("failed to write config: %v", err) } @@ -455,6 +513,7 @@ func TestConfigNULValueAgainstGit(t *testing.T) { if (gitErr == nil) != (furErr == nil) { t.Fatalf("git: %v\nfur: %v", gitErr, furErr) } + if furErr != nil { return } @@ -470,7 +529,9 @@ func TestConfigCarriageReturnSeparatorAgainstGit(t *testing.T) { cfgPath := filepath.Join(testRepo.Dir(), "config") cfgData := []byte("[Core \"sub\"]\rBAre") - if err := os.WriteFile(cfgPath, cfgData, 0o600); err != nil { + + err := os.WriteFile(cfgPath, cfgData, 0o600) + if err != nil { t.Fatalf("failed to write config: %v", err) } @@ -480,6 +541,7 @@ func TestConfigCarriageReturnSeparatorAgainstGit(t *testing.T) { if (gitErr == nil) != (furErr == nil) { t.Fatalf("git: %v\nfur: %v", gitErr, furErr) } + if furErr != nil { return } @@ -498,11 +560,13 @@ func FuzzConfig(f *testing.F) { cfgPath := filepath.Join(testRepo.Dir(), "config") f.Fuzz(func(t *testing.T, cfgData []byte, gitKey string) { - if err := os.WriteFile(cfgPath, cfgData, 0o600); err != nil { + err := os.WriteFile(cfgPath, cfgData, 0o600) + if err != nil { t.Fatalf("failed to write config: %v", err) } gitValue, gitErr := gitConfigGetE(testRepo, gitKey) + furConfig, furErr := config.ParseConfig(bytes.NewReader(cfgData)) if furErr == nil && furConfig == nil { t.Fatalf("ParseConfig returned nil config with nil error") @@ -513,12 +577,16 @@ func FuzzConfig(f *testing.F) { if furErr == nil { return } + t.Fatalf("git: %v\nfur: %v", gitErr, furErr) } + if furErr == nil { parts := strings.SplitN(gitKey, ".", 3) furSection := parts[0] + var furSubsection, furKey string + switch len(parts) { case 1: case 2: diff --git a/diff/lines/diff.go b/diff/lines/diff.go index bdcb4d93..ca34f371 100644 --- a/diff/lines/diff.go +++ b/diff/lines/diff.go @@ -5,7 +5,7 @@ import "bytes" // Diff performs a line-based diff. // Lines are bytes up to and including '\n' (final line may lack '\n'). -func Diff(oldB, newB []byte) ([]Chunk, error) { +func Diff(oldB, newB []byte) ([]Chunk, error) { //nolint:maintidx type lineRef struct { base []byte start int @@ -16,17 +16,22 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { if len(b) == 0 { return nil } + var res []lineRef + start := 0 + for i := range b { if b[i] == '\n' { res = append(res, lineRef{base: b, start: start, end: i + 1}) start = i + 1 } } + if start < len(b) { res = append(res, lineRef{base: b, start: start, end: len(b)}) } + return res } @@ -34,6 +39,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { newLines := split(newB) n := len(oldLines) + m := len(newLines) if n == 0 && m == 0 { return nil, nil @@ -42,25 +48,32 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { idOf := make(map[string]int) nextID := 0 oldIDs := make([]int, n) + for i, ln := range oldLines { key := string(ln.base[ln.start:ln.end]) + id, ok := idOf[key] if !ok { id = nextID idOf[key] = id nextID++ } + oldIDs[i] = id } + newIDs := make([]int, m) + for i, ln := range newLines { key := string(ln.base[ln.start:ln.end]) + id, ok := idOf[key] if !ok { id = nextID idOf[key] = id nextID++ } + newIDs[i] = id } @@ -74,11 +87,13 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { } x0 := 0 + y0 := 0 for x0 < n && y0 < m && oldIDs[x0] == newIDs[y0] { x0++ y0++ } + Vprev[offset+0] = x0 trace = append(trace, append([]int(nil), Vprev...)) @@ -97,17 +112,20 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { } else { x = Vprev[offset+(k-1)] + 1 } + y := x - k for x < n && y < m && oldIDs[x] == newIDs[y] { x++ y++ } + V[offset+k] = x if x >= n && y >= m { trace = append(trace, V) found = true + break } } @@ -122,9 +140,11 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { kind ChunkKind lineref lineRef } + revEdits := make([]edit, 0, n+m) x := n + y := m for D := len(trace) - 1; D >= 0; D-- { k := x - y @@ -134,6 +154,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { prevX int prevY int ) + if D > 0 { prevV := trace[D-1] if k == -D || (k != D && prevV[offset+(k-1)] < prevV[offset+(k+1)]) { @@ -141,6 +162,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { } else { prevK = k - 1 } + prevX = prevV[offset+prevK] prevY = prevX - prevK } @@ -148,6 +170,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { for x > prevX && y > prevY { x-- y-- + revEdits = append(revEdits, edit{kind: ChunkKindUnchanged, lineref: oldLines[x]}) } @@ -169,11 +192,13 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { } var out []Chunk + type meta struct { base []byte start int end int } + var metas []meta for _, e := range revEdits { @@ -184,6 +209,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { if len(out) == 0 || out[len(out)-1].Kind != e.kind { out = append(out, Chunk{Kind: e.kind, Data: curBase[curStart:curEnd]}) metas = append(metas, meta{base: curBase, start: curStart, end: curEnd}) + continue } @@ -193,6 +219,7 @@ func Diff(oldB, newB []byte) ([]Chunk, error) { if bytes.Equal(lastMeta.base, curBase) && lastMeta.end == curStart { metas[lastIdx].end = curEnd out[lastIdx].Data = curBase[metas[lastIdx].start:metas[lastIdx].end] + continue } diff --git a/diff/lines/diff_test.go b/diff/lines/diff_test.go index 7ff2c386..c5d5be9f 100644 --- a/diff/lines/diff_test.go +++ b/diff/lines/diff_test.go @@ -9,7 +9,7 @@ import ( "codeberg.org/lindenii/furgit/diff/lines" ) -func TestDiff(t *testing.T) { +func TestDiff(t *testing.T) { //nolint:maintidx t.Parallel() tests := []struct { @@ -291,6 +291,7 @@ func TestDiff(t *testing.T) { if chunks[i].Kind != tt.expected[i].Kind { t.Fatalf("chunk %d kind mismatch: got %v, want %v; chunks: %s", i, chunks[i].Kind, tt.expected[i].Kind, formatChunks(chunks)) } + if !bytes.Equal(chunks[i].Data, tt.expected[i].Data) { t.Fatalf("chunk %d data mismatch: got %q, want %q; chunks: %s", i, string(chunks[i].Data), string(tt.expected[i].Data), formatChunks(chunks)) } @@ -302,15 +303,19 @@ func TestDiff(t *testing.T) { func formatChunks(chunks []lines.Chunk) string { var b strings.Builder b.WriteByte('[') + for i, chunk := range chunks { if i > 0 { b.WriteString(", ") } + b.WriteString(chunkKindName(chunk.Kind)) b.WriteByte(':') b.WriteString(strconv.Quote(string(chunk.Data))) } + b.WriteByte(']') + return b.String() } diff --git a/diff/trees/diff.go b/diff/trees/diff.go index 836b71cc..9583c939 100644 --- a/diff/trees/diff.go +++ b/diff/trees/diff.go @@ -12,9 +12,12 @@ import ( // reaches directory entries. func Diff(a, b *object.Tree, readTree func(objectid.ObjectID) (*object.Tree, error)) ([]Entry, error) { var out []Entry - if err := diffRecursive(a, b, nil, readTree, &out); err != nil { + + err := diffRecursive(a, b, nil, readTree, &out) + if err != nil { return nil, err } + return out, nil } @@ -27,17 +30,23 @@ func diffRecursive(a, b *object.Tree, prefix []byte, readTree func(objectid.Obje for i := range b.Entries { entry := &b.Entries[i] full := joinPath(prefix, entry.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindAdded, Old: nil, New: entry}) - if entry.Mode == object.FileModeDir { - sub, err := readTree(entry.ID) - if err != nil { - return err - } - if err := diffRecursive(nil, sub, full, readTree, out); err != nil { - return err - } + if entry.Mode != object.FileModeDir { + continue + } + + sub, err := readTree(entry.ID) + if err != nil { + return err + } + + err = diffRecursive(nil, sub, full, readTree, out) + if err != nil { + return err } } + return nil } @@ -45,25 +54,33 @@ func diffRecursive(a, b *object.Tree, prefix []byte, readTree func(objectid.Obje for i := range a.Entries { entry := &a.Entries[i] full := joinPath(prefix, entry.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindDeleted, Old: entry, New: nil}) - if entry.Mode == object.FileModeDir { - sub, err := readTree(entry.ID) - if err != nil { - return err - } - if err := diffRecursive(sub, nil, full, readTree, out); err != nil { - return err - } + if entry.Mode != object.FileModeDir { + continue + } + + sub, err := readTree(entry.ID) + if err != nil { + return err + } + + err = diffRecursive(sub, nil, full, readTree, out) + if err != nil { + return err } } + return nil } i := 0 + j := 0 for i < len(a.Entries) && j < len(b.Entries) { left := &a.Entries[i] right := &b.Entries[j] + cmp := object.TreeEntryNameCompare( left.Name, left.Mode, @@ -73,49 +90,63 @@ func diffRecursive(a, b *object.Tree, prefix []byte, readTree func(objectid.Obje switch { case cmp < 0: full := joinPath(prefix, left.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindDeleted, Old: left, New: nil}) if left.Mode == object.FileModeDir { sub, err := readTree(left.ID) if err != nil { return err } - if err := diffRecursive(sub, nil, full, readTree, out); err != nil { + + err = diffRecursive(sub, nil, full, readTree, out) + if err != nil { return err } } + i++ case cmp > 0: full := joinPath(prefix, right.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindAdded, Old: nil, New: right}) if right.Mode == object.FileModeDir { sub, err := readTree(right.ID) if err != nil { return err } - if err := diffRecursive(nil, sub, full, readTree, out); err != nil { + + err = diffRecursive(nil, sub, full, readTree, out) + if err != nil { return err } } + j++ default: full := joinPath(prefix, left.Name) + modified := left.Mode != right.Mode || left.ID != right.ID if modified { *out = append(*out, Entry{Path: full, Kind: EntryKindModified, Old: left, New: right}) } + if left.Mode == object.FileModeDir && right.Mode == object.FileModeDir && left.ID != right.ID { leftSub, err := readTree(left.ID) if err != nil { return err } + rightSub, err := readTree(right.ID) if err != nil { return err } - if err := diffRecursive(leftSub, rightSub, full, readTree, out); err != nil { + + err = diffRecursive(leftSub, rightSub, full, readTree, out) + if err != nil { return err } } + i++ j++ } @@ -124,13 +155,16 @@ func diffRecursive(a, b *object.Tree, prefix []byte, readTree func(objectid.Obje for ; i < len(a.Entries); i++ { left := &a.Entries[i] full := joinPath(prefix, left.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindDeleted, Old: left, New: nil}) if left.Mode == object.FileModeDir { sub, err := readTree(left.ID) if err != nil { return err } - if err := diffRecursive(sub, nil, full, readTree, out); err != nil { + + err = diffRecursive(sub, nil, full, readTree, out) + if err != nil { return err } } @@ -139,13 +173,16 @@ func diffRecursive(a, b *object.Tree, prefix []byte, readTree func(objectid.Obje for ; j < len(b.Entries); j++ { right := &b.Entries[j] full := joinPath(prefix, right.Name) + *out = append(*out, Entry{Path: full, Kind: EntryKindAdded, Old: nil, New: right}) if right.Mode == object.FileModeDir { sub, err := readTree(right.ID) if err != nil { return err } - if err := diffRecursive(nil, sub, full, readTree, out); err != nil { + + err = diffRecursive(nil, sub, full, readTree, out) + if err != nil { return err } } diff --git a/diff/trees/diff_test.go b/diff/trees/diff_test.go index 2fb8540f..1664bdf8 100644 --- a/diff/trees/diff_test.go +++ b/diff/trees/diff_test.go @@ -157,88 +157,112 @@ type diffExpectation struct { func writeTestFile(t *testing.T, path, data string) { t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + + err := os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { t.Fatalf("create directory for %s: %v", path, err) } - if err := os.WriteFile(path, []byte(data), 0o644); err != nil { + + err = os.WriteFile(path, []byte(data), 0o644) + if err != nil { t.Fatalf("write %s: %v", path, err) } } func openLooseStore(t *testing.T, objectsPath string, algo objectid.Algorithm) *loose.Store { t.Helper() + root, err := os.OpenRoot(objectsPath) if err != nil { t.Fatalf("OpenRoot(%q): %v", objectsPath, err) } + t.Cleanup(func() { _ = root.Close() }) + store, err := loose.New(root, algo) if err != nil { t.Fatalf("loose.New: %v", err) } + t.Cleanup(func() { _ = store.Close() }) + return store } func makeReadTree(t *testing.T, store *loose.Store, algo objectid.Algorithm) func(objectid.ObjectID) (*object.Tree, error) { t.Helper() + return func(id objectid.ObjectID) (*object.Tree, error) { ty, content, err := store.ReadBytesContent(id) if err != nil { return nil, err } + if ty != objecttype.TypeTree { return nil, errors.New("diff/trees test: object is not a tree") } + return object.ParseTree(content, algo) } } func mustReadTree(t *testing.T, readTree func(objectid.ObjectID) (*object.Tree, error), id objectid.ObjectID) *object.Tree { t.Helper() + tree, err := readTree(id) if err != nil { t.Fatalf("read tree %s: %v", id, err) } + return tree } func parseID(t *testing.T, algo objectid.Algorithm, hex string) objectid.ObjectID { t.Helper() + id, err := objectid.ParseHex(algo, hex) if err != nil { t.Fatalf("parse object id %q: %v", hex, err) } + return id } func checkDiffs(t *testing.T, diffs []trees.Entry, expected map[string]diffExpectation) { t.Helper() + got := make(map[string]trees.Entry, len(diffs)) for _, diff := range diffs { path := string(diff.Path) if _, exists := got[path]; exists { t.Fatalf("duplicate diff path %q", path) } + got[path] = diff } + if len(got) != len(expected) { t.Fatalf("diff count = %d, want %d", len(got), len(expected)) } + for path, want := range expected { diff, ok := got[path] if !ok { t.Fatalf("missing diff for %q", path) } + if diff.Kind != want.kind { t.Errorf("%s kind = %v, want %v", path, diff.Kind, want.kind) } + if (diff.Old == nil) != want.oldNil { t.Errorf("%s old nil = %v, want %v", path, diff.Old == nil, want.oldNil) } + if (diff.New == nil) != want.newNil { t.Errorf("%s new nil = %v, want %v", path, diff.New == nil, want.newNil) } + if diff.Kind == trees.EntryKindModified && diff.Old != nil && diff.New != nil && diff.Old.ID == diff.New.ID { t.Errorf("%s modified entry should change IDs", path) } diff --git a/diff/trees/path.go b/diff/trees/path.go index 0ced379a..e40f3de5 100644 --- a/diff/trees/path.go +++ b/diff/trees/path.go @@ -4,11 +4,14 @@ func joinPath(prefix, name []byte) []byte { if len(prefix) == 0 { out := make([]byte, len(name)) copy(out, name) + return out } + out := make([]byte, len(prefix)+1+len(name)) copy(out, prefix) out[len(prefix)] = '/' copy(out[len(prefix)+1:], name) + return out } diff --git a/format/delta/apply/apply.go b/format/delta/apply/apply.go index f9f2fbaf..cd53f837 100644 --- a/format/delta/apply/apply.go +++ b/format/delta/apply/apply.go @@ -6,101 +6,129 @@ import "fmt" // Apply applies one Git delta instruction stream to base. func Apply(base, delta []byte) ([]byte, error) { pos := 0 + srcSize, err := readVarint(delta, &pos) if err != nil { return nil, err } + dstSize, err := readVarint(delta, &pos) if err != nil { return nil, err } + if srcSize != len(base) { return nil, fmt.Errorf("format/delta/apply: delta source size mismatch: got %d want %d", srcSize, len(base)) } out := make([]byte, dstSize) outPos := 0 + for pos < len(delta) { op := delta[pos] pos++ + + //nolint:nestif if op&0x80 != 0 { off := 0 + if op&0x01 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy offset") } + off |= int(delta[pos]) pos++ } + if op&0x02 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy offset") } + off |= int(delta[pos]) << 8 pos++ } + if op&0x04 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy offset") } + off |= int(delta[pos]) << 16 pos++ } + if op&0x08 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy offset") } + off |= int(delta[pos]) << 24 pos++ } n := 0 + if op&0x10 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy size") } + n |= int(delta[pos]) pos++ } + if op&0x20 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy size") } + n |= int(delta[pos]) << 8 pos++ } + if op&0x40 != 0 { if pos >= len(delta) { return nil, fmt.Errorf("format/delta/apply: malformed delta copy size") } + n |= int(delta[pos]) << 16 pos++ } + if n == 0 { n = 0x10000 } + if off < 0 || n < 0 || off+n > len(base) || outPos+n > len(out) { return nil, fmt.Errorf("format/delta/apply: delta copy out of bounds") } + copy(out[outPos:outPos+n], base[off:off+n]) outPos += n + continue } if op == 0 { return nil, fmt.Errorf("format/delta/apply: invalid delta opcode 0") } + n := int(op) if pos+n > len(delta) || outPos+n > len(out) { return nil, fmt.Errorf("format/delta/apply: delta insert out of bounds") } + copy(out[outPos:outPos+n], delta[pos:pos+n]) outPos += n pos += n } + if outPos != len(out) { return nil, fmt.Errorf("format/delta/apply: delta output size mismatch: got %d want %d", outPos, len(out)) } + return out, nil } @@ -108,20 +136,25 @@ func Apply(base, delta []byte) ([]byte, error) { func readVarint(buf []byte, pos *int) (int, error) { value := 0 shift := uint(0) + for { if *pos >= len(buf) { return 0, fmt.Errorf("format/delta/apply: malformed delta varint") } + b := buf[*pos] *pos++ + value |= int(b&0x7f) << shift if b&0x80 == 0 { break } + shift += 7 if shift > 63 { return 0, fmt.Errorf("format/delta/apply: delta varint overflow") } } + return value, nil } diff --git a/format/delta/apply/header.go b/format/delta/apply/header.go index 996b006b..dbd29550 100644 --- a/format/delta/apply/header.go +++ b/format/delta/apply/header.go @@ -14,10 +14,12 @@ func ReadHeaderSizes(reader io.ByteReader) (int, int, error) { if err != nil { return 0, 0, err } + dstSize, err := readVarintFromByteReader(reader) if err != nil { return 0, 0, err } + return srcSize, dstSize, nil } @@ -25,15 +27,18 @@ func ReadHeaderSizes(reader io.ByteReader) (int, int, error) { func readVarintFromByteReader(reader io.ByteReader) (int, error) { value := 0 shift := uint(0) + for { b, err := reader.ReadByte() if err != nil { return 0, fmt.Errorf("format/delta/apply: malformed delta varint: %w", err) } + value |= int(b&0x7f) << shift if b&0x80 == 0 { return value, nil } + shift += 7 if shift > 63 { return 0, fmt.Errorf("format/delta/apply: delta varint overflow") diff --git a/format/pack/entry.go b/format/pack/entry.go index b95ad0ac..93d232a2 100644 --- a/format/pack/entry.go +++ b/format/pack/entry.go @@ -31,19 +31,23 @@ func ParseEntryHeader(data []byte) (EntryHeader, error) { } shift := uint(4) + b := first for b&0x80 != 0 { if header.HeaderSize >= len(data) { return zero, fmt.Errorf("format/pack: truncated entry header") } + b = data[header.HeaderSize] header.HeaderSize++ header.Size |= int64(b&0x7f) << shift shift += 7 } + if header.Size < 0 { return zero, fmt.Errorf("format/pack: negative entry size") } + return header, nil } @@ -73,6 +77,7 @@ func ParseEntry(data []byte, hashSize int) (Entry, error) { if err != nil { return zero, err } + entry := Entry{ Type: header.Type, Size: header.Size, @@ -86,10 +91,12 @@ func ParseEntry(data []byte, hashSize int) (Entry, error) { if hashSize <= 0 { return zero, fmt.Errorf("format/pack: invalid hash size %d", hashSize) } + end := entry.DataOffset + hashSize if end > len(data) { return zero, fmt.Errorf("format/pack: truncated ref-delta base id") } + entry.RefBaseID = data[entry.DataOffset:end] entry.DataOffset = end case objecttype.TypeOfsDelta: @@ -97,6 +104,7 @@ func ParseEntry(data []byte, hashSize int) (Entry, error) { if err != nil { return zero, err } + entry.OfsBaseDistance = dist entry.DataOffset += consumed case objecttype.TypeInvalid, objecttype.TypeFuture: @@ -108,5 +116,6 @@ func ParseEntry(data []byte, hashSize int) (Entry, error) { if entry.DataOffset > len(data) { return zero, fmt.Errorf("format/pack: entry data offset out of bounds") } + return entry, nil } diff --git a/format/pack/pack.go b/format/pack/pack.go index 45fe6a1c..e87e3360 100644 --- a/format/pack/pack.go +++ b/format/pack/pack.go @@ -33,16 +33,20 @@ func ParseOfsDeltaDistance(buf []byte) (uint64, int, error) { if len(buf) == 0 { return 0, 0, fmt.Errorf("format/pack: malformed ofs-delta distance") } + b := buf[0] dist := uint64(b & 0x7f) + consumed := 1 for b&0x80 != 0 { if consumed >= len(buf) { return 0, 0, fmt.Errorf("format/pack: malformed ofs-delta distance") } + b = buf[consumed] consumed++ dist = ((dist + 1) << 7) + uint64(b&0x7f) } + return dist, consumed, nil } diff --git a/internal/adler32/adler32_amd64.go b/internal/adler32/adler32_amd64.go index cb67f21c..7dfab299 100644 --- a/internal/adler32/adler32_amd64.go +++ b/internal/adler32/adler32_amd64.go @@ -27,8 +27,10 @@ func New() hash.Hash32 { if !hasAVX2 { return adler32.New() } + d := new(digest) d.Reset() + return d } @@ -36,6 +38,7 @@ func (d *digest) MarshalBinary() ([]byte, error) { b := make([]byte, 0, marshaledSize) b = append(b, magic...) b = binary.BigEndian.AppendUint32(b, uint32(*d)) + return b, nil } @@ -43,10 +46,13 @@ func (d *digest) UnmarshalBinary(b []byte) error { if len(b) < len(magic) || string(b[:len(magic)]) != magic { return errors.New("hash/adler32: invalid hash state identifier") } + if len(b) != marshaledSize { return errors.New("hash/adler32: invalid hash state size") } + *d = digest(binary.BigEndian.Uint32(b[len(magic):])) + return nil } @@ -62,6 +68,7 @@ func (d *digest) Write(data []byte) (nn int, err error) { h := update(uint32(*d), data) *d = digest(h) } + return len(data), nil } @@ -76,5 +83,6 @@ func Checksum(data []byte) uint32 { if hasAVX2 && len(data) >= 64 { return adler32_avx2(1, data) } + return adler32.Checksum(data) } diff --git a/internal/adler32/adler32_generic.go b/internal/adler32/adler32_generic.go index 0908d8f7..56e3ff8b 100644 --- a/internal/adler32/adler32_generic.go +++ b/internal/adler32/adler32_generic.go @@ -16,11 +16,13 @@ const ( // Add p to the running checksum d. func update(d uint32, p []byte) uint32 { s1, s2 := d&0xffff, d>>16 + for len(p) > 0 { var q []byte if len(p) > nmax { p, q = p[:nmax], p[nmax:] } + for len(p) >= 4 { s1 += uint32(p[0]) s2 += s1 @@ -32,13 +34,16 @@ func update(d uint32, p []byte) uint32 { s2 += s1 p = p[4:] } + for _, x := range p { s1 += uint32(x) s2 += s1 } + s1 %= mod s2 %= mod p = q } + return s2<<16 | s1 } diff --git a/internal/adler32/bench_test.go b/internal/adler32/bench_test.go index 6c6f75ea..d2aebe8f 100644 --- a/internal/adler32/bench_test.go +++ b/internal/adler32/bench_test.go @@ -10,7 +10,7 @@ const benchmarkSize = 64 * 1024 var data = make([]byte, benchmarkSize) -func init() { +func init() { //nolint:gochecknoinits for i := range benchmarkSize { data[i] = byte(i % 256) } @@ -18,6 +18,7 @@ func init() { func BenchmarkChecksum(b *testing.B) { b.ReportAllocs() + for b.Loop() { adler32.Checksum(data) } diff --git a/internal/bufpool/buffers.go b/internal/bufpool/buffers.go index a5c27b67..91e30a31 100644 --- a/internal/bufpool/buffers.go +++ b/internal/bufpool/buffers.go @@ -62,9 +62,11 @@ var bufferPools = func() []sync.Pool { capCopy := classCap pools[i].New = func() any { buf := make([]byte, 0, capCopy) + return &buf } } + return pools }() @@ -80,9 +82,11 @@ func Borrow(capHint int) Buffer { if capHint < DefaultBufferCap { capHint = DefaultBufferCap } + classIdx, classCap, pooled := classFor(capHint) if !pooled { newBuf := make([]byte, 0, capHint) + return Buffer{buf: newBuf, pool: unpooled} } //nolint:forcetypeassert @@ -90,7 +94,9 @@ func Borrow(capHint int) Buffer { if cap(*buf) < classCap { *buf = make([]byte, 0, classCap) } + slice := (*buf)[:0] + return Buffer{buf: slice, pool: poolIndex(classIdx)} //#nosec G115 } @@ -110,6 +116,7 @@ func (buf *Buffer) Resize(n int) { if n < 0 { n = 0 } + buf.ensureCapacity(n) buf.buf = buf.buf[:n] } @@ -122,6 +129,7 @@ func (buf *Buffer) Append(src []byte) { if len(src) == 0 { return } + start := len(buf.buf) buf.ensureCapacity(start + len(src)) buf.buf = buf.buf[:start+len(src)] @@ -144,6 +152,7 @@ func (buf *Buffer) Release() { if buf.buf == nil { return } + buf.returnToPool() buf.buf = nil buf.pool = unpooled @@ -157,20 +166,26 @@ func (buf *Buffer) ensureCapacity(needed int) { if cap(buf.buf) >= needed { return } + classIdx, classCap, pooled := classFor(needed) + var newBuf []byte + if pooled { //nolint:forcetypeassert raw := bufferPools[classIdx].Get().(*[]byte) if cap(*raw) < classCap { *raw = make([]byte, 0, classCap) } + newBuf = (*raw)[:len(buf.buf)] } else { newBuf = make([]byte, len(buf.buf), classCap) } + copy(newBuf, buf.buf) buf.returnToPool() + buf.buf = newBuf if pooled { buf.pool = poolIndex(classIdx) //#nosec G115 @@ -185,6 +200,7 @@ func classFor(size int) (idx, classCap int, ok bool) { return i, class, true } } + return -1, size, false } @@ -192,6 +208,7 @@ func (buf *Buffer) returnToPool() { if buf.pool == unpooled { return } + tmp := buf.buf[:0] bufferPools[int(buf.pool)].Put(&tmp) } diff --git a/internal/bufpool/buffers_test.go b/internal/bufpool/buffers_test.go index 70861d33..224fa98c 100644 --- a/internal/bufpool/buffers_test.go +++ b/internal/bufpool/buffers_test.go @@ -15,19 +15,23 @@ func TestBorrowBufferResizeAndAppend(t *testing.T) { b.Append([]byte("alpha")) b.Append([]byte("beta")) + if got := string(b.Bytes()); got != "alphabeta" { t.Fatalf("unexpected contents: %q", got) } b.Resize(3) + if got := string(b.Bytes()); got != "alp" { t.Fatalf("resize shrink mismatch: %q", got) } b.Resize(8) + if len(b.Bytes()) != 8 { t.Fatalf("expected len 8 after grow, got %d", len(b.Bytes())) } + if prefix := string(b.Bytes()[:3]); prefix != "alp" { t.Fatalf("prefix lost after grow: %q", prefix) } @@ -39,6 +43,7 @@ func TestBorrowBufferRelease(t *testing.T) { b := Borrow(DefaultBufferCap / 2) b.Append([]byte("data")) b.Release() + if b.buf != nil { t.Fatal("expected buffer cleared after release") } @@ -59,9 +64,11 @@ func TestBorrowUsesLargerPools(t *testing.T) { if b.pool != poolIndex(classIdx) { t.Fatalf("expected pooled buffer in class %d, got %d", classIdx, b.pool) } + if cap(b.buf) != classCap { t.Fatalf("expected capacity %d, got %d", classCap, cap(b.buf)) } + b.Release() b2 := Borrow(request) @@ -70,6 +77,7 @@ func TestBorrowUsesLargerPools(t *testing.T) { if b2.pool != poolIndex(classIdx) { t.Fatalf("expected pooled buffer in class %d on reuse, got %d", classIdx, b2.pool) } + if cap(b2.buf) != classCap { t.Fatalf("expected capacity %d on reuse, got %d", classCap, cap(b2.buf)) } @@ -82,6 +90,7 @@ func TestGrowingBufferStaysPooled(t *testing.T) { defer b.Release() b.Append(make([]byte, DefaultBufferCap*3)) + if b.pool == unpooled { t.Fatal("buffer should stay pooled after growth within limit") } diff --git a/internal/intconv/intconv.go b/internal/intconv/intconv.go index 8bc77d8e..67f99a14 100644 --- a/internal/intconv/intconv.go +++ b/internal/intconv/intconv.go @@ -11,6 +11,7 @@ func Uint64ToInt(v uint64) (int, error) { if v > uint64(math.MaxInt) { return 0, fmt.Errorf("intconv: uint64 %d overflows int", v) } + return int(v), nil } @@ -19,6 +20,7 @@ func UintptrToInt(v uintptr) (int, error) { if v > uintptr(math.MaxInt) { return 0, fmt.Errorf("intconv: uintptr %d overflows int", v) } + return int(v), nil } @@ -27,6 +29,7 @@ func IntToUint64(v int) (uint64, error) { if v < 0 { return 0, fmt.Errorf("intconv: int %d is negative", v) } + return uint64(v), nil } @@ -35,5 +38,6 @@ func Int64ToInt32(v int64) (int32, error) { if v < math.MinInt32 || v > math.MaxInt32 { return 0, fmt.Errorf("intconv: int64 %d overflows int32", v) } + return int32(v), nil } diff --git a/internal/iolimit/expect_length_reader.go b/internal/iolimit/expect_length_reader.go index 477c207f..288e0e62 100644 --- a/internal/iolimit/expect_length_reader.go +++ b/internal/iolimit/expect_length_reader.go @@ -39,13 +39,16 @@ func (reader *expectLengthReader) Read(dst []byte) (int, error) { if reader.remaining == 0 { var probe [1]byte + n, err := reader.src.Read(probe[:]) if n > 0 { return 0, ErrExpectedLengthExceeded } + if err == nil { return 0, nil } + return 0, err } @@ -66,9 +69,11 @@ func (reader *expectLengthReader) Read(dst []byte) (int, error) { if reader.remaining > 0 { return n, io.ErrUnexpectedEOF } + if n > 0 { return n, nil } + return 0, io.EOF } diff --git a/internal/iolimit/expect_length_reader_test.go b/internal/iolimit/expect_length_reader_test.go index 503c88ed..e2cfeab0 100644 --- a/internal/iolimit/expect_length_reader_test.go +++ b/internal/iolimit/expect_length_reader_test.go @@ -13,15 +13,18 @@ func TestExpectLengthReaderExact(t *testing.T) { t.Parallel() r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hello")), 5) + got, err := io.ReadAll(r) if err != nil { t.Fatalf("ReadAll error: %v", err) } + if !bytes.Equal(got, []byte("hello")) { t.Fatalf("ReadAll = %q, want %q", got, "hello") } buf := make([]byte, 1) + n, err := r.Read(buf) if n != 0 || !errors.Is(err, io.EOF) { t.Fatalf("post-boundary Read = (%d,%v), want (0,EOF)", n, err) @@ -32,6 +35,7 @@ func TestExpectLengthReaderShort(t *testing.T) { t.Parallel() r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hey")), 5) + _, err := io.ReadAll(r) if !errors.Is(err, io.ErrUnexpectedEOF) { t.Fatalf("ReadAll error = %v, want ErrUnexpectedEOF", err) @@ -43,15 +47,18 @@ func TestExpectLengthReaderLongDetectedOnNextRead(t *testing.T) { r := iolimit.ExpectLengthReader(bytes.NewReader([]byte("hello!")), 5) buf := make([]byte, 5) + n, err := io.ReadFull(r, buf) if err != nil { t.Fatalf("ReadFull error: %v", err) } + if n != 5 || !bytes.Equal(buf, []byte("hello")) { t.Fatalf("ReadFull = (%d,%q), want (5,hello)", n, buf) } probe := make([]byte, 1) + n, err = r.Read(probe) if n != 0 || !errors.Is(err, iolimit.ErrExpectedLengthExceeded) { t.Fatalf("overflow Read = (%d,%v), want (0,ErrExpectedLengthExceeded)", n, err) @@ -63,6 +70,7 @@ func TestExpectLengthReaderEmptyExpected(t *testing.T) { r := iolimit.ExpectLengthReader(bytes.NewReader(nil), 0) buf := make([]byte, 1) + n, err := r.Read(buf) if n != 0 || !errors.Is(err, io.EOF) { t.Fatalf("Read = (%d,%v), want (0,EOF)", n, err) diff --git a/internal/lru/lru.go b/internal/lru/lru.go index 585aaa3f..fcbab646 100644 --- a/internal/lru/lru.go +++ b/internal/lru/lru.go @@ -39,9 +39,11 @@ func New[K comparable, V any](maxWeight int64, weightFn WeightFunc[K, V], onEvic if maxWeight < 0 { panic("lru: negative max weight") } + if weightFn == nil { panic("lru: nil weight function") } + return &Cache[K, V]{ maxWeight: maxWeight, weightFn: weightFn, @@ -61,6 +63,7 @@ func (cache *Cache[K, V]) Add(key K, value V) bool { if w < 0 { panic("lru: negative entry weight") } + if w > cache.maxWeight { return false } @@ -79,6 +82,7 @@ func (cache *Cache[K, V]) Add(key K, value V) bool { cache.weight += w cache.evictOverBudget() + return true } @@ -87,8 +91,10 @@ func (cache *Cache[K, V]) Get(key K) (V, bool) { elem, ok := cache.items[key] if !ok { var zero V + return zero, false } + cache.lru.MoveToBack(elem) //nolint:forcetypeassert return elem.Value.(*entry[K, V]).value, true @@ -99,6 +105,7 @@ func (cache *Cache[K, V]) Peek(key K) (V, bool) { elem, ok := cache.items[key] if !ok { var zero V + return zero, false } //nolint:forcetypeassert @@ -110,9 +117,12 @@ func (cache *Cache[K, V]) Remove(key K) (V, bool) { elem, ok := cache.items[key] if !ok { var zero V + return zero, false } + ent := cache.removeElem(elem) + return ent.value, true } @@ -148,6 +158,7 @@ func (cache *Cache[K, V]) SetMaxWeight(maxWeight int64) { if maxWeight < 0 { panic("lru: negative max weight") } + cache.maxWeight = maxWeight cache.evictOverBudget() } @@ -158,6 +169,7 @@ func (cache *Cache[K, V]) evictOverBudget() { if elem == nil { return } + cache.removeElem(elem) } } @@ -167,9 +179,11 @@ func (cache *Cache[K, V]) removeElem(elem *list.Element) *entry[K, V] { ent := elem.Value.(*entry[K, V]) cache.lru.Remove(elem) delete(cache.items, ent.key) + cache.weight -= ent.weight if cache.onEvict != nil { cache.onEvict(ent.key, ent.value) } + return ent } diff --git a/internal/lru/lru_test.go b/internal/lru/lru_test.go index adfec403..006a32b8 100644 --- a/internal/lru/lru_test.go +++ b/internal/lru/lru_test.go @@ -27,9 +27,11 @@ func TestCacheEvictsLRUAndGetUpdatesRecency(t *testing.T) { if _, ok := cache.Peek("a"); ok { t.Fatalf("expected a to be evicted") } + if _, ok := cache.Peek("b"); !ok { t.Fatalf("expected b to be present") } + if _, ok := cache.Peek("c"); !ok { t.Fatalf("expected c to be present") } @@ -37,14 +39,17 @@ func TestCacheEvictsLRUAndGetUpdatesRecency(t *testing.T) { if _, ok := cache.Get("b"); !ok { t.Fatalf("Get(b) should hit") } + cache.Add("d", testValue{weight: 4, label: "d"}) if _, ok := cache.Peek("c"); ok { t.Fatalf("expected c to be evicted after b was touched") } + if _, ok := cache.Peek("b"); !ok { t.Fatalf("expected b to remain present") } + if _, ok := cache.Peek("d"); !ok { t.Fatalf("expected d to be present") } @@ -60,11 +65,13 @@ func TestCachePeekDoesNotUpdateRecency(t *testing.T) { if _, ok := cache.Peek("a"); !ok { t.Fatalf("Peek(a) should hit") } + cache.Add("c", testValue{weight: 2, label: "c"}) if _, ok := cache.Peek("a"); ok { t.Fatalf("expected a to be evicted; Peek must not update recency") } + if _, ok := cache.Peek("b"); !ok { t.Fatalf("expected b to remain present") } @@ -74,6 +81,7 @@ func TestCacheReplaceAndResize(t *testing.T) { t.Parallel() var evicted []string + cache := lru.New[string, testValue](10, weightFn, func(key string, value testValue) { evicted = append(evicted, key+":"+value.label) }) @@ -85,17 +93,21 @@ func TestCacheReplaceAndResize(t *testing.T) { if cache.Weight() != 10 { t.Fatalf("Weight() = %d, want 10", cache.Weight()) } + if got, ok := cache.Peek("a"); !ok || got.label != "new" { t.Fatalf("Peek(a) = (%+v,%v), want new,true", got, ok) } + if !slices.Equal(evicted, []string{"a:old"}) { t.Fatalf("evicted = %v, want [a:old]", evicted) } cache.SetMaxWeight(8) + if _, ok := cache.Peek("b"); ok { t.Fatalf("expected b to be evicted after shrinking max weight") } + if !slices.Equal(evicted, []string{"a:old", "b:b"}) { t.Fatalf("evicted = %v, want [a:old b:b]", evicted) } @@ -105,6 +117,7 @@ func TestCacheRejectsOversizedWithoutMutation(t *testing.T) { t.Parallel() var evicted []string + cache := lru.New[string, testValue](5, weightFn, func(key string, value testValue) { evicted = append(evicted, key) }) @@ -113,12 +126,15 @@ func TestCacheRejectsOversizedWithoutMutation(t *testing.T) { if ok := cache.Add("b", testValue{weight: 6, label: "b"}); ok { t.Fatalf("Add oversized should return false") } + if got, ok := cache.Peek("a"); !ok || got.label != "a" { t.Fatalf("cache should remain unchanged after oversized add") } + if cache.Weight() != 3 { t.Fatalf("Weight() = %d, want 3", cache.Weight()) } + if len(evicted) != 0 { t.Fatalf("evicted = %v, want none", evicted) } @@ -126,9 +142,11 @@ func TestCacheRejectsOversizedWithoutMutation(t *testing.T) { if ok := cache.Add("a", testValue{weight: 6, label: "new"}); ok { t.Fatalf("oversized replace should return false") } + if got, ok := cache.Peek("a"); !ok || got.label != "a" { t.Fatalf("existing key should remain unchanged after oversized replace") } + if len(evicted) != 0 { t.Fatalf("evicted = %v, want none", evicted) } @@ -138,6 +156,7 @@ func TestCacheRemoveAndClear(t *testing.T) { t.Parallel() var evicted []string + cache := lru.New[string, testValue](10, weightFn, func(key string, value testValue) { evicted = append(evicted, key) }) @@ -150,11 +169,13 @@ func TestCacheRemoveAndClear(t *testing.T) { if !ok || removed.label != "b" { t.Fatalf("Remove(b) = (%+v,%v), want b,true", removed, ok) } + if cache.Len() != 2 || cache.Weight() != 6 { t.Fatalf("post-remove Len/Weight = %d/%d, want 2/6", cache.Len(), cache.Weight()) } cache.Clear() + if cache.Len() != 0 || cache.Weight() != 0 { t.Fatalf("post-clear Len/Weight = %d/%d, want 0/0", cache.Len(), cache.Weight()) } @@ -170,45 +191,55 @@ func TestCachePanicsForInvalidConfiguration(t *testing.T) { t.Run("negative max", func(t *testing.T) { t.Parallel() + defer func() { if recover() == nil { t.Fatalf("expected panic") } }() + _ = lru.New[string, testValue](-1, weightFn, nil) }) t.Run("nil weight function", func(t *testing.T) { t.Parallel() + defer func() { if recover() == nil { t.Fatalf("expected panic") } }() + _ = lru.New[string, testValue](1, nil, nil) }) t.Run("negative entry weight", func(t *testing.T) { t.Parallel() + cache := lru.New[string, testValue](10, func(_ string, _ testValue) int64 { return -1 }, nil) + defer func() { if recover() == nil { t.Fatalf("expected panic") } }() + cache.Add("x", testValue{weight: 1, label: "x"}) }) t.Run("set negative max", func(t *testing.T) { t.Parallel() + cache := lru.New[string, testValue](10, weightFn, nil) + defer func() { if recover() == nil { t.Fatalf("expected panic") } }() + cache.SetMaxWeight(-1) }) } diff --git a/internal/testgit/algorithms.go b/internal/testgit/algorithms.go index 81af4f75..5534aad0 100644 --- a/internal/testgit/algorithms.go +++ b/internal/testgit/algorithms.go @@ -9,6 +9,7 @@ import ( // ForEachAlgorithm runs a subtest for every supported algorithm. func ForEachAlgorithm(t *testing.T, fn func(t *testing.T, algo objectid.Algorithm)) { t.Helper() + for _, algo := range objectid.SupportedAlgorithms() { t.Run(algo.String(), func(t *testing.T) { fn(t, algo) diff --git a/internal/testgit/repo_cat_file.go b/internal/testgit/repo_cat_file.go index 9cc56db6..1325cf6f 100644 --- a/internal/testgit/repo_cat_file.go +++ b/internal/testgit/repo_cat_file.go @@ -9,5 +9,6 @@ import ( // CatFile returns raw output from git cat-file. func (testRepo *TestRepo) CatFile(tb testing.TB, mode string, id objectid.ObjectID) []byte { tb.Helper() + return testRepo.RunBytes(tb, "cat-file", mode, id.String()) } diff --git a/internal/testgit/repo_commit_tree.go b/internal/testgit/repo_commit_tree.go index 763474c2..5eee21ba 100644 --- a/internal/testgit/repo_commit_tree.go +++ b/internal/testgit/repo_commit_tree.go @@ -9,16 +9,21 @@ import ( // CommitTree creates a commit from a tree and message, optionally with parents. func (testRepo *TestRepo) CommitTree(tb testing.TB, tree objectid.ObjectID, message string, parents ...objectid.ObjectID) objectid.ObjectID { tb.Helper() + args := make([]string, 0, 2+2*len(parents)+2) + args = append(args, "commit-tree", tree.String()) for _, p := range parents { args = append(args, "-p", p.String()) } + args = append(args, "-m", message) hex := testRepo.Run(tb, args...) + id, err := objectid.ParseHex(testRepo.algo, hex) if err != nil { tb.Fatalf("parse commit-tree output %q: %v", hex, err) } + return id } diff --git a/internal/testgit/repo_hash_object.go b/internal/testgit/repo_hash_object.go index 10a05381..bc2def72 100644 --- a/internal/testgit/repo_hash_object.go +++ b/internal/testgit/repo_hash_object.go @@ -10,9 +10,11 @@ import ( func (testRepo *TestRepo) HashObject(tb testing.TB, objType string, body []byte) objectid.ObjectID { tb.Helper() hex := testRepo.RunInput(tb, body, "hash-object", "-t", objType, "-w", "--stdin") + id, err := objectid.ParseHex(testRepo.algo, hex) if err != nil { tb.Fatalf("parse git hash-object output %q: %v", hex, err) } + return id } diff --git a/internal/testgit/repo_make_commit.go b/internal/testgit/repo_make_commit.go index a569dfb1..c8bdc428 100644 --- a/internal/testgit/repo_make_commit.go +++ b/internal/testgit/repo_make_commit.go @@ -11,5 +11,6 @@ func (testRepo *TestRepo) MakeCommit(tb testing.TB, message string) (objectid.Ob tb.Helper() blobID, treeID := testRepo.MakeSingleFileTree(tb, "file.txt", []byte("commit-body\n")) commitID := testRepo.CommitTree(tb, treeID, message) + return blobID, treeID, commitID } diff --git a/internal/testgit/repo_make_single_file_tree.go b/internal/testgit/repo_make_single_file_tree.go index 7c53c658..e7a235a7 100644 --- a/internal/testgit/repo_make_single_file_tree.go +++ b/internal/testgit/repo_make_single_file_tree.go @@ -13,5 +13,6 @@ func (testRepo *TestRepo) MakeSingleFileTree(tb testing.TB, fileName string, fil blobID := testRepo.HashObject(tb, "blob", fileContent) treeInput := fmt.Sprintf("100644 blob %s\t%s\n", blobID.String(), fileName) treeID := testRepo.Mktree(tb, treeInput) + return blobID, treeID } diff --git a/internal/testgit/repo_mktree.go b/internal/testgit/repo_mktree.go index 34e6388d..565a0083 100644 --- a/internal/testgit/repo_mktree.go +++ b/internal/testgit/repo_mktree.go @@ -10,9 +10,11 @@ import ( func (testRepo *TestRepo) Mktree(tb testing.TB, input string) objectid.ObjectID { tb.Helper() hex := testRepo.RunInput(tb, []byte(input), "mktree") + id, err := objectid.ParseHex(testRepo.algo, hex) if err != nil { tb.Fatalf("parse mktree output %q: %v", hex, err) } + return id } diff --git a/internal/testgit/repo_new.go b/internal/testgit/repo_new.go index 8120a9a2..8a71e406 100644 --- a/internal/testgit/repo_new.go +++ b/internal/testgit/repo_new.go @@ -21,6 +21,7 @@ type RepoOptions struct { // NewRepo creates a temporary repository initialized with the requested options. func NewRepo(tb testing.TB, opts RepoOptions) *TestRepo { tb.Helper() + algo := opts.ObjectFormat if algo.Size() == 0 { tb.Fatalf("invalid algorithm: %v", algo) @@ -47,10 +48,13 @@ func NewRepo(tb testing.TB, opts RepoOptions) *TestRepo { if opts.Bare { args = append(args, "--bare") } + if opts.RefFormat != "" { args = append(args, "--ref-format="+opts.RefFormat) } + args = append(args, dir) testRepo.runBytes(tb, nil, "", args...) + return testRepo } diff --git a/internal/testgit/repo_refs.go b/internal/testgit/repo_refs.go index eb09a78b..66e08561 100644 --- a/internal/testgit/repo_refs.go +++ b/internal/testgit/repo_refs.go @@ -28,6 +28,7 @@ func (testRepo *TestRepo) SymbolicRef(tb testing.TB, name, target string) { // PackRefs runs git pack-refs with args. func (testRepo *TestRepo) PackRefs(tb testing.TB, args ...string) { tb.Helper() + cmd := append([]string{"pack-refs"}, args...) testRepo.Run(tb, cmd...) } @@ -35,10 +36,13 @@ func (testRepo *TestRepo) PackRefs(tb testing.TB, args ...string) { // ShowRef returns lines from git show-ref output. func (testRepo *TestRepo) ShowRef(tb testing.TB, args ...string) []string { tb.Helper() + cmd := append([]string{"show-ref"}, args...) + out := testRepo.Run(tb, cmd...) if strings.TrimSpace(out) == "" { return nil } + return strings.Split(strings.TrimSpace(out), "\n") } diff --git a/internal/testgit/repo_repack.go b/internal/testgit/repo_repack.go index 29fa8a4f..7773ac13 100644 --- a/internal/testgit/repo_repack.go +++ b/internal/testgit/repo_repack.go @@ -5,6 +5,7 @@ import "testing" // Repack runs "git repack" with args in the repository. func (testRepo *TestRepo) Repack(tb testing.TB, args ...string) { tb.Helper() + cmdArgs := make([]string, 0, len(args)+1) cmdArgs = append(cmdArgs, "repack") cmdArgs = append(cmdArgs, args...) diff --git a/internal/testgit/repo_rev_parse.go b/internal/testgit/repo_rev_parse.go index bebdfa8e..3bee6108 100644 --- a/internal/testgit/repo_rev_parse.go +++ b/internal/testgit/repo_rev_parse.go @@ -10,9 +10,11 @@ import ( func (testRepo *TestRepo) RevParse(tb testing.TB, spec string) objectid.ObjectID { tb.Helper() hex := testRepo.Run(tb, "rev-parse", spec) + id, err := objectid.ParseHex(testRepo.algo, hex) if err != nil { tb.Fatalf("parse rev-parse output %q: %v", hex, err) } + return id } diff --git a/internal/testgit/repo_run.go b/internal/testgit/repo_run.go index 8022835e..162a0d72 100644 --- a/internal/testgit/repo_run.go +++ b/internal/testgit/repo_run.go @@ -11,12 +11,14 @@ import ( func (testRepo *TestRepo) Run(tb testing.TB, args ...string) string { tb.Helper() out := testRepo.runBytes(tb, nil, testRepo.dir, args...) + return strings.TrimSpace(string(out)) } // RunBytes executes git and returns raw output bytes. func (testRepo *TestRepo) RunBytes(tb testing.TB, args ...string) []byte { tb.Helper() + return testRepo.runBytes(tb, nil, testRepo.dir, args...) } @@ -24,12 +26,14 @@ func (testRepo *TestRepo) RunBytes(tb testing.TB, args ...string) []byte { func (testRepo *TestRepo) RunInput(tb testing.TB, stdin []byte, args ...string) string { tb.Helper() out := testRepo.runBytes(tb, stdin, testRepo.dir, args...) + return strings.TrimSpace(string(out)) } // RunInputBytes executes git with stdin and returns raw output bytes. func (testRepo *TestRepo) RunInputBytes(tb testing.TB, stdin []byte, args ...string) []byte { tb.Helper() + return testRepo.runBytes(tb, stdin, testRepo.dir, args...) } @@ -38,13 +42,16 @@ func (testRepo *TestRepo) runBytes(tb testing.TB, stdin []byte, dir string, args //nolint:noctx cmd := exec.Command("git", args...) //#nosec G204 cmd.Dir = dir + cmd.Env = testRepo.env if stdin != nil { cmd.Stdin = bytes.NewReader(stdin) } + out, err := cmd.CombinedOutput() if err != nil { tb.Fatalf("git %v failed: %v\n%s", args, err, out) } + return out } diff --git a/internal/testgit/repo_tag_annotated.go b/internal/testgit/repo_tag_annotated.go index a3ffafa6..7e9bfbf5 100644 --- a/internal/testgit/repo_tag_annotated.go +++ b/internal/testgit/repo_tag_annotated.go @@ -11,5 +11,6 @@ import ( func (testRepo *TestRepo) TagAnnotated(tb testing.TB, name string, target objectid.ObjectID, message string) objectid.ObjectID { tb.Helper() testRepo.Run(tb, "tag", "-a", name, target.String(), "-m", message) + return testRepo.RevParse(tb, fmt.Sprintf("refs/tags/%s", name)) } diff --git a/internal/zlib/reader.go b/internal/zlib/reader.go index 5d6dcd88..e4babb9e 100644 --- a/internal/zlib/reader.go +++ b/internal/zlib/reader.go @@ -63,6 +63,7 @@ var ( var readerPool = sync.Pool{ New: func() any { r := new(reader) + return r }, } @@ -89,14 +90,17 @@ func NewReader(r io.Reader) (io.ReadCloser, error) { // If the compressed data refers to a different dictionary, NewReaderDict returns [ErrDictionary]. func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) { v := readerPool.Get() + z, ok := v.(*reader) if !ok { panic("zlib: pool returned unexpected type") } + err := z.Reset(r, dict) if err != nil { return nil, err } + return z, nil } @@ -106,30 +110,40 @@ func (z *reader) Read(p []byte) (int, error) { } var n int + n, z.err = z.decompressor.Read(p) - if _, err := z.digest.Write(p[0:n]); err != nil { + + _, err := z.digest.Write(p[0:n]) + if err != nil { z.err = err + return n, z.err } + if !errors.Is(z.err, io.EOF) { // In the normal case we return here. return n, z.err } // Finished file; check checksum. - if _, err := io.ReadFull(z.r, z.scratch[0:4]); err != nil { - if err == io.EOF { + _, err = io.ReadFull(z.r, z.scratch[0:4]) + if err != nil { + if errors.Is(err, io.EOF) { err = io.ErrUnexpectedEOF } + z.err = err + return n, z.err } // ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952). checksum := binary.BigEndian.Uint32(z.scratch[:4]) if checksum != z.digest.Sum32() { z.err = ErrChecksum + return n, z.err } + return n, io.EOF } @@ -140,12 +154,14 @@ func (z *reader) Close() error { if z.err != nil && !errors.Is(z.err, io.EOF) { return z.err } + z.err = z.decompressor.Close() if z.err != nil { return z.err } readerPool.Put(z) + return nil } @@ -163,13 +179,17 @@ func (z *reader) Reset(r io.Reader, dict []byte) error { if errors.Is(z.err, io.EOF) { z.err = io.ErrUnexpectedEOF } + return z.err } + h := binary.BigEndian.Uint16(z.scratch[:2]) if (z.scratch[0]&0x0f != zlibDeflate) || (z.scratch[0]>>4 > zlibMaxWindow) || (h%31 != 0) { z.err = ErrHeader + return z.err } + haveDict := z.scratch[1]&0x20 != 0 if haveDict { _, z.err = io.ReadFull(z.r, z.scratch[0:4]) @@ -177,31 +197,41 @@ func (z *reader) Reset(r io.Reader, dict []byte) error { if errors.Is(z.err, io.EOF) { z.err = io.ErrUnexpectedEOF } + return z.err } + checksum := binary.BigEndian.Uint32(z.scratch[:4]) if checksum != adler32.Checksum(dict) { z.err = ErrDictionary + return z.err } } - if z.decompressor == nil { - if haveDict { - z.decompressor = flate.NewReaderDict(z.r, dict) - } else { - z.decompressor = flate.NewReader(z.r) - } - } else { + if z.decompressor != nil { resetter, ok := z.decompressor.(flate.Resetter) if !ok { panic("zlib: pooled decompressor does not implement flate.Resetter") } + z.err = resetter.Reset(z.r, dict) if z.err != nil { return z.err } + + z.digest = adler32.New() + + return nil + } + + if haveDict { + z.decompressor = flate.NewReaderDict(z.r, dict) + } else { + z.decompressor = flate.NewReader(z.r) } + z.digest = adler32.New() + return nil } diff --git a/internal/zlib/writer.go b/internal/zlib/writer.go index 75a8ec1d..bfc52889 100644 --- a/internal/zlib/writer.go +++ b/internal/zlib/writer.go @@ -52,6 +52,7 @@ var writerPool = sync.Pool{ // Writes may be buffered and not flushed until Close. func NewWriter(w io.Writer) *Writer { z, _ := NewWriterLevelDict(w, DefaultCompression, nil) + return z } @@ -74,7 +75,9 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) { if level < HuffmanOnly || level > BestCompression { return nil, fmt.Errorf("zlib: invalid compression level: %d", level) } + v := writerPool.Get() + z, ok := v.(*Writer) if !ok { panic("zlib: pool returned unexpected type") @@ -86,6 +89,7 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) { if !reuseCompressor { z.compressor = nil } + if z.digest != nil { z.digest.Reset() } @@ -100,6 +104,7 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) { if z.compressor != nil { z.compressor.Reset(w) } + return z, nil } @@ -112,9 +117,11 @@ func (z *Writer) Reset(w io.Writer) { if z.compressor != nil { z.compressor.Reset(w) } + if z.digest != nil { z.digest.Reset() } + z.err = nil z.scratch = [4]byte{} z.wroteHeader = false @@ -127,21 +134,29 @@ func (z *Writer) Write(p []byte) (n int, err error) { if !z.wroteHeader { z.err = z.writeHeader() } + if z.err != nil { return 0, z.err } + if len(p) == 0 { return 0, nil } + n, err = z.compressor.Write(p) if err != nil { z.err = err + return n, err } - if _, err = z.digest.Write(p); err != nil { + + _, err = z.digest.Write(p) + if err != nil { z.err = err + return 0, z.err } + return n, err } @@ -150,10 +165,13 @@ func (z *Writer) Flush() error { if !z.wroteHeader { z.err = z.writeHeader() } + if z.err != nil { return z.err } + z.err = z.compressor.Flush() + return z.err } @@ -163,22 +181,27 @@ func (z *Writer) Close() error { if !z.wroteHeader { z.err = z.writeHeader() } + if z.err != nil { return z.err } + z.err = z.compressor.Close() if z.err != nil { return z.err } + checksum := z.digest.Sum32() // ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952). binary.BigEndian.PutUint32(z.scratch[:], checksum) + _, z.err = z.w.Write(z.scratch[0:4]) if z.err != nil { return z.err } writerPool.Put(z) + return nil } @@ -205,20 +228,28 @@ func (z *Writer) writeHeader() (err error) { default: panic("unreachable") } + if z.dict != nil { z.scratch[1] |= 1 << 5 } + z.scratch[1] += uint8(31 - binary.BigEndian.Uint16(z.scratch[:2])%31) //#nosec G115 - if _, err = z.w.Write(z.scratch[0:2]); err != nil { + + _, err = z.w.Write(z.scratch[0:2]) + if err != nil { return err } + if z.dict != nil { // The next four bytes are the Adler-32 checksum of the dictionary. binary.BigEndian.PutUint32(z.scratch[:], adler32.Checksum(z.dict)) - if _, err = z.w.Write(z.scratch[0:4]); err != nil { + + _, err = z.w.Write(z.scratch[0:4]) + if err != nil { return err } } + if z.compressor == nil { // Initialize deflater unless the Writer is being reused // after a Reset call. @@ -226,7 +257,9 @@ func (z *Writer) writeHeader() (err error) { if err != nil { return err } + z.digest = adler32.New() } + return nil } diff --git a/object/blob.go b/object/blob.go index 8f094405..9c507e1f 100644 --- a/object/blob.go +++ b/object/blob.go @@ -10,5 +10,6 @@ type Blob struct { // ObjectType returns TypeBlob. func (blob *Blob) ObjectType() objecttype.Type { _ = blob + return objecttype.TypeBlob } diff --git a/object/blob_parse_test.go b/object/blob_parse_test.go index 7b242ef7..1cf3990f 100644 --- a/object/blob_parse_test.go +++ b/object/blob_parse_test.go @@ -17,10 +17,12 @@ func TestBlobParseFromGit(t *testing.T) { blobID := testRepo.HashObject(t, "blob", body) rawBody := testRepo.CatFile(t, "blob", blobID) + blob, err := object.ParseBlob(rawBody) if err != nil { t.Fatalf("ParseBlob: %v", err) } + if !bytes.Equal(blob.Data, body) { t.Fatalf("blob body mismatch") } diff --git a/object/blob_serialize.go b/object/blob_serialize.go index 70354ddc..e9c0ac5e 100644 --- a/object/blob_serialize.go +++ b/object/blob_serialize.go @@ -18,12 +18,15 @@ func (blob *Blob) SerializeWithHeader() ([]byte, error) { if err != nil { return nil, err } + header, ok := objectheader.Encode(objecttype.TypeBlob, int64(len(body))) if !ok { return nil, errors.New("object: blob: failed to encode object header") } + raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return raw, nil } diff --git a/object/blob_serialize_test.go b/object/blob_serialize_test.go index c49815da..69dbe849 100644 --- a/object/blob_serialize_test.go +++ b/object/blob_serialize_test.go @@ -16,10 +16,12 @@ func TestBlobSerialize(t *testing.T) { wantID := testRepo.HashObject(t, "blob", body) blob := &object.Blob{Data: body} + rawObj, err := blob.SerializeWithHeader() if err != nil { t.Fatalf("SerializeWithHeader: %v", err) } + gotID := algo.Sum(rawObj) if gotID != wantID { t.Fatalf("object id mismatch: got %s want %s", gotID, wantID) diff --git a/object/commit.go b/object/commit.go index bd48bb44..34e89033 100644 --- a/object/commit.go +++ b/object/commit.go @@ -19,5 +19,6 @@ type Commit struct { // ObjectType returns TypeCommit. func (commit *Commit) ObjectType() objecttype.Type { _ = commit + return objecttype.TypeCommit } diff --git a/object/commit_parse.go b/object/commit_parse.go index ae1b2559..31e215de 100644 --- a/object/commit_parse.go +++ b/object/commit_parse.go @@ -11,14 +11,17 @@ import ( // ParseCommit decodes a commit object body. func ParseCommit(body []byte, algo objectid.Algorithm) (*Commit, error) { c := new(Commit) + i := 0 for i < len(body) { rel := bytes.IndexByte(body[i:], '\n') if rel < 0 { return nil, errors.New("object: commit: missing newline") } + line := body[i : i+rel] i += rel + 1 + if len(line) == 0 { break } @@ -34,24 +37,28 @@ func ParseCommit(body []byte, algo objectid.Algorithm) (*Commit, error) { if err != nil { return nil, fmt.Errorf("object: commit: tree: %w", err) } + c.Tree = id case "parent": id, err := objectid.ParseHex(algo, string(value)) if err != nil { return nil, fmt.Errorf("object: commit: parent: %w", err) } + c.Parents = append(c.Parents, id) case "author": idt, err := ParseSignature(value) if err != nil { return nil, fmt.Errorf("object: commit: author: %w", err) } + c.Author = *idt case "committer": idt, err := ParseSignature(value) if err != nil { return nil, fmt.Errorf("object: commit: committer: %w", err) } + c.Committer = *idt case "change-id": c.ChangeID = string(value) @@ -61,9 +68,11 @@ func ParseCommit(body []byte, algo objectid.Algorithm) (*Commit, error) { if nextRel < 0 { return nil, errors.New("object: commit: unterminated gpgsig") } + if body[i] != ' ' { break } + i += nextRel + 1 } default: @@ -77,6 +86,8 @@ func ParseCommit(body []byte, algo objectid.Algorithm) (*Commit, error) { if i > len(body) { return nil, errors.New("object: commit: parser position out of bounds") } + c.Message = append([]byte(nil), body[i:]...) + return c, nil } diff --git a/object/commit_parse_test.go b/object/commit_parse_test.go index a29ab1fa..4dc1dea1 100644 --- a/object/commit_parse_test.go +++ b/object/commit_parse_test.go @@ -17,22 +17,28 @@ func TestCommitParseFromGit(t *testing.T) { _, treeID, commitID := testRepo.MakeCommit(t, "subject\n\nbody") rawBody := testRepo.CatFile(t, "commit", commitID) + commit, err := object.ParseCommit(rawBody, algo) if err != nil { t.Fatalf("ParseCommit: %v", err) } + if commit.Tree != treeID { t.Fatalf("tree id mismatch: got %s want %s", commit.Tree, treeID) } + if len(commit.Parents) != 0 { t.Fatalf("parent count = %d, want 0", len(commit.Parents)) } + if !bytes.Equal(commit.Author.Name, []byte("Test Author")) { t.Fatalf("author name = %q, want %q", commit.Author.Name, "Test Author") } + if !bytes.Equal(commit.Committer.Name, []byte("Test Committer")) { t.Fatalf("committer name = %q, want %q", commit.Committer.Name, "Test Committer") } + if !bytes.Contains(commit.Message, []byte("subject")) { t.Fatalf("commit message missing subject: %q", commit.Message) } @@ -61,18 +67,23 @@ func TestCommitParseMultipleParents(t *testing.T) { if err != nil { t.Fatalf("ParseCommit(merge): %v", err) } + if commit.Tree != treeID { t.Fatalf("merge tree = %s, want %s", commit.Tree, treeID) } + if len(commit.Parents) != 2 { t.Fatalf("merge parent count = %d, want 2", len(commit.Parents)) } + if commit.Parents[0] != parent1 { t.Fatalf("merge parent[0] = %s, want %s", commit.Parents[0], parent1) } + if commit.Parents[1] != parent2 { t.Fatalf("merge parent[1] = %s, want %s", commit.Parents[1], parent2) } + if !bytes.Equal(commit.Message, []byte("Merge commit\n")) { t.Fatalf("merge message = %q, want %q", commit.Message, "Merge commit\n") } diff --git a/object/commit_serialize.go b/object/commit_serialize.go index ec28aded..eef45ef4 100644 --- a/object/commit_serialize.go +++ b/object/commit_serialize.go @@ -16,7 +16,9 @@ func (commit *Commit) SerializeWithoutHeader() ([]byte, error) { if commit.Tree.Size() == 0 { return nil, errors.New("object: commit: missing tree id") } + fmt.Fprintf(&buf, "tree %s\n", commit.Tree.String()) + for _, parent := range commit.Parents { fmt.Fprintf(&buf, "parent %s\n", parent.String()) } @@ -25,6 +27,7 @@ func (commit *Commit) SerializeWithoutHeader() ([]byte, error) { if err != nil { return nil, err } + buf.WriteString("author ") buf.Write(authorBytes) buf.WriteByte('\n') @@ -33,6 +36,7 @@ func (commit *Commit) SerializeWithoutHeader() ([]byte, error) { if err != nil { return nil, err } + buf.WriteString("committer ") buf.Write(committerBytes) buf.WriteByte('\n') @@ -42,10 +46,12 @@ func (commit *Commit) SerializeWithoutHeader() ([]byte, error) { buf.WriteString(commit.ChangeID) buf.WriteByte('\n') } + for _, h := range commit.ExtraHeaders { if h.Key == "" { return nil, errors.New("object: commit: extra header has empty key") } + buf.WriteString(h.Key) buf.WriteByte(' ') buf.Write(h.Value) @@ -54,6 +60,7 @@ func (commit *Commit) SerializeWithoutHeader() ([]byte, error) { buf.WriteByte('\n') buf.Write(commit.Message) + return buf.Bytes(), nil } @@ -63,12 +70,15 @@ func (commit *Commit) SerializeWithHeader() ([]byte, error) { if err != nil { return nil, err } + header, ok := objectheader.Encode(objecttype.TypeCommit, int64(len(body))) if !ok { return nil, errors.New("object: commit: failed to encode object header") } + raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return raw, nil } diff --git a/object/commit_serialize_test.go b/object/commit_serialize_test.go index 4f9856b0..70b3fc92 100644 --- a/object/commit_serialize_test.go +++ b/object/commit_serialize_test.go @@ -15,6 +15,7 @@ func TestCommitSerialize(t *testing.T) { _, _, commitID := testRepo.MakeCommit(t, "subject\n\nbody") rawBody := testRepo.CatFile(t, "commit", commitID) + commit, err := object.ParseCommit(rawBody, algo) if err != nil { t.Fatalf("ParseCommit: %v", err) @@ -24,6 +25,7 @@ func TestCommitSerialize(t *testing.T) { if err != nil { t.Fatalf("SerializeWithHeader: %v", err) } + gotID := algo.Sum(rawObj) if gotID != commitID { t.Fatalf("commit id mismatch: got %s want %s", gotID, commitID) diff --git a/object/ident.go b/object/ident.go index 1ea55cc2..049b0c01 100644 --- a/object/ident.go +++ b/object/ident.go @@ -26,10 +26,12 @@ func ParseSignature(line []byte) (*Signature, error) { if lt < 0 { return nil, errors.New("object: signature: missing opening <") } + gtRel := bytes.IndexByte(line[lt+1:], '>') if gtRel < 0 { return nil, errors.New("object: signature: missing closing >") } + gt := lt + 1 + gtRel nameBytes := append([]byte(nil), bytes.TrimRight(line[:lt], " ")...) @@ -39,11 +41,14 @@ func ParseSignature(line []byte) (*Signature, error) { if len(rest) == 0 || rest[0] != ' ' { return nil, errors.New("object: signature: missing timestamp separator") } + rest = rest[1:] + before, after, ok := bytes.Cut(rest, []byte{' '}) if !ok { return nil, errors.New("object: signature: missing timezone separator") } + when, err := strconv.ParseInt(string(before), 10, 64) if err != nil { return nil, fmt.Errorf("object: signature: invalid timestamp: %w", err) @@ -53,7 +58,9 @@ func ParseSignature(line []byte) (*Signature, error) { if len(tz) < 5 { return nil, errors.New("object: signature: invalid timezone encoding") } + sign := 1 + switch tz[0] { case '-': sign = -1 @@ -66,24 +73,31 @@ func ParseSignature(line []byte) (*Signature, error) { if err != nil { return nil, fmt.Errorf("object: signature: invalid timezone hours: %w", err) } + mm, err := strconv.Atoi(string(tz[3:5])) if err != nil { return nil, fmt.Errorf("object: signature: invalid timezone minutes: %w", err) } + if hh < 0 || hh > 23 { return nil, errors.New("object: signature: invalid timezone hours range") } + if mm < 0 || mm > 59 { return nil, errors.New("object: signature: invalid timezone minutes range") } + total := int64(hh)*60 + int64(mm) + offset, err := intconv.Int64ToInt32(total) if err != nil { return nil, errors.New("object: signature: timezone overflow") } + if sign < 0 { offset = -offset } + return &Signature{ Name: nameBytes, Email: emailBytes, @@ -104,19 +118,23 @@ func (signature Signature) Serialize() ([]byte, error) { b.WriteByte(' ') offset := signature.OffsetMinutes + sign := '+' if offset < 0 { sign = '-' offset = -offset } + hh := offset / 60 mm := offset % 60 fmt.Fprintf(&b, "%c%02d%02d", sign, hh, mm) + return []byte(b.String()), nil } // When returns a time.Time with the signature's timezone offset. func (signature Signature) When() time.Time { loc := time.FixedZone("git", int(signature.OffsetMinutes)*60) + return time.Unix(signature.WhenUnix, 0).In(loc) } diff --git a/object/tag.go b/object/tag.go index 9a621ec9..0da3e4a8 100644 --- a/object/tag.go +++ b/object/tag.go @@ -17,5 +17,6 @@ type Tag struct { // ObjectType returns TypeTag. func (tag *Tag) ObjectType() objecttype.Type { _ = tag + return objecttype.TypeTag } diff --git a/object/tag_parse.go b/object/tag_parse.go index ea194085..c2fee81a 100644 --- a/object/tag_parse.go +++ b/object/tag_parse.go @@ -13,6 +13,7 @@ import ( func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { t := new(Tag) i := 0 + var haveTarget, haveType bool for i < len(body) { @@ -20,8 +21,10 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if rel < 0 { return nil, errors.New("object: tag: missing newline") } + line := body[i : i+rel] i += rel + 1 + if len(line) == 0 { break } @@ -37,6 +40,7 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if err != nil { return nil, fmt.Errorf("object: tag: object: %w", err) } + t.Target = id haveTarget = true case "type": @@ -44,6 +48,7 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if !ok { return nil, errors.New("object: tag: unknown target type") } + t.TargetType = ty haveType = true case "tag": @@ -53,6 +58,7 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if err != nil { return nil, fmt.Errorf("object: tag: tagger: %w", err) } + t.Tagger = idt case "gpgsig", "gpgsig-sha256": for i < len(body) { @@ -60,9 +66,11 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if nextRel < 0 { return nil, errors.New("object: tag: unterminated gpgsig") } + if body[i] != ' ' { break } + i += nextRel + 1 } default: @@ -73,6 +81,8 @@ func ParseTag(body []byte, algo objectid.Algorithm) (*Tag, error) { if !haveTarget || !haveType { return nil, errors.New("object: tag: missing required headers") } + t.Message = append([]byte(nil), body[i:]...) + return t, nil } diff --git a/object/tag_parse_test.go b/object/tag_parse_test.go index 7ddb60e9..456d2f63 100644 --- a/object/tag_parse_test.go +++ b/object/tag_parse_test.go @@ -18,22 +18,28 @@ func TestTagParseFromGit(t *testing.T) { tagID := testRepo.TagAnnotated(t, "v1", commitID, "tag message") rawBody := testRepo.CatFile(t, "tag", tagID) + tag, err := object.ParseTag(rawBody, algo) if err != nil { t.Fatalf("ParseTag: %v", err) } + if tag.Target != commitID { t.Fatalf("tag target mismatch: got %s want %s", tag.Target, commitID) } + if tag.TargetType != objecttype.TypeCommit { t.Fatalf("tag target type = %v, want %v", tag.TargetType, objecttype.TypeCommit) } + if !bytes.Equal(tag.Name, []byte("v1")) { t.Fatalf("tag name = %q, want %q", tag.Name, "v1") } + if tag.Tagger == nil { t.Fatalf("expected tagger") } + if !bytes.Contains(tag.Message, []byte("tag message")) { t.Fatalf("tag message mismatch: %q", tag.Message) } diff --git a/object/tag_serialize.go b/object/tag_serialize.go index 9ccf0bd0..1e016cdb 100644 --- a/object/tag_serialize.go +++ b/object/tag_serialize.go @@ -22,6 +22,7 @@ func (tag *Tag) SerializeWithoutHeader() ([]byte, error) { if !ok { return nil, fmt.Errorf("object: tag: invalid target type %d", tag.TargetType) } + buf.WriteString("type ") buf.WriteString(tyName) buf.WriteByte('\n') @@ -35,6 +36,7 @@ func (tag *Tag) SerializeWithoutHeader() ([]byte, error) { if err != nil { return nil, err } + buf.WriteString("tagger ") buf.Write(taggerBytes) buf.WriteByte('\n') @@ -42,6 +44,7 @@ func (tag *Tag) SerializeWithoutHeader() ([]byte, error) { buf.WriteByte('\n') buf.Write(tag.Message) + return buf.Bytes(), nil } @@ -51,12 +54,15 @@ func (tag *Tag) SerializeWithHeader() ([]byte, error) { if err != nil { return nil, err } + header, ok := objectheader.Encode(objecttype.TypeTag, int64(len(body))) if !ok { return nil, errors.New("object: tag: failed to encode object header") } + raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return raw, nil } diff --git a/object/tag_serialize_test.go b/object/tag_serialize_test.go index 1b3ea2f8..e1bdbab2 100644 --- a/object/tag_serialize_test.go +++ b/object/tag_serialize_test.go @@ -16,6 +16,7 @@ func TestTagSerialize(t *testing.T) { tagID := testRepo.TagAnnotated(t, "v1", commitID, "tag message") rawBody := testRepo.CatFile(t, "tag", tagID) + tag, err := object.ParseTag(rawBody, algo) if err != nil { t.Fatalf("ParseTag: %v", err) @@ -25,6 +26,7 @@ func TestTagSerialize(t *testing.T) { if err != nil { t.Fatalf("SerializeWithHeader: %v", err) } + gotID := algo.Sum(rawObj) if gotID != tagID { t.Fatalf("tag id mismatch: got %s want %s", gotID, tagID) diff --git a/object/tree.go b/object/tree.go index 4bb459be..ad4b8f34 100644 --- a/object/tree.go +++ b/object/tree.go @@ -35,6 +35,7 @@ type Tree struct { // ObjectType returns TypeTree. func (tree *Tree) ObjectType() objecttype.Type { _ = tree + return objecttype.TypeTree } @@ -43,9 +44,11 @@ func (tree *Tree) Entry(name []byte) *TreeEntry { if len(tree.Entries) == 0 { return nil } + if e := tree.entry(name, true); e != nil { return e } + return tree.entry(name, false) } @@ -54,6 +57,7 @@ func (tree *Tree) InsertEntry(newEntry TreeEntry) error { if tree.entry(newEntry.Name, true) != nil || tree.entry(newEntry.Name, false) != nil { return fmt.Errorf("object: tree: entry %q already exists", newEntry.Name) } + newIsTree := newEntry.Mode == FileModeDir insertAt := sort.Search(len(tree.Entries), func(i int) bool { return TreeEntryNameCompare(tree.Entries[i].Name, tree.Entries[i].Mode, newEntry.Name, newIsTree) >= 0 @@ -61,6 +65,7 @@ func (tree *Tree) InsertEntry(newEntry TreeEntry) error { tree.Entries = append(tree.Entries, TreeEntry{}) copy(tree.Entries[insertAt+1:], tree.Entries[insertAt:]) tree.Entries[insertAt] = newEntry + return nil } @@ -69,13 +74,16 @@ func (tree *Tree) RemoveEntry(name []byte) error { if len(tree.Entries) == 0 { return fmt.Errorf("object: tree: entry %q not found", name) } + for i := range tree.Entries { if bytes.Equal(tree.Entries[i].Name, name) { copy(tree.Entries[i:], tree.Entries[i+1:]) tree.Entries = tree.Entries[:len(tree.Entries)-1] + return nil } } + return fmt.Errorf("object: tree: entry %q not found", name) } @@ -84,19 +92,23 @@ func (tree *Tree) entry(name []byte, searchIsTree bool) *TreeEntry { for low <= high { mid := low + (high-low)/2 entry := &tree.Entries[mid] + cmp := TreeEntryNameCompare(entry.Name, entry.Mode, name, searchIsTree) if cmp == 0 { if bytes.Equal(entry.Name, name) { return entry } + return nil } + if cmp < 0 { low = mid + 1 } else { high = mid - 1 } } + return nil } @@ -108,6 +120,7 @@ func TreeEntryNameCompare(entryName []byte, entryMode FileMode, searchName []byt if isEntryTree { entryLen++ } + searchLen := len(searchName) if searchIsTree { searchLen++ @@ -122,14 +135,17 @@ func TreeEntryNameCompare(entryName []byte, entryMode FileMode, searchName []byt } else { ec = '/' } + if i < len(searchName) { sc = searchName[i] } else { sc = '/' } + if ec < sc { return -1 } + if ec > sc { return 1 } @@ -138,8 +154,10 @@ func TreeEntryNameCompare(entryName []byte, entryMode FileMode, searchName []byt if entryLen < searchLen { return -1 } + if entryLen > searchLen { return 1 } + return 0 } diff --git a/object/tree_helpers_test.go b/object/tree_helpers_test.go index 4727e1c7..2577e0e1 100644 --- a/object/tree_helpers_test.go +++ b/object/tree_helpers_test.go @@ -15,6 +15,7 @@ func buildGitMktreeInput(entries []object.TreeEntry) string { for _, e := range entries { fmt.Fprintf(&b, "%o %s %s\t%s\n", e.Mode, mktreeTypeFromMode(e.Mode), e.ID.String(), e.Name) } + return b.String() } @@ -35,14 +36,17 @@ func gitLsTreeNames(out []byte) [][]byte { if len(out) == 0 { return nil } + parts := bytes.Split(out, []byte{0}) if len(parts) > 0 && len(parts[len(parts)-1]) == 0 { parts = parts[:len(parts)-1] } + names := make([][]byte, 0, len(parts)) for _, name := range parts { names = append(names, append([]byte(nil), name...)) } + return names } diff --git a/object/tree_parse.go b/object/tree_parse.go index 37a2fa4b..dd4faa8b 100644 --- a/object/tree_parse.go +++ b/object/tree_parse.go @@ -11,12 +11,14 @@ import ( // ParseTree decodes a tree object body. func ParseTree(body []byte, algo objectid.Algorithm) (*Tree, error) { var entries []TreeEntry + i := 0 for i < len(body) { space := bytes.IndexByte(body[i:], ' ') if space < 0 { return nil, fmt.Errorf("object: tree: missing mode terminator") } + modeBytes := body[i : i+space] i += space + 1 @@ -24,6 +26,7 @@ func ParseTree(body []byte, algo objectid.Algorithm) (*Tree, error) { if nul < 0 { return nil, fmt.Errorf("object: tree: missing name terminator") } + nameBytes := body[i : i+nul] i += nul + 1 @@ -31,10 +34,12 @@ func ParseTree(body []byte, algo objectid.Algorithm) (*Tree, error) { if idEnd > len(body) { return nil, fmt.Errorf("object: tree: truncated child object id") } + id, err := objectid.FromBytes(algo, body[i:idEnd]) if err != nil { return nil, err } + i = idEnd mode, err := strconv.ParseUint(string(modeBytes), 8, 32) diff --git a/object/tree_parse_test.go b/object/tree_parse_test.go index 989d6ff1..d4b7c1e6 100644 --- a/object/tree_parse_test.go +++ b/object/tree_parse_test.go @@ -14,9 +14,11 @@ func TestTreeParseFromGit(t *testing.T) { testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) entries := adversarialRootEntries(t, testRepo) + inserted := &object.Tree{} for _, entry := range entries { - if err := inserted.InsertEntry(entry); err != nil { + err := inserted.InsertEntry(entry) + if err != nil { t.Fatalf("InsertEntry(%q): %v", entry.Name, err) } } @@ -24,16 +26,19 @@ func TestTreeParseFromGit(t *testing.T) { treeID := testRepo.Mktree(t, buildGitMktreeInput(inserted.Entries)) rawBody := testRepo.CatFile(t, "tree", treeID) + tree, err := object.ParseTree(rawBody, algo) if err != nil { t.Fatalf("ParseTree: %v", err) } + if len(tree.Entries) != len(inserted.Entries) { t.Fatalf("entry count = %d, want %d", len(tree.Entries), len(inserted.Entries)) } for i := range inserted.Entries { got := tree.Entries[i] + want := inserted.Entries[i] if got.Mode != want.Mode || got.ID != want.ID || !bytes.Equal(got.Name, want.Name) { t.Fatalf("entry[%d] mismatch: got (%o,%q,%s) want (%o,%q,%s)", @@ -45,6 +50,7 @@ func TestTreeParseFromGit(t *testing.T) { if len(lsNames) != len(tree.Entries) { t.Fatalf("ls-tree names = %d, want %d", len(lsNames), len(tree.Entries)) } + for i := range lsNames { if !bytes.Equal(lsNames[i], tree.Entries[i].Name) { t.Fatalf("ordering mismatch at %d: git=%q parsed=%q", i, lsNames[i], tree.Entries[i].Name) @@ -62,6 +68,7 @@ func TestTreeParseFromGit(t *testing.T) { t.Fatalf("Entry(%q) mismatch", want.Name) } } + if tree.Entry([]byte("does-not-exist")) != nil { t.Fatalf("Entry on missing name should be nil") } diff --git a/object/tree_serialize.go b/object/tree_serialize.go index 5c10bef6..42f60f72 100644 --- a/object/tree_serialize.go +++ b/object/tree_serialize.go @@ -11,6 +11,7 @@ import ( // SerializeWithoutHeader renders the raw tree body bytes. func (tree *Tree) SerializeWithoutHeader() ([]byte, error) { var bodyLen int + for _, entry := range tree.Entries { mode := strconv.FormatUint(uint64(entry.Mode), 8) bodyLen += len(mode) + 1 + len(entry.Name) + 1 + entry.ID.Size() @@ -18,6 +19,7 @@ func (tree *Tree) SerializeWithoutHeader() ([]byte, error) { body := make([]byte, bodyLen) pos := 0 + for _, entry := range tree.Entries { mode := strconv.FormatUint(uint64(entry.Mode), 8) pos += copy(body[pos:], mode) @@ -39,12 +41,15 @@ func (tree *Tree) SerializeWithHeader() ([]byte, error) { if err != nil { return nil, err } + header, ok := objectheader.Encode(objecttype.TypeTree, int64(len(body))) if !ok { return nil, errors.New("object: tree: failed to encode object header") } + raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return raw, nil } diff --git a/object/tree_serialize_test.go b/object/tree_serialize_test.go index e8ebb140..c038ad58 100644 --- a/object/tree_serialize_test.go +++ b/object/tree_serialize_test.go @@ -16,32 +16,44 @@ func TestTreeSerialize(t *testing.T) { tree := &object.Tree{} for i := len(entries) - 1; i >= 0; i-- { - if err := tree.InsertEntry(entries[i]); err != nil { + err := tree.InsertEntry(entries[i]) + if err != nil { t.Fatalf("InsertEntry(%q): %v", entries[i].Name, err) } } + if len(tree.Entries) < 32 { t.Fatalf("expected at least 32 entries, got %d", len(tree.Entries)) } dup := tree.Entries[0] - if err := tree.InsertEntry(dup); err == nil { + + err := tree.InsertEntry(dup) + if err == nil { t.Fatalf("duplicate InsertEntry should fail") } removed := tree.Entries[len(tree.Entries)/2] - if err := tree.RemoveEntry(removed.Name); err != nil { + + err = tree.RemoveEntry(removed.Name) + if err != nil { t.Fatalf("RemoveEntry(%q): %v", removed.Name, err) } + if tree.Entry(removed.Name) != nil { t.Fatalf("Entry(%q) should be nil after remove", removed.Name) } - if err := tree.RemoveEntry([]byte("no-such-entry")); err == nil { + + err = tree.RemoveEntry([]byte("no-such-entry")) + if err == nil { t.Fatalf("RemoveEntry missing entry should fail") } - if err := tree.InsertEntry(removed); err != nil { + + err = tree.InsertEntry(removed) + if err != nil { t.Fatalf("re-InsertEntry(%q): %v", removed.Name, err) } + if tree.Entry(removed.Name) == nil { t.Fatalf("Entry(%q) should exist after reinsert", removed.Name) } @@ -52,6 +64,7 @@ func TestTreeSerialize(t *testing.T) { if err != nil { t.Fatalf("SerializeWithHeader: %v", err) } + gotTreeID := algo.Sum(rawObj) if gotTreeID != wantTreeID { t.Fatalf("tree id mismatch: got %s want %s", gotTreeID, wantTreeID) diff --git a/objectheader/append.go b/objectheader/append.go index 3965dc21..bfccf388 100644 --- a/objectheader/append.go +++ b/objectheader/append.go @@ -11,6 +11,7 @@ func Append(dst []byte, ty objecttype.Type, size int64) ([]byte, bool) { if size < 0 { return nil, false } + tyName, ok := objecttype.Name(ty) if !ok { return nil, false @@ -23,5 +24,6 @@ func Append(dst []byte, ty objecttype.Type, size int64) ([]byte, bool) { out = append(out, ' ') out = append(out, sizeStr...) out = append(out, 0) + return out, true } diff --git a/objectheader/parse.go b/objectheader/parse.go index 72d91d3e..677dffdb 100644 --- a/objectheader/parse.go +++ b/objectheader/parse.go @@ -21,6 +21,7 @@ func Parse(data []byte) (objecttype.Type, int64, int, bool) { if nulRel < 0 { return objecttype.TypeInvalid, 0, 0, false } + nul := space + 1 + nulRel ty, ok := objecttype.ParseName(string(data[:space])) @@ -32,6 +33,7 @@ func Parse(data []byte) (objecttype.Type, int64, int, bool) { if len(sizeBytes) == 0 { return objecttype.TypeInvalid, 0, 0, false } + size, err := strconv.ParseInt(string(sizeBytes), 10, 64) if err != nil || size < 0 { return objecttype.TypeInvalid, 0, 0, false diff --git a/objectid/objectid.go b/objectid/objectid.go index 7ce011f3..c1ebfb2c 100644 --- a/objectid/objectid.go +++ b/objectid/objectid.go @@ -43,9 +43,11 @@ var algorithmTable = [...]algorithmDetails{ size: sha1.Size, sum: func(data []byte) ObjectID { sum := sha1.Sum(data) //#nosec G401 + var id ObjectID copy(id.data[:], sum[:]) id.algo = AlgorithmSHA1 + return id }, new: sha1.New, @@ -55,9 +57,11 @@ var algorithmTable = [...]algorithmDetails{ size: sha256.Size, sum: func(data []byte) ObjectID { sum := sha256.Sum256(data) + var id ObjectID copy(id.data[:], sum[:]) id.algo = AlgorithmSHA256 + return id }, new: sha256.New, @@ -69,12 +73,13 @@ var ( supportedAlgorithms []Algorithm ) -func init() { +func init() { //nolint:gochecknoinits for algo := Algorithm(0); int(algo) < len(algorithmTable); algo++ { info := algorithmTable[algo] if info.name == "" { continue } + algorithmByName[info.name] = algo supportedAlgorithms = append(supportedAlgorithms, algo) } @@ -89,6 +94,7 @@ func SupportedAlgorithms() []Algorithm { // ParseAlgorithm parses a canonical algorithm name (e.g. "sha1", "sha256"). func ParseAlgorithm(s string) (Algorithm, bool) { algo, ok := algorithmByName[s] + return algo, ok } @@ -103,6 +109,7 @@ func (algo Algorithm) String() string { if inf.name == "" { return "unknown" } + return inf.name } @@ -122,6 +129,7 @@ func (algo Algorithm) New() (hash.Hash, error) { if newFn == nil { return nil, ErrInvalidAlgorithm } + return newFn(), nil } @@ -150,12 +158,14 @@ func (id ObjectID) Size() int { // String returns the canonical hex representation. func (id ObjectID) String() string { size := id.Size() + return hex.EncodeToString(id.data[:size]) } // Bytes returns a copy of the object ID bytes. func (id ObjectID) Bytes() []byte { size := id.Size() + return append([]byte(nil), id.data[:size]...) } @@ -167,6 +177,7 @@ func (id ObjectID) Bytes() []byte { // Use Bytes when an independent copy is required. func (id *ObjectID) RawBytes() []byte { size := id.Size() + return id.data[:size:size] } @@ -176,18 +187,23 @@ func ParseHex(algo Algorithm, s string) (ObjectID, error) { if algo.Size() == 0 { return id, ErrInvalidAlgorithm } + if len(s)%2 != 0 { return id, fmt.Errorf("%w: odd hex length %d", ErrInvalidObjectID, len(s)) } + if len(s) != algo.HexLen() { return id, fmt.Errorf("%w: got %d chars, expected %d", ErrInvalidObjectID, len(s), algo.HexLen()) } + decoded, err := hex.DecodeString(s) if err != nil { return id, fmt.Errorf("%w: decode: %w", ErrInvalidObjectID, err) } + copy(id.data[:], decoded) id.algo = algo + return id, nil } @@ -197,10 +213,13 @@ func FromBytes(algo Algorithm, b []byte) (ObjectID, error) { if algo.Size() == 0 { return id, ErrInvalidAlgorithm } + if len(b) != algo.Size() { return id, fmt.Errorf("%w: got %d bytes, expected %d", ErrInvalidObjectID, len(b), algo.Size()) } + copy(id.data[:], b) id.algo = algo + return id, nil } diff --git a/objectid/objectid_test.go b/objectid/objectid_test.go index ef191d39..1c5f337a 100644 --- a/objectid/objectid_test.go +++ b/objectid/objectid_test.go @@ -48,13 +48,16 @@ func TestParseHexRoundtrip(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + id, err := objectid.ParseHex(tt.algo, tt.hex) if err != nil { t.Fatalf("ParseHex failed: %v", err) } + if got := id.String(); got != tt.hex { t.Fatalf("String() = %q, want %q", got, tt.hex) } + if got := id.Size(); got != tt.algo.Size() { t.Fatalf("Size() = %d, want %d", got, tt.algo.Size()) } @@ -68,6 +71,7 @@ func TestParseHexRoundtrip(t *testing.T) { if err != nil { t.Fatalf("FromBytes failed: %v", err) } + if id2.String() != tt.hex { t.Fatalf("FromBytes roundtrip = %q, want %q", id2.String(), tt.hex) } @@ -92,7 +96,9 @@ func TestParseHexErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - if _, err := objectid.ParseHex(tt.algo, tt.hex); err == nil { + + _, err := objectid.ParseHex(tt.algo, tt.hex) + if err == nil { t.Fatalf("expected ParseHex error") } }) @@ -102,10 +108,13 @@ func TestParseHexErrors(t *testing.T) { func TestFromBytesErrors(t *testing.T) { t.Parallel() - if _, err := objectid.FromBytes(objectid.AlgorithmUnknown, []byte{1, 2}); err == nil { + _, err := objectid.FromBytes(objectid.AlgorithmUnknown, []byte{1, 2}) + if err == nil { t.Fatalf("expected FromBytes unknown algo error") } - if _, err := objectid.FromBytes(objectid.AlgorithmSHA1, []byte{1, 2}); err == nil { + + _, err = objectid.FromBytes(objectid.AlgorithmSHA1, []byte{1, 2}) + if err == nil { t.Fatalf("expected FromBytes wrong size error") } } @@ -119,10 +128,12 @@ func TestBytesReturnsCopy(t *testing.T) { } b1 := id.Bytes() + b2 := id.Bytes() if !bytes.Equal(b1, b2) { t.Fatalf("Bytes mismatch") } + b1[0] ^= 0xff if bytes.Equal(b1, b2) { t.Fatalf("Bytes should return independent copies") @@ -141,12 +152,14 @@ func TestRawBytesAliasesStorage(t *testing.T) { if len(b) != id.Size() { t.Fatalf("RawBytes len = %d, want %d", len(b), id.Size()) } + if cap(b) != len(b) { t.Fatalf("RawBytes cap = %d, want %d", cap(b), len(b)) } orig := id.String() b[0] ^= 0xff + if id.String() == orig { t.Fatalf("RawBytes should alias object ID storage") } diff --git a/objectstore/chain/chain.go b/objectstore/chain/chain.go index f2992b34..8e10feb6 100644 --- a/objectstore/chain/chain.go +++ b/objectstore/chain/chain.go @@ -25,13 +25,17 @@ type Chain struct { // New creates a Chain from backends. func New(backends ...objectstore.Store) *Chain { nodeByStore := make(map[objectstore.Store]*backendNode, len(backends)) - var head *backendNode - var tail *backendNode + + var ( + head *backendNode + tail *backendNode + ) for _, backend := range backends { if backend == nil { continue } + node := &backendNode{ backend: backend, prev: tail, @@ -39,9 +43,11 @@ func New(backends ...objectstore.Store) *Chain { if tail != nil { tail.next = node } + if head == nil { head = node } + tail = node nodeByStore[backend] = node } @@ -59,13 +65,17 @@ func (chain *Chain) ReadBytesFull(id objectid.ObjectID) ([]byte, error) { full, err := backend.ReadBytesFull(id) if err == nil { chain.touchBackend(backend) + return full, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return nil, fmt.Errorf("objectstore: backend %d read bytes full: %w", i, err) } + return nil, objectstore.ErrObjectNotFound } @@ -76,13 +86,17 @@ func (chain *Chain) ReadBytesContent(id objectid.ObjectID) (objecttype.Type, []b ty, content, err := backend.ReadBytesContent(id) if err == nil { chain.touchBackend(backend) + return ty, content, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return objecttype.TypeInvalid, nil, fmt.Errorf("objectstore: backend %d read bytes content: %w", i, err) } + return objecttype.TypeInvalid, nil, objectstore.ErrObjectNotFound } @@ -93,13 +107,17 @@ func (chain *Chain) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, error) reader, err := backend.ReadReaderFull(id) if err == nil { chain.touchBackend(backend) + return reader, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return nil, fmt.Errorf("objectstore: backend %d read reader full: %w", i, err) } + return nil, objectstore.ErrObjectNotFound } @@ -110,13 +128,17 @@ func (chain *Chain) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, in ty, size, reader, err := backend.ReadReaderContent(id) if err == nil { chain.touchBackend(backend) + return ty, size, reader, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return objecttype.TypeInvalid, 0, nil, fmt.Errorf("objectstore: backend %d read reader content: %w", i, err) } + return objecttype.TypeInvalid, 0, nil, objectstore.ErrObjectNotFound } @@ -126,13 +148,17 @@ func (chain *Chain) ReadSize(id objectid.ObjectID) (int64, error) { size, err := backend.ReadSize(id) if err == nil { chain.touchBackend(backend) + return size, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return 0, fmt.Errorf("objectstore: backend %d read size: %w", i, err) } + return 0, objectstore.ErrObjectNotFound } @@ -142,31 +168,40 @@ func (chain *Chain) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64, er ty, size, err := backend.ReadHeader(id) if err == nil { chain.touchBackend(backend) + return ty, size, nil } + if errors.Is(err, objectstore.ErrObjectNotFound) { continue } + return objecttype.TypeInvalid, 0, fmt.Errorf("objectstore: backend %d read header: %w", i, err) } + return objecttype.TypeInvalid, 0, objectstore.ErrObjectNotFound } // Close closes all backends and joins close errors. func (chain *Chain) Close() error { chain.mu.RLock() + backends := make([]objectstore.Store, 0, len(chain.backendNodeByStore)) for node := chain.backendHead; node != nil; node = node.next { backends = append(backends, node.backend) } + chain.mu.RUnlock() var errs []error + for _, backend := range backends { - if err := backend.Close(); err != nil { + err := backend.Close() + if err != nil { errs = append(errs, err) } } + return errors.Join(errs...) } @@ -179,19 +214,23 @@ type backendNode struct { func (chain *Chain) firstBackend() objectstore.Store { chain.mu.RLock() defer chain.mu.RUnlock() + if chain.backendHead == nil { return nil } + return chain.backendHead.backend } func (chain *Chain) nextBackend(current objectstore.Store) objectstore.Store { chain.mu.RLock() defer chain.mu.RUnlock() + node := chain.backendNodeByStore[current] if node == nil || node.next == nil { return nil } + return node.next.backend } @@ -199,6 +238,7 @@ func (chain *Chain) touchBackend(backend objectstore.Store) { if backend == nil { return } + if !chain.mu.TryLock() { return } @@ -208,21 +248,26 @@ func (chain *Chain) touchBackend(backend objectstore.Store) { if node == nil || node == chain.backendHead { return } + if node.prev != nil { node.prev.next = node.next } + if node.next != nil { node.next.prev = node.prev } + if chain.backendTail == node { chain.backendTail = node.prev } node.prev = nil + node.next = chain.backendHead if chain.backendHead != nil { chain.backendHead.prev = node } + chain.backendHead = node if chain.backendTail == nil { chain.backendTail = node diff --git a/objectstore/loose/helpers_test.go b/objectstore/loose/helpers_test.go index 972059e0..4b0bb60e 100644 --- a/objectstore/loose/helpers_test.go +++ b/objectstore/loose/helpers_test.go @@ -15,30 +15,39 @@ import ( func openLooseStore(t *testing.T, repoPath string, algo objectid.Algorithm) *loose.Store { t.Helper() + objectsPath := filepath.Join(repoPath, "objects") + root, err := os.OpenRoot(objectsPath) if err != nil { t.Fatalf("OpenRoot(%q): %v", objectsPath, err) } + t.Cleanup(func() { _ = root.Close() }) store, err := loose.New(root, algo) if err != nil { t.Fatalf("loose.New: %v", err) } + return store } func mustReadAllAndClose(t *testing.T, reader io.ReadCloser) []byte { t.Helper() + data, err := io.ReadAll(reader) if err != nil { _ = reader.Close() + t.Fatalf("ReadAll: %v", err) } - if err := reader.Close(); err != nil { + + err = reader.Close() + if err != nil { t.Fatalf("Close: %v", err) } + return data } @@ -46,11 +55,14 @@ func expectedRawObject(t *testing.T, testRepo *testgit.TestRepo, id objectid.Obj t.Helper() typeName := testRepo.Run(t, "cat-file", "-t", id.String()) + ty, ok := objecttype.ParseName(typeName) if !ok { t.Fatalf("ParseName(%q) failed", typeName) } + body := testRepo.CatFile(t, typeName, id) + header, ok := objectheader.Encode(ty, int64(len(body))) if !ok { t.Fatalf("objectheader.Encode failed") @@ -59,5 +71,6 @@ func expectedRawObject(t *testing.T, testRepo *testgit.TestRepo, id objectid.Obj raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return ty, body, raw } diff --git a/objectstore/loose/parse.go b/objectstore/loose/parse.go index 54bb2375..e88d7c6c 100644 --- a/objectstore/loose/parse.go +++ b/objectstore/loose/parse.go @@ -17,7 +17,9 @@ func decodeAll(file *os.File) ([]byte, error) { if err != nil { return nil, err } + defer func() { _ = zr.Close() }() + return io.ReadAll(zr) } @@ -27,10 +29,12 @@ func parseRaw(raw []byte) (objecttype.Type, []byte, error) { if !ok { return objecttype.TypeInvalid, nil, errors.New("objectstore/loose: malformed object header") } + content := raw[headerLen:] if int64(len(content)) != size { return objecttype.TypeInvalid, nil, errors.New("objectstore/loose: object header size/content mismatch") } + return ty, content, nil } @@ -41,9 +45,11 @@ func readHeader(br *bufio.Reader) ([]byte, objecttype.Type, int64, error) { if err != nil { return nil, objecttype.TypeInvalid, 0, err } + ty, size, _, ok := objectheader.Parse(header) if !ok { return nil, objecttype.TypeInvalid, 0, errors.New("objectstore/loose: malformed object header") } + return header, ty, size, nil } diff --git a/objectstore/loose/paths.go b/objectstore/loose/paths.go index 04730bd3..e8020d72 100644 --- a/objectstore/loose/paths.go +++ b/objectstore/loose/paths.go @@ -16,7 +16,9 @@ func (store *Store) objectPath(id objectid.ObjectID) (string, error) { if id.Algorithm() != store.algo { return "", fmt.Errorf("objectstore/loose: object id algorithm mismatch: got %s want %s", id.Algorithm(), store.algo) } + hex := id.String() + return filepath.Join(hex[:2], hex[2:]), nil } @@ -27,12 +29,15 @@ func (store *Store) openObject(id objectid.ObjectID) (*os.File, error) { if err != nil { return nil, err } + file, err := store.root.Open(relPath) if err != nil { if errors.Is(err, fs.ErrNotExist) { return nil, objectstore.ErrObjectNotFound } + return nil, err } + return file, nil } diff --git a/objectstore/loose/read_bytes.go b/objectstore/loose/read_bytes.go index 2f7c24bc..78e1009e 100644 --- a/objectstore/loose/read_bytes.go +++ b/objectstore/loose/read_bytes.go @@ -12,16 +12,19 @@ func (store *Store) readBytesParsed(id objectid.ObjectID) ([]byte, objecttype.Ty if err != nil { return nil, objecttype.TypeInvalid, nil, err } + defer func() { _ = file.Close() }() raw, err := decodeAll(file) if err != nil { return nil, objecttype.TypeInvalid, nil, err } + ty, content, err := parseRaw(raw) if err != nil { return nil, objecttype.TypeInvalid, nil, err } + return raw, ty, content, nil } @@ -31,6 +34,7 @@ func (store *Store) ReadBytesFull(id objectid.ObjectID) ([]byte, error) { if err != nil { return nil, err } + return raw, nil } @@ -40,5 +44,6 @@ func (store *Store) ReadBytesContent(id objectid.ObjectID) (objecttype.Type, []b if err != nil { return objecttype.TypeInvalid, nil, err } + return ty, content, nil } diff --git a/objectstore/loose/read_header.go b/objectstore/loose/read_header.go index ce76600e..abfb1a02 100644 --- a/objectstore/loose/read_header.go +++ b/objectstore/loose/read_header.go @@ -14,17 +14,20 @@ func (store *Store) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64, er if err != nil { return objecttype.TypeInvalid, 0, err } + defer func() { _ = file.Close() }() zr, err := zlib.NewReader(file) if err != nil { return objecttype.TypeInvalid, 0, err } + defer func() { _ = zr.Close() }() _, ty, size, err := readHeader(bufio.NewReader(zr)) if err != nil { return objecttype.TypeInvalid, 0, err } + return ty, size, nil } diff --git a/objectstore/loose/read_reader.go b/objectstore/loose/read_reader.go index 6a377ba3..a0a51cc1 100644 --- a/objectstore/loose/read_reader.go +++ b/objectstore/loose/read_reader.go @@ -29,6 +29,7 @@ func (reader *objectReader) Read(dst []byte) (int, error) { func (reader *objectReader) Close() error { errZlib := reader.zr.Close() errFile := reader.file.Close() + return errors.Join(errZlib, errFile) } @@ -39,11 +40,14 @@ func (store *Store) openInflated(id objectid.ObjectID) (*os.File, io.ReadCloser, if err != nil { return nil, nil, err } + zr, err := zlib.NewReader(file) if err != nil { _ = file.Close() + return nil, nil, err } + return file, zr, nil } @@ -56,10 +60,12 @@ func (store *Store) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, error) } br := bufio.NewReader(zr) + header, _, size, err := readHeader(br) if err != nil { _ = zr.Close() _ = file.Close() + return nil, err } @@ -82,10 +88,12 @@ func (store *Store) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, in } br := bufio.NewReader(zr) + _, ty, size, err := readHeader(br) if err != nil { _ = zr.Close() _ = file.Close() + return objecttype.TypeInvalid, 0, nil, err } diff --git a/objectstore/loose/read_size.go b/objectstore/loose/read_size.go index 45f1f0fe..2a1eaec9 100644 --- a/objectstore/loose/read_size.go +++ b/objectstore/loose/read_size.go @@ -5,5 +5,6 @@ import "codeberg.org/lindenii/furgit/objectid" // ReadSize reads an object's declared content length. func (store *Store) ReadSize(id objectid.ObjectID) (int64, error) { _, size, err := store.ReadHeader(id) + return size, err } diff --git a/objectstore/loose/read_test.go b/objectstore/loose/read_test.go index d8166c9e..1efc1682 100644 --- a/objectstore/loose/read_test.go +++ b/objectstore/loose/read_test.go @@ -41,6 +41,7 @@ func TestLooseStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadBytesFull: %v", err) } + if !bytes.Equal(gotRaw, wantRaw) { t.Fatalf("ReadBytesFull mismatch") } @@ -49,9 +50,11 @@ func TestLooseStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadBytesContent: %v", err) } + if gotType != wantType { t.Fatalf("ReadBytesContent type = %v, want %v", gotType, wantType) } + if !bytes.Equal(gotBody, wantBody) { t.Fatalf("ReadBytesContent body mismatch") } @@ -60,9 +63,11 @@ func TestLooseStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadHeader: %v", err) } + if headType != wantType { t.Fatalf("ReadHeader type = %v, want %v", headType, wantType) } + if headSize != int64(len(wantBody)) { t.Fatalf("ReadHeader size = %d, want %d", headSize, len(wantBody)) } @@ -71,7 +76,9 @@ func TestLooseStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadReaderFull: %v", err) } - if got := mustReadAllAndClose(t, fullReader); !bytes.Equal(got, wantRaw) { + + got := mustReadAllAndClose(t, fullReader) + if !bytes.Equal(got, wantRaw) { t.Fatalf("ReadReaderFull stream mismatch") } @@ -79,13 +86,17 @@ func TestLooseStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadReaderContent: %v", err) } + if contentType != wantType { t.Fatalf("ReadReaderContent type = %v, want %v", contentType, wantType) } + if contentSize != int64(len(wantBody)) { t.Fatalf("ReadReaderContent size = %d, want %d", contentSize, len(wantBody)) } - if got := mustReadAllAndClose(t, contentReader); !bytes.Equal(got, wantBody) { + + got = mustReadAllAndClose(t, contentReader) + if !bytes.Equal(got, wantBody) { t.Fatalf("ReadReaderContent stream mismatch") } }) @@ -104,19 +115,28 @@ func TestLooseStoreErrors(t *testing.T) { t.Fatalf("ParseHex(notFoundID): %v", err) } - if _, err := store.ReadBytesFull(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + _, err = store.ReadBytesFull(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadBytesFull not-found error = %v", err) } - if _, _, err := store.ReadBytesContent(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, err = store.ReadBytesContent(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadBytesContent not-found error = %v", err) } - if _, err := store.ReadReaderFull(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, err = store.ReadReaderFull(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadReaderFull not-found error = %v", err) } - if _, _, _, err := store.ReadReaderContent(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, _, err = store.ReadReaderContent(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadReaderContent not-found error = %v", err) } - if _, _, err := store.ReadHeader(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, err = store.ReadHeader(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadHeader not-found error = %v", err) } @@ -126,12 +146,14 @@ func TestLooseStoreErrors(t *testing.T) { } else { otherAlgo = objectid.AlgorithmSHA1 } + otherID, err := objectid.ParseHex(otherAlgo, strings.Repeat("1", otherAlgo.HexLen())) if err != nil { t.Fatalf("ParseHex(otherID): %v", err) } - if _, err := store.ReadBytesFull(otherID); err == nil || !strings.Contains(err.Error(), "algorithm mismatch") { + _, err = store.ReadBytesFull(otherID) + if err == nil || !strings.Contains(err.Error(), "algorithm mismatch") { t.Fatalf("ReadBytesFull algorithm-mismatch error = %v", err) } }) @@ -139,13 +161,16 @@ func TestLooseStoreErrors(t *testing.T) { func TestLooseStoreNewValidation(t *testing.T) { t.Parallel() + root, err := os.OpenRoot(t.TempDir()) if err != nil { t.Fatalf("OpenRoot: %v", err) } + defer func() { _ = root.Close() }() - if _, err := loose.New(root, objectid.AlgorithmUnknown); err == nil { + _, err = loose.New(root, objectid.AlgorithmUnknown) + if err == nil { t.Fatalf("loose.New(root, unknown) expected error") } } diff --git a/objectstore/loose/store.go b/objectstore/loose/store.go index 05459a6c..c3ae989c 100644 --- a/objectstore/loose/store.go +++ b/objectstore/loose/store.go @@ -24,6 +24,7 @@ func New(root *os.Root, algo objectid.Algorithm) (*Store, error) { if algo.Size() == 0 { return nil, objectid.ErrInvalidAlgorithm } + return &Store{ root: root, algo: algo, diff --git a/objectstore/loose/write_reader.go b/objectstore/loose/write_reader.go index b2329f02..9dbf3818 100644 --- a/objectstore/loose/write_reader.go +++ b/objectstore/loose/write_reader.go @@ -27,12 +27,15 @@ func (store *Store) WriteReaderContent(ty objecttype.Type, size int64, src io.Re if err != nil { return objectid.ObjectID{}, err } + writer.headerDone = true writer.expectedContentLeft = size - if err := writer.writeRawChunk(header); err != nil { + err = writer.writeRawChunk(header) + if err != nil { _ = writer.Close() _ = store.root.Remove(writer.tmpRelPath) + return objectid.ObjectID{}, err } @@ -46,25 +49,33 @@ func (store *Store) WriteReaderFull(src io.Reader) (objectid.ObjectID, error) { if err != nil { return objectid.ObjectID{}, err } + return writeReaderIntoStreamWriter(writer, src) } // writeReaderIntoStreamWriter copies src into writer and publishes the object. func writeReaderIntoStreamWriter(writer *streamWriter, src io.Reader) (objectid.ObjectID, error) { - if _, err := io.Copy(writer, src); err != nil { + _, err := io.Copy(writer, src) + if err != nil { _ = writer.Close() _ = writer.store.root.Remove(writer.tmpRelPath) + return objectid.ObjectID{}, err } - if err := writer.Close(); err != nil { + + err = writer.Close() + if err != nil { _ = writer.store.root.Remove(writer.tmpRelPath) + return objectid.ObjectID{}, err } id, err := writer.finalize() if err != nil { _ = writer.store.root.Remove(writer.tmpRelPath) + return objectid.ObjectID{}, err } + return id, nil } diff --git a/objectstore/loose/write_test.go b/objectstore/loose/write_test.go index cceabe5a..5604c5b0 100644 --- a/objectstore/loose/write_test.go +++ b/objectstore/loose/write_test.go @@ -18,6 +18,7 @@ func TestLooseStoreWriteReaderContentAgainstGit(t *testing.T) { content := []byte("written-by-content-reader\n") expectedHex := testRepo.RunInput(t, content, "hash-object", "-t", "blob", "--stdin") + expectedID, err := objectid.ParseHex(algo, expectedHex) if err != nil { t.Fatalf("ParseHex(expected): %v", err) @@ -27,6 +28,7 @@ func TestLooseStoreWriteReaderContentAgainstGit(t *testing.T) { if err != nil { t.Fatalf("WriteReaderContent: %v", err) } + if writtenID != expectedID { t.Fatalf("WriteReaderContent id = %s, want %s", writtenID, expectedID) } @@ -41,6 +43,7 @@ func TestLooseStoreWriteReaderContentAgainstGit(t *testing.T) { if err != nil { t.Fatalf("WriteReaderContent second: %v", err) } + if writtenID2 != expectedID { t.Fatalf("WriteReaderContent second id = %s, want %s", writtenID2, expectedID) } @@ -54,19 +57,23 @@ func TestLooseStoreWriteReaderFullAgainstGit(t *testing.T) { store := openLooseStore(t, testRepo.Dir(), algo) body := []byte("full-reader-body\n") + header, ok := objectheader.Encode(objecttype.TypeBlob, int64(len(body))) if !ok { t.Fatalf("objectheader.Encode failed") } + raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) wantID := algo.Sum(raw) + gotID, err := store.WriteReaderFull(bytes.NewReader(raw)) if err != nil { t.Fatalf("WriteReaderFull: %v", err) } + if gotID != wantID { t.Fatalf("WriteReaderFull id = %s, want %s", gotID, wantID) } @@ -86,7 +93,8 @@ func TestLooseStoreReaderValidationErrors(t *testing.T) { testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) store := openLooseStore(t, testRepo.Dir(), algo) - if _, err := store.WriteReaderContent(objecttype.TypeBlob, 1, bytes.NewReader([]byte("hello"))); err == nil { + _, err := store.WriteReaderContent(objecttype.TypeBlob, 1, bytes.NewReader([]byte("hello"))) + if err == nil { t.Fatalf("expected error after overflow") } }) @@ -96,7 +104,8 @@ func TestLooseStoreReaderValidationErrors(t *testing.T) { testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) store := openLooseStore(t, testRepo.Dir(), algo) - if _, err := store.WriteReaderContent(objecttype.TypeBlob, 5, bytes.NewReader([]byte("x"))); err == nil { + _, err := store.WriteReaderContent(objecttype.TypeBlob, 5, bytes.NewReader([]byte("x"))) + if err == nil { t.Fatalf("expected error for short content") } }) @@ -106,7 +115,8 @@ func TestLooseStoreReaderValidationErrors(t *testing.T) { testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) store := openLooseStore(t, testRepo.Dir(), algo) - if _, err := store.WriteReaderFull(bytes.NewReader([]byte("not-a-header"))); err == nil { + _, err := store.WriteReaderFull(bytes.NewReader([]byte("not-a-header"))) + if err == nil { t.Fatalf("expected error for malformed header") } }) @@ -117,7 +127,9 @@ func TestLooseStoreReaderValidationErrors(t *testing.T) { store := openLooseStore(t, testRepo.Dir(), algo) raw := []byte("blob 1\x00hello") - if _, err := store.WriteReaderFull(bytes.NewReader(raw)); err == nil { + + _, err := store.WriteReaderFull(bytes.NewReader(raw)) + if err == nil { t.Fatalf("expected error after mismatch") } }) diff --git a/objectstore/loose/write_writer.go b/objectstore/loose/write_writer.go index c075f2ba..a0f24f2b 100644 --- a/objectstore/loose/write_writer.go +++ b/objectstore/loose/write_writer.go @@ -76,23 +76,28 @@ func (writer *streamWriter) Write(src []byte) (int, error) { if writer.finalized { return 0, errors.New("objectstore/loose: write after finalize") } + if writer.closed { return 0, errors.New("objectstore/loose: write after close") } if writer.fullMode { - if err := writer.acceptFull(src); err != nil { + err := writer.acceptFull(src) + if err != nil { return 0, err } } else { - if err := writer.acceptContent(int64(len(src))); err != nil { + err := writer.acceptContent(int64(len(src))) + if err != nil { return 0, err } } - if err := writer.writeRawChunk(src); err != nil { + err := writer.writeRawChunk(src) + if err != nil { return 0, err } + return len(src), nil } @@ -102,12 +107,14 @@ func (writer *streamWriter) Close() error { if writer.closed { return nil } + writer.closed = true errZlib := writer.zw.Close() errSync := writer.file.Sync() errFile := writer.file.Close() writer.file = nil + return errors.Join(errZlib, errSync, errFile) } @@ -118,84 +125,107 @@ func (writer *streamWriter) finalize() (objectid.ObjectID, error) { if writer.finalized { return writer.finalID, writer.finalErr } + writer.finalized = true var zero objectid.ObjectID if !writer.closed { - if err := writer.Close(); err != nil { + err := writer.Close() + if err != nil { writer.finalErr = err + return zero, err } } if writer.fullMode && !writer.headerDone { writer.finalErr = errors.New("objectstore/loose: missing full object header") + return zero, writer.finalErr } + if writer.expectedContentLeft != 0 { writer.finalErr = errors.New("objectstore/loose: object content shorter than declared size") + return zero, writer.finalErr } idBytes := writer.hash.Sum(nil) + id, err := objectid.FromBytes(writer.store.algo, idBytes) if err != nil { writer.finalErr = err + return zero, err } relPath, err := writer.store.objectPath(id) if err != nil { writer.finalErr = err + return zero, err } dir := filepath.Dir(relPath) - if err := writer.store.root.MkdirAll(dir, 0o755); err != nil { + + err = writer.store.root.MkdirAll(dir, 0o755) + if err != nil { writer.finalErr = err + return zero, err } cleanup := true + defer func() { if cleanup { _ = writer.store.root.Remove(writer.tmpRelPath) } }() - if err := writer.store.root.Link(writer.tmpRelPath, relPath); err != nil { + err = writer.store.root.Link(writer.tmpRelPath, relPath) + if err != nil { if errors.Is(err, fs.ErrExist) { writer.finalID = id cleanup = false _ = writer.store.root.Remove(writer.tmpRelPath) + return id, nil } + writer.finalErr = err + return zero, err } writer.finalID = id cleanup = false + return id, nil } // acceptFull validates and accounts raw full-object input. func (writer *streamWriter) acceptFull(src []byte) error { if !writer.headerDone { - if nul := bytes.IndexByte(src, 0); nul >= 0 { + nul := bytes.IndexByte(src, 0) + if nul >= 0 { headerChunkLen := nul + 1 writer.headerBuf = append(writer.headerBuf, src[:headerChunkLen]...) + _, size, _, ok := objectheader.Parse(writer.headerBuf) if !ok { return errors.New("objectstore/loose: malformed object header") } + writer.headerDone = true writer.expectedContentLeft = size + return writer.acceptContent(int64(len(src) - headerChunkLen)) } writer.headerBuf = append(writer.headerBuf, src...) + return nil } @@ -207,18 +237,24 @@ func (writer *streamWriter) acceptContent(n int64) error { if n > writer.expectedContentLeft { return errors.New("objectstore/loose: object content exceeds declared size") } + writer.expectedContentLeft -= n + return nil } // writeRawChunk forwards raw bytes to the hash and deflate pipeline. func (writer *streamWriter) writeRawChunk(src []byte) error { - if _, err := writer.hash.Write(src); err != nil { + _, err := writer.hash.Write(src) + if err != nil { return err } - if _, err := writer.zw.Write(src); err != nil { + + _, err = writer.zw.Write(src) + if err != nil { return err } + return nil } @@ -227,13 +263,16 @@ func (writer *streamWriter) writeRawChunk(src []byte) error { func (store *Store) createTempObjectFile(dir string) (string, *os.File, error) { for range 16 { relPath := filepath.Join(dir, tempObjectFilePrefix+rand.Text()) + file, err := store.root.OpenFile(relPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644) if err == nil { return relPath, file, nil } + if errors.Is(err, fs.ErrExist) { continue } + return "", nil, err } diff --git a/objectstore/packed/delta_apply.go b/objectstore/packed/delta_apply.go index 5245e0ba..71f09ead 100644 --- a/objectstore/packed/delta_apply.go +++ b/objectstore/packed/delta_apply.go @@ -14,10 +14,12 @@ func (store *Store) deltaResolveContent(start location) (objecttype.Type, []byte if err != nil { return objecttype.TypeInvalid, nil, err } + pack, meta, err := store.entryMetaAt(start) if err != nil { return objecttype.TypeInvalid, nil, err } + declaredSize := meta.size if !packfmt.IsBaseObjectType(meta.ty) { declaredSize, err = deltaDeclaredSizeAt(pack, meta.dataOffset) @@ -25,6 +27,7 @@ func (store *Store) deltaResolveContent(start location) (objecttype.Type, []byte return objecttype.TypeInvalid, nil, err } } + return store.deltaResolveChain(chain, declaredSize) } @@ -37,18 +40,22 @@ func (store *Store) deltaResolveChain(chain deltaChain, declaredSize int64) (obj for i := nextDelta; i >= 0; i-- { node := chain.deltas[i] + pack, err := store.openPack(node.loc.packName) if err != nil { return objecttype.TypeInvalid, nil, err } + delta, err := inflateAt(pack, node.dataOffset, -1) if err != nil { return objecttype.TypeInvalid, nil, err } + out, err = deltaapply.Apply(out, delta) if err != nil { return objecttype.TypeInvalid, nil, err } + store.cacheMu.Lock() store.deltaCache.add( deltaBaseKey{packName: node.loc.packName, offset: node.loc.offset}, @@ -65,6 +72,7 @@ func (store *Store) deltaResolveChain(chain deltaChain, declaredSize int64) (obj declaredSize, ) } + if ty != chain.baseType { return objecttype.TypeInvalid, nil, fmt.Errorf( "objectstore/packed: resolved content type mismatch: got %d want %d", @@ -72,6 +80,7 @@ func (store *Store) deltaResolveChain(chain deltaChain, declaredSize int64) (obj chain.baseType, ) } + return ty, out, nil } @@ -85,6 +94,7 @@ func (store *Store) deltaResolveChainStart(chain deltaChain) (objecttype.Type, [ deltaBaseKey{packName: node.loc.packName, offset: node.loc.offset}, ) store.cacheMu.RUnlock() + if ok { return ty, out, i - 1, nil } @@ -95,6 +105,7 @@ func (store *Store) deltaResolveChainStart(chain deltaChain) (objecttype.Type, [ deltaBaseKey{packName: chain.baseLoc.packName, offset: chain.baseLoc.offset}, ) store.cacheMu.RUnlock() + if ok { return ty, out, len(chain.deltas) - 1, nil } @@ -103,9 +114,11 @@ func (store *Store) deltaResolveChainStart(chain deltaChain) (objecttype.Type, [ if err != nil { return objecttype.TypeInvalid, nil, 0, err } + if !packfmt.IsBaseObjectType(meta.ty) { return objecttype.TypeInvalid, nil, 0, fmt.Errorf("objectstore/packed: delta chain base is not a base object") } + base, err := inflateAt(pack, meta.dataOffset, meta.size) if err != nil { return objecttype.TypeInvalid, nil, 0, err diff --git a/objectstore/packed/delta_cache.go b/objectstore/packed/delta_cache.go index add21698..a911b254 100644 --- a/objectstore/packed/delta_cache.go +++ b/objectstore/packed/delta_cache.go @@ -41,6 +41,7 @@ func (cache *deltaCache) get(key deltaBaseKey) (objecttype.Type, []byte, bool) { if !ok { return objecttype.TypeInvalid, nil, false } + return value.ty, append([]byte(nil), value.content...), true } diff --git a/objectstore/packed/delta_plan.go b/objectstore/packed/delta_plan.go index 5f2ae959..b0b0324c 100644 --- a/objectstore/packed/delta_plan.go +++ b/objectstore/packed/delta_plan.go @@ -38,6 +38,7 @@ func (store *Store) deltaBuildChain(start location) (deltaChain, error) { if _, ok := visited[current]; ok { return deltaChain{}, fmt.Errorf("objectstore/packed: delta cycle while resolving object") } + visited[current] = struct{}{} _, meta, err := store.entryMetaAt(current) @@ -48,6 +49,7 @@ func (store *Store) deltaBuildChain(start location) (deltaChain, error) { if packfmt.IsBaseObjectType(meta.ty) { chain.baseLoc = current chain.baseType = meta.ty + return chain, nil } @@ -57,10 +59,12 @@ func (store *Store) deltaBuildChain(start location) (deltaChain, error) { loc: current, dataOffset: meta.dataOffset, }) + next, err := store.lookup(meta.baseRefID) if err != nil { return deltaChain{}, err } + current = next case objecttype.TypeOfsDelta: chain.deltas = append(chain.deltas, deltaNode{ @@ -88,12 +92,15 @@ func deltaDeclaredSizeAt(pack *packFile, dataOffset int) (int64, error) { if err != nil { return 0, err } + defer func() { _ = reader.Close() }() br := bufio.NewReaderSize(reader, 32) + _, size, err := deltaapply.ReadHeaderSizes(br) if err != nil { return 0, err } + return int64(size), nil } diff --git a/objectstore/packed/entry_inflate.go b/objectstore/packed/entry_inflate.go index 4f91710e..cbdb6a89 100644 --- a/objectstore/packed/entry_inflate.go +++ b/objectstore/packed/entry_inflate.go @@ -14,6 +14,7 @@ func zlibReaderAt(pack *packFile, offset int) (io.ReadCloser, error) { if offset < 0 || offset > len(pack.data) { return nil, fmt.Errorf("objectstore/packed: pack %q zlib offset out of bounds", pack.name) } + return zlib.NewReader(bytes.NewReader(pack.data[offset:])) } @@ -23,6 +24,7 @@ func inflateAt(pack *packFile, offset int, expectedSize int64) ([]byte, error) { if err != nil { return nil, err } + defer func() { _ = reader.Close() }() if expectedSize >= 0 { @@ -35,9 +37,12 @@ func inflateAt(pack *packFile, offset int, expectedSize int64) ([]byte, error) { } body := make([]byte, int(expectedSize)) - if _, err := io.ReadFull(reader, body); err != nil { + + _, err := io.ReadFull(reader, body) + if err != nil { return nil, err } + return body, nil } @@ -45,5 +50,6 @@ func inflateAt(pack *packFile, offset int, expectedSize int64) ([]byte, error) { if err != nil { return nil, err } + return body, nil } diff --git a/objectstore/packed/entry_parse.go b/objectstore/packed/entry_parse.go index 56287386..7af20af1 100644 --- a/objectstore/packed/entry_parse.go +++ b/objectstore/packed/entry_parse.go @@ -34,6 +34,7 @@ func parseEntryMeta(pack *packFile, algo objectid.Algorithm, offset uint64) (ent if err != nil { return zero, fmt.Errorf("objectstore/packed: pack %q offset conversion: %w", pack.name, err) } + entry, err := packfmt.ParseEntry(pack.data[pos:], algo.Size()) if err != nil { return zero, fmt.Errorf("objectstore/packed: pack %q: %w", pack.name, err) @@ -50,11 +51,13 @@ func parseEntryMeta(pack *packFile, algo objectid.Algorithm, offset uint64) (ent if err != nil { return zero, fmt.Errorf("objectstore/packed: pack %q invalid ref-delta base id: %w", pack.name, err) } + meta.baseRefID = baseID case objecttype.TypeOfsDelta: if offset <= entry.OfsBaseDistance { return zero, fmt.Errorf("objectstore/packed: pack %q has invalid ofs-delta base", pack.name) } + meta.baseOfs = offset - entry.OfsBaseDistance case objecttype.TypeCommit, objecttype.TypeTree, objecttype.TypeBlob, objecttype.TypeTag: // Base object types do not have delta base metadata. @@ -63,5 +66,6 @@ func parseEntryMeta(pack *packFile, algo objectid.Algorithm, offset uint64) (ent default: return zero, fmt.Errorf("objectstore/packed: pack %q has unsupported entry type %d", pack.name, meta.ty) } + return meta, nil } diff --git a/objectstore/packed/helpers_test.go b/objectstore/packed/helpers_test.go index f8cbd439..1b517294 100644 --- a/objectstore/packed/helpers_test.go +++ b/objectstore/packed/helpers_test.go @@ -18,30 +18,39 @@ import ( func openPackedStore(t *testing.T, repoPath string, algo objectid.Algorithm) *packed.Store { t.Helper() + packPath := filepath.Join(repoPath, "objects", "pack") + root, err := os.OpenRoot(packPath) if err != nil { t.Fatalf("OpenRoot(%q): %v", packPath, err) } + t.Cleanup(func() { _ = root.Close() }) store, err := packed.New(root, algo) if err != nil { t.Fatalf("packed.New: %v", err) } + return store } func mustReadAllAndClose(t *testing.T, reader io.ReadCloser) []byte { t.Helper() + data, err := io.ReadAll(reader) if err != nil { _ = reader.Close() + t.Fatalf("ReadAll: %v", err) } - if err := reader.Close(); err != nil { + + err = reader.Close() + if err != nil { t.Fatalf("Close: %v", err) } + return data } @@ -49,11 +58,14 @@ func expectedRawObject(t *testing.T, testRepo *testgit.TestRepo, id objectid.Obj t.Helper() typeName := testRepo.Run(t, "cat-file", "-t", id.String()) + ty, ok := objecttype.ParseName(typeName) if !ok { t.Fatalf("ParseName(%q) failed", typeName) } + body := testRepo.CatFile(t, typeName, id) + header, ok := objectheader.Encode(ty, int64(len(body))) if !ok { t.Fatalf("objectheader.Encode failed") @@ -62,6 +74,7 @@ func expectedRawObject(t *testing.T, testRepo *testgit.TestRepo, id objectid.Obj raw := make([]byte, len(header)+len(body)) copy(raw, header) copy(raw[len(header):], body) + return ty, body, raw } @@ -74,6 +87,7 @@ func createPackedFixtureRepo(t *testing.T, algo objectid.Algorithm) (*testgit.Te tagID := testRepo.TagAnnotated(t, "v1.0.0", commitID, "packed-store-tag") parent := commitID + for i := range 24 { content := "common-prefix\n" + strings.Repeat("line-"+strconv.Itoa(i%3)+"\n", 256) + fmt.Sprintf("tail-%d\n", i) nextBlob, nextTree := testRepo.MakeSingleFileTree(t, fmt.Sprintf("file-%02d.txt", i), []byte(content)) @@ -86,6 +100,7 @@ func createPackedFixtureRepo(t *testing.T, algo objectid.Algorithm) (*testgit.Te } testRepo.Repack(t, "-a", "-d", "-f", "--window=64", "--depth=64") + return testRepo, []objectid.ObjectID{ blobID, treeID, diff --git a/objectstore/packed/idx_lookup_candidates.go b/objectstore/packed/idx_lookup_candidates.go index 83055aac..72121b25 100644 --- a/objectstore/packed/idx_lookup_candidates.go +++ b/objectstore/packed/idx_lookup_candidates.go @@ -37,8 +37,11 @@ func (store *Store) ensureCandidates() error { candidateByPack := make(map[string]packCandidate, len(candidates)) nodeByPack := make(map[string]*packCandidateNode, len(candidates)) - var head *packCandidateNode - var tail *packCandidateNode + var ( + head *packCandidateNode + tail *packCandidateNode + ) + for _, candidate := range candidates { node := &packCandidateNode{ candidate: candidate, @@ -47,9 +50,11 @@ func (store *Store) ensureCandidates() error { if tail != nil { tail.next = node } + if head == nil { head = node } + tail = node candidateByPack[candidate.packName] = candidate nodeByPack[candidate.packName] = node @@ -67,6 +72,7 @@ func (store *Store) ensureCandidates() error { store.candidatesMu.RLock() err := store.discoverErr store.candidatesMu.RUnlock() + return err } @@ -78,8 +84,10 @@ func (store *Store) discoverCandidates() ([]packCandidate, error) { if os.IsNotExist(err) { return nil, nil } + return nil, err } + defer func() { _ = dir.Close() }() entries, err := dir.ReadDir(-1) @@ -95,11 +103,13 @@ func (store *Store) discoverCandidates() ([]packCandidate, error) { idxName := entry.Name() packName := strings.TrimSuffix(idxName, ".idx") + ".pack" + packInfo, err := store.root.Stat(packName) if err != nil { if os.IsNotExist(err) { return nil, fmt.Errorf("objectstore/packed: missing pack file for index %q", idxName) } + return nil, err } @@ -115,8 +125,10 @@ func (store *Store) discoverCandidates() ([]packCandidate, error) { if a.mtime > b.mtime { return -1 } + return 1 } + return strings.Compare(a.packName, b.packName) }) @@ -139,18 +151,22 @@ func (store *Store) touchCandidate(packName string) { if node.prev != nil { node.prev.next = node.next } + if node.next != nil { node.next.prev = node.prev } + if store.candidateTail == node { store.candidateTail = node.prev } node.prev = nil + node.next = store.candidateHead if store.candidateHead != nil { store.candidateHead.prev = node } + store.candidateHead = node if store.candidateTail == nil { store.candidateTail = node @@ -162,9 +178,11 @@ func (store *Store) touchCandidate(packName string) { func (store *Store) firstCandidatePackName() string { store.candidatesMu.RLock() defer store.candidatesMu.RUnlock() + if store.candidateHead == nil { return "" } + return store.candidateHead.candidate.packName } @@ -173,9 +191,11 @@ func (store *Store) firstCandidatePackName() string { func (store *Store) nextCandidatePackName(currentPack string) string { store.candidatesMu.RLock() defer store.candidatesMu.RUnlock() + node := store.candidateNodeByPack[currentPack] if node == nil || node.next == nil { return "" } + return node.next.candidate.packName } diff --git a/objectstore/packed/idx_open.go b/objectstore/packed/idx_open.go index c00a7bac..c3c97e4d 100644 --- a/objectstore/packed/idx_open.go +++ b/objectstore/packed/idx_open.go @@ -43,16 +43,21 @@ func (store *Store) candidateForPack(packName string) (packCandidate, bool) { store.candidatesMu.RLock() candidate, ok := store.candidateByPack[packName] store.candidatesMu.RUnlock() + return candidate, ok } // openIndex returns one opened and parsed index, caching it by pack basename. func (store *Store) openIndex(candidate packCandidate) (*idxFile, error) { store.idxMu.RLock() - if index, ok := store.idxByPack[candidate.packName]; ok { + + index, ok := store.idxByPack[candidate.packName] + if ok { store.idxMu.RUnlock() + return index, nil } + store.idxMu.RUnlock() index, err := openIdxFile(store.root, candidate.idxName, candidate.packName, store.algo) @@ -61,13 +66,19 @@ func (store *Store) openIndex(candidate packCandidate) (*idxFile, error) { } store.idxMu.Lock() - if existing, ok := store.idxByPack[candidate.packName]; ok { + + existing, ok := store.idxByPack[candidate.packName] + if ok { store.idxMu.Unlock() + _ = index.close() + return existing, nil } + store.idxByPack[candidate.packName] = index store.idxMu.Unlock() + return index, nil } @@ -77,24 +88,32 @@ func openIdxFile(root *os.Root, idxName, packName string, algo objectid.Algorith if err != nil { return nil, err } + info, err := file.Stat() if err != nil { _ = file.Close() + return nil, err } + size := info.Size() if size < 0 || size > int64(int(^uint(0)>>1)) { _ = file.Close() + return nil, fmt.Errorf("objectstore/packed: idx %q has unsupported size", idxName) } + fd, err := intconv.UintptrToInt(file.Fd()) if err != nil { _ = file.Close() + return nil, err } + data, err := syscall.Mmap(fd, 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE) if err != nil { _ = file.Close() + return nil, err } @@ -105,27 +124,38 @@ func openIdxFile(root *os.Root, idxName, packName string, algo objectid.Algorith file: file, data: data, } - if err := index.parse(); err != nil { + + err = index.parse() + if err != nil { _ = index.close() + return nil, err } + return index, nil } // close unmaps and closes one idx handle. func (index *idxFile) close() error { var closeErr error + if index.data != nil { - if err := syscall.Munmap(index.data); err != nil && closeErr == nil { + err := syscall.Munmap(index.data) + if err != nil && closeErr == nil { closeErr = err } + index.data = nil } + if index.file != nil { - if err := index.file.Close(); err != nil && closeErr == nil { + err := index.file.Close() + if err != nil && closeErr == nil { closeErr = err } + index.file = nil } + return closeErr } diff --git a/objectstore/packed/idx_parse.go b/objectstore/packed/idx_parse.go index 0af72594..870ffdae 100644 --- a/objectstore/packed/idx_parse.go +++ b/objectstore/packed/idx_parse.go @@ -19,27 +19,34 @@ func (index *idxFile) parse() error { if hashSize <= 0 { return fmt.Errorf("objectstore/packed: idx %q has invalid hash algorithm", index.idxName) } + minLen := 8 + 256*4 + 2*hashSize if len(index.data) < minLen { return fmt.Errorf("objectstore/packed: idx %q too short", index.idxName) } + if binary.BigEndian.Uint32(index.data[:4]) != idxMagicV2 { return fmt.Errorf("objectstore/packed: idx %q invalid magic", index.idxName) } + if binary.BigEndian.Uint32(index.data[4:8]) != idxVersionV2 { return fmt.Errorf("objectstore/packed: idx %q unsupported version", index.idxName) } prev := uint32(0) + for i := range 256 { base := 8 + i*4 + cur := binary.BigEndian.Uint32(index.data[base : base+4]) if cur < prev { return fmt.Errorf("objectstore/packed: idx %q has non-monotonic fanout table", index.idxName) } + index.fanout[i] = cur prev = cur } + index.numObjects = int(index.fanout[255]) if index.numObjects < 0 { return fmt.Errorf("objectstore/packed: idx %q has invalid object count", index.idxName) @@ -48,6 +55,7 @@ func (index *idxFile) parse() error { namesBytes := index.numObjects * hashSize crcBytes := index.numObjects * 4 offset32Bytes := index.numObjects * 4 + minSize := 8 + 256*4 + namesBytes + crcBytes + offset32Bytes + 2*hashSize if minSize < 0 || len(index.data) < minSize { return fmt.Errorf("objectstore/packed: idx %q has truncated tables", index.idxName) @@ -61,11 +69,14 @@ func (index *idxFile) parse() error { if offset64Bytes < 0 || offset64Bytes%8 != 0 { return fmt.Errorf("objectstore/packed: idx %q has malformed 64-bit offset table", index.idxName) } + index.offset64Count = offset64Bytes / 8 + maxOffset64Count := max(index.numObjects-1, 0) if index.offset64Count > maxOffset64Count { return fmt.Errorf("objectstore/packed: idx %q has oversized 64-bit offset table", index.idxName) } + return nil } @@ -74,17 +85,21 @@ func (index *idxFile) lookup(id objectid.ObjectID) (uint64, bool, error) { if id.Algorithm() != index.algo { return 0, false, fmt.Errorf("objectstore/packed: object id algorithm mismatch") } + idBytes := (&id).RawBytes() + hashSize := len(idBytes) if hashSize != index.algo.Size() { return 0, false, fmt.Errorf("objectstore/packed: unexpected object id length") } first := int(idBytes[0]) + lo := 0 if first > 0 { lo = int(index.fanout[first-1]) } + hi := int(index.fanout[first]) if lo < 0 || hi < 0 || lo > hi || hi > index.numObjects { return 0, false, fmt.Errorf("objectstore/packed: idx %q has invalid fanout bounds", index.idxName) @@ -92,24 +107,29 @@ func (index *idxFile) lookup(id objectid.ObjectID) (uint64, bool, error) { for lo < hi { mid := lo + (hi-lo)/2 + nameOffset := index.namesOffset + mid*hashSize if nameOffset < 0 || nameOffset+hashSize > len(index.data) { return 0, false, fmt.Errorf("objectstore/packed: idx %q truncated name table", index.idxName) } + cmp := bytes.Compare(index.data[nameOffset:nameOffset+hashSize], idBytes) if cmp == 0 { offset, err := index.offsetAt(mid) if err != nil { return 0, false, err } + return offset, true, nil } + if cmp < 0 { lo = mid + 1 } else { hi = mid } } + return 0, false, nil } @@ -118,10 +138,12 @@ func (index *idxFile) offsetAt(objectIndex int) (uint64, error) { if objectIndex < 0 || objectIndex >= index.numObjects { return 0, fmt.Errorf("objectstore/packed: idx %q offset index out of bounds", index.idxName) } + wordOffset := index.offset32Offset + objectIndex*4 if wordOffset < 0 || wordOffset+4 > len(index.data) { return 0, fmt.Errorf("objectstore/packed: idx %q truncated 32-bit offset table", index.idxName) } + word := binary.BigEndian.Uint32(index.data[wordOffset : wordOffset+4]) if word&0x80000000 == 0 { return uint64(word), nil @@ -131,9 +153,11 @@ func (index *idxFile) offsetAt(objectIndex int) (uint64, error) { if pos < 0 || pos >= index.offset64Count { return 0, fmt.Errorf("objectstore/packed: idx %q invalid 64-bit offset position", index.idxName) } + offOffset := index.offset64Offset + pos*8 if offOffset < 0 || offOffset+8 > len(index.data)-2*index.algo.Size() { return 0, fmt.Errorf("objectstore/packed: idx %q truncated 64-bit offset table", index.idxName) } + return binary.BigEndian.Uint64(index.data[offOffset : offOffset+8]), nil } diff --git a/objectstore/packed/pack.go b/objectstore/packed/pack.go index 9af4c860..874b2b76 100644 --- a/objectstore/packed/pack.go +++ b/objectstore/packed/pack.go @@ -25,43 +25,58 @@ func openPackFile(name string, file *os.File, size int64) (*packFile, error) { if size < 12 { return nil, fmt.Errorf("objectstore/packed: pack %q too short", name) } + if size > int64(int(^uint(0)>>1)) { return nil, fmt.Errorf("objectstore/packed: pack %q has unsupported size", name) } + fd, err := intconv.UintptrToInt(file.Fd()) if err != nil { return nil, err } + data, err := syscall.Mmap(fd, 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE) if err != nil { return nil, err } + if binary.BigEndian.Uint32(data[:4]) != packfmt.Signature { _ = syscall.Munmap(data) + return nil, fmt.Errorf("objectstore/packed: pack %q invalid signature", name) } + version := binary.BigEndian.Uint32(data[4:8]) if !packfmt.VersionSupported(version) { _ = syscall.Munmap(data) + return nil, fmt.Errorf("objectstore/packed: pack %q unsupported version %d", name, version) } + return &packFile{name: name, file: file, data: data}, nil } // close unmaps and closes one pack handle. func (pack *packFile) close() error { var closeErr error + if pack.data != nil { - if err := syscall.Munmap(pack.data); err != nil && closeErr == nil { + err := syscall.Munmap(pack.data) + if err != nil && closeErr == nil { closeErr = err } + pack.data = nil } + if pack.file != nil { - if err := pack.file.Close(); err != nil && closeErr == nil { + err := pack.file.Close() + if err != nil && closeErr == nil { closeErr = err } + pack.file = nil } + return closeErr } diff --git a/objectstore/packed/pack_idx_checksum.go b/objectstore/packed/pack_idx_checksum.go index 2f55a469..25556088 100644 --- a/objectstore/packed/pack_idx_checksum.go +++ b/objectstore/packed/pack_idx_checksum.go @@ -14,17 +14,21 @@ func verifyMappedPackMatchesMappedIdx(packData, idxData []byte, algo objectid.Al if hashSize <= 0 { return objectid.ErrInvalidAlgorithm } + if len(packData) < hashSize { return fmt.Errorf("objectstore/packed: pack too short for trailer hash") } + if len(idxData) < hashSize*2 { return fmt.Errorf("objectstore/packed: idx too short for trailer hashes") } packTrailerHash := packData[len(packData)-hashSize:] + idxPackHash := idxData[len(idxData)-hashSize*2 : len(idxData)-hashSize] if !bytes.Equal(packTrailerHash, idxPackHash) { return fmt.Errorf("objectstore/packed: pack hash does not match idx") } + return nil } diff --git a/objectstore/packed/read_bytes.go b/objectstore/packed/read_bytes.go index b6f42a0d..e272b626 100644 --- a/objectstore/packed/read_bytes.go +++ b/objectstore/packed/read_bytes.go @@ -14,6 +14,7 @@ func (store *Store) ReadBytesContent(id objectid.ObjectID) (objecttype.Type, []b if err != nil { return objecttype.TypeInvalid, nil, err } + return store.deltaResolveContent(loc) } @@ -23,12 +24,15 @@ func (store *Store) ReadBytesFull(id objectid.ObjectID) ([]byte, error) { if err != nil { return nil, err } + header, ok := objectheader.Encode(ty, int64(len(content))) if !ok { return nil, fmt.Errorf("objectstore/packed: failed to encode object header for type %d", ty) } + out := make([]byte, len(header)+len(content)) copy(out, header) copy(out[len(header):], content) + return out, nil } diff --git a/objectstore/packed/read_header.go b/objectstore/packed/read_header.go index 6822975c..5eb37c92 100644 --- a/objectstore/packed/read_header.go +++ b/objectstore/packed/read_header.go @@ -11,5 +11,6 @@ func (store *Store) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64, er if err != nil { return objecttype.TypeInvalid, 0, err } + return store.resolveHeaderAt(loc) } diff --git a/objectstore/packed/read_header_resolve.go b/objectstore/packed/read_header_resolve.go index cf49fe2b..420d9363 100644 --- a/objectstore/packed/read_header_resolve.go +++ b/objectstore/packed/read_header_resolve.go @@ -17,12 +17,14 @@ func (store *Store) resolveHeaderAt(start location) (objecttype.Type, int64, err if _, ok := visited[current]; ok { return objecttype.TypeInvalid, 0, fmt.Errorf("objectstore/packed: delta cycle while resolving object header") } + visited[current] = struct{}{} pack, meta, err := store.entryMetaAt(current) if err != nil { return objecttype.TypeInvalid, 0, err } + if declaredSize < 0 { if packfmt.IsBaseObjectType(meta.ty) { declaredSize = meta.size @@ -31,9 +33,11 @@ func (store *Store) resolveHeaderAt(start location) (objecttype.Type, int64, err if err != nil { return objecttype.TypeInvalid, 0, err } + declaredSize = size } } + if packfmt.IsBaseObjectType(meta.ty) { return meta.ty, declaredSize, nil } @@ -44,6 +48,7 @@ func (store *Store) resolveHeaderAt(start location) (objecttype.Type, int64, err if err != nil { return objecttype.TypeInvalid, 0, err } + current = next case objecttype.TypeOfsDelta: current = location{ diff --git a/objectstore/packed/read_reader.go b/objectstore/packed/read_reader.go index a1f24799..d8dfdca9 100644 --- a/objectstore/packed/read_reader.go +++ b/objectstore/packed/read_reader.go @@ -41,11 +41,13 @@ func (store *Store) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, in if err != nil { return objecttype.TypeInvalid, 0, nil, err } + if packfmt.IsBaseObjectType(meta.ty) { zr, err := zlibReaderAt(pack, meta.dataOffset) if err != nil { return objecttype.TypeInvalid, 0, nil, err } + return meta.ty, meta.size, &readCloser{ reader: iolimit.ExpectLengthReader(zr, meta.size), closer: zr, @@ -56,6 +58,7 @@ func (store *Store) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, in if err != nil { return objecttype.TypeInvalid, 0, nil, err } + return ty, int64(len(content)), io.NopCloser(bytes.NewReader(content)), nil } @@ -72,15 +75,18 @@ func (store *Store) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, error) if err != nil { return nil, err } + if packfmt.IsBaseObjectType(meta.ty) { header, ok := objectheader.Encode(meta.ty, meta.size) if !ok { return nil, fmt.Errorf("objectstore/packed: failed to encode object header for type %d", meta.ty) } + zr, err := zlibReaderAt(pack, meta.dataOffset) if err != nil { return nil, err } + return &readCloser{ reader: io.MultiReader(bytes.NewReader(header), iolimit.ExpectLengthReader(zr, meta.size)), closer: zr, @@ -91,5 +97,6 @@ func (store *Store) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, error) if err != nil { return nil, err } + return io.NopCloser(bytes.NewReader(raw)), nil } diff --git a/objectstore/packed/read_size.go b/objectstore/packed/read_size.go index e162586a..a0a75db7 100644 --- a/objectstore/packed/read_size.go +++ b/objectstore/packed/read_size.go @@ -14,6 +14,7 @@ func (store *Store) ReadSize(id objectid.ObjectID) (int64, error) { if err != nil { return 0, err } + return store.resolveSizeAt(loc) } @@ -23,9 +24,11 @@ func (store *Store) resolveSizeAt(start location) (int64, error) { if err != nil { return 0, err } + if packfmt.IsBaseObjectType(meta.ty) { return meta.size, nil } + switch meta.ty { case objecttype.TypeRefDelta, objecttype.TypeOfsDelta: return deltaDeclaredSizeAt(pack, meta.dataOffset) diff --git a/objectstore/packed/read_test.go b/objectstore/packed/read_test.go index 9bfa6610..9ba89fdf 100644 --- a/objectstore/packed/read_test.go +++ b/objectstore/packed/read_test.go @@ -30,16 +30,20 @@ func TestPackedStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadHeader: %v", err) } + if gotHeaderType != wantType { t.Fatalf("ReadHeader type = %v, want %v", gotHeaderType, wantType) } + if gotHeaderSize != int64(len(wantBody)) { t.Fatalf("ReadHeader size = %d, want %d", gotHeaderSize, len(wantBody)) } + gotSize, err := store.ReadSize(id) if err != nil { t.Fatalf("ReadSize: %v", err) } + if gotSize != int64(len(wantBody)) { t.Fatalf("ReadSize = %d, want %d", gotSize, len(wantBody)) } @@ -48,6 +52,7 @@ func TestPackedStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadBytesFull: %v", err) } + if !bytes.Equal(gotRaw, wantRaw) { t.Fatalf("ReadBytesFull mismatch") } @@ -56,9 +61,11 @@ func TestPackedStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadBytesContent: %v", err) } + if gotType != wantType { t.Fatalf("ReadBytesContent type = %v, want %v", gotType, wantType) } + if !bytes.Equal(gotBody, wantBody) { t.Fatalf("ReadBytesContent mismatch") } @@ -67,7 +74,9 @@ func TestPackedStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadReaderFull: %v", err) } - if got := mustReadAllAndClose(t, fullReader); !bytes.Equal(got, wantRaw) { + + got := mustReadAllAndClose(t, fullReader) + if !bytes.Equal(got, wantRaw) { t.Fatalf("ReadReaderFull mismatch") } @@ -75,13 +84,17 @@ func TestPackedStoreReadAgainstGit(t *testing.T) { if err != nil { t.Fatalf("ReadReaderContent: %v", err) } + if contentType != wantType { t.Fatalf("ReadReaderContent type = %v, want %v", contentType, wantType) } + if contentSize != int64(len(wantBody)) { t.Fatalf("ReadReaderContent size = %d, want %d", contentSize, len(wantBody)) } - if got := mustReadAllAndClose(t, contentReader); !bytes.Equal(got, wantBody) { + + got = mustReadAllAndClose(t, contentReader) + if !bytes.Equal(got, wantBody) { t.Fatalf("ReadReaderContent mismatch") } }) @@ -100,38 +113,54 @@ func TestPackedStoreErrors(t *testing.T) { t.Fatalf("ParseHex(notFound): %v", err) } - if _, err := store.ReadBytesFull(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + _, err = store.ReadBytesFull(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadBytesFull not-found error = %v", err) } - if _, _, err := store.ReadBytesContent(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, err = store.ReadBytesContent(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadBytesContent not-found error = %v", err) } - if _, err := store.ReadReaderFull(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, err = store.ReadReaderFull(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadReaderFull not-found error = %v", err) } - if _, _, _, err := store.ReadReaderContent(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, _, err = store.ReadReaderContent(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadReaderContent not-found error = %v", err) } - if _, _, err := store.ReadHeader(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, _, err = store.ReadHeader(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadHeader not-found error = %v", err) } - if _, err := store.ReadSize(notFoundID); !errors.Is(err, objectstore.ErrObjectNotFound) { + + _, err = store.ReadSize(notFoundID) + if !errors.Is(err, objectstore.ErrObjectNotFound) { t.Fatalf("ReadSize not-found error = %v", err) } var otherAlgo objectid.Algorithm + for _, candidate := range objectid.SupportedAlgorithms() { if candidate != algo { otherAlgo = candidate + break } } + if otherAlgo != objectid.AlgorithmUnknown { mismatchID, err := objectid.ParseHex(otherAlgo, strings.Repeat("0", otherAlgo.HexLen())) if err != nil { t.Fatalf("ParseHex(mismatch): %v", err) } - if _, err := store.ReadBytesFull(mismatchID); err == nil || !strings.Contains(err.Error(), "algorithm mismatch") { + + _, err = store.ReadBytesFull(mismatchID) + if err == nil || !strings.Contains(err.Error(), "algorithm mismatch") { t.Fatalf("ReadBytesFull algorithm-mismatch error = %v", err) } } @@ -141,11 +170,16 @@ func TestPackedStoreErrors(t *testing.T) { func TestPackedStoreNewValidation(t *testing.T) { t.Parallel() testRepo, _ := createPackedFixtureRepo(t, objectid.AlgorithmSHA1) + store := openPackedStore(t, testRepo.Dir(), objectid.AlgorithmSHA1) - if err := store.Close(); err != nil { + + err := store.Close() + if err != nil { t.Fatalf("Close: %v", err) } - if err := store.Close(); err != nil { + + err = store.Close() + if err != nil { t.Fatalf("Close second: %v", err) } } @@ -153,13 +187,16 @@ func TestPackedStoreNewValidation(t *testing.T) { func TestPackedStoreInvalidAlgorithm(t *testing.T) { t.Parallel() testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: objectid.AlgorithmSHA1, Bare: true}) + root, err := os.OpenRoot(testRepo.Dir()) if err != nil { t.Fatalf("OpenRoot(%q): %v", testRepo.Dir(), err) } + t.Cleanup(func() { _ = root.Close() }) - if _, err := packed.New(root, objectid.AlgorithmUnknown); !errors.Is(err, objectid.ErrInvalidAlgorithm) { + _, err = packed.New(root, objectid.AlgorithmUnknown) + if !errors.Is(err, objectid.ErrInvalidAlgorithm) { t.Fatalf("packed.New invalid algorithm error = %v", err) } } @@ -170,15 +207,20 @@ func TestPackedStoreReadHeaderUsesResolvedObjectSizeForDelta(t *testing.T) { testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) var parent objectid.ObjectID + for i := range 96 { content := strings.Repeat("common-line-"+strconv.Itoa(i%7)+"\n", 384) + fmt.Sprintf("tail-%03d\n", i) + _, treeID := testRepo.MakeSingleFileTree(t, "file.txt", []byte(content)) if i == 0 { parent = testRepo.CommitTree(t, treeID, "delta-header-size-0") + continue } + parent = testRepo.CommitTree(t, treeID, fmt.Sprintf("delta-header-size-%03d", i), parent) } + testRepo.UpdateRef(t, "refs/heads/main", parent) testRepo.Repack(t, "-a", "-d", "-f", "--window=128", "--depth=128") @@ -189,13 +231,16 @@ func TestPackedStoreReadHeaderUsesResolvedObjectSizeForDelta(t *testing.T) { if err != nil { t.Fatalf("ReadHeader(%s): %v", deltaID, err) } + if gotSize != wantResolvedSize { t.Fatalf("ReadHeader(%s) size = %d, want resolved size %d", deltaID, gotSize, wantResolvedSize) } + gotReadSize, err := store.ReadSize(deltaID) if err != nil { t.Fatalf("ReadSize(%s): %v", deltaID, err) } + if gotReadSize != wantResolvedSize { t.Fatalf("ReadSize(%s) = %d, want resolved size %d", deltaID, gotReadSize, wantResolvedSize) } @@ -209,6 +254,7 @@ func findDeltaObjectWithResolvedSizeMismatch(t *testing.T, testRepo *testgit.Tes if err != nil { t.Fatalf("Glob idx: %v", err) } + if len(idxFiles) == 0 { t.Fatalf("no idx files found") } @@ -221,16 +267,19 @@ func findDeltaObjectWithResolvedSizeMismatch(t *testing.T, testRepo *testgit.Tes } idHex := fields[0] + deltaStreamSize, err := strconv.ParseInt(fields[2], 10, 64) if err != nil { continue } resolvedSizeStr := testRepo.Run(t, "cat-file", "-s", idHex) + resolvedSize, err := strconv.ParseInt(strings.TrimSpace(resolvedSizeStr), 10, 64) if err != nil { t.Fatalf("parse cat-file size for %s: %v", idHex, err) } + if deltaStreamSize == resolvedSize { continue } @@ -239,9 +288,11 @@ func findDeltaObjectWithResolvedSizeMismatch(t *testing.T, testRepo *testgit.Tes if err != nil { t.Fatalf("ParseHex(%s): %v", idHex, err) } + return id, resolvedSize } t.Fatalf("did not find a delta object with mismatched stream/resolved size") + return objectid.ObjectID{}, 0 } diff --git a/objectstore/packed/store.go b/objectstore/packed/store.go index abd7175f..d28113d1 100644 --- a/objectstore/packed/store.go +++ b/objectstore/packed/store.go @@ -60,6 +60,7 @@ func New(root *os.Root, algo objectid.Algorithm) (*Store, error) { if algo.Size() == 0 { return nil, objectid.ErrInvalidAlgorithm } + return &Store{ root: root, algo: algo, @@ -76,8 +77,10 @@ func (store *Store) Close() error { store.stateMu.Lock() if store.closed { store.stateMu.Unlock() + return nil } + store.closed = true root := store.root packs := store.packs @@ -87,23 +90,30 @@ func (store *Store) Close() error { store.idxMu.RUnlock() var closeErr error + for _, pack := range packs { - if err := pack.close(); err != nil && closeErr == nil { + err := pack.close() + if err != nil && closeErr == nil { closeErr = err } } + for _, index := range indexes { - if err := index.close(); err != nil && closeErr == nil { + err := index.close() + if err != nil && closeErr == nil { closeErr = err } } + store.cacheMu.Lock() store.deltaCache.clear() store.cacheMu.Unlock() - if err := root.Close(); err != nil && closeErr == nil { + err := root.Close() + if err != nil && closeErr == nil { closeErr = err } + return closeErr } @@ -113,7 +123,9 @@ func (store *Store) lookup(id objectid.ObjectID) (location, error) { if id.Algorithm() != store.algo { return zero, errors.New("objectstore/packed: object id algorithm mismatch") } - if err := store.ensureCandidates(); err != nil { + + err := store.ensureCandidates() + if err != nil { return zero, err } @@ -122,81 +134,111 @@ func (store *Store) lookup(id objectid.ObjectID) (location, error) { candidate, ok := store.candidateForPack(nextPackName) if !ok { nextPackName = store.firstCandidatePackName() + continue } + nextPackName = store.nextCandidatePackName(candidate.packName) + index, err := store.openIndex(candidate) if err != nil { return zero, err } + offset, ok, err := index.lookup(id) if err != nil { return zero, err } + if ok { store.touchCandidate(candidate.packName) + return location{packName: index.packName, offset: offset}, nil } } + return zero, objectstore.ErrObjectNotFound } // openPack returns one opened and validated pack handle. func (store *Store) openPack(name string) (*packFile, error) { store.stateMu.RLock() - if pack, ok := store.packs[name]; ok { + + pack, ok := store.packs[name] + if ok { store.stateMu.RUnlock() + return pack, nil } + store.stateMu.RUnlock() file, err := store.root.Open(name) if err != nil { return nil, err } + info, err := file.Stat() if err != nil { _ = file.Close() + return nil, err } - pack, err := openPackFile(name, file, info.Size()) + + pack, err = openPackFile(name, file, info.Size()) if err != nil { _ = file.Close() + return nil, err } - if err := store.verifyPackMatchesIndexes(pack); err != nil { + + err = store.verifyPackMatchesIndexes(pack) + if err != nil { _ = pack.close() + return nil, err } store.stateMu.Lock() - if existing, ok := store.packs[name]; ok { + + existing, ok := store.packs[name] + if ok { store.stateMu.Unlock() + _ = pack.close() + return existing, nil } + store.packs[name] = pack store.stateMu.Unlock() + return pack, nil } // verifyPackMatchesIndexes checks that one opened pack's trailer hash matches // every loaded index that references the same pack name. func (store *Store) verifyPackMatchesIndexes(pack *packFile) error { - if err := store.ensureCandidates(); err != nil { + err := store.ensureCandidates() + if err != nil { return err } + candidate, ok := store.candidateForPack(pack.name) if !ok { return fmt.Errorf("objectstore/packed: missing index for pack %q", pack.name) } + index, err := store.openIndex(candidate) if err != nil { return err } - if err := verifyMappedPackMatchesMappedIdx(pack.data, index.data, store.algo); err != nil { + + err = verifyMappedPackMatchesMappedIdx(pack.data, index.data, store.algo) + if err != nil { return fmt.Errorf("objectstore/packed: pack %q does not match idx %q: %w", pack.name, index.idxName, err) } + return nil } @@ -206,9 +248,11 @@ func (store *Store) entryMetaAt(loc location) (*packFile, entryMeta, error) { if err != nil { return nil, entryMeta{}, err } + meta, err := parseEntryMeta(pack, store.algo, loc.offset) if err != nil { return nil, entryMeta{}, err } + return pack, meta, nil } diff --git a/reachability/errors.go b/reachability/errors.go index e52bf0a4..7d0d9a18 100644 --- a/reachability/errors.go +++ b/reachability/errors.go @@ -29,9 +29,11 @@ func (e *ErrObjectType) Error() string { if !gotOK { gotName = fmt.Sprintf("type(%d)", e.Got) } + wantName, wantOK := objecttype.Name(e.Want) if !wantOK { wantName = fmt.Sprintf("type(%d)", e.Want) } + return fmt.Sprintf("reachability: object %s has type %s, want %s", e.OID, gotName, wantName) } diff --git a/reachability/helpers.go b/reachability/helpers.go index 1368a3f1..41a2f80b 100644 --- a/reachability/helpers.go +++ b/reachability/helpers.go @@ -22,7 +22,9 @@ func containsOID(set map[objectid.ObjectID]struct{}, id objectid.ObjectID) bool if len(set) == 0 { return false } + _, ok := set[id] + return ok } @@ -39,8 +41,10 @@ func (r *Reachability) readHeaderType(id objectid.ObjectID) (objecttype.Type, er if errors.Is(err, objectstore.ErrObjectNotFound) { return objecttype.TypeInvalid, &ErrObjectMissing{OID: id} } + return objecttype.TypeInvalid, err } + return ty, nil } @@ -49,6 +53,7 @@ func (walk *Walk) readBytesContent(id objectid.ObjectID) ([]byte, error) { if err != nil { return nil, err } + return content, nil } @@ -58,7 +63,9 @@ func (r *Reachability) readBytesContent(id objectid.ObjectID) ([]byte, error) { if errors.Is(err, objectstore.ErrObjectNotFound) { return nil, &ErrObjectMissing{OID: id} } + return nil, err } + return content, nil } diff --git a/reachability/integration_test.go b/reachability/integration_test.go index 10668006..079ce5fc 100644 --- a/reachability/integration_test.go +++ b/reachability/integration_test.go @@ -47,8 +47,11 @@ func TestWalkCommitsMatchesGitRevList(t *testing.T) { nil, map[objectid.ObjectID]struct{}{merge: {}}, ) + got := oidSetFromSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } @@ -62,12 +65,17 @@ func TestWalkCommitsMatchesGitRevList(t *testing.T) { nil, map[objectid.ObjectID]struct{}{tag2: {}}, ) + peelGot := oidSetFromSeq(peelWalk.Seq()) - if err := peelWalk.Err(); err != nil { + + err = peelWalk.Err() + if err != nil { t.Fatalf("peelWalk.Err(): %v", err) } + wantWithTags := maps.Clone(want) wantWithTags[tag1] = struct{}{} + wantWithTags[tag2] = struct{}{} if !maps.Equal(peelGot, wantWithTags) { t.Fatalf("tag-root commit walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantWithTags)) @@ -104,8 +112,11 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) { nil, map[objectid.ObjectID]struct{}{head: {}}, ) + got := oidSetFromSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } @@ -119,10 +130,14 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) { nil, map[objectid.ObjectID]struct{}{tag: {}}, ) + peelGot := oidSetFromSeq(peelWalk.Seq()) - if err := peelWalk.Err(); err != nil { + + err = peelWalk.Err() + if err != nil { t.Fatalf("peelWalk.Err(): %v", err) } + wantFromTag := gitRevListSet(t, testRepo, true, []objectid.ObjectID{tag}, nil) if !maps.Equal(peelGot, wantFromTag) { t.Fatalf("tag-root object walk mismatch:\n got=%v\nwant=%v", sortedOIDStrings(peelGot), sortedOIDStrings(wantFromTag)) @@ -133,11 +148,16 @@ func TestWalkObjectsMatchesGitRevListObjects(t *testing.T) { map[objectid.ObjectID]struct{}{base: {}}, map[objectid.ObjectID]struct{}{head: {}}, ) + withHave := oidSetFromSeq(walkWithHave.Seq()) - if err := walkWithHave.Err(); err != nil { + + err = walkWithHave.Err() + if err != nil { t.Fatalf("walkWithHave.Err(): %v", err) } - if _, ok := withHave[base]; ok { + + _, ok := withHave[base] + if ok { t.Fatalf("walk output unexpectedly contains have commit %s", base) } }) @@ -170,7 +190,9 @@ func TestIsAncestorMatchesGitMergeBase(t *testing.T) { if err != nil { t.Fatalf("IsAncestor(c1, tag): %v", err) } - if want := gitMergeBaseIsAncestor(t, testRepo, c1, c2); got != want { + + want := gitMergeBaseIsAncestor(t, testRepo, c1, c2) + if got != want { t.Fatalf("IsAncestor(c1, tag)=%v, want %v", got, want) } @@ -178,7 +200,9 @@ func TestIsAncestorMatchesGitMergeBase(t *testing.T) { if err != nil { t.Fatalf("IsAncestor(c3, c2): %v", err) } - if want := gitMergeBaseIsAncestor(t, testRepo, c3, c2); got != want { + + want = gitMergeBaseIsAncestor(t, testRepo, c3, c2) + if got != want { t.Fatalf("IsAncestor(c3, c2)=%v, want %v", got, want) } }) @@ -195,12 +219,15 @@ func TestCheckConnectedMissingObject(t *testing.T) { }) _, treeID, commitID := testRepo.MakeCommit(t, "missing") - if err := os.Remove(looseObjectPath(testRepo.Dir(), treeID)); err != nil { + + err := os.Remove(looseObjectPath(testRepo.Dir(), treeID)) + if err != nil { t.Fatalf("remove tree object: %v", err) } r := openReachabilityFromTestRepo(t, testRepo) - err := r.CheckConnected( + + err = r.CheckConnected( reachability.DomainObjects, nil, map[objectid.ObjectID]struct{}{commitID: {}}, @@ -208,10 +235,12 @@ func TestCheckConnectedMissingObject(t *testing.T) { if err == nil { t.Fatal("expected error") } + var missing *reachability.ErrObjectMissing if !errors.As(err, &missing) { t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err) } + if missing.OID != treeID { t.Fatalf("missing oid = %s, want %s", missing.OID, treeID) } @@ -246,14 +275,21 @@ func TestWalkOnPackedOnlyRepo(t *testing.T) { nil, map[objectid.ObjectID]struct{}{c2: {}}, ) + got := oidSetFromSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } - if _, ok := got[c2]; !ok { + + _, ok := got[c2] + if !ok { t.Fatalf("walk output missing HEAD commit %s", c2) } - if _, ok := got[c1]; !ok { + + _, ok = got[c1] + if !ok { t.Fatalf("walk output missing parent commit %s", c1) } }) @@ -261,16 +297,19 @@ func TestWalkOnPackedOnlyRepo(t *testing.T) { func openReachabilityFromTestRepo(t *testing.T, testRepo *testgit.TestRepo) *reachability.Reachability { t.Helper() + root, err := os.OpenRoot(testRepo.Dir()) if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + t.Cleanup(func() { _ = root.Close() }) repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + t.Cleanup(func() { _ = repo.Close() }) return reachability.New(repo.Objects()) @@ -278,10 +317,13 @@ func openReachabilityFromTestRepo(t *testing.T, testRepo *testgit.TestRepo) *rea func oidSetFromSeq(seq func(func(objectid.ObjectID) bool)) map[objectid.ObjectID]struct{} { out := make(map[objectid.ObjectID]struct{}) + seq(func(id objectid.ObjectID) bool { out[id] = struct{}{} + return true }) + return out } @@ -298,9 +340,11 @@ func gitRevListSet( if includeObjects { args = append(args, "--objects") } + for _, want := range wants { args = append(args, want.String()) } + if len(haves) > 0 { args = append(args, "--not") for _, have := range haves { @@ -310,21 +354,28 @@ func gitRevListSet( out := testRepo.Run(t, args...) set := make(map[objectid.ObjectID]struct{}) + for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") { line = strings.TrimSpace(line) if line == "" { continue } + tok := line - if i := strings.IndexByte(tok, ' '); i >= 0 { + + i := strings.IndexByte(tok, ' ') + if i >= 0 { tok = tok[:i] } + id, err := objectid.ParseHex(testRepo.Algorithm(), tok) if err != nil { t.Fatalf("parse rev-list oid %q: %v", tok, err) } + set[id] = struct{}{} } + return set } @@ -332,6 +383,7 @@ func gitMergeBaseIsAncestor(t *testing.T, testRepo *testgit.TestRepo, a, b objec t.Helper() // testgit.Run fatals on non-zero status, so we compare merge-base output. mb := testRepo.Run(t, "merge-base", a.String(), b.String()) + return mb == a.String() } @@ -340,33 +392,40 @@ func sortedOIDStrings(set map[objectid.ObjectID]struct{}) []string { for id := range set { out = append(out, id.String()) } + slices.Sort(out) + return out } func looseObjectPath(repoDir string, id objectid.ObjectID) string { hex := id.String() + return filepath.Join(repoDir, "objects", hex[:2], hex[2:]) } func assertPackedOnly(t *testing.T, repoDir string) { t.Helper() + objectsDir := filepath.Join(repoDir, "objects") entries, err := os.ReadDir(objectsDir) if err != nil { t.Fatalf("ReadDir(objects): %v", err) } + for _, entry := range entries { name := entry.Name() if name == "pack" || name == "info" { continue } + if len(name) == 2 && isHexDirName(name) { subEntries, err := os.ReadDir(filepath.Join(objectsDir, name)) if err != nil { t.Fatalf("ReadDir(objects/%s): %v", name, err) } + if len(subEntries) != 0 { t.Fatalf("found loose objects in %s", filepath.Join(objectsDir, name)) } @@ -378,11 +437,13 @@ func isHexDirName(name string) bool { if len(name) != 2 { return false } + for i := range 2 { c := name[i] if (c < '0' || c > '9') && (c < 'a' || c > 'f') { return false } } + return true } diff --git a/reachability/peel.go b/reachability/peel.go index 9df9bb4d..7b9e7bf1 100644 --- a/reachability/peel.go +++ b/reachability/peel.go @@ -7,18 +7,22 @@ import ( ) func (r *Reachability) peelRootToDomain(id objectid.ObjectID, domain Domain) (objectid.ObjectID, error) { - if err := validateDomain(domain); err != nil { + err := validateDomain(domain) + if err != nil { return objectid.ObjectID{}, err } + for { ty, err := r.readHeaderType(id) if err != nil { return objectid.ObjectID{}, err } + if ty != objecttype.TypeTag { if domain == DomainCommits && ty != objecttype.TypeCommit { return objectid.ObjectID{}, &ErrObjectType{OID: id, Got: ty, Want: objecttype.TypeCommit} } + return id, nil } @@ -26,10 +30,12 @@ func (r *Reachability) peelRootToDomain(id objectid.ObjectID, domain Domain) (ob if err != nil { return objectid.ObjectID{}, err } + tag, err := object.ParseTag(content, id.Algorithm()) if err != nil { return objectid.ObjectID{}, err } + id = tag.Target } } diff --git a/reachability/reachability.go b/reachability/reachability.go index 0bec055f..93bc840b 100644 --- a/reachability/reachability.go +++ b/reachability/reachability.go @@ -26,10 +26,12 @@ func (r *Reachability) IsAncestor(ancestor, descendant objectid.ObjectID) (bool, if err != nil { return false, err } + descendantCommit, err := r.peelRootToDomain(descendant, DomainCommits) if err != nil { return false, err } + if ancestorCommit == descendantCommit { return true, nil } @@ -40,9 +42,12 @@ func (r *Reachability) IsAncestor(ancestor, descendant objectid.ObjectID) (bool, return true, nil } } - if err := walk.Err(); err != nil { + + err = walk.Err() + if err != nil { return false, err } + return false, nil } @@ -53,6 +58,7 @@ func (r *Reachability) CheckConnected(domain Domain, haves, wants map[objectid.O walk := r.Walk(domain, haves, wants) for range walk.Seq() { } + return walk.Err() } @@ -64,8 +70,11 @@ func (r *Reachability) Walk(domain Domain, haves, wants map[objectid.ObjectID]st haves: haves, wants: wants, } - if err := validateDomain(domain); err != nil { + + err := validateDomain(domain) + if err != nil { walk.err = err } + return walk } diff --git a/reachability/unit_test.go b/reachability/unit_test.go index ec177938..8f19cfcd 100644 --- a/reachability/unit_test.go +++ b/reachability/unit_test.go @@ -42,13 +42,16 @@ func (store *memStore) ReadBytesFull(id objectid.ObjectID) ([]byte, error) { if !ok { return nil, objectstore.ErrObjectNotFound } + header, ok := objectheader.Encode(obj.ty, int64(len(obj.content))) if !ok { panic("failed to encode object header") } + raw := make([]byte, len(header)+len(obj.content)) copy(raw, header) copy(raw[len(header):], obj.content) + return raw, nil } @@ -57,7 +60,9 @@ func (store *memStore) ReadBytesContent(id objectid.ObjectID) (objecttype.Type, if !ok { return objecttype.TypeInvalid, nil, objectstore.ErrObjectNotFound } + store.readBytesByObjectID[id]++ + return obj.ty, append([]byte(nil), obj.content...), nil } @@ -66,6 +71,7 @@ func (store *memStore) ReadReaderFull(id objectid.ObjectID) (io.ReadCloser, erro if err != nil { return nil, err } + return io.NopCloser(bytes.NewReader(raw)), nil } @@ -74,6 +80,7 @@ func (store *memStore) ReadReaderContent(id objectid.ObjectID) (objecttype.Type, if err != nil { return objecttype.TypeInvalid, 0, nil, err } + return ty, int64(len(content)), io.NopCloser(bytes.NewReader(content)), nil } @@ -82,6 +89,7 @@ func (store *memStore) ReadSize(id objectid.ObjectID) (int64, error) { if err != nil { return 0, err } + return size, nil } @@ -90,6 +98,7 @@ func (store *memStore) ReadHeader(id objectid.ObjectID) (objecttype.Type, int64, if !ok { return objecttype.TypeInvalid, 0, objectstore.ErrObjectNotFound } + return obj.ty, int64(len(obj.content)), nil } @@ -102,7 +111,9 @@ func commitBody(tree objectid.ObjectID, parents ...objectid.ObjectID) []byte { for _, parent := range parents { buf = append(buf, fmt.Appendf(nil, "parent %s\n", parent.String())...) } + buf = append(buf, []byte("\nmsg\n")...) + return buf } @@ -111,15 +122,19 @@ func tagBody(target objectid.ObjectID, targetType objecttype.Type) []byte { if !ok { panic("invalid tag target type") } + return fmt.Appendf(nil, "object %s\ntype %s\ntag t\n\nmsg\n", target.String(), targetName) } func collectSeq(seq func(func(objectid.ObjectID) bool)) []objectid.ObjectID { var out []objectid.ObjectID + seq(func(id objectid.ObjectID) bool { out = append(out, id) + return true }) + return out } @@ -128,6 +143,7 @@ func toSet(ids []objectid.ObjectID) map[objectid.ObjectID]struct{} { for _, id := range ids { set[id] = struct{}{} } + return set } @@ -149,12 +165,16 @@ func TestWalkDomainCommitsIncludesTagNodes(t *testing.T) { r := reachability.New(store) walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag2: {}}) + got := collectSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } gotSet := toSet(got) + wantSet := map[objectid.ObjectID]struct{}{tag2: {}, tag1: {}, commit2: {}, commit1: {}} if !maps.Equal(gotSet, wantSet) { t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet))) @@ -177,10 +197,14 @@ func TestWalkExcludesHavesCompletely(t *testing.T) { r := reachability.New(store) walk := r.Walk(reachability.DomainCommits, map[objectid.ObjectID]struct{}{commit: {}}, map[objectid.ObjectID]struct{}{commit: {}}) + got := collectSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } + if len(got) != 0 { t.Fatalf("expected empty output, got %v", got) } @@ -203,14 +227,17 @@ func TestWalkDomainCommitsRejectsNonCommitRootAfterPeel(t *testing.T) { r := reachability.New(store) walk := r.Walk(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{tag: {}}) _ = collectSeq(walk.Seq()) + err := walk.Err() if err == nil { t.Fatal("expected error") } + var typeErr *reachability.ErrObjectType if !errors.As(err, &typeErr) { t.Fatalf("expected ErrObjectType, got %T (%v)", err, err) } + if typeErr.Got != objecttype.TypeTree || typeErr.Want != objecttype.TypeCommit { t.Fatalf("unexpected type error: %+v", typeErr) } @@ -239,12 +266,16 @@ func TestWalkDomainCommitsHaveTagStopsTraversal(t *testing.T) { map[objectid.ObjectID]struct{}{tag1: {}}, map[objectid.ObjectID]struct{}{tag2: {}}, ) + got := collectSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } gotSet := toSet(got) + wantSet := map[objectid.ObjectID]struct{}{tag2: {}} if !maps.Equal(gotSet, wantSet) { t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet))) @@ -276,16 +307,21 @@ func TestWalkDomainObjectsRecursesTreesAndSkipsBlobContentReads(t *testing.T) { r := reachability.New(store) walk := r.Walk(reachability.DomainObjects, nil, map[objectid.ObjectID]struct{}{commit: {}}) + got := collectSeq(walk.Seq()) - if err := walk.Err(); err != nil { + + err := walk.Err() + if err != nil { t.Fatalf("walk.Err(): %v", err) } gotSet := toSet(got) + wantSet := map[objectid.ObjectID]struct{}{commit: {}, rootTree: {}, subtree: {}, blob1: {}, blob2: {}} if !maps.Equal(gotSet, wantSet) { t.Fatalf("walk output mismatch: got %v, want %v", slices.Collect(maps.Keys(gotSet)), slices.Collect(maps.Keys(wantSet))) } + if store.readBytesByObjectID[blob1] != 0 || store.readBytesByObjectID[blob2] != 0 { t.Fatalf("blob contents should not be read; counts: blob1=%d blob2=%d", store.readBytesByObjectID[blob1], store.readBytesByObjectID[blob2]) } @@ -307,14 +343,17 @@ func TestCheckConnectedReturnsConcreteMissingObject(t *testing.T) { commit := store.addObject(objecttype.TypeCommit, commitBody(tree, missingParent)) r := reachability.New(store) + err := r.CheckConnected(reachability.DomainCommits, nil, map[objectid.ObjectID]struct{}{commit: {}}) if err == nil { t.Fatal("expected error") } + var missing *reachability.ErrObjectMissing if !errors.As(err, &missing) { t.Fatalf("expected ErrObjectMissing, got %T (%v)", err, err) } + if missing.OID != missingParent { t.Fatalf("unexpected missing oid: got %s want %s", missing.OID, missingParent) } @@ -327,8 +366,11 @@ func TestWalkInvalidDomainReturnsPlainError(t *testing.T) { testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper r := reachability.New(newMemStore(algo)) walk := r.Walk(reachability.Domain(99), nil, nil) + _ = collectSeq(walk.Seq()) - if err := walk.Err(); err == nil { + + err := walk.Err() + if err == nil { t.Fatal("expected error") } }) @@ -357,10 +399,12 @@ func TestIsAncestor(t *testing.T) { tag := store.addObject(objecttype.TypeTag, tagBody(c2, objecttype.TypeCommit)) r := reachability.New(store) + ok, err := r.IsAncestor(c1, tag) if err != nil { t.Fatalf("IsAncestor(c1, tag): %v", err) } + if !ok { t.Fatal("expected c1 to be ancestor of tag->c2") } @@ -369,6 +413,7 @@ func TestIsAncestor(t *testing.T) { if err != nil { t.Fatalf("IsAncestor(c3, c2): %v", err) } + if ok { t.Fatal("did not expect c3 to be ancestor of c2") } @@ -390,10 +435,12 @@ func TestIsAncestorRejectsNonCommitAfterPeel(t *testing.T) { tagToTree := store.addObject(objecttype.TypeTag, tagBody(tree, objecttype.TypeTree)) r := reachability.New(store) + _, err := r.IsAncestor(commit, tagToTree) if err == nil { t.Fatal("expected error") } + var typeErr *reachability.ErrObjectType if !errors.As(err, &typeErr) { t.Fatalf("expected ErrObjectType, got %T (%v)", err, err) @@ -403,10 +450,12 @@ func TestIsAncestorRejectsNonCommitAfterPeel(t *testing.T) { func mustSerializeTree(tb testing.TB, tree *object.Tree) []byte { tb.Helper() + body, err := tree.SerializeWithoutHeader() if err != nil { tb.Fatalf("SerializeWithoutHeader: %v", err) } + return body } @@ -415,8 +464,10 @@ func (store *memStore) addObject(ty objecttype.Type, body []byte) objectid.Objec if !ok { panic("failed to encode object header") } + raw := append(append([]byte(nil), header...), body...) id := store.algo.Sum(raw) store.objects[id] = storeObject{ty: ty, content: append([]byte(nil), body...)} + return id } diff --git a/reachability/walk.go b/reachability/walk.go index 2b592a45..89de19e8 100644 --- a/reachability/walk.go +++ b/reachability/walk.go @@ -26,18 +26,24 @@ func (walk *Walk) Seq() iter.Seq[objectid.ObjectID] { if walk.seqUsed { return func(yield func(objectid.ObjectID) bool) { _ = yield + if walk.err == nil { walk.err = errors.New("reachability: walk sequence already consumed") } } } + walk.seqUsed = true + return func(yield func(objectid.ObjectID) bool) { if walk.err != nil { return } + stack := walk.initialStack() + var err error + visited := make(map[objectid.ObjectID]struct{}, len(stack)) for len(stack) > 0 { item := stack[len(stack)-1] @@ -46,20 +52,26 @@ func (walk *Walk) Seq() iter.Seq[objectid.ObjectID] { if containsOID(walk.haves, item.id) { continue } + if _, ok := visited[item.id]; ok { continue } + visited[item.id] = struct{}{} var next []walkItem + next, err = walk.expand(item) if err != nil { walk.err = err + return } + if !yield(item.id) { return } + stack = append(stack, next...) } } @@ -79,10 +91,12 @@ func (walk *Walk) initialStack() []walkItem { if len(walk.wants) == 0 { return nil } + stack := make([]walkItem, 0, len(walk.wants)) for want := range walk.wants { stack = append(stack, walkItem{id: want, want: objecttype.TypeInvalid}) } + return stack } @@ -90,6 +104,7 @@ func (walk *Walk) expand(item walkItem) ([]walkItem, error) { if walk.domain == DomainCommits { return walk.expandCommits(item) } + return walk.expandObjects(item) } @@ -98,35 +113,42 @@ func (walk *Walk) expandCommits(item walkItem) ([]walkItem, error) { if err != nil { return nil, err } + switch ty { case objecttype.TypeCommit: content, err := walk.readBytesContent(item.id) if err != nil { return nil, err } + commit, err := object.ParseCommit(content, item.id.Algorithm()) if err != nil { return nil, err } + next := make([]walkItem, 0, len(commit.Parents)) for _, parent := range commit.Parents { next = append(next, walkItem{id: parent, want: objecttype.TypeInvalid}) } + return next, nil case objecttype.TypeTag: content, err := walk.readBytesContent(item.id) if err != nil { return nil, err } + tag, err := object.ParseTag(content, item.id.Algorithm()) if err != nil { return nil, err } + return []walkItem{{id: tag.Target, want: objecttype.TypeInvalid}}, nil case objecttype.TypeTree, objecttype.TypeBlob, objecttype.TypeInvalid, objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta: return nil, &ErrObjectType{OID: item.id, Got: ty, Want: objecttype.TypeCommit} } + return nil, fmt.Errorf("reachability: unreachable object type %d", ty) } @@ -135,6 +157,7 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) { if err != nil { return nil, err } + if item.want != objecttype.TypeInvalid && ty != item.want { return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want} } @@ -147,25 +170,31 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) { if err != nil { return nil, err } + commit, err := object.ParseCommit(content, item.id.Algorithm()) if err != nil { return nil, err } + next := make([]walkItem, 0, len(commit.Parents)+1) + next = append(next, walkItem{id: commit.Tree, want: objecttype.TypeTree}) for _, parent := range commit.Parents { next = append(next, walkItem{id: parent, want: objecttype.TypeCommit}) } + return next, nil case objecttype.TypeTree: content, err := walk.readBytesContent(item.id) if err != nil { return nil, err } + tree, err := object.ParseTree(content, item.id.Algorithm()) if err != nil { return nil, err } + next := make([]walkItem, 0, len(tree.Entries)) for _, entry := range tree.Entries { switch entry.Mode { @@ -177,19 +206,23 @@ func (walk *Walk) expandObjects(item walkItem) ([]walkItem, error) { next = append(next, walkItem{id: entry.ID, want: objecttype.TypeBlob}) } } + return next, nil case objecttype.TypeTag: content, err := walk.readBytesContent(item.id) if err != nil { return nil, err } + tag, err := object.ParseTag(content, item.id.Algorithm()) if err != nil { return nil, err } + return []walkItem{{id: tag.Target, want: tag.TargetType}}, nil case objecttype.TypeInvalid, objecttype.TypeFuture, objecttype.TypeOfsDelta, objecttype.TypeRefDelta: return nil, &ErrObjectType{OID: item.id, Got: ty, Want: item.want} } + return nil, fmt.Errorf("reachability: unreachable object type %d", ty) } diff --git a/refstore/chain/chain.go b/refstore/chain/chain.go index 633bac25..9e04aeec 100644 --- a/refstore/chain/chain.go +++ b/refstore/chain/chain.go @@ -28,15 +28,19 @@ func (chain *Chain) Resolve(name string) (ref.Ref, error) { if backend == nil { continue } + resolved, err := backend.Resolve(name) if err == nil { return resolved, nil } + if errors.Is(err, refstore.ErrReferenceNotFound) { continue } + return nil, fmt.Errorf("refstore: backend %d resolve: %w", i, err) } + return nil, refstore.ErrReferenceNotFound } @@ -46,11 +50,13 @@ func (chain *Chain) Resolve(name string) (ref.Ref, error) { // references to cross backends in the chain. func (chain *Chain) ResolveFully(name string) (ref.Detached, error) { cur := name + seen := map[string]struct{}{} for { if _, ok := seen[cur]; ok { return ref.Detached{}, fmt.Errorf("refstore: symbolic reference cycle at %q", cur) } + seen[cur] = struct{}{} resolved, err := chain.Resolve(cur) @@ -65,6 +71,7 @@ func (chain *Chain) ResolveFully(name string) (ref.Detached, error) { if resolved.Target == "" { return ref.Detached{}, fmt.Errorf("refstore: symbolic reference %q has empty target", resolved.Name()) } + cur = resolved.Target default: return ref.Detached{}, fmt.Errorf("refstore: unsupported reference type %T", resolved) @@ -77,25 +84,31 @@ func (chain *Chain) ResolveFully(name string) (ref.Detached, error) { // First-seen wins, so earlier backends have precedence. func (chain *Chain) List(pattern string) ([]ref.Ref, error) { var refs []ref.Ref + seen := map[string]struct{}{} for i, backend := range chain.backends { if backend == nil { continue } + listed, err := backend.List(pattern) if err != nil { return nil, fmt.Errorf("refstore: backend %d list: %w", i, err) } + for _, entry := range listed { if entry == nil { continue } + name := entry.Name() if _, ok := seen[name]; ok { continue } + seen[name] = struct{}{} + refs = append(refs, entry) } } @@ -109,34 +122,44 @@ func (chain *Chain) Shorten(name string) (string, error) { if err != nil { return "", err } + names := make([]string, 0, len(refs)) found := false + for _, entry := range refs { if entry == nil { continue } + full := entry.Name() + names = append(names, full) if full == name { found = true } } + if !found { return "", refstore.ErrReferenceNotFound } + return refstore.ShortenName(name, names), nil } // Close closes all backends and joins close errors. func (chain *Chain) Close() error { var errs []error + for _, backend := range chain.backends { if backend == nil { continue } - if err := backend.Close(); err != nil { + + err := backend.Close() + if err != nil { errs = append(errs, err) } } + return errors.Join(errs...) } diff --git a/refstore/loose/list.go b/refstore/loose/list.go index d28016da..1fa0adee 100644 --- a/refstore/loose/list.go +++ b/refstore/loose/list.go @@ -17,7 +17,8 @@ import ( func (store *Store) List(pattern string) ([]ref.Ref, error) { matchAll := pattern == "" if !matchAll { - if _, err := path.Match(pattern, "HEAD"); err != nil { + _, err := path.Match(pattern, "HEAD") + if err != nil { return nil, err } } @@ -26,6 +27,7 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { if err != nil { return nil, err } + slices.Sort(names) refs := make([]ref.Ref, 0, len(names)) @@ -35,19 +37,24 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { if err != nil { return nil, err } + if !matched { continue } } + resolved, err := store.resolveOne(name) if err != nil { if errors.Is(err, refstore.ErrReferenceNotFound) { continue } + return nil, err } + refs = append(refs, resolved) } + return refs, nil } @@ -55,42 +62,53 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { func (store *Store) collectLooseRefNames() ([]string, error) { names := make([]string, 0, 16) - if _, err := store.root.Stat("HEAD"); err == nil { + _, err := store.root.Stat("HEAD") + if err == nil { names = append(names, "HEAD") } else if !errors.Is(err, os.ErrNotExist) { return nil, err } var walk func(string) error + walk = func(dir string) error { file, err := store.root.Open(dir) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil } + return err } + defer func() { _ = file.Close() }() entries, err := file.ReadDir(-1) if err != nil { return err } + for _, entry := range entries { name := path.Join(dir, entry.Name()) if entry.IsDir() { - if err := walk(name); err != nil { + err := walk(name) + if err != nil { return err } + continue } + names = append(names, name) } + return nil } - if err := walk("refs"); err != nil { + err = walk("refs") + if err != nil { return nil, err } + return names, nil } diff --git a/refstore/loose/loose_test.go b/refstore/loose/loose_test.go index 8c9d6f98..7b295bbb 100644 --- a/refstore/loose/loose_test.go +++ b/refstore/loose/loose_test.go @@ -16,16 +16,19 @@ import ( func openLooseStore(t *testing.T, repoPath string, algo objectid.Algorithm) *loose.Store { t.Helper() + root, err := os.OpenRoot(repoPath) if err != nil { t.Fatalf("OpenRoot(%q): %v", repoPath, err) } + t.Cleanup(func() { _ = root.Close() }) store, err := loose.New(root, algo) if err != nil { t.Fatalf("loose.New: %v", err) } + return store } @@ -43,10 +46,12 @@ func TestLooseResolveAndResolveFully(t *testing.T) { if err != nil { t.Fatalf("Resolve(HEAD): %v", err) } + headSym, ok := resolvedHead.(ref.Symbolic) if !ok { t.Fatalf("Resolve(HEAD) type = %T, want ref.Symbolic", resolvedHead) } + if headSym.Target != "refs/heads/main" { t.Fatalf("Resolve(HEAD) target = %q, want %q", headSym.Target, "refs/heads/main") } @@ -55,10 +60,12 @@ func TestLooseResolveAndResolveFully(t *testing.T) { if err != nil { t.Fatalf("Resolve(refs/heads/main): %v", err) } + mainDet, ok := resolvedMain.(ref.Detached) if !ok { t.Fatalf("Resolve(main) type = %T, want ref.Detached", resolvedMain) } + if mainDet.ID != commitID { t.Fatalf("Resolve(main) id = %s, want %s", mainDet.ID, commitID) } @@ -67,11 +74,13 @@ func TestLooseResolveAndResolveFully(t *testing.T) { if err != nil { t.Fatalf("ResolveFully(HEAD): %v", err) } + if fullHead.ID != commitID { t.Fatalf("ResolveFully(HEAD) id = %s, want %s", fullHead.ID, commitID) } - if _, err := store.Resolve("refs/heads/does-not-exist"); !errors.Is(err, refstore.ErrReferenceNotFound) { + _, err = store.Resolve("refs/heads/does-not-exist") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Resolve(not-found) error = %v", err) } }) @@ -85,7 +94,9 @@ func TestLooseResolveFullyCycle(t *testing.T) { testRepo.SymbolicRef(t, "refs/heads/b", "refs/heads/a") store := openLooseStore(t, testRepo.Dir(), algo) - if _, err := store.ResolveFully("refs/heads/a"); err == nil { + + _, err := store.ResolveFully("refs/heads/a") + if err == nil { t.Fatalf("ResolveFully(cycle) expected error") } }) @@ -107,11 +118,14 @@ func TestLooseListPattern(t *testing.T) { if err != nil { t.Fatalf("List(\"\"): %v", err) } + allNames := make([]string, 0, len(allRefs)) for _, entry := range allRefs { allNames = append(allNames, entry.Name()) } + slices.Sort(allNames) + wantAll := []string{"HEAD", "refs/heads/feature", "refs/heads/main", "refs/tags/v1.0.0"} if !slices.Equal(allNames, wantAll) { t.Fatalf("List(\"\") names = %v, want %v", allNames, wantAll) @@ -121,11 +135,14 @@ func TestLooseListPattern(t *testing.T) { if err != nil { t.Fatalf("List(refs/heads/*): %v", err) } + headNames := make([]string, 0, len(headRefs)) for _, entry := range headRefs { headNames = append(headNames, entry.Name()) } + slices.Sort(headNames) + wantHeads := []string{"refs/heads/feature", "refs/heads/main"} if !slices.Equal(headNames, wantHeads) { t.Fatalf("List(refs/heads/*) names = %v, want %v", headNames, wantHeads) @@ -182,13 +199,17 @@ func TestLooseListPatternMatrix(t *testing.T) { if err != nil { t.Fatalf("List(%q): %v", tt.pattern, err) } + gotNames := make([]string, 0, len(got)) for _, entry := range got { gotNames = append(gotNames, entry.Name()) } + slices.Sort(gotNames) + wantNames := append([]string(nil), tt.want...) slices.Sort(wantNames) + if !slices.Equal(gotNames, wantNames) { t.Fatalf("List(%q) names = %v, want %v", tt.pattern, gotNames, wantNames) } @@ -201,16 +222,23 @@ func TestLooseMalformedDetachedRef(t *testing.T) { t.Parallel() testgit.ForEachAlgorithm(t, func(t *testing.T, algo objectid.Algorithm) { //nolint:thelper testRepo := testgit.NewRepo(t, testgit.RepoOptions{ObjectFormat: algo, Bare: true}) + refPath := filepath.Join(testRepo.Dir(), "refs", "heads", "bad") - if err := os.MkdirAll(filepath.Dir(refPath), 0o755); err != nil { + + err := os.MkdirAll(filepath.Dir(refPath), 0o755) + if err != nil { t.Fatalf("MkdirAll: %v", err) } - if err := os.WriteFile(refPath, []byte("not-a-hash\n"), 0o644); err != nil { + + err = os.WriteFile(refPath, []byte("not-a-hash\n"), 0o644) + if err != nil { t.Fatalf("WriteFile: %v", err) } store := openLooseStore(t, testRepo.Dir(), algo) - if _, err := store.Resolve("refs/heads/bad"); err == nil { + + _, err = store.Resolve("refs/heads/bad") + if err == nil { t.Fatalf("Resolve(malformed) expected error") } }) @@ -231,6 +259,7 @@ func TestLooseShorten(t *testing.T) { if err != nil { t.Fatalf("Shorten(head): %v", err) } + if shortHead != "heads/main" { t.Fatalf("Shorten(refs/heads/main) = %q, want %q", shortHead, "heads/main") } @@ -239,11 +268,13 @@ func TestLooseShorten(t *testing.T) { if err != nil { t.Fatalf("Shorten(remote): %v", err) } + if shortRemote != "origin/main" { t.Fatalf("Shorten(remote) = %q, want %q", shortRemote, "origin/main") } - if _, err := store.Shorten("refs/heads/does-not-exist"); !errors.Is(err, refstore.ErrReferenceNotFound) { + _, err = store.Shorten("refs/heads/does-not-exist") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Shorten(not-found) error = %v", err) } }) diff --git a/refstore/loose/resolve.go b/refstore/loose/resolve.go index f54ab5a4..076c4098 100644 --- a/refstore/loose/resolve.go +++ b/refstore/loose/resolve.go @@ -16,10 +16,12 @@ func (store *Store) Resolve(name string) (ref.Ref, error) { if name == "" { return nil, refstore.ErrReferenceNotFound } + resolved, err := store.resolveOne(name) if err != nil { return nil, err } + return resolved, nil } @@ -30,17 +32,20 @@ func (store *Store) ResolveFully(name string) (ref.Detached, error) { } cur := name + seen := make(map[string]struct{}) for { if _, ok := seen[cur]; ok { return ref.Detached{}, fmt.Errorf("refstore/loose: symbolic reference cycle at %q", cur) } + seen[cur] = struct{}{} resolved, err := store.resolveOne(cur) if err != nil { return ref.Detached{}, err } + switch resolved := resolved.(type) { case ref.Detached: return resolved, nil @@ -49,6 +54,7 @@ func (store *Store) ResolveFully(name string) (ref.Detached, error) { if target == "" { return ref.Detached{}, fmt.Errorf("refstore/loose: symbolic reference %q has empty target", resolved.Name()) } + cur = target default: return ref.Detached{}, fmt.Errorf("refstore/loose: unsupported reference type %T", resolved) @@ -63,23 +69,28 @@ func (store *Store) resolveOne(name string) (ref.Ref, error) { if errors.Is(err, os.ErrNotExist) { return nil, refstore.ErrReferenceNotFound } + return nil, err } + line := strings.TrimSpace(string(data)) if strings.HasPrefix(line, "ref: ") { target := strings.TrimSpace(line[len("ref: "):]) if target == "" { return nil, fmt.Errorf("refstore/loose: symbolic reference %q has empty target", name) } + return ref.Symbolic{ RefName: name, Target: target, }, nil } + id, err := objectid.ParseHex(store.algo, line) if err != nil { return nil, fmt.Errorf("refstore/loose: invalid detached reference %q: %w", name, err) } + return ref.Detached{ RefName: name, ID: id, diff --git a/refstore/loose/shorten.go b/refstore/loose/shorten.go index 17a60def..e863d783 100644 --- a/refstore/loose/shorten.go +++ b/refstore/loose/shorten.go @@ -10,20 +10,26 @@ func (store *Store) Shorten(name string) (string, error) { if err != nil { return "", err } + names := make([]string, 0, len(refs)) found := false + for _, entry := range refs { if entry == nil { continue } + full := entry.Name() + names = append(names, full) if full == name { found = true } } + if !found { return "", refstore.ErrReferenceNotFound } + return refstore.ShortenName(name, names), nil } diff --git a/refstore/loose/store.go b/refstore/loose/store.go index e4dc3a34..ec814188 100644 --- a/refstore/loose/store.go +++ b/refstore/loose/store.go @@ -25,6 +25,7 @@ func New(root *os.Root, algo objectid.Algorithm) (*Store, error) { if algo.Size() == 0 { return nil, objectid.ErrInvalidAlgorithm } + return &Store{ root: root, algo: algo, diff --git a/refstore/packed/packed_test.go b/refstore/packed/packed_test.go index dffed2a8..0ddceabf 100644 --- a/refstore/packed/packed_test.go +++ b/refstore/packed/packed_test.go @@ -16,30 +16,39 @@ import ( func openPackedRefStoreFromRepo(t *testing.T, repoPath string, algo objectid.Algorithm) *packed.Store { t.Helper() + root, err := os.OpenRoot(repoPath) if err != nil { t.Fatalf("OpenRoot(repo): %v", err) } + defer func() { _ = root.Close() }() store, err := packed.New(root, algo) if err != nil { t.Fatalf("packed.New: %v", err) } + return store } func openPackedRefStoreFromContent(t *testing.T, content string, algo objectid.Algorithm) (*packed.Store, error) { t.Helper() + dir := t.TempDir() - if err := os.WriteFile(dir+"/packed-refs", []byte(content), 0o644); err != nil { + + err := os.WriteFile(dir+"/packed-refs", []byte(content), 0o644) + if err != nil { t.Fatalf("WriteFile(packed-refs): %v", err) } + root, err := os.OpenRoot(dir) if err != nil { t.Fatalf("OpenRoot(temp): %v", err) } + defer func() { _ = root.Close() }() + return packed.New(root, algo) } @@ -58,10 +67,12 @@ func TestPackedResolveAndPeeled(t *testing.T) { if err != nil { t.Fatalf("Resolve(main): %v", err) } + mainDet, ok := resolvedMain.(ref.Detached) if !ok { t.Fatalf("Resolve(main) type = %T, want ref.Detached", resolvedMain) } + if mainDet.ID != commitID { t.Fatalf("Resolve(main) id = %s, want %s", mainDet.ID, commitID) } @@ -70,16 +81,20 @@ func TestPackedResolveAndPeeled(t *testing.T) { if err != nil { t.Fatalf("Resolve(tag): %v", err) } + tagDet, ok := resolvedTag.(ref.Detached) if !ok { t.Fatalf("Resolve(tag) type = %T, want ref.Detached", resolvedTag) } + if tagDet.ID != tagID { t.Fatalf("Resolve(tag) id = %s, want %s", tagDet.ID, tagID) } + if tagDet.Peeled == nil { t.Fatalf("Resolve(tag) peeled = nil, want commit") } + if *tagDet.Peeled != commitID { t.Fatalf("Resolve(tag) peeled = %s, want %s", *tagDet.Peeled, commitID) } @@ -88,11 +103,13 @@ func TestPackedResolveAndPeeled(t *testing.T) { if err != nil { t.Fatalf("ResolveFully(tag): %v", err) } + if fullTag.ID != tagDet.ID { t.Fatalf("ResolveFully(tag) id = %s, want %s", fullTag.ID, tagDet.ID) } - if _, err := store.Resolve("refs/heads/does-not-exist"); !errors.Is(err, refstore.ErrReferenceNotFound) { + _, err = store.Resolve("refs/heads/does-not-exist") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Resolve(not-found) error = %v", err) } }) @@ -114,11 +131,14 @@ func TestPackedListAndShorten(t *testing.T) { if err != nil { t.Fatalf("List(all): %v", err) } + allNames := make([]string, 0, len(all)) for _, entry := range all { allNames = append(allNames, entry.Name()) } + slices.Sort(allNames) + wantAll := []string{"refs/heads/main", "refs/remotes/origin/main", "refs/tags/main"} if !slices.Equal(allNames, wantAll) { t.Fatalf("List(all) names = %v, want %v", allNames, wantAll) @@ -128,6 +148,7 @@ func TestPackedListAndShorten(t *testing.T) { if err != nil { t.Fatalf("List(pattern): %v", err) } + if len(filtered) != 1 || filtered[0].Name() != "refs/heads/main" { t.Fatalf("List(refs/heads/*) = %v, want refs/heads/main only", filtered) } @@ -136,11 +157,13 @@ func TestPackedListAndShorten(t *testing.T) { if err != nil { t.Fatalf("Shorten(main): %v", err) } + if short != "heads/main" { t.Fatalf("Shorten(main) = %q, want %q", short, "heads/main") } - if _, err := store.Shorten("refs/heads/does-not-exist"); !errors.Is(err, refstore.ErrReferenceNotFound) { + _, err = store.Shorten("refs/heads/does-not-exist") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Shorten(not-found) error = %v", err) } }) @@ -195,10 +218,13 @@ func TestPackedListPatternMatrix(t *testing.T) { if err != nil { t.Fatalf("List(%q): %v", tt.pattern, err) } + gotNames := refNames(got) slices.Sort(gotNames) + wantNames := append([]string(nil), tt.want...) slices.Sort(wantNames) + if !slices.Equal(gotNames, wantNames) { t.Fatalf("List(%q) names = %v, want %v", tt.pattern, gotNames, wantNames) } @@ -231,7 +257,8 @@ func TestPackedParseErrors(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - if _, err := openPackedRefStoreFromContent(t, tt.data, algo); err == nil { + _, err := openPackedRefStoreFromContent(t, tt.data, algo) + if err == nil { t.Fatalf("packed.New expected parse error") } }) @@ -242,16 +269,21 @@ func TestPackedParseErrors(t *testing.T) { func TestPackedNewValidation(t *testing.T) { t.Parallel() dir := t.TempDir() + root, err := os.OpenRoot(dir) if err != nil { t.Fatalf("OpenRoot(temp): %v", err) } + defer func() { _ = root.Close() }() - if _, err := packed.New(root, objectid.AlgorithmUnknown); !errors.Is(err, objectid.ErrInvalidAlgorithm) { + _, err = packed.New(root, objectid.AlgorithmUnknown) + if !errors.Is(err, objectid.ErrInvalidAlgorithm) { t.Fatalf("packed.New invalid algorithm error = %v", err) } - if _, err := packed.New(root, objectid.AlgorithmSHA256); !errors.Is(err, os.ErrNotExist) { + + _, err = packed.New(root, objectid.AlgorithmSHA256) + if !errors.Is(err, os.ErrNotExist) { t.Fatalf("packed.New missing packed-refs error = %v", err) } } @@ -261,6 +293,7 @@ func refNames(refs []ref.Ref) []string { for _, entry := range refs { names = append(names, entry.Name()) } + return names } diff --git a/refstore/packed/parse.go b/refstore/packed/parse.go index 6fe88061..4846d258 100644 --- a/refstore/packed/parse.go +++ b/refstore/packed/parse.go @@ -24,24 +24,30 @@ func parsePackedRefs(r io.Reader, algo objectid.Algorithm) (map[string]ref.Detac if err != nil && err != io.EOF { return nil, nil, err } + if line == "" && err == io.EOF { break } + lineNum++ line = strings.TrimSuffix(line, "\n") line = strings.TrimSuffix(line, "\r") + line = strings.TrimSpace(line) if line == "" { if err == io.EOF { break } + continue } + if strings.HasPrefix(line, "#") { if err == io.EOF { break } + continue } @@ -49,19 +55,24 @@ func parsePackedRefs(r io.Reader, algo objectid.Algorithm) (map[string]ref.Detac if prev < 0 { return nil, nil, fmt.Errorf("refstore/packed: line %d: peeled line without preceding ref", lineNum) } + peeledHex := strings.TrimSpace(strings.TrimPrefix(line, "^")) + peeled, parseErr := objectid.ParseHex(algo, peeledHex) if parseErr != nil { return nil, nil, fmt.Errorf("refstore/packed: line %d: invalid peeled oid: %w", lineNum, parseErr) } + peeledCopy := peeled cur := ordered[prev] cur.Peeled = &peeledCopy ordered[prev] = cur byName[cur.Name()] = cur + if err == io.EOF { break } + continue } @@ -79,6 +90,7 @@ func parsePackedRefs(r io.Reader, algo objectid.Algorithm) (map[string]ref.Detac if name == "" { return nil, nil, fmt.Errorf("refstore/packed: line %d: empty ref name", lineNum) } + if _, exists := byName[name]; exists { return nil, nil, fmt.Errorf("refstore/packed: line %d: duplicate ref %q", lineNum, name) } diff --git a/refstore/packed/store.go b/refstore/packed/store.go index 7705dacb..5ab9d602 100644 --- a/refstore/packed/store.go +++ b/refstore/packed/store.go @@ -25,16 +25,19 @@ func New(root *os.Root, algo objectid.Algorithm) (*Store, error) { if algo.Size() == 0 { return nil, objectid.ErrInvalidAlgorithm } + packedRefs, err := root.Open("packed-refs") if err != nil { return nil, fmt.Errorf("refstore/packed: open packed-refs: %w", err) } + defer func() { _ = packedRefs.Close() }() byName, ordered, err := parsePackedRefs(packedRefs, algo) if err != nil { return nil, err } + return &Store{ byName: byName, ordered: ordered, @@ -47,6 +50,7 @@ func (store *Store) Resolve(name string) (ref.Ref, error) { if !ok { return nil, refstore.ErrReferenceNotFound } + return detached, nil } @@ -58,6 +62,7 @@ func (store *Store) ResolveFully(name string) (ref.Detached, error) { if !ok { return ref.Detached{}, refstore.ErrReferenceNotFound } + return detached, nil } @@ -68,7 +73,8 @@ func (store *Store) ResolveFully(name string) (ref.Detached, error) { func (store *Store) List(pattern string) ([]ref.Ref, error) { matchAll := pattern == "" if !matchAll { - if _, err := path.Match(pattern, "refs/heads/main"); err != nil { + _, err := path.Match(pattern, "refs/heads/main") + if err != nil { return nil, err } } @@ -80,12 +86,15 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { if err != nil { return nil, err } + if !matched { continue } } + refs = append(refs, entry) } + return refs, nil } @@ -100,6 +109,7 @@ func (store *Store) Shorten(name string) (string, error) { for _, entry := range store.ordered { names = append(names, entry.Name()) } + return refstore.ShortenName(name, names), nil } diff --git a/refstore/reftable/lookup.go b/refstore/reftable/lookup.go index 8862f7e6..53483bbf 100644 --- a/refstore/reftable/lookup.go +++ b/refstore/reftable/lookup.go @@ -16,13 +16,16 @@ func (table *tableFile) resolveRecord(name string) (recordValue, bool, error) { if err != nil { return recordValue{}, false, err } + pos, ok, err := table.resolveRefBlockPosFromIndex(name, indexPos) if err != nil { return recordValue{}, false, err } + if !ok { return recordValue{}, false, nil } + return table.lookupInRefBlock(name, pos) } @@ -32,28 +35,36 @@ func (table *tableFile) resolveRecord(name string) (recordValue, bool, error) { for pos < table.refEnd && table.data[pos] == 0 { pos++ } + if pos >= table.refEnd { break } + if table.data[pos] != blockTypeRef { return recordValue{}, false, fmt.Errorf("refstore/reftable: table %q: unexpected block type %q in ref section", table.name, table.data[pos]) } + block, blockEnd, err := table.readBlockAt(pos) if err != nil { return recordValue{}, false, err } + found, done, rec, err := lookupRecordInRefBlock(table, block, name) if err != nil { return recordValue{}, false, err } + if found { return rec, true, nil } + if done { return recordValue{}, false, nil } + pos = table.nextBlockPos(blockEnd) } + return recordValue{}, false, nil } @@ -63,16 +74,20 @@ func (table *tableFile) resolveRefBlockPosFromIndex(name string, indexPos int) ( if err != nil { return 0, false, err } + if block.blockType != blockTypeIndex { return 0, false, fmt.Errorf("refstore/reftable: table %q: ref index root is not index block", table.name) } + childPos, ok, err := lookupChildPosInIndexBlock(block, name) if err != nil { return 0, false, err } + if !ok { return 0, false, nil } + if childPos < 0 || childPos >= len(table.data) { return 0, false, fmt.Errorf("refstore/reftable: table %q: index child position out of range", table.name) } @@ -94,13 +109,16 @@ func (table *tableFile) lookupInRefBlock(name string, pos int) (recordValue, boo if err != nil { return recordValue{}, false, err } + if block.blockType != blockTypeRef { return recordValue{}, false, fmt.Errorf("refstore/reftable: table %q: expected ref block at %d", table.name, pos) } + found, _, rec, err := lookupRecordInRefBlock(table, block, name) if err != nil { return recordValue{}, false, err } + return rec, found, nil } @@ -108,13 +126,16 @@ func (table *tableFile) lookupInRefBlock(name string, pos int) (recordValue, boo func (table *tableFile) forEachRecord(fn func(name string, rec recordValue) error) error { pos := table.headerLen prevLast := "" + for pos < table.refEnd { for pos < table.refEnd && table.data[pos] == 0 { pos++ } + if pos >= table.refEnd { break } + if table.data[pos] != blockTypeRef { return fmt.Errorf("refstore/reftable: table %q: unexpected block type %q in ref section", table.name, table.data[pos]) } @@ -123,25 +144,33 @@ func (table *tableFile) forEachRecord(fn func(name string, rec recordValue) erro if err != nil { return err } + var first, last string + err = forEachRecordInRefBlock(table, block, func(name string, rec recordValue) error { if first == "" { first = name } + last = name + return fn(name, rec) }) if err != nil { return err } + if prevLast != "" && first != "" && strings.Compare(first, prevLast) <= 0 { return fmt.Errorf("refstore/reftable: table %q: ref blocks are not strictly ordered", table.name) } + if last != "" { prevLast = last } + pos = table.nextBlockPos(blockEnd) } + return nil } @@ -159,22 +188,29 @@ func (table *tableFile) readBlockAt(pos int) (blockView, int, error) { if pos < 0 || pos+4 > len(table.data) { return blockView{}, 0, fmt.Errorf("refstore/reftable: table %q: block header out of range", table.name) } + blockLen := int(readUint24(table.data[pos+1 : pos+4])) + effectiveLen := blockLen if pos == table.headerLen { if blockLen < table.headerLen { return blockView{}, 0, fmt.Errorf("refstore/reftable: table %q: invalid first block length", table.name) } + effectiveLen = blockLen - table.headerLen } + if effectiveLen < 4 { return blockView{}, 0, fmt.Errorf("refstore/reftable: table %q: invalid block length", table.name) } + end := pos + effectiveLen if end > len(table.data) { return blockView{}, 0, fmt.Errorf("refstore/reftable: table %q: block out of range", table.name) } + view := blockView{blockType: table.data[pos], start: pos, end: end, first: pos == table.headerLen, payload: table.data[pos:end]} + return view, end, nil } @@ -183,6 +219,7 @@ func (table *tableFile) nextBlockPos(blockEnd int) int { if table.blockSize > 0 { return alignUp(blockEnd, table.blockSize) } + return blockEnd } @@ -192,35 +229,45 @@ func lookupChildPosInIndexBlock(block blockView, key string) (int, bool, error) if err != nil { return 0, false, err } - if err := validateRestarts(block, restarts, off, recordsEnd, true); err != nil { + + err = validateRestarts(block, restarts, off, recordsEnd, true) + if err != nil { return 0, false, err } + prev := "" for off < recordsEnd { name, v, nextOff, err := parseKeyedRecord(block.payload, off, recordsEnd, prev) if err != nil { return 0, false, err } + if (v & 0x7) != 0 { return 0, false, fmt.Errorf("index value_type must be 0") } + childPos, nextOff, err := readVarint(block.payload, nextOff, recordsEnd) if err != nil { return 0, false, err } + if strings.Compare(key, name) <= 0 { childPosInt, err := intconv.Uint64ToInt(childPos) if err != nil { return 0, false, fmt.Errorf("index child position conversion: %w", err) } + return childPosInt, true, nil } + prev = name off = nextOff } + if off != recordsEnd { return 0, false, fmt.Errorf("malformed index block") } + return 0, false, nil } @@ -230,37 +277,48 @@ func lookupRecordInRefBlock(table *tableFile, block blockView, key string) (foun if err != nil { return false, false, recordValue{}, err } - if err := validateRestarts(block, restarts, off, recordsEnd, true); err != nil { + + err = validateRestarts(block, restarts, off, recordsEnd, true) + if err != nil { return false, false, recordValue{}, err } + prev := "" for off < recordsEnd { name, v, nextOff, err := parseKeyedRecord(block.payload, off, recordsEnd, prev) if err != nil { return false, false, recordValue{}, err } + typeBits := byte(v & 0x7) + _, nextOff, err = readVarint(block.payload, nextOff, recordsEnd) if err != nil { return false, false, recordValue{}, err } + recVal, nextOff, err := parseRefValue(block.payload, nextOff, recordsEnd, table.algo, typeBits) if err != nil { return false, false, recordValue{}, err } + cmp := strings.Compare(name, key) if cmp == 0 { return true, true, recVal, nil } + if cmp > 0 { return false, true, recordValue{}, nil } + prev = name off = nextOff } + if off != recordsEnd { return false, false, recordValue{}, fmt.Errorf("malformed ref block") } + return false, false, recordValue{}, nil } @@ -270,33 +328,44 @@ func forEachRecordInRefBlock(table *tableFile, block blockView, fn func(name str if err != nil { return err } - if err := validateRestarts(block, restarts, off, recordsEnd, true); err != nil { + + err = validateRestarts(block, restarts, off, recordsEnd, true) + if err != nil { return err } + prev := "" for off < recordsEnd { name, v, nextOff, err := parseKeyedRecord(block.payload, off, recordsEnd, prev) if err != nil { return err } + typeBits := byte(v & 0x7) + _, nextOff, err = readVarint(block.payload, nextOff, recordsEnd) if err != nil { return err } + recVal, nextOff, err := parseRefValue(block.payload, nextOff, recordsEnd, table.algo, typeBits) if err != nil { return err } - if err := fn(name, recVal); err != nil { + + err = fn(name, recVal) + if err != nil { return err } + prev = name off = nextOff } + if off != recordsEnd { return fmt.Errorf("malformed ref block") } + return nil } @@ -305,51 +374,63 @@ func parseBlockLayout(block blockView) (recordsStart, recordsEnd int, restarts [ if len(block.payload) < 6 { return 0, 0, nil, fmt.Errorf("short block") } + restartCount := int(binary.BigEndian.Uint16(block.payload[len(block.payload)-2:])) if restartCount <= 0 { return 0, 0, nil, fmt.Errorf("invalid restart count") } + restarts = make([]int, restartCount) restartBytes := restartCount * 3 + restartsStart := len(block.payload) - 2 - restartBytes if restartsStart < 4 { return 0, 0, nil, fmt.Errorf("invalid restart table") } + for i := range restartCount { off := restartsStart + i*3 rel := int(readUint24(block.payload[off : off+3])) + base := block.start if block.first { // In the first block, restart offsets are relative to file start. base = 0 } + abs := base + rel restarts[i] = abs - block.start } + return 4, restartsStart, restarts, nil } // validateRestarts validates restart monotonicity, bounds and record-prefix invariants. func validateRestarts(block blockView, restarts []int, recordsStart, recordsEnd int, requirePrefixZero bool) error { prev := -1 + for _, off := range restarts { if off < recordsStart || off >= recordsEnd { return fmt.Errorf("restart offset out of range") } + if off <= prev { return fmt.Errorf("restart offsets not strictly increasing") } + prev = off if requirePrefixZero { prefix, _, err := readVarint(block.payload, off, recordsEnd) if err != nil { return err } + if prefix != 0 { return fmt.Errorf("restart record prefix length must be zero") } } } + return nil } @@ -359,26 +440,33 @@ func parseKeyedRecord(buf []byte, off, end int, prev string) (name string, rawTy if err != nil { return "", 0, 0, err } + suffixAndType, next, err := readVarint(buf, next, end) if err != nil { return "", 0, 0, err } + suffixLen, err := intconv.Uint64ToInt(suffixAndType >> 3) if err != nil || suffixLen < 0 || next+suffixLen > end { return "", 0, 0, fmt.Errorf("invalid suffix length") } + prefixLenInt, err := intconv.Uint64ToInt(prefixLen) if err != nil { return "", 0, 0, fmt.Errorf("invalid prefix length") } + if prefixLenInt > len(prev) { return "", 0, 0, fmt.Errorf("invalid prefix length") } + name = prev[:prefixLenInt] + string(buf[next:next+suffixLen]) next += suffixLen + if prev != "" && strings.Compare(name, prev) <= 0 { return "", 0, 0, fmt.Errorf("keys not strictly increasing") } + return name, suffixAndType, next, nil } @@ -392,40 +480,50 @@ func parseRefValue(buf []byte, off, end int, algo objectid.Algorithm, valueType if err != nil { return recordValue{}, 0, err } + return recordValue{detachedID: id, hasDetached: true}, next, nil case 0x2: id, next, err := readObjectID(buf, off, end, algo) if err != nil { return recordValue{}, 0, err } + peeled, next, err := readObjectID(buf, next, end, algo) if err != nil { return recordValue{}, 0, err } + peeledCopy := peeled + return recordValue{detachedID: id, hasDetached: true, peeled: &peeledCopy}, next, nil case 0x3: targetLen, next, err := readVarint(buf, off, end) if err != nil { return recordValue{}, 0, err } + remaining := end - next if remaining < 0 { return recordValue{}, 0, fmt.Errorf("invalid symref target length") } + remainingU64, err := intconv.IntToUint64(remaining) if err != nil { return recordValue{}, 0, fmt.Errorf("invalid symref target length") } + if targetLen > remainingU64 { return recordValue{}, 0, fmt.Errorf("invalid symref target length") } + targetLenInt, err := intconv.Uint64ToInt(targetLen) if err != nil { return recordValue{}, 0, fmt.Errorf("invalid symref target length") } + target := string(buf[next : next+targetLenInt]) next += targetLenInt + return recordValue{symbolicTarget: target}, next, nil default: return recordValue{}, 0, fmt.Errorf("unsupported ref value type %d", valueType) @@ -438,9 +536,11 @@ func readObjectID(buf []byte, off, end int, algo objectid.Algorithm) (objectid.O if off < 0 || sz < 0 || off+sz > end { return objectid.ObjectID{}, 0, fmt.Errorf("truncated object id") } + id, err := objectid.FromBytes(algo, buf[off:off+sz]) if err != nil { return objectid.ObjectID{}, 0, err } + return id, off + sz, nil } diff --git a/refstore/reftable/parse_helpers.go b/refstore/reftable/parse_helpers.go index b5da555e..5b5fae24 100644 --- a/refstore/reftable/parse_helpers.go +++ b/refstore/reftable/parse_helpers.go @@ -13,6 +13,7 @@ func alignUp(pos, blockSize int) int { if rem == 0 { return pos } + return pos + (blockSize - rem) } @@ -21,16 +22,20 @@ func readVarint(buf []byte, off, end int) (uint64, int, error) { if off >= end { return 0, 0, fmt.Errorf("unexpected EOF") } + b := buf[off] val := uint64(b & 0x7f) + off++ for b&0x80 != 0 { if off >= end { return 0, 0, fmt.Errorf("unexpected EOF") } + b = buf[off] off++ val = ((val + 1) << 7) | uint64(b&0x7f) } + return val, off, nil } diff --git a/refstore/reftable/reftable_test.go b/refstore/reftable/reftable_test.go index 2a6e0738..26aa7584 100644 --- a/refstore/reftable/reftable_test.go +++ b/refstore/reftable/reftable_test.go @@ -17,6 +17,7 @@ import ( // newBareReftableRepo creates a bare repository that uses reftable ref storage. func newBareReftableRepo(tb testing.TB, algo objectid.Algorithm) *testgit.TestRepo { tb.Helper() + return testgit.NewRepo(tb, testgit.RepoOptions{ ObjectFormat: algo, Bare: true, @@ -27,15 +28,19 @@ func newBareReftableRepo(tb testing.TB, algo objectid.Algorithm) *testgit.TestRe // openStore opens a reftable store against repoDir/reftable. func openStore(tb testing.TB, repoDir string, algo objectid.Algorithm) *reftable.Store { tb.Helper() + root, err := os.OpenRoot(filepath.Join(repoDir, "reftable")) if err != nil { tb.Fatalf("OpenRoot(reftable): %v", err) } + tb.Cleanup(func() { _ = root.Close() }) + store, err := reftable.New(root, algo) if err != nil { tb.Fatalf("reftable.New: %v", err) } + return store } @@ -48,14 +53,17 @@ func TestResolveAndResolveFully(t *testing.T) { repo.SymbolicRef(t, "HEAD", "refs/heads/main") store := openStore(t, repo.Dir(), algo) + head, err := store.Resolve("HEAD") if err != nil { t.Fatalf("Resolve(HEAD): %v", err) } + sym, ok := head.(ref.Symbolic) if !ok { t.Fatalf("Resolve(HEAD) type = %T, want ref.Symbolic", head) } + if sym.Target != "refs/heads/main" { t.Fatalf("Resolve(HEAD) target = %q, want refs/heads/main", sym.Target) } @@ -64,11 +72,13 @@ func TestResolveAndResolveFully(t *testing.T) { if err != nil { t.Fatalf("ResolveFully(HEAD): %v", err) } + if main.ID != id { t.Fatalf("ResolveFully(HEAD) id = %s, want %s", main.ID, id) } - if _, err := store.Resolve("refs/heads/missing"); !errors.Is(err, refstore.ErrReferenceNotFound) { + _, err = store.Resolve("refs/heads/missing") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Resolve(missing) error = %v", err) } }) @@ -82,7 +92,9 @@ func TestResolveFullyCycle(t *testing.T) { repo.SymbolicRef(t, "refs/heads/b", "refs/heads/a") store := openStore(t, repo.Dir(), algo) - if _, err := store.ResolveFully("refs/heads/a"); err == nil { + + _, err := store.ResolveFully("refs/heads/a") + if err == nil { t.Fatalf("ResolveFully(cycle) expected error") } }) @@ -99,14 +111,17 @@ func TestListAndShorten(t *testing.T) { repo.UpdateRef(t, "refs/remotes/origin/main", id) store := openStore(t, repo.Dir(), algo) + all, err := store.List("") if err != nil { t.Fatalf("List(all): %v", err) } + names := make([]string, 0, len(all)) for _, entry := range all { names = append(names, entry.Name()) } + want := []string{"HEAD", "refs/heads/feature", "refs/heads/main", "refs/remotes/origin/main", "refs/tags/main"} if !slices.Equal(names, want) { t.Fatalf("List(all) = %v, want %v", names, want) @@ -116,10 +131,12 @@ func TestListAndShorten(t *testing.T) { if err != nil { t.Fatalf("List(heads): %v", err) } + headNames := make([]string, 0, len(heads)) for _, entry := range heads { headNames = append(headNames, entry.Name()) } + wantHeads := []string{"refs/heads/feature", "refs/heads/main"} if !slices.Equal(headNames, wantHeads) { t.Fatalf("List(heads) = %v, want %v", headNames, wantHeads) @@ -129,6 +146,7 @@ func TestListAndShorten(t *testing.T) { if err != nil { t.Fatalf("Shorten(remote): %v", err) } + if short != "origin/main" { t.Fatalf("Shorten(remote) = %q, want origin/main", short) } @@ -146,7 +164,9 @@ func TestTombstoneNewestWins(t *testing.T) { repo.DeleteRef(t, "refs/heads/main") store := openStore(t, repo.Dir(), algo) - if _, err := store.Resolve("refs/heads/main"); !errors.Is(err, refstore.ErrReferenceNotFound) { + + _, err := store.Resolve("refs/heads/main") + if !errors.Is(err, refstore.ErrReferenceNotFound) { t.Fatalf("Resolve(main) after delete error = %v", err) } }) @@ -160,20 +180,25 @@ func TestAnnotatedTagPeeled(t *testing.T) { tagID := repo.TagAnnotated(t, "v1.0.0", commitID, "annotated") store := openStore(t, repo.Dir(), algo) + resolved, err := store.Resolve("refs/tags/v1.0.0") if err != nil { t.Fatalf("Resolve(tag): %v", err) } + detached, ok := resolved.(ref.Detached) if !ok { t.Fatalf("Resolve(tag) type = %T, want ref.Detached", resolved) } + if detached.ID != tagID { t.Fatalf("Resolve(tag) id = %s, want %s", detached.ID, tagID) } + if detached.Peeled == nil { t.Fatalf("Resolve(tag) peeled = nil") } + if *detached.Peeled != commitID { t.Fatalf("Resolve(tag) peeled = %s, want %s", *detached.Peeled, commitID) } diff --git a/refstore/reftable/store.go b/refstore/reftable/store.go index 7c02c157..d0d906fc 100644 --- a/refstore/reftable/store.go +++ b/refstore/reftable/store.go @@ -42,6 +42,7 @@ func New(root *os.Root, algo objectid.Algorithm) (*Store, error) { if algo.Size() == 0 { return nil, objectid.ErrInvalidAlgorithm } + return &Store{root: root, algo: algo}, nil } @@ -50,25 +51,33 @@ func (store *Store) Close() error { store.stateMu.Lock() if store.closed { store.stateMu.Unlock() + return nil } + store.closed = true root := store.root tables := store.tables store.stateMu.Unlock() var closeErr error + for _, table := range tables { if table == nil { continue } - if err := table.close(); err != nil && closeErr == nil { + + err := table.close() + if err != nil && closeErr == nil { closeErr = err } } - if err := root.Close(); err != nil && closeErr == nil { + + err := root.Close() + if err != nil && closeErr == nil { closeErr = err } + return closeErr } @@ -78,23 +87,29 @@ func (store *Store) Resolve(name string) (ref.Ref, error) { if err != nil { return nil, err } + for i := len(tables) - 1; i >= 0; i-- { rec, found, err := tables[i].resolveRecord(name) if err != nil { return nil, err } + if !found { continue } + if rec.deleted { return nil, refstore.ErrReferenceNotFound } + resolved, err := rec.toRef(name) if err != nil { return nil, err } + return resolved, nil } + return nil, refstore.ErrReferenceNotFound } @@ -104,16 +119,21 @@ func (store *Store) Resolve(name string) (ref.Ref, error) { // annotated tag objects. func (store *Store) ResolveFully(name string) (ref.Detached, error) { seen := map[string]struct{}{} + cur := name for { - if _, exists := seen[cur]; exists { + _, exists := seen[cur] + if exists { return ref.Detached{}, errors.New("refstore/reftable: symbolic reference cycle") } + seen[cur] = struct{}{} + resolved, err := store.Resolve(cur) if err != nil { return ref.Detached{}, err } + switch resolved := resolved.(type) { case ref.Detached: return resolved, nil @@ -121,6 +141,7 @@ func (store *Store) ResolveFully(name string) (ref.Detached, error) { if resolved.Target == "" { return ref.Detached{}, errors.New("refstore/reftable: symbolic reference has empty target") } + cur = resolved.Target default: return ref.Detached{}, errors.New("refstore/reftable: unsupported reference type") @@ -137,32 +158,41 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { if err != nil { return nil, err } + visible := make(map[string]ref.Ref) masked := make(map[string]struct{}) for i := len(tables) - 1; i >= 0; i-- { - if err := tables[i].forEachRecord(func(name string, rec recordValue) error { - if _, done := masked[name]; done { + err := tables[i].forEachRecord(func(name string, rec recordValue) error { + _, done := masked[name] + if done { return nil } + masked[name] = struct{}{} + if rec.deleted { return nil } + resolved, err := rec.toRef(name) if err != nil { return err } + visible[name] = resolved + return nil - }); err != nil { + }) + if err != nil { return nil, err } } matchAll := pattern == "" if !matchAll { - if _, err := pathMatch(pattern, "refs/heads/main"); err != nil { + _, err := pathMatch(pattern, "refs/heads/main") + if err != nil { return nil, err } } @@ -171,6 +201,7 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { for name := range visible { names = append(names, name) } + sort.Strings(names) out := make([]ref.Ref, 0, len(names)) @@ -180,12 +211,15 @@ func (store *Store) List(pattern string) ([]ref.Ref, error) { if err != nil { return nil, err } + if !ok { continue } } + out = append(out, visible[name]) } + return out, nil } @@ -195,21 +229,27 @@ func (store *Store) Shorten(name string) (string, error) { if err != nil { return "", err } + names := make([]string, 0, len(refs)) found := false + for _, entry := range refs { if entry == nil { continue } + full := entry.Name() + names = append(names, full) if full == name { found = true } } + if !found { return "", refstore.ErrReferenceNotFound } + return refstore.ShortenName(name, names), nil } @@ -225,9 +265,11 @@ func (store *Store) ensureTables() ([]*tableFile, error) { store.stateMu.RLock() defer store.stateMu.RUnlock() + if store.closed { return nil, errors.New("refstore/reftable: store is closed") } + return store.tables, store.loadErr } @@ -238,18 +280,23 @@ func (store *Store) loadTables() ([]*tableFile, error) { if errors.Is(err, os.ErrNotExist) { return nil, nil } + return nil, err } + lines := strings.Split(string(listRaw), "\n") + names := make([]string, 0, len(lines)) for _, line := range lines { line = strings.TrimSuffix(line, "\r") if line == "" { continue } + if strings.Contains(line, "/") { return nil, errors.New("refstore/reftable: invalid table name") } + names = append(names, line) } @@ -260,9 +307,12 @@ func (store *Store) loadTables() ([]*tableFile, error) { for _, opened := range out { _ = opened.close() } + return nil, err } + out = append(out, table) } + return out, nil } diff --git a/refstore/reftable/table.go b/refstore/reftable/table.go index 35982bf9..5c05a633 100644 --- a/refstore/reftable/table.go +++ b/refstore/reftable/table.go @@ -71,49 +71,69 @@ func openTableFile(root *os.Root, name string, algo objectid.Algorithm) (*tableF if err != nil { return nil, err } + info, err := file.Stat() if err != nil { _ = file.Close() + return nil, err } + size := info.Size() if size < 0 || size > int64(int(^uint(0)>>1)) { _ = file.Close() + return nil, fmt.Errorf("refstore/reftable: table %q has unsupported size", name) } + fd, err := intconv.UintptrToInt(file.Fd()) if err != nil { _ = file.Close() + return nil, err } + data, err := syscall.Mmap(fd, 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE) if err != nil { _ = file.Close() + return nil, err } + out := &tableFile{name: name, algo: algo, file: file, data: data} - if err := out.parseMeta(); err != nil { + + err = out.parseMeta() + if err != nil { _ = out.close() + return nil, err } + return out, nil } // close unmaps and closes one table file. func (table *tableFile) close() error { var closeErr error + if table.data != nil { - if err := syscall.Munmap(table.data); err != nil && closeErr == nil { + err := syscall.Munmap(table.data) + if err != nil && closeErr == nil { closeErr = err } + table.data = nil } + if table.file != nil { - if err := table.file.Close(); err != nil && closeErr == nil { + err := table.file.Close() + if err != nil && closeErr == nil { closeErr = err } + table.file = nil } + return closeErr } @@ -122,9 +142,11 @@ func (table *tableFile) parseMeta() error { if len(table.data) < 24 { return fmt.Errorf("refstore/reftable: table %q: file too short", table.name) } + if string(table.data[:4]) != reftableMagic { return fmt.Errorf("refstore/reftable: table %q: bad magic", table.name) } + version := table.data[4] switch version { case version1: @@ -137,35 +159,47 @@ func (table *tableFile) parseMeta() error { if len(table.data) < table.headerLen { return fmt.Errorf("refstore/reftable: table %q: truncated header", table.name) } + hashID := binary.BigEndian.Uint32(table.data[24:28]) - if err := validateHashID(hashID, table.algo); err != nil { + + err := validateHashID(hashID, table.algo) + if err != nil { return fmt.Errorf("refstore/reftable: table %q: %w", table.name, err) } default: return fmt.Errorf("refstore/reftable: table %q: unsupported version %d", table.name, version) } + table.blockSize = int(readUint24(table.data[5:8])) footerLen := 68 if version == version2 { footerLen = 72 } + if len(table.data) < footerLen { return fmt.Errorf("refstore/reftable: table %q: missing footer", table.name) } + footerStart := len(table.data) - footerLen + footer := table.data[footerStart:] if string(footer[:4]) != reftableMagic || footer[4] != version { return fmt.Errorf("refstore/reftable: table %q: invalid footer header", table.name) } + wantCRC := binary.BigEndian.Uint32(footer[footerLen-4:]) + haveCRC := crc32.ChecksumIEEE(footer[:footerLen-4]) if wantCRC != haveCRC { return fmt.Errorf("refstore/reftable: table %q: footer crc mismatch", table.name) } + if version == version2 { hashID := binary.BigEndian.Uint32(footer[24:28]) - if err := validateHashID(hashID, table.algo); err != nil { + + err := validateHashID(hashID, table.algo) + if err != nil { return fmt.Errorf("refstore/reftable: table %q: %w", table.name, err) } } @@ -188,34 +222,44 @@ func (table *tableFile) parseMeta() error { if err != nil { return fmt.Errorf("refstore/reftable: table %q: invalid footer offset: %w", table.name, err) } + if table.refIndexPos != 0 && table.refIndexPos < refEnd { refEnd = table.refIndexPos } + if objPos != 0 && objPos < refEnd { refEnd = objPos } + if logPos != 0 && logPos < refEnd { refEnd = logPos } + headerLenU64, err := intconv.IntToUint64(table.headerLen) if err != nil { return fmt.Errorf("refstore/reftable: table %q: invalid header length: %w", table.name, err) } + dataLenU64, err := intconv.IntToUint64(len(table.data)) if err != nil { return fmt.Errorf("refstore/reftable: table %q: invalid data length: %w", table.name, err) } + if refEnd < headerLenU64 || refEnd > dataLenU64 { return fmt.Errorf("refstore/reftable: table %q: invalid ref section", table.name) } + if table.refIndexPos > dataLenU64 { return fmt.Errorf("refstore/reftable: table %q: invalid ref index position", table.name) } + refEndInt, err := intconv.Uint64ToInt(refEnd) if err != nil { return fmt.Errorf("refstore/reftable: table %q: invalid ref section end: %w", table.name, err) } + table.refEnd = refEndInt + return nil } @@ -226,11 +270,13 @@ func validateHashID(hashID uint32, algo objectid.Algorithm) error { if algo != objectid.AlgorithmSHA1 { return errors.New("hash id sha1 mismatch") } + return nil case hashIDSHA256: if algo != objectid.AlgorithmSHA256 { return errors.New("hash id s256 mismatch") } + return nil default: return fmt.Errorf("unknown hash id 0x%08x", hashID) @@ -242,11 +288,14 @@ func (record recordValue) toRef(name string) (ref.Ref, error) { if record.deleted { return nil, errors.New("refstore/reftable: cannot materialize deleted record") } + if record.symbolicTarget != "" { return ref.Symbolic{RefName: name, Target: record.symbolicTarget}, nil } + if !record.hasDetached { return nil, errors.New("refstore/reftable: malformed detached record") } + return ref.Detached{RefName: name, ID: record.detachedID, Peeled: record.peeled}, nil } diff --git a/refstore/shorten.go b/refstore/shorten.go index 26fa82c0..250ab01f 100644 --- a/refstore/shorten.go +++ b/refstore/shorten.go @@ -20,17 +20,22 @@ func (rule shortenRule) match(name string) (string, bool) { if !strings.HasPrefix(name, rule.prefix) { return "", false } + if !strings.HasSuffix(name, rule.suffix) { return "", false } + short := strings.TrimPrefix(name, rule.prefix) + short = strings.TrimSuffix(short, rule.suffix) if short == "" { return "", false } + if rule.prefix+short+rule.suffix != name { return "", false } + return short, true } @@ -47,6 +52,7 @@ func ShortenName(name string, all []string) string { if full == "" { continue } + names[full] = struct{}{} } @@ -55,20 +61,26 @@ func ShortenName(name string, all []string) string { if !ok { continue } + ambiguous := false + for j := range shortenRules { if j == i { continue } + full := shortenRules[j].render(short) if _, found := names[full]; found { ambiguous = true + break } } + if !ambiguous { return short } } + return name } diff --git a/refstore/shorten_test.go b/refstore/shorten_test.go index 53e7e003..a4d91453 100644 --- a/refstore/shorten_test.go +++ b/refstore/shorten_test.go @@ -11,6 +11,7 @@ func TestShortenName(t *testing.T) { t.Run("simple", func(t *testing.T) { t.Parallel() + got := refstore.ShortenName("refs/heads/main", []string{"refs/heads/main"}) if got != "main" { t.Fatalf("ShortenName simple = %q, want %q", got, "main") @@ -19,6 +20,7 @@ func TestShortenName(t *testing.T) { t.Run("ambiguous with tags", func(t *testing.T) { t.Parallel() + got := refstore.ShortenName( "refs/heads/main", []string{ @@ -64,7 +66,9 @@ func TestShortenName(t *testing.T) { t.Run("refs-prefix fallback", func(t *testing.T) { t.Parallel() + name := "refs/notes/review/topic" + got := refstore.ShortenName(name, []string{name}) if got != "notes/review/topic" { t.Fatalf("ShortenName refs-prefix fallback = %q, want %q", got, "notes/review/topic") diff --git a/repository/read_stored.go b/repository/read_stored.go index b26421e8..1e92dc40 100644 --- a/repository/read_stored.go +++ b/repository/read_stored.go @@ -15,6 +15,7 @@ func (repo *Repository) ReadStored(id objectid.ObjectID) (objectstored.StoredObj if err != nil { return nil, err } + switch parsed := parsed.(type) { case *object.Blob: return objectstored.NewStoredBlob(id, parsed), nil @@ -35,10 +36,12 @@ func (repo *Repository) ReadStoredBlob(id objectid.ObjectID) (*objectstored.Stor if err != nil { return nil, err } + blob, ok := stored.(*objectstored.StoredBlob) if !ok { return nil, fmt.Errorf("repository: expected blob object %s, got %v", id, stored.Object().ObjectType()) } + return blob, nil } @@ -48,10 +51,12 @@ func (repo *Repository) ReadStoredTree(id objectid.ObjectID) (*objectstored.Stor if err != nil { return nil, err } + tree, ok := stored.(*objectstored.StoredTree) if !ok { return nil, fmt.Errorf("repository: expected tree object %s, got %v", id, stored.Object().ObjectType()) } + return tree, nil } @@ -61,10 +66,12 @@ func (repo *Repository) ReadStoredCommit(id objectid.ObjectID) (*objectstored.St if err != nil { return nil, err } + commit, ok := stored.(*objectstored.StoredCommit) if !ok { return nil, fmt.Errorf("repository: expected commit object %s, got %v", id, stored.Object().ObjectType()) } + return commit, nil } @@ -74,10 +81,12 @@ func (repo *Repository) ReadStoredTag(id objectid.ObjectID) (*objectstored.Store if err != nil { return nil, err } + tag, ok := stored.(*objectstored.StoredTag) if !ok { return nil, fmt.Errorf("repository: expected tag object %s, got %v", id, stored.Object().ObjectType()) } + return tag, nil } @@ -87,13 +96,16 @@ func (repo *Repository) readParsedObject(id objectid.ObjectID) (object.Object, e if err != nil { return nil, err } + parsed, err := object.ParseObjectWithoutHeader(ty, content, repo.algo) if err != nil { tyName, ok := objecttype.Name(ty) if !ok { tyName = fmt.Sprintf("type %d", ty) } + return nil, fmt.Errorf("repository: parse object %s (%s): %w", id, tyName, err) } + return parsed, nil } diff --git a/repository/read_stored_passthrough_test.go b/repository/read_stored_passthrough_test.go index 3adcc103..676dd428 100644 --- a/repository/read_stored_passthrough_test.go +++ b/repository/read_stored_passthrough_test.go @@ -28,21 +28,25 @@ func TestReadStoredPassThroughs(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() headerTy, headerSize, err := repo.ReadStoredHeader(commitID) if err != nil { t.Fatalf("ReadStoredHeader: %v", err) } + if headerTy != objecttype.TypeCommit { t.Fatalf("ReadStoredHeader type = %v, want %v", headerTy, objecttype.TypeCommit) } + if headerSize <= 0 { t.Fatalf("ReadStoredHeader size = %d, want > 0", headerSize) } @@ -51,6 +55,7 @@ func TestReadStoredPassThroughs(t *testing.T) { if err != nil { t.Fatalf("ReadStoredBytesFull: %v", err) } + if len(full) == 0 { t.Fatalf("ReadStoredBytesFull returned empty payload") } @@ -59,9 +64,11 @@ func TestReadStoredPassThroughs(t *testing.T) { if err != nil { t.Fatalf("ReadStoredBytesContent: %v", err) } + if contentTy != objecttype.TypeCommit { t.Fatalf("ReadStoredBytesContent type = %v, want %v", contentTy, objecttype.TypeCommit) } + if len(content) == 0 { t.Fatalf("ReadStoredBytesContent returned empty content") } @@ -70,14 +77,18 @@ func TestReadStoredPassThroughs(t *testing.T) { if err != nil { t.Fatalf("ReadStoredReaderFull: %v", err) } + fullReaderBytes, readErr := io.ReadAll(fullReader) closeErr := fullReader.Close() + if readErr != nil { t.Fatalf("ReadStoredReaderFull read: %v", readErr) } + if closeErr != nil { t.Fatalf("ReadStoredReaderFull close: %v", closeErr) } + if !bytes.Equal(fullReaderBytes, full) { t.Fatalf("ReadStoredReaderFull bytes mismatch against ReadStoredBytesFull") } @@ -86,20 +97,26 @@ func TestReadStoredPassThroughs(t *testing.T) { if err != nil { t.Fatalf("ReadStoredReaderContent: %v", err) } + if readerTy != objecttype.TypeCommit { t.Fatalf("ReadStoredReaderContent type = %v, want %v", readerTy, objecttype.TypeCommit) } + if readerSize != int64(len(content)) { t.Fatalf("ReadStoredReaderContent size = %d, want %d", readerSize, len(content)) } + readerContentBytes, readErr := io.ReadAll(contentReader) closeErr = contentReader.Close() + if readErr != nil { t.Fatalf("ReadStoredReaderContent read: %v", readErr) } + if closeErr != nil { t.Fatalf("ReadStoredReaderContent close: %v", closeErr) } + if !bytes.Equal(readerContentBytes, content) { t.Fatalf("ReadStoredReaderContent bytes mismatch against ReadStoredBytesContent") } diff --git a/repository/refs_test.go b/repository/refs_test.go index d0cb216b..68f01898 100644 --- a/repository/refs_test.go +++ b/repository/refs_test.go @@ -30,22 +30,26 @@ func TestRefConvenienceMethods(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() resolved, err := repo.ResolveRef("HEAD") if err != nil { t.Fatalf("ResolveRef(HEAD): %v", err) } + sym, ok := resolved.(ref.Symbolic) if !ok { t.Fatalf("ResolveRef(HEAD) type = %T, want ref.Symbolic", resolved) } + if sym.Target != "refs/heads/main" { t.Fatalf("ResolveRef(HEAD) target = %q, want %q", sym.Target, "refs/heads/main") } @@ -54,6 +58,7 @@ func TestRefConvenienceMethods(t *testing.T) { if err != nil { t.Fatalf("ResolveRefFully(HEAD): %v", err) } + if fully.ID != commitID { t.Fatalf("ResolveRefFully(HEAD) id = %s, want %s", fully.ID, commitID) } @@ -62,6 +67,7 @@ func TestRefConvenienceMethods(t *testing.T) { if err != nil { t.Fatalf("ListRefs: %v", err) } + if len(refs) < 2 { t.Fatalf("ListRefs returned %d refs, want >= 2", len(refs)) } @@ -70,6 +76,7 @@ func TestRefConvenienceMethods(t *testing.T) { if err != nil { t.Fatalf("ShortenRef: %v", err) } + if short != "heads/main" && short != "main" { t.Fatalf("ShortenRef = %q, want %q or %q", short, "heads/main", "main") } @@ -90,18 +97,21 @@ func TestResolveRefErrorSurface(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() _, err = repo.ResolveRef("refs/heads/does-not-exist") if err == nil { t.Fatalf("ResolveRef missing: expected error") } + if !strings.Contains(err.Error(), "not found") { t.Fatalf("ResolveRef missing error = %v, want not found detail", err) } @@ -131,18 +141,21 @@ func TestListRefsLooseOverridesPacked(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() mainRef, err := repo.ResolveRefFully("refs/heads/main") if err != nil { t.Fatalf("ResolveRefFully(main): %v", err) } + if mainRef.ID != commit2 { t.Fatalf("ResolveRefFully(main) id = %s, want %s", mainRef.ID, commit2) } @@ -151,12 +164,14 @@ func TestListRefsLooseOverridesPacked(t *testing.T) { if err != nil { t.Fatalf("ListRefs(refs/heads/*): %v", err) } + byName := make(map[string]ref.Ref, len(refs)) for _, entry := range refs { name := entry.Name() if _, exists := byName[name]; exists { t.Fatalf("duplicate ref %q in ListRefs output", name) } + byName[name] = entry } @@ -164,10 +179,12 @@ func TestListRefsLooseOverridesPacked(t *testing.T) { if !ok { t.Fatalf("missing refs/heads/main in ListRefs output") } + mainDetached, ok := main.(ref.Detached) if !ok { t.Fatalf("refs/heads/main type = %T, want ref.Detached", main) } + if mainDetached.ID != commit2 { t.Fatalf("refs/heads/main id = %s, want %s", mainDetached.ID, commit2) } @@ -176,10 +193,12 @@ func TestListRefsLooseOverridesPacked(t *testing.T) { if !ok { t.Fatalf("missing refs/heads/feature in ListRefs output") } + featureDetached, ok := feature.(ref.Detached) if !ok { t.Fatalf("refs/heads/feature type = %T, want ref.Detached", feature) } + if featureDetached.ID != commit1 { t.Fatalf("refs/heads/feature id = %s, want %s", featureDetached.ID, commit1) } diff --git a/repository/repository.go b/repository/repository.go index 31034d2a..9927264a 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -37,6 +37,7 @@ type Repository struct { // Open borrows root during construction and does not close it. func Open(root *os.Root) (repo *Repository, err error) { repo = &Repository{} + defer func() { if err != nil { _ = repo.Close() @@ -47,18 +48,21 @@ func Open(root *os.Root) (repo *Repository, err error) { if err != nil { return nil, err } + repo.config = cfg algo, err := detectObjectAlgorithm(cfg) if err != nil { return nil, err } + repo.algo = algo objects, objectsLooseForWritingOnly, err := openObjectStore(root, algo) if err != nil { return nil, err } + repo.objects = objects repo.objectsLooseForWritingOnly = objectsLooseForWritingOnly @@ -66,6 +70,7 @@ func Open(root *os.Root) (repo *Repository, err error) { if err != nil { return nil, err } + repo.refs = refs return repo, nil @@ -100,17 +105,22 @@ func (repo *Repository) Close() error { var errs []error if repo.refs != nil { - if err := repo.refs.Close(); err != nil { + err := repo.refs.Close() + if err != nil { errs = append(errs, err) } } + if repo.objects != nil { - if err := repo.objects.Close(); err != nil { + err := repo.objects.Close() + if err != nil { errs = append(errs, err) } } + if repo.objectsLooseForWritingOnly != nil { - if err := repo.objectsLooseForWritingOnly.Close(); err != nil { + err := repo.objectsLooseForWritingOnly.Close() + if err != nil { errs = append(errs, err) } } @@ -123,12 +133,14 @@ func parseRepositoryConfig(root *os.Root) (*config.Config, error) { if err != nil { return nil, fmt.Errorf("repository: open config: %w", err) } + defer func() { _ = configFile.Close() }() cfg, err := config.ParseConfig(configFile) if err != nil { return nil, fmt.Errorf("repository: parse config: %w", err) } + return cfg, nil } @@ -137,10 +149,12 @@ func detectObjectAlgorithm(cfg *config.Config) (objectid.Algorithm, error) { if algoName == "" { algoName = objectid.AlgorithmSHA1.String() } + algo, ok := objectid.ParseAlgorithm(algoName) if !ok { return objectid.AlgorithmUnknown, fmt.Errorf("repository: unsupported object format %q", algoName) } + return algo, nil } @@ -154,19 +168,24 @@ func openObjectStore(root *os.Root, algo objectid.Algorithm) (objectstore.Store, if err != nil { return nil, nil, err } + backends := []objectstore.Store{looseStore} packRoot, err := objectsRoot.OpenRoot("pack") if err == nil { var packedStore *objectpacked.Store + packedStore, err = objectpacked.New(packRoot, algo) if err != nil { _ = looseStore.Close() + return nil, nil, err } + backends = append(backends, packedStore) } else if !errors.Is(err, os.ErrNotExist) { _ = looseStore.Close() + return nil, nil, fmt.Errorf("repository: open objects/pack: %w", err) } @@ -175,12 +194,15 @@ func openObjectStore(root *os.Root, algo objectid.Algorithm) (objectstore.Store, objectsRootForWriting, err := root.OpenRoot("objects") if err != nil { _ = objectsChain.Close() + return nil, nil, fmt.Errorf("repository: open objects for loose writing: %w", err) } + objectsLooseForWritingOnly, err := objectloose.New(objectsRootForWriting, algo) if err != nil { _ = objectsRootForWriting.Close() _ = objectsChain.Close() + return nil, nil, err } @@ -192,16 +214,20 @@ func openRefStore(root *os.Root, algo objectid.Algorithm) (out refstore.Store, e if err != nil { return nil, err } + if hasReftable { reftableRoot, err := root.OpenRoot("reftable") if err != nil { return nil, fmt.Errorf("repository: open reftable: %w", err) } + reftableStore, err := reftable.New(reftableRoot, algo) if err != nil { _ = reftableRoot.Close() + return nil, err } + return reftableStore, nil } @@ -209,22 +235,29 @@ func openRefStore(root *os.Root, algo objectid.Algorithm) (out refstore.Store, e if err != nil { return nil, fmt.Errorf("repository: open root for loose refs: %w", err) } + looseStore, err := refloose.New(looseRoot, algo) if err != nil { _ = looseRoot.Close() + return nil, err } + backends := []refstore.Store{looseStore} - if _, err := root.Stat("packed-refs"); err == nil { + _, err = root.Stat("packed-refs") + if err == nil { packedStore, packedErr := refpacked.New(root, algo) if packedErr != nil { _ = looseStore.Close() + return nil, packedErr } + backends = append(backends, packedStore) } else if !errors.Is(err, os.ErrNotExist) { _ = looseStore.Close() + return nil, fmt.Errorf("repository: stat packed-refs: %w", err) } @@ -236,8 +269,10 @@ func hasReftableStack(root *os.Root) (bool, error) { if err == nil { return true, nil } + if errors.Is(err, os.ErrNotExist) { return false, nil } + return false, fmt.Errorf("repository: stat reftable/tables.list: %w", err) } diff --git a/repository/repository_test.go b/repository/repository_test.go index f8b33c8a..22ae5a1a 100644 --- a/repository/repository_test.go +++ b/repository/repository_test.go @@ -29,12 +29,14 @@ func TestOpenFilesRefFormat(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() if repo.Algorithm() != algo { @@ -45,9 +47,11 @@ func TestOpenFilesRefFormat(t *testing.T) { if err != nil { t.Fatalf("ReadHeader(commit): %v", err) } + if headerType != objecttype.TypeCommit { t.Fatalf("ReadHeader(commit) type = %v, want %v", headerType, objecttype.TypeCommit) } + if headerSize <= 0 { t.Fatalf("ReadHeader(commit) size = %d, want > 0", headerSize) } @@ -56,10 +60,12 @@ func TestOpenFilesRefFormat(t *testing.T) { if err != nil { t.Fatalf("Resolve(refs/heads/main): %v", err) } + detached, ok := resolved.(ref.Detached) if !ok { t.Fatalf("Resolve(refs/heads/main) type = %T, want ref.Detached", resolved) } + if detached.ID != commitID { t.Fatalf("Resolve(refs/heads/main) id = %s, want %s", detached.ID, commitID) } @@ -68,6 +74,7 @@ func TestOpenFilesRefFormat(t *testing.T) { if err != nil { t.Fatalf("ResolveFully(HEAD): %v", err) } + if head.ID != commitID { t.Fatalf("ResolveFully(HEAD) id = %s, want %s", head.ID, commitID) } @@ -97,6 +104,7 @@ func TestOpenReftableRefFormat(t *testing.T) { func newRepoForRefs(t *testing.T, algo objectid.Algorithm, refFormat string) *testgit.TestRepo { t.Helper() + return testgit.NewRepo(t, testgit.RepoOptions{ ObjectFormat: algo, Bare: true, @@ -109,6 +117,7 @@ func writeMainAndHead(t *testing.T, repoHarness *testgit.TestRepo) objectid.Obje _, _, commitID := repoHarness.MakeCommit(t, "refs") repoHarness.UpdateRef(t, "refs/heads/main", commitID) repoHarness.SymbolicRef(t, "HEAD", "refs/heads/main") + return commitID } @@ -119,18 +128,21 @@ func assertResolveFully(t *testing.T, repoHarness *testgit.TestRepo, name string if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() resolved, err := repo.Refs().ResolveFully(name) if err != nil { t.Fatalf("ResolveFully(%s): %v", name, err) } + if resolved.ID != want { t.Fatalf("ResolveFully(%s) id = %s, want %s", name, resolved.ID, want) } diff --git a/repository/stored_test.go b/repository/stored_test.go index b53fcde6..6ebd4259 100644 --- a/repository/stored_test.go +++ b/repository/stored_test.go @@ -28,21 +28,25 @@ func TestReadStoredTyped(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() blob, err := repo.ReadStoredBlob(blobID) if err != nil { t.Fatalf("ReadStoredBlob: %v", err) } + if blob.ID() != blobID { t.Fatalf("blob ID = %s, want %s", blob.ID(), blobID) } + if string(blob.Blob().Data) != "commit-body\n" { t.Fatalf("blob body = %q, want %q", blob.Blob().Data, "commit-body\n") } @@ -51,9 +55,11 @@ func TestReadStoredTyped(t *testing.T) { if err != nil { t.Fatalf("ReadStoredTree: %v", err) } + if tree.ID() != treeID { t.Fatalf("tree ID = %s, want %s", tree.ID(), treeID) } + if len(tree.Tree().Entries) != 1 { t.Fatalf("tree entries = %d, want 1", len(tree.Tree().Entries)) } @@ -62,9 +68,11 @@ func TestReadStoredTyped(t *testing.T) { if err != nil { t.Fatalf("ReadStoredCommit: %v", err) } + if commit.ID() != commitID { t.Fatalf("commit ID = %s, want %s", commit.ID(), commitID) } + if commit.Commit().Tree != treeID { t.Fatalf("commit tree = %s, want %s", commit.Commit().Tree, treeID) } @@ -89,12 +97,14 @@ func TestResolveTreeEntry(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() rootTree, err := repo.ReadStoredTree(rootTreeID) @@ -106,9 +116,11 @@ func TestResolveTreeEntry(t *testing.T) { if err != nil { t.Fatalf("ResolveTreeEntry: %v", err) } + if entry.Mode != object.FileModeRegular { t.Fatalf("ResolveTreeEntry mode = %o, want %o", entry.Mode, object.FileModeRegular) } + if entry.ID != blobID { t.Fatalf("ResolveTreeEntry id = %s, want %s", entry.ID, blobID) } @@ -133,12 +145,14 @@ func TestResolveTreeEntryErrors(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() rootTree, err := repo.ReadStoredTree(rootTreeID) @@ -166,12 +180,14 @@ func TestResolveTreeEntryErrors(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() rootTree, err := repo.ReadStoredTree(rootTreeID) @@ -208,18 +224,21 @@ func TestResolveTreeEntryDeepPath(t *testing.T) { currentTree = repoHarness.Mktree(t, fmt.Sprintf("040000 tree %s\t%s\n", currentTree, name)) parts = append([][]byte{[]byte(name)}, parts...) } + parts = append(parts, []byte("leaf.txt")) root, err := os.OpenRoot(repoHarness.Dir()) if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() rootTree, err := repo.ReadStoredTree(currentTree) @@ -231,9 +250,11 @@ func TestResolveTreeEntryDeepPath(t *testing.T) { if err != nil { t.Fatalf("ResolveTreeEntry(deep): %v", err) } + if entry.Mode != object.FileModeRegular { t.Fatalf("ResolveTreeEntry(deep) mode = %o, want %o", entry.Mode, object.FileModeRegular) } + if entry.ID != leafBlobID { t.Fatalf("ResolveTreeEntry(deep) id = %s, want %s", entry.ID, leafBlobID) } @@ -270,12 +291,14 @@ func TestReadStoredTreeMixedModes(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() rootTree, err := repo.ReadStoredTree(rootTreeID) diff --git a/repository/traversal_bench_test.go b/repository/traversal_bench_test.go index 63e131de..3480964e 100644 --- a/repository/traversal_bench_test.go +++ b/repository/traversal_bench_test.go @@ -25,6 +25,7 @@ func BenchmarkTraverseHeadTree(b *testing.B) { if err != nil { b.Fatalf("os.OpenRoot(%q): %v", repoPath, err) } + b.Cleanup(func() { _ = root.Close() }) @@ -33,6 +34,7 @@ func BenchmarkTraverseHeadTree(b *testing.B) { if err != nil { b.Fatalf("repository.Open(root for %q): %v", repoPath, err) } + b.Cleanup(func() { _ = repo.Close() }) @@ -41,10 +43,12 @@ func BenchmarkTraverseHeadTree(b *testing.B) { if err != nil { b.Fatalf("ResolveRefFully(HEAD): %v", err) } + stored, err := repo.ReadStored(head.ID) if err != nil { b.Fatalf("ReadStored(%s): %v", head.ID, err) } + commit, ok := stored.Object().(*object.Commit) if !ok { b.Fatalf("HEAD object type %T, want *object.Commit", stored.Object()) @@ -62,6 +66,7 @@ func BenchmarkTraverseHeadTree(b *testing.B) { } b.StopTimer() + if lastCount <= 0 { b.Fatalf("traverseTreeIter count = %d, want > 0", lastCount) } diff --git a/repository/traversal_helpers_test.go b/repository/traversal_helpers_test.go index b23a81c8..e2d662e6 100644 --- a/repository/traversal_helpers_test.go +++ b/repository/traversal_helpers_test.go @@ -21,10 +21,13 @@ func traverseTreeIter(repo *repository.Repository, root objectid.ObjectID) (int, id := frame.id if !frame.isTree { - if _, err := repo.ReadStoredSize(id); err != nil { + _, err := repo.ReadStoredSize(id) + if err != nil { return 0, err } + total++ + continue } @@ -32,12 +35,15 @@ func traverseTreeIter(repo *repository.Repository, root objectid.ObjectID) (int, if err != nil { return 0, err } + total++ + for i := len(tree.Tree().Entries) - 1; i >= 0; i-- { entry := tree.Tree().Entries[i] if entry.Mode == object.FileModeGitlink { continue } + stack = append(stack, treeWalkFrame{ id: entry.ID, isTree: entry.Mode == object.FileModeDir, @@ -56,15 +62,19 @@ func traverseReachableIter(repo *repository.Repository, root objectid.ObjectID) for len(stack) > 0 { id := stack[len(stack)-1] stack = stack[:len(stack)-1] - if _, ok := visited[id]; ok { + + _, ok := visited[id] + if ok { continue } + visited[id] = struct{}{} stored, err := repo.ReadStored(id) if err != nil { return 0, err } + total++ switch obj := stored.Object().(type) { @@ -77,6 +87,7 @@ func traverseReachableIter(repo *repository.Repository, root objectid.ObjectID) if entry.Mode == object.FileModeGitlink { continue } + stack = append(stack, entry.ID) } case *object.Tag: diff --git a/repository/traversal_test.go b/repository/traversal_test.go index 28c03a2c..34c3a75b 100644 --- a/repository/traversal_test.go +++ b/repository/traversal_test.go @@ -26,6 +26,7 @@ func TestRepositoryDepthFirstEnumerationFromHEAD(t *testing.T) { blob2, tree2 := repoHarness.MakeSingleFileTree(t, "second.txt", []byte("second\n")) commit2 := repoHarness.CommitTree(t, tree2, "walk-two", commit1) _ = blob2 + repoHarness.UpdateRef(t, "refs/heads/main", commit2) repoHarness.SymbolicRef(t, "HEAD", "refs/heads/main") @@ -46,6 +47,7 @@ func TestRepositoryDepthFirstEnumerationCurrentWorktree(t *testing.T) { if info.IsDir() { walkRepositoryFromHead(t, gitPath) + return } @@ -57,7 +59,9 @@ func TestRepositoryDepthFirstEnumerationCurrentWorktree(t *testing.T) { if err != nil { t.Fatalf("read %q: %v", gitPath, err) } + line := strings.TrimSpace(string(content)) + prefix := "gitdir: " if !strings.HasPrefix(line, prefix) { t.Fatalf("%q file does not begin with %q", gitPath, prefix) @@ -67,23 +71,30 @@ func TestRepositoryDepthFirstEnumerationCurrentWorktree(t *testing.T) { if gitdirRel == "" { t.Fatalf("%q contains empty gitdir path", gitPath) } + gitdirPath := gitdirRel if !filepath.IsAbs(gitdirPath) { gitdirPath = filepath.Join(worktreeRoot, gitdirPath) } + commondirPath := filepath.Join(gitdirPath, "commondir") + commondirContent, err := os.ReadFile(commondirPath) //#nosec G304 if err != nil { t.Fatalf("read %q: %v", commondirPath, err) } + repoPath := strings.TrimSpace(string(commondirContent)) if repoPath == "" { t.Fatalf("%q contains empty repo path", commondirPath) } + if filepath.IsAbs(repoPath) { walkRepositoryFromHead(t, repoPath) + return } + repoPath = filepath.Join(gitdirPath, repoPath) walkRepositoryFromHead(t, repoPath) @@ -96,22 +107,26 @@ func walkRepositoryFromHead(t *testing.T, repoPath string) { if err != nil { t.Fatalf("os.OpenRoot(%q): %v", repoPath, err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open(root for %q): %v", repoPath, err) } + defer func() { _ = repo.Close() }() head, err := repo.ResolveRefFully("HEAD") if err != nil { t.Fatalf("ResolveRefFully(HEAD): %v", err) } + objectsRead, err := traverseReachableIter(repo, head.ID) if err != nil { t.Fatalf("traverseReachableIter(%s): %v", head.ID, err) } + if objectsRead <= 0 { t.Fatalf("no objects were enumerated from HEAD (%s)", fmt.Sprintf("%q", repoPath)) } diff --git a/repository/tree_resolve.go b/repository/tree_resolve.go index 6b7023ba..d4ef529e 100644 --- a/repository/tree_resolve.go +++ b/repository/tree_resolve.go @@ -16,11 +16,13 @@ func (repo *Repository) ResolveTreeEntry(tree *objectstored.StoredTree, parts [] if tree == nil { return object.TreeEntry{}, errors.New("repository: nil root tree") } + if len(parts) == 0 { return object.TreeEntry{}, errors.New("repository: empty tree path") } current := tree + for i, part := range parts { if len(part) == 0 { return object.TreeEntry{}, errors.New("repository: empty tree path segment") @@ -30,9 +32,11 @@ func (repo *Repository) ResolveTreeEntry(tree *objectstored.StoredTree, parts [] if entry == nil { return object.TreeEntry{}, fmt.Errorf("repository: tree entry %q not found", part) } + if i == len(parts)-1 { return *entry, nil } + if entry.Mode != object.FileModeDir { return object.TreeEntry{}, fmt.Errorf("repository: path segment %q is not a tree", part) } @@ -41,6 +45,7 @@ func (repo *Repository) ResolveTreeEntry(tree *objectstored.StoredTree, parts [] if err != nil { return object.TreeEntry{}, err } + current = next } diff --git a/repository/write_loose.go b/repository/write_loose.go index adebb2f6..9784ef25 100644 --- a/repository/write_loose.go +++ b/repository/write_loose.go @@ -15,6 +15,7 @@ func (repo *Repository) WriteLooseBytesFull(raw []byte) (objectid.ObjectID, erro if err != nil { return objectid.ObjectID{}, fmt.Errorf("repository: write loose full bytes: %w", err) } + return id, nil } @@ -24,6 +25,7 @@ func (repo *Repository) WriteLooseBytesContent(ty objecttype.Type, content []byt if err != nil { return objectid.ObjectID{}, fmt.Errorf("repository: write loose content bytes: %w", err) } + return id, nil } @@ -34,6 +36,7 @@ func (repo *Repository) WriteLooseReaderFull(src io.Reader) (objectid.ObjectID, if err != nil { return objectid.ObjectID{}, fmt.Errorf("repository: write loose full reader: %w", err) } + return id, nil } @@ -44,5 +47,6 @@ func (repo *Repository) WriteLooseReaderContent(ty objecttype.Type, size int64, if err != nil { return objectid.ObjectID{}, fmt.Errorf("repository: write loose content reader: %w", err) } + return id, nil } diff --git a/repository/write_loose_test.go b/repository/write_loose_test.go index 603b3a88..ab732df4 100644 --- a/repository/write_loose_test.go +++ b/repository/write_loose_test.go @@ -25,15 +25,18 @@ func TestWriteLooseBytesContent(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() content := []byte("write-loose-bytes-content\n") + gotID, err := repo.WriteLooseBytesContent(objecttype.TypeBlob, content) if err != nil { t.Fatalf("WriteLooseBytesContent: %v", err) @@ -48,9 +51,11 @@ func TestWriteLooseBytesContent(t *testing.T) { if err != nil { t.Fatalf("ReadStoredBytesContent: %v", err) } + if ty != objecttype.TypeBlob { t.Fatalf("ReadStoredBytesContent type = %v, want %v", ty, objecttype.TypeBlob) } + if !bytes.Equal(gotContent, content) { t.Fatalf("ReadStoredBytesContent content mismatch") } @@ -71,15 +76,18 @@ func TestWriteLooseReaderContent(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() content := []byte("write-loose-reader-content\n") + gotID, err := repo.WriteLooseReaderContent(objecttype.TypeBlob, int64(len(content)), bytes.NewReader(content)) if err != nil { t.Fatalf("WriteLooseReaderContent: %v", err) @@ -107,12 +115,14 @@ func TestWriteLooseFull(t *testing.T) { if err != nil { t.Fatalf("os.OpenRoot: %v", err) } + defer func() { _ = root.Close() }() repo, err := repository.Open(root) if err != nil { t.Fatalf("repository.Open: %v", err) } + defer func() { _ = repo.Close() }() raw, err := repo.ReadStoredBytesFull(commitID) @@ -124,6 +134,7 @@ func TestWriteLooseFull(t *testing.T) { if err != nil { t.Fatalf("WriteLooseBytesFull: %v", err) } + if idFromBytes != commitID { t.Fatalf("WriteLooseBytesFull id = %s, want %s", idFromBytes, commitID) } @@ -132,6 +143,7 @@ func TestWriteLooseFull(t *testing.T) { if err != nil { t.Fatalf("WriteLooseReaderFull: %v", err) } + if idFromReader != commitID { t.Fatalf("WriteLooseReaderFull id = %s, want %s", idFromReader, commitID) } |
