diff options
Diffstat (limited to 'plumbing')
34 files changed, 879 insertions, 171 deletions
diff --git a/plumbing/format/gitattributes/attributes.go b/plumbing/format/gitattributes/attributes.go index d36ec1b..026d221 100644 --- a/plumbing/format/gitattributes/attributes.go +++ b/plumbing/format/gitattributes/attributes.go @@ -1,6 +1,7 @@ package gitattributes import ( + "bufio" "errors" "io" "strings" @@ -88,13 +89,10 @@ func (a attribute) String() string { // ReadAttributes reads patterns and attributes from the gitattributes format. func ReadAttributes(r io.Reader, domain []string, allowMacro bool) (attributes []MatchAttribute, err error) { - data, err := io.ReadAll(r) - if err != nil { - return nil, err - } + scanner := bufio.NewScanner(r) - for _, line := range strings.Split(string(data), eol) { - attribute, err := ParseAttributesLine(line, domain, allowMacro) + for scanner.Scan() { + attribute, err := ParseAttributesLine(scanner.Text(), domain, allowMacro) if err != nil { return attributes, err } @@ -105,6 +103,10 @@ func ReadAttributes(r io.Reader, domain []string, allowMacro bool) (attributes [ attributes = append(attributes, attribute) } + if err := scanner.Err(); err != nil { + return attributes, err + } + return attributes, nil } diff --git a/plumbing/format/gitattributes/dir.go b/plumbing/format/gitattributes/dir.go index 123fe25..4238196 100644 --- a/plumbing/format/gitattributes/dir.go +++ b/plumbing/format/gitattributes/dir.go @@ -2,8 +2,11 @@ package gitattributes import ( "os" + "path/filepath" + "strings" "github.com/go-git/go-billy/v5" + "github.com/go-git/go-git/v5/plumbing/format/config" gioutil "github.com/go-git/go-git/v5/utils/ioutil" ) @@ -26,6 +29,8 @@ func ReadAttributesFile(fs billy.Filesystem, path []string, attributesFile strin return nil, err } + defer gioutil.CheckClose(f, &err) + return ReadAttributes(f, path, allowMacro) } @@ -56,7 +61,14 @@ func walkDirectory(fs billy.Filesystem, root []string) (attributes []MatchAttrib continue } - path := append(root, fi.Name()) + p := fi.Name() + + // Handles the case whereby just the volume name ("C:") is appended, + // to root. Change it to "C:\", which is better handled by fs.Join(). + if filepath.VolumeName(p) != "" && !strings.HasSuffix(p, string(filepath.Separator)) { + p = p + string(filepath.Separator) + } + path := append(root, p) dirAttributes, err := ReadAttributesFile(fs, path, gitattributesFile, false) if err != nil { diff --git a/plumbing/format/gitignore/dir.go b/plumbing/format/gitignore/dir.go index d8fb30c..92df5a3 100644 --- a/plumbing/format/gitignore/dir.go +++ b/plumbing/format/gitignore/dir.go @@ -64,6 +64,10 @@ func ReadPatterns(fs billy.Filesystem, path []string) (ps []Pattern, err error) for _, fi := range fis { if fi.IsDir() && fi.Name() != gitDir { + if NewMatcher(ps).Match(append(path, fi.Name()), true) { + continue + } + var subps []Pattern subps, err = ReadPatterns(fs, append(path, fi.Name())) if err != nil { @@ -116,7 +120,7 @@ func loadPatterns(fs billy.Filesystem, path string) (ps []Pattern, err error) { return } -// LoadGlobalPatterns loads gitignore patterns from from the gitignore file +// LoadGlobalPatterns loads gitignore patterns from the gitignore file // declared in a user's ~/.gitconfig file. If the ~/.gitconfig file does not // exist the function will return nil. If the core.excludesfile property // is not declared, the function will return nil. If the file pointed to by @@ -132,7 +136,7 @@ func LoadGlobalPatterns(fs billy.Filesystem) (ps []Pattern, err error) { return loadPatterns(fs, fs.Join(home, gitconfigFile)) } -// LoadSystemPatterns loads gitignore patterns from from the gitignore file +// LoadSystemPatterns loads gitignore patterns from the gitignore file // declared in a system's /etc/gitconfig file. If the /etc/gitconfig file does // not exist the function will return nil. If the core.excludesfile property // is not declared, the function will return nil. If the file pointed to by diff --git a/plumbing/format/gitignore/dir_test.go b/plumbing/format/gitignore/dir_test.go index 465c571..ba8ad80 100644 --- a/plumbing/format/gitignore/dir_test.go +++ b/plumbing/format/gitignore/dir_test.go @@ -44,6 +44,8 @@ func (s *MatcherSuite) SetUpTest(c *C) { c.Assert(err, IsNil) _, err = f.Write([]byte("ignore.crlf\r\n")) c.Assert(err, IsNil) + _, err = f.Write([]byte("ignore_dir\n")) + c.Assert(err, IsNil) err = f.Close() c.Assert(err, IsNil) @@ -56,6 +58,17 @@ func (s *MatcherSuite) SetUpTest(c *C) { err = f.Close() c.Assert(err, IsNil) + err = fs.MkdirAll("ignore_dir", os.ModePerm) + c.Assert(err, IsNil) + f, err = fs.Create("ignore_dir/.gitignore") + c.Assert(err, IsNil) + _, err = f.Write([]byte("!file\n")) + c.Assert(err, IsNil) + _, err = fs.Create("ignore_dir/file") + c.Assert(err, IsNil) + err = f.Close() + c.Assert(err, IsNil) + err = fs.MkdirAll("another", os.ModePerm) c.Assert(err, IsNil) err = fs.MkdirAll("exclude.crlf", os.ModePerm) @@ -267,12 +280,13 @@ func (s *MatcherSuite) SetUpTest(c *C) { func (s *MatcherSuite) TestDir_ReadPatterns(c *C) { checkPatterns := func(ps []Pattern) { - c.Assert(ps, HasLen, 6) + c.Assert(ps, HasLen, 7) m := NewMatcher(ps) c.Assert(m.Match([]string{"exclude.crlf"}, true), Equals, true) c.Assert(m.Match([]string{"ignore.crlf"}, true), Equals, true) c.Assert(m.Match([]string{"vendor", "gopkg.in"}, true), Equals, true) + c.Assert(m.Match([]string{"ignore_dir", "file"}, false), Equals, true) c.Assert(m.Match([]string{"vendor", "github.com"}, true), Equals, false) c.Assert(m.Match([]string{"multiple", "sub", "ignores", "first", "ignore_dir"}, true), Equals, true) c.Assert(m.Match([]string{"multiple", "sub", "ignores", "second", "ignore_dir"}, true), Equals, true) diff --git a/plumbing/format/index/decoder.go b/plumbing/format/index/decoder.go index 6778cf7..f43b1c5 100644 --- a/plumbing/format/index/decoder.go +++ b/plumbing/format/index/decoder.go @@ -24,8 +24,8 @@ var ( // ErrInvalidChecksum is returned by Decode if the SHA1 hash mismatch with // the read content ErrInvalidChecksum = errors.New("invalid checksum") - - errUnknownExtension = errors.New("unknown extension") + // ErrUnknownExtension is returned when an index extension is encountered that is considered mandatory + ErrUnknownExtension = errors.New("unknown extension") ) const ( @@ -39,6 +39,7 @@ const ( // A Decoder reads and decodes index files from an input stream. type Decoder struct { + buf *bufio.Reader r io.Reader hash hash.Hash lastEntry *Entry @@ -49,8 +50,10 @@ type Decoder struct { // NewDecoder returns a new decoder that reads from r. func NewDecoder(r io.Reader) *Decoder { h := hash.New(hash.CryptoType) + buf := bufio.NewReader(r) return &Decoder{ - r: io.TeeReader(r, h), + buf: buf, + r: io.TeeReader(buf, h), hash: h, extReader: bufio.NewReader(nil), } @@ -210,71 +213,76 @@ func (d *Decoder) readExtensions(idx *Index) error { // count that they are not supported by jgit or libgit var expected []byte + var peeked []byte var err error - var header [4]byte + // we should always be able to peek for 4 bytes (header) + 4 bytes (extlen) + final hash + // if this fails, we know that we're at the end of the index + peekLen := 4 + 4 + d.hash.Size() + for { expected = d.hash.Sum(nil) - - var n int - if n, err = io.ReadFull(d.r, header[:]); err != nil { - if n == 0 { - err = io.EOF - } - + peeked, err = d.buf.Peek(peekLen) + if len(peeked) < peekLen { + // there can't be an extension at this point, so let's bail out + err = nil break } + if err != nil { + return err + } - err = d.readExtension(idx, header[:]) + err = d.readExtension(idx) if err != nil { - break + return err } } - if err != errUnknownExtension { + return d.readChecksum(expected) +} + +func (d *Decoder) readExtension(idx *Index) error { + var header [4]byte + + if _, err := io.ReadFull(d.r, header[:]); err != nil { return err } - return d.readChecksum(expected, header) -} + r, err := d.getExtensionReader() + if err != nil { + return err + } -func (d *Decoder) readExtension(idx *Index, header []byte) error { switch { - case bytes.Equal(header, treeExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], treeExtSignature): idx.Cache = &Tree{} d := &treeExtensionDecoder{r} if err := d.Decode(idx.Cache); err != nil { return err } - case bytes.Equal(header, resolveUndoExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], resolveUndoExtSignature): idx.ResolveUndo = &ResolveUndo{} d := &resolveUndoDecoder{r} if err := d.Decode(idx.ResolveUndo); err != nil { return err } - case bytes.Equal(header, endOfIndexEntryExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], endOfIndexEntryExtSignature): idx.EndOfIndexEntry = &EndOfIndexEntry{} d := &endOfIndexEntryDecoder{r} if err := d.Decode(idx.EndOfIndexEntry); err != nil { return err } default: - return errUnknownExtension + // See https://git-scm.com/docs/index-format, which says: + // If the first byte is 'A'..'Z' the extension is optional and can be ignored. + if header[0] < 'A' || header[0] > 'Z' { + return ErrUnknownExtension + } + + d := &unknownExtensionDecoder{r} + if err := d.Decode(); err != nil { + return err + } } return nil @@ -290,11 +298,10 @@ func (d *Decoder) getExtensionReader() (*bufio.Reader, error) { return d.extReader, nil } -func (d *Decoder) readChecksum(expected []byte, alreadyRead [4]byte) error { +func (d *Decoder) readChecksum(expected []byte) error { var h plumbing.Hash - copy(h[:4], alreadyRead[:]) - if _, err := io.ReadFull(d.r, h[4:]); err != nil { + if _, err := io.ReadFull(d.r, h[:]); err != nil { return err } @@ -476,3 +483,22 @@ func (d *endOfIndexEntryDecoder) Decode(e *EndOfIndexEntry) error { _, err = io.ReadFull(d.r, e.Hash[:]) return err } + +type unknownExtensionDecoder struct { + r *bufio.Reader +} + +func (d *unknownExtensionDecoder) Decode() error { + var buf [1024]byte + + for { + _, err := d.r.Read(buf[:]) + if err == io.EOF { + break + } + if err != nil { + return err + } + } + return nil +} diff --git a/plumbing/format/index/decoder_test.go b/plumbing/format/index/decoder_test.go index 39ab336..4adddda 100644 --- a/plumbing/format/index/decoder_test.go +++ b/plumbing/format/index/decoder_test.go @@ -1,6 +1,11 @@ package index import ( + "bytes" + "crypto" + "github.com/go-git/go-git/v5/plumbing/hash" + "github.com/go-git/go-git/v5/utils/binary" + "io" "testing" "github.com/go-git/go-git/v5/plumbing" @@ -218,3 +223,100 @@ func (s *IndexSuite) TestDecodeEndOfIndexEntry(c *C) { c.Assert(idx.EndOfIndexEntry.Offset, Equals, uint32(716)) c.Assert(idx.EndOfIndexEntry.Hash.String(), Equals, "922e89d9ffd7cefce93a211615b2053c0f42bd78") } + +func (s *IndexSuite) readSimpleIndex(c *C) *Index { + f, err := fixtures.Basic().One().DotGit().Open("index") + c.Assert(err, IsNil) + defer func() { c.Assert(f.Close(), IsNil) }() + + idx := &Index{} + d := NewDecoder(f) + err = d.Decode(idx) + c.Assert(err, IsNil) + + return idx +} + +func (s *IndexSuite) buildIndexWithExtension(c *C, signature string, data string) []byte { + idx := s.readSimpleIndex(c) + + buf := bytes.NewBuffer(nil) + e := NewEncoder(buf) + + err := e.encode(idx, false) + c.Assert(err, IsNil) + err = e.encodeRawExtension(signature, []byte(data)) + c.Assert(err, IsNil) + + err = e.encodeFooter() + c.Assert(err, IsNil) + + return buf.Bytes() +} + +func (s *IndexSuite) TestDecodeUnknownOptionalExt(c *C) { + f := bytes.NewReader(s.buildIndexWithExtension(c, "TEST", "testdata")) + + idx := &Index{} + d := NewDecoder(f) + err := d.Decode(idx) + c.Assert(err, IsNil) +} + +func (s *IndexSuite) TestDecodeUnknownMandatoryExt(c *C) { + f := bytes.NewReader(s.buildIndexWithExtension(c, "test", "testdata")) + + idx := &Index{} + d := NewDecoder(f) + err := d.Decode(idx) + c.Assert(err, ErrorMatches, ErrUnknownExtension.Error()) +} + +func (s *IndexSuite) TestDecodeTruncatedExt(c *C) { + idx := s.readSimpleIndex(c) + + buf := bytes.NewBuffer(nil) + e := NewEncoder(buf) + + err := e.encode(idx, false) + c.Assert(err, IsNil) + + _, err = e.w.Write([]byte("TEST")) + c.Assert(err, IsNil) + + err = binary.WriteUint32(e.w, uint32(100)) + c.Assert(err, IsNil) + + _, err = e.w.Write([]byte("truncated")) + c.Assert(err, IsNil) + + err = e.encodeFooter() + c.Assert(err, IsNil) + + idx = &Index{} + d := NewDecoder(buf) + err = d.Decode(idx) + c.Assert(err, ErrorMatches, io.EOF.Error()) +} + +func (s *IndexSuite) TestDecodeInvalidHash(c *C) { + idx := s.readSimpleIndex(c) + + buf := bytes.NewBuffer(nil) + e := NewEncoder(buf) + + err := e.encode(idx, false) + c.Assert(err, IsNil) + + err = e.encodeRawExtension("TEST", []byte("testdata")) + c.Assert(err, IsNil) + + h := hash.New(crypto.SHA1) + err = binary.Write(e.w, h.Sum(nil)) + c.Assert(err, IsNil) + + idx = &Index{} + d := NewDecoder(buf) + err = d.Decode(idx) + c.Assert(err, ErrorMatches, ErrInvalidChecksum.Error()) +} diff --git a/plumbing/format/index/encoder.go b/plumbing/format/index/encoder.go index fa2d814..c292c2c 100644 --- a/plumbing/format/index/encoder.go +++ b/plumbing/format/index/encoder.go @@ -3,6 +3,7 @@ package index import ( "bytes" "errors" + "fmt" "io" "sort" "time" @@ -35,6 +36,11 @@ func NewEncoder(w io.Writer) *Encoder { // Encode writes the Index to the stream of the encoder. func (e *Encoder) Encode(idx *Index) error { + return e.encode(idx, true) +} + +func (e *Encoder) encode(idx *Index, footer bool) error { + // TODO: support v4 // TODO: support extensions if idx.Version > EncodeVersionSupported { @@ -49,7 +55,10 @@ func (e *Encoder) Encode(idx *Index) error { return err } - return e.encodeFooter() + if footer { + return e.encodeFooter() + } + return nil } func (e *Encoder) encodeHeader(idx *Index) error { @@ -135,6 +144,29 @@ func (e *Encoder) encodeEntry(entry *Entry) error { return binary.Write(e.w, []byte(entry.Name)) } +func (e *Encoder) encodeRawExtension(signature string, data []byte) error { + if len(signature) != 4 { + return fmt.Errorf("invalid signature length") + } + + _, err := e.w.Write([]byte(signature)) + if err != nil { + return err + } + + err = binary.WriteUint32(e.w, uint32(len(data))) + if err != nil { + return err + } + + _, err = e.w.Write(data) + if err != nil { + return err + } + + return nil +} + func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) { if t.IsZero() { return 0, 0, nil diff --git a/plumbing/format/packfile/delta_index.go b/plumbing/format/packfile/delta_index.go index 07a6112..a60ec0b 100644 --- a/plumbing/format/packfile/delta_index.go +++ b/plumbing/format/packfile/delta_index.go @@ -32,19 +32,17 @@ func (idx *deltaIndex) findMatch(src, tgt []byte, tgtOffset int) (srcOffset, l i return 0, -1 } - if len(tgt) >= tgtOffset+s && len(src) >= blksz { - h := hashBlock(tgt, tgtOffset) - tIdx := h & idx.mask - eIdx := idx.table[tIdx] - if eIdx != 0 { - srcOffset = idx.entries[eIdx] - } else { - return - } - - l = matchLength(src, tgt, tgtOffset, srcOffset) + h := hashBlock(tgt, tgtOffset) + tIdx := h & idx.mask + eIdx := idx.table[tIdx] + if eIdx == 0 { + return } + srcOffset = idx.entries[eIdx] + + l = matchLength(src, tgt, tgtOffset, srcOffset) + return } diff --git a/plumbing/object/commit_test.go b/plumbing/object/commit_test.go index 6651ef8..a048926 100644 --- a/plumbing/object/commit_test.go +++ b/plumbing/object/commit_test.go @@ -455,7 +455,7 @@ func (s *SuiteCommit) TestStat(c *C) { c.Assert(fileStats[1].Name, Equals, "php/crappy.php") c.Assert(fileStats[1].Addition, Equals, 259) c.Assert(fileStats[1].Deletion, Equals, 0) - c.Assert(fileStats[1].String(), Equals, " php/crappy.php | 259 ++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + c.Assert(fileStats[1].String(), Equals, " php/crappy.php | 259 +++++++++++++++++++++++++++++++++++++++++++++++++++++\n") } func (s *SuiteCommit) TestVerify(c *C) { diff --git a/plumbing/object/commit_walker_path.go b/plumbing/object/commit_walker_path.go index aa0ca15..c1ec8ba 100644 --- a/plumbing/object/commit_walker_path.go +++ b/plumbing/object/commit_walker_path.go @@ -57,6 +57,8 @@ func (c *commitPathIter) Next() (*Commit, error) { } func (c *commitPathIter) getNextFileCommit() (*Commit, error) { + var parentTree, currentTree *Tree + for { // Parent-commit can be nil if the current-commit is the initial commit parentCommit, parentCommitErr := c.sourceIter.Next() @@ -68,13 +70,17 @@ func (c *commitPathIter) getNextFileCommit() (*Commit, error) { parentCommit = nil } - // Fetch the trees of the current and parent commits - currentTree, currTreeErr := c.currentCommit.Tree() - if currTreeErr != nil { - return nil, currTreeErr + if parentTree == nil { + var currTreeErr error + currentTree, currTreeErr = c.currentCommit.Tree() + if currTreeErr != nil { + return nil, currTreeErr + } + } else { + currentTree = parentTree + parentTree = nil } - var parentTree *Tree if parentCommit != nil { var parentTreeErr error parentTree, parentTreeErr = parentCommit.Tree() @@ -115,7 +121,8 @@ func (c *commitPathIter) hasFileChange(changes Changes, parent *Commit) bool { // filename matches, now check if source iterator contains all commits (from all refs) if c.checkParent { - if parent != nil && isParentHash(parent.Hash, c.currentCommit) { + // Check if parent is beyond the initial commit + if parent == nil || isParentHash(parent.Hash, c.currentCommit) { return true } continue diff --git a/plumbing/object/commit_walker_test.go b/plumbing/object/commit_walker_test.go index c47d68b..fa0ca7d 100644 --- a/plumbing/object/commit_walker_test.go +++ b/plumbing/object/commit_walker_test.go @@ -228,3 +228,29 @@ func (s *CommitWalkerSuite) TestCommitBSFIteratorWithIgnore(c *C) { c.Assert(commit.Hash.String(), Equals, expected[i]) } } + +func (s *CommitWalkerSuite) TestCommitPathIteratorInitialCommit(c *C) { + commit := s.commit(c, plumbing.NewHash(s.Fixture.Head)) + + fileName := "LICENSE" + + var commits []*Commit + NewCommitPathIterFromIter( + func(path string) bool { return path == fileName }, + NewCommitIterCTime(commit, nil, nil), + true, + ).ForEach(func(c *Commit) error { + commits = append(commits, c) + return nil + }) + + expected := []string{ + "b029517f6300c2da0f4b651b8642506cd6aaf45d", + } + + c.Assert(commits, HasLen, len(expected)) + + for i, commit := range commits { + c.Assert(commit.Hash.String(), Equals, expected[i]) + } +} diff --git a/plumbing/object/patch.go b/plumbing/object/patch.go index dd8fef4..3c61f62 100644 --- a/plumbing/object/patch.go +++ b/plumbing/object/patch.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" "io" - "math" + "strconv" "strings" "github.com/go-git/go-git/v5/plumbing" @@ -234,69 +234,56 @@ func (fileStats FileStats) String() string { return printStat(fileStats) } +// printStat prints the stats of changes in content of files. +// Original implementation: https://github.com/git/git/blob/1a87c842ece327d03d08096395969aca5e0a6996/diff.c#L2615 +// Parts of the output: +// <pad><filename><pad>|<pad><changeNumber><pad><+++/---><newline> +// example: " main.go | 10 +++++++--- " func printStat(fileStats []FileStat) string { - padLength := float64(len(" ")) - newlineLength := float64(len("\n")) - separatorLength := float64(len("|")) - // Soft line length limit. The text length calculation below excludes - // length of the change number. Adding that would take it closer to 80, - // but probably not more than 80, until it's a huge number. - lineLength := 72.0 - - // Get the longest filename and longest total change. - var longestLength float64 - var longestTotalChange float64 - for _, fs := range fileStats { - if int(longestLength) < len(fs.Name) { - longestLength = float64(len(fs.Name)) - } - totalChange := fs.Addition + fs.Deletion - if int(longestTotalChange) < totalChange { - longestTotalChange = float64(totalChange) - } - } - - // Parts of the output: - // <pad><filename><pad>|<pad><changeNumber><pad><+++/---><newline> - // example: " main.go | 10 +++++++--- " - - // <pad><filename><pad> - leftTextLength := padLength + longestLength + padLength - - // <pad><number><pad><+++++/-----><newline> - // Excluding number length here. - rightTextLength := padLength + padLength + newlineLength + maxGraphWidth := uint(53) + maxNameLen := 0 + maxChangeLen := 0 - totalTextArea := leftTextLength + separatorLength + rightTextLength - heightOfHistogram := lineLength - totalTextArea + scaleLinear := func(it, width, max uint) uint { + if it == 0 || max == 0 { + return 0 + } - // Scale the histogram. - var scaleFactor float64 - if longestTotalChange > heightOfHistogram { - // Scale down to heightOfHistogram. - scaleFactor = longestTotalChange / heightOfHistogram - } else { - scaleFactor = 1.0 + return 1 + (it * (width - 1) / max) } - finalOutput := "" for _, fs := range fileStats { - addn := float64(fs.Addition) - deln := float64(fs.Deletion) - addc := int(math.Floor(addn/scaleFactor)) - delc := int(math.Floor(deln/scaleFactor)) - if addc < 0 { - addc = 0 + if len(fs.Name) > maxNameLen { + maxNameLen = len(fs.Name) } - if delc < 0 { - delc = 0 + + changes := strconv.Itoa(fs.Addition + fs.Deletion) + if len(changes) > maxChangeLen { + maxChangeLen = len(changes) } - adds := strings.Repeat("+", addc) - dels := strings.Repeat("-", delc) - finalOutput += fmt.Sprintf(" %s | %d %s%s\n", fs.Name, (fs.Addition + fs.Deletion), adds, dels) } - return finalOutput + result := "" + for _, fs := range fileStats { + add := uint(fs.Addition) + del := uint(fs.Deletion) + np := maxNameLen - len(fs.Name) + cp := maxChangeLen - len(strconv.Itoa(fs.Addition+fs.Deletion)) + + total := add + del + if total > maxGraphWidth { + add = scaleLinear(add, maxGraphWidth, total) + del = scaleLinear(del, maxGraphWidth, total) + } + + adds := strings.Repeat("+", int(add)) + dels := strings.Repeat("-", int(del)) + namePad := strings.Repeat(" ", np) + changePad := strings.Repeat(" ", cp) + + result += fmt.Sprintf(" %s%s | %s%d %s%s\n", fs.Name, namePad, changePad, total, adds, dels) + } + return result } func getFileStatsFromFilePatches(filePatches []fdiff.FilePatch) FileStats { diff --git a/plumbing/object/patch_test.go b/plumbing/object/patch_test.go index 2cff795..e0e63a5 100644 --- a/plumbing/object/patch_test.go +++ b/plumbing/object/patch_test.go @@ -45,3 +45,113 @@ func (s *PatchSuite) TestStatsWithSubmodules(c *C) { c.Assert(err, IsNil) c.Assert(p, NotNil) } + +func (s *PatchSuite) TestFileStatsString(c *C) { + testCases := []struct { + description string + input FileStats + expected string + }{ + + { + description: "no files changed", + input: []FileStat{}, + expected: "", + }, + { + description: "one file touched - no changes", + input: []FileStat{ + { + Name: "file1", + }, + }, + expected: " file1 | 0 \n", + }, + { + description: "one file changed", + input: []FileStat{ + { + Name: "file1", + Addition: 1, + }, + }, + expected: " file1 | 1 +\n", + }, + { + description: "one file changed with one addition and one deletion", + input: []FileStat{ + { + Name: ".github/workflows/git.yml", + Addition: 1, + Deletion: 1, + }, + }, + expected: " .github/workflows/git.yml | 2 +-\n", + }, + { + description: "two files changed", + input: []FileStat{ + { + Name: ".github/workflows/git.yml", + Addition: 1, + Deletion: 1, + }, + { + Name: "cli/go-git/go.mod", + Addition: 4, + Deletion: 4, + }, + }, + expected: " .github/workflows/git.yml | 2 +-\n cli/go-git/go.mod | 8 ++++----\n", + }, + { + description: "three files changed", + input: []FileStat{ + { + Name: ".github/workflows/git.yml", + Addition: 3, + Deletion: 3, + }, + { + Name: "worktree.go", + Addition: 107, + }, + { + Name: "worktree_test.go", + Addition: 75, + }, + }, + expected: " .github/workflows/git.yml | 6 +++---\n" + + " worktree.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++++\n" + + " worktree_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++\n", + }, + { + description: "three files changed with deletions and additions", + input: []FileStat{ + { + Name: ".github/workflows/git.yml", + Addition: 3, + Deletion: 3, + }, + { + Name: "worktree.go", + Addition: 107, + Deletion: 217, + }, + { + Name: "worktree_test.go", + Addition: 75, + Deletion: 275, + }, + }, + expected: " .github/workflows/git.yml | 6 +++---\n" + + " worktree.go | 324 ++++++++++++++++++-----------------------------------\n" + + " worktree_test.go | 350 ++++++++++++-----------------------------------------\n", + }, + } + + for _, tc := range testCases { + c.Log("Executing test cases:", tc.description) + c.Assert(printStat(tc.input), Equals, tc.expected) + } +} diff --git a/plumbing/object/tree.go b/plumbing/object/tree.go index e9f7666..0fd0e51 100644 --- a/plumbing/object/tree.go +++ b/plumbing/object/tree.go @@ -7,6 +7,7 @@ import ( "io" "path" "path/filepath" + "sort" "strings" "github.com/go-git/go-git/v5/plumbing" @@ -27,6 +28,7 @@ var ( ErrFileNotFound = errors.New("file not found") ErrDirectoryNotFound = errors.New("directory not found") ErrEntryNotFound = errors.New("entry not found") + ErrEntriesNotSorted = errors.New("entries in tree are not sorted") ) // Tree is basically like a directory - it references a bunch of other trees @@ -270,6 +272,28 @@ func (t *Tree) Decode(o plumbing.EncodedObject) (err error) { return nil } +type TreeEntrySorter []TreeEntry + +func (s TreeEntrySorter) Len() int { + return len(s) +} + +func (s TreeEntrySorter) Less(i, j int) bool { + name1 := s[i].Name + name2 := s[j].Name + if s[i].Mode == filemode.Dir { + name1 += "/" + } + if s[j].Mode == filemode.Dir { + name2 += "/" + } + return name1 < name2 +} + +func (s TreeEntrySorter) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + // Encode transforms a Tree into a plumbing.EncodedObject. func (t *Tree) Encode(o plumbing.EncodedObject) (err error) { o.SetType(plumbing.TreeObject) @@ -279,7 +303,15 @@ func (t *Tree) Encode(o plumbing.EncodedObject) (err error) { } defer ioutil.CheckClose(w, &err) + + if !sort.IsSorted(TreeEntrySorter(t.Entries)) { + return ErrEntriesNotSorted + } + for _, entry := range t.Entries { + if strings.IndexByte(entry.Name, 0) != -1 { + return fmt.Errorf("malformed filename %q", entry.Name) + } if _, err = fmt.Fprintf(w, "%o %s", entry.Mode, entry.Name); err != nil { return err } diff --git a/plumbing/object/tree_test.go b/plumbing/object/tree_test.go index bb5fc7a..feb058a 100644 --- a/plumbing/object/tree_test.go +++ b/plumbing/object/tree_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "sort" "testing" fixtures "github.com/go-git/go-git-fixtures/v4" @@ -220,6 +221,30 @@ func (o *SortReadCloser) Read(p []byte) (int, error) { return nw, nil } +func (s *TreeSuite) TestTreeEntriesSorted(c *C) { + tree := &Tree{ + Entries: []TreeEntry{ + {"foo", filemode.Empty, plumbing.NewHash("b029517f6300c2da0f4b651b8642506cd6aaf45d")}, + {"bar", filemode.Empty, plumbing.NewHash("c029517f6300c2da0f4b651b8642506cd6aaf45d")}, + {"baz", filemode.Empty, plumbing.NewHash("d029517f6300c2da0f4b651b8642506cd6aaf45d")}, + }, + } + + { + c.Assert(sort.IsSorted(TreeEntrySorter(tree.Entries)), Equals, false) + obj := &plumbing.MemoryObject{} + err := tree.Encode(obj) + c.Assert(err, Equals, ErrEntriesNotSorted) + } + + { + sort.Sort(TreeEntrySorter(tree.Entries)) + obj := &plumbing.MemoryObject{} + err := tree.Encode(obj) + c.Assert(err, IsNil) + } +} + func (s *TreeSuite) TestTreeDecodeEncodeIdempotent(c *C) { trees := []*Tree{ { @@ -231,6 +256,7 @@ func (s *TreeSuite) TestTreeDecodeEncodeIdempotent(c *C) { }, } for _, tree := range trees { + sort.Sort(TreeEntrySorter(tree.Entries)) obj := &plumbing.MemoryObject{} err := tree.Encode(obj) c.Assert(err, IsNil) diff --git a/plumbing/object/treenoder.go b/plumbing/object/treenoder.go index 6e7b334..2adb645 100644 --- a/plumbing/object/treenoder.go +++ b/plumbing/object/treenoder.go @@ -88,7 +88,9 @@ func (t *treeNoder) Children() ([]noder.Noder, error) { } } - return transformChildren(parent) + var err error + t.children, err = transformChildren(parent) + return t.children, err } // Returns the children of a tree as treenoders. diff --git a/plumbing/protocol/packp/filter.go b/plumbing/protocol/packp/filter.go new file mode 100644 index 0000000..145fc71 --- /dev/null +++ b/plumbing/protocol/packp/filter.go @@ -0,0 +1,76 @@ +package packp + +import ( + "errors" + "fmt" + "github.com/go-git/go-git/v5/plumbing" + "net/url" + "strings" +) + +var ErrUnsupportedObjectFilterType = errors.New("unsupported object filter type") + +// Filter values enable the partial clone capability which causes +// the server to omit objects that match the filter. +// +// See [Git's documentation] for more details. +// +// [Git's documentation]: https://github.com/git/git/blob/e02ecfcc534e2021aae29077a958dd11c3897e4c/Documentation/rev-list-options.txt#L948 +type Filter string + +type BlobLimitPrefix string + +const ( + BlobLimitPrefixNone BlobLimitPrefix = "" + BlobLimitPrefixKibi BlobLimitPrefix = "k" + BlobLimitPrefixMebi BlobLimitPrefix = "m" + BlobLimitPrefixGibi BlobLimitPrefix = "g" +) + +// FilterBlobNone omits all blobs. +func FilterBlobNone() Filter { + return "blob:none" +} + +// FilterBlobLimit omits blobs of size at least n bytes (when prefix is +// BlobLimitPrefixNone), n kibibytes (when prefix is BlobLimitPrefixKibi), +// n mebibytes (when prefix is BlobLimitPrefixMebi) or n gibibytes (when +// prefix is BlobLimitPrefixGibi). n can be zero, in which case all blobs +// will be omitted. +func FilterBlobLimit(n uint64, prefix BlobLimitPrefix) Filter { + return Filter(fmt.Sprintf("blob:limit=%d%s", n, prefix)) +} + +// FilterTreeDepth omits all blobs and trees whose depth from the root tree +// is larger or equal to depth. +func FilterTreeDepth(depth uint64) Filter { + return Filter(fmt.Sprintf("tree:%d", depth)) +} + +// FilterObjectType omits all objects which are not of the requested type t. +// Supported types are TagObject, CommitObject, TreeObject and BlobObject. +func FilterObjectType(t plumbing.ObjectType) (Filter, error) { + switch t { + case plumbing.TagObject: + fallthrough + case plumbing.CommitObject: + fallthrough + case plumbing.TreeObject: + fallthrough + case plumbing.BlobObject: + return Filter(fmt.Sprintf("object:type=%s", t.String())), nil + default: + return "", fmt.Errorf("%w: %s", ErrUnsupportedObjectFilterType, t.String()) + } +} + +// FilterCombine combines multiple Filter values together. +func FilterCombine(filters ...Filter) Filter { + var escapedFilters []string + + for _, filter := range filters { + escapedFilters = append(escapedFilters, url.QueryEscape(string(filter))) + } + + return Filter(fmt.Sprintf("combine:%s", strings.Join(escapedFilters, "+"))) +} diff --git a/plumbing/protocol/packp/filter_test.go b/plumbing/protocol/packp/filter_test.go new file mode 100644 index 0000000..266670f --- /dev/null +++ b/plumbing/protocol/packp/filter_test.go @@ -0,0 +1,58 @@ +package packp + +import ( + "github.com/go-git/go-git/v5/plumbing" + "github.com/stretchr/testify/require" + "testing" +) + +func TestFilterBlobNone(t *testing.T) { + require.EqualValues(t, "blob:none", FilterBlobNone()) +} + +func TestFilterBlobLimit(t *testing.T) { + require.EqualValues(t, "blob:limit=0", FilterBlobLimit(0, BlobLimitPrefixNone)) + require.EqualValues(t, "blob:limit=1000", FilterBlobLimit(1000, BlobLimitPrefixNone)) + require.EqualValues(t, "blob:limit=4k", FilterBlobLimit(4, BlobLimitPrefixKibi)) + require.EqualValues(t, "blob:limit=4m", FilterBlobLimit(4, BlobLimitPrefixMebi)) + require.EqualValues(t, "blob:limit=4g", FilterBlobLimit(4, BlobLimitPrefixGibi)) +} + +func TestFilterTreeDepth(t *testing.T) { + require.EqualValues(t, "tree:0", FilterTreeDepth(0)) + require.EqualValues(t, "tree:1", FilterTreeDepth(1)) + require.EqualValues(t, "tree:2", FilterTreeDepth(2)) +} + +func TestFilterObjectType(t *testing.T) { + filter, err := FilterObjectType(plumbing.TagObject) + require.NoError(t, err) + require.EqualValues(t, "object:type=tag", filter) + + filter, err = FilterObjectType(plumbing.CommitObject) + require.NoError(t, err) + require.EqualValues(t, "object:type=commit", filter) + + filter, err = FilterObjectType(plumbing.TreeObject) + require.NoError(t, err) + require.EqualValues(t, "object:type=tree", filter) + + filter, err = FilterObjectType(plumbing.BlobObject) + require.NoError(t, err) + require.EqualValues(t, "object:type=blob", filter) + + _, err = FilterObjectType(plumbing.InvalidObject) + require.Error(t, err) + + _, err = FilterObjectType(plumbing.OFSDeltaObject) + require.Error(t, err) +} + +func TestFilterCombine(t *testing.T) { + require.EqualValues(t, "combine:tree%3A2+blob%3Anone", + FilterCombine( + FilterTreeDepth(2), + FilterBlobNone(), + ), + ) +} diff --git a/plumbing/protocol/packp/sideband/demux.go b/plumbing/protocol/packp/sideband/demux.go index 0116f96..01d95a3 100644 --- a/plumbing/protocol/packp/sideband/demux.go +++ b/plumbing/protocol/packp/sideband/demux.go @@ -114,7 +114,7 @@ func (d *Demuxer) nextPackData() ([]byte, error) { size := len(content) if size == 0 { - return nil, nil + return nil, io.EOF } else if size > d.max { return nil, ErrMaxPackedExceeded } diff --git a/plumbing/protocol/packp/sideband/demux_test.go b/plumbing/protocol/packp/sideband/demux_test.go index 8f23353..1ba3ad9 100644 --- a/plumbing/protocol/packp/sideband/demux_test.go +++ b/plumbing/protocol/packp/sideband/demux_test.go @@ -105,8 +105,34 @@ func (s *SidebandSuite) TestDecodeWithProgress(c *C) { c.Assert(progress, DeepEquals, []byte{'F', 'O', 'O', '\n'}) } -func (s *SidebandSuite) TestDecodeWithUnknownChannel(c *C) { +func (s *SidebandSuite) TestDecodeFlushEOF(c *C) { + expected := []byte("abcdefghijklmnopqrstuvwxyz") + + input := bytes.NewBuffer(nil) + e := pktline.NewEncoder(input) + e.Encode(PackData.WithPayload(expected[0:8])) + e.Encode(ProgressMessage.WithPayload([]byte{'F', 'O', 'O', '\n'})) + e.Encode(PackData.WithPayload(expected[8:16])) + e.Encode(PackData.WithPayload(expected[16:26])) + e.Flush() + e.Encode(PackData.WithPayload([]byte("bar\n"))) + + output := bytes.NewBuffer(nil) + content := bytes.NewBuffer(nil) + d := NewDemuxer(Sideband64k, input) + d.Progress = output + + n, err := content.ReadFrom(d) + c.Assert(err, IsNil) + c.Assert(n, Equals, int64(26)) + c.Assert(content.Bytes(), DeepEquals, expected) + progress, err := io.ReadAll(output) + c.Assert(err, IsNil) + c.Assert(progress, DeepEquals, []byte{'F', 'O', 'O', '\n'}) +} + +func (s *SidebandSuite) TestDecodeWithUnknownChannel(c *C) { buf := bytes.NewBuffer(nil) e := pktline.NewEncoder(buf) e.Encode([]byte{'4', 'F', 'O', 'O', '\n'}) @@ -150,5 +176,4 @@ func (s *SidebandSuite) TestDecodeErrMaxPacked(c *C) { n, err := io.ReadFull(d, content) c.Assert(err, Equals, ErrMaxPackedExceeded) c.Assert(n, Equals, 0) - } diff --git a/plumbing/protocol/packp/ulreq.go b/plumbing/protocol/packp/ulreq.go index 344f8c7..ef4e08a 100644 --- a/plumbing/protocol/packp/ulreq.go +++ b/plumbing/protocol/packp/ulreq.go @@ -17,6 +17,7 @@ type UploadRequest struct { Wants []plumbing.Hash Shallows []plumbing.Hash Depth Depth + Filter Filter } // Depth values stores the desired depth of the requested packfile: see diff --git a/plumbing/protocol/packp/ulreq_encode.go b/plumbing/protocol/packp/ulreq_encode.go index c451e23..8b19c0f 100644 --- a/plumbing/protocol/packp/ulreq_encode.go +++ b/plumbing/protocol/packp/ulreq_encode.go @@ -132,6 +132,17 @@ func (e *ulReqEncoder) encodeDepth() stateFn { return nil } + return e.encodeFilter +} + +func (e *ulReqEncoder) encodeFilter() stateFn { + if filter := e.data.Filter; filter != "" { + if err := e.pe.Encodef("filter %s\n", filter); err != nil { + e.err = fmt.Errorf("encoding filter %s: %s", filter, err) + return nil + } + } + return e.encodeFlush } diff --git a/plumbing/protocol/packp/ulreq_encode_test.go b/plumbing/protocol/packp/ulreq_encode_test.go index ba6df1a..247de27 100644 --- a/plumbing/protocol/packp/ulreq_encode_test.go +++ b/plumbing/protocol/packp/ulreq_encode_test.go @@ -273,6 +273,20 @@ func (s *UlReqEncodeSuite) TestDepthReference(c *C) { testUlReqEncode(c, ur, expected) } +func (s *UlReqEncodeSuite) TestFilter(c *C) { + ur := NewUploadRequest() + ur.Wants = append(ur.Wants, plumbing.NewHash("1111111111111111111111111111111111111111")) + ur.Filter = FilterTreeDepth(0) + + expected := []string{ + "want 1111111111111111111111111111111111111111\n", + "filter tree:0\n", + pktline.FlushString, + } + + testUlReqEncode(c, ur, expected) +} + func (s *UlReqEncodeSuite) TestAll(c *C) { ur := NewUploadRequest() ur.Wants = append(ur.Wants, diff --git a/plumbing/serverinfo/serverinfo_test.go b/plumbing/serverinfo/serverinfo_test.go index 0a52ea2..251746b 100644 --- a/plumbing/serverinfo/serverinfo_test.go +++ b/plumbing/serverinfo/serverinfo_test.go @@ -179,6 +179,7 @@ func (s *ServerInfoSuite) TestUpdateServerInfoBasicChange(c *C) { c.Assert(err, IsNil) err = UpdateServerInfo(st, fs) + c.Assert(err, IsNil) assertInfoRefs(c, st, fs) assertObjectPacks(c, st, fs) diff --git a/plumbing/transport/common.go b/plumbing/transport/common.go index b05437f..fae1aa9 100644 --- a/plumbing/transport/common.go +++ b/plumbing/transport/common.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/url" + "path/filepath" "strconv" "strings" @@ -295,7 +296,11 @@ func parseFile(endpoint string) (*Endpoint, bool) { return nil, false } - path := endpoint + path, err := filepath.Abs(endpoint) + if err != nil { + return nil, false + } + return &Endpoint{ Protocol: "file", Path: path, diff --git a/plumbing/transport/common_test.go b/plumbing/transport/common_test.go index 3efc555..1501f73 100644 --- a/plumbing/transport/common_test.go +++ b/plumbing/transport/common_test.go @@ -3,6 +3,9 @@ package transport import ( "fmt" "net/url" + "os" + "path/filepath" + "runtime" "testing" "github.com/go-git/go-git/v5/plumbing/protocol/packp/capability" @@ -120,6 +123,14 @@ func (s *SuiteCommon) TestNewEndpointSCPLikeWithPort(c *C) { } func (s *SuiteCommon) TestNewEndpointFileAbs(c *C) { + var err error + abs := "/foo.git" + + if runtime.GOOS == "windows" { + abs, err = filepath.Abs(abs) + c.Assert(err, IsNil) + } + e, err := NewEndpoint("/foo.git") c.Assert(err, IsNil) c.Assert(e.Protocol, Equals, "file") @@ -127,11 +138,14 @@ func (s *SuiteCommon) TestNewEndpointFileAbs(c *C) { c.Assert(e.Password, Equals, "") c.Assert(e.Host, Equals, "") c.Assert(e.Port, Equals, 0) - c.Assert(e.Path, Equals, "/foo.git") - c.Assert(e.String(), Equals, "file:///foo.git") + c.Assert(e.Path, Equals, abs) + c.Assert(e.String(), Equals, "file://"+abs) } func (s *SuiteCommon) TestNewEndpointFileRel(c *C) { + abs, err := filepath.Abs("foo.git") + c.Assert(err, IsNil) + e, err := NewEndpoint("foo.git") c.Assert(err, IsNil) c.Assert(e.Protocol, Equals, "file") @@ -139,11 +153,20 @@ func (s *SuiteCommon) TestNewEndpointFileRel(c *C) { c.Assert(e.Password, Equals, "") c.Assert(e.Host, Equals, "") c.Assert(e.Port, Equals, 0) - c.Assert(e.Path, Equals, "foo.git") - c.Assert(e.String(), Equals, "file://foo.git") + c.Assert(e.Path, Equals, abs) + c.Assert(e.String(), Equals, "file://"+abs) } func (s *SuiteCommon) TestNewEndpointFileWindows(c *C) { + abs := "C:\\foo.git" + + if runtime.GOOS != "windows" { + cwd, err := os.Getwd() + c.Assert(err, IsNil) + + abs = filepath.Join(cwd, "C:\\foo.git") + } + e, err := NewEndpoint("C:\\foo.git") c.Assert(err, IsNil) c.Assert(e.Protocol, Equals, "file") @@ -151,8 +174,8 @@ func (s *SuiteCommon) TestNewEndpointFileWindows(c *C) { c.Assert(e.Password, Equals, "") c.Assert(e.Host, Equals, "") c.Assert(e.Port, Equals, 0) - c.Assert(e.Path, Equals, "C:\\foo.git") - c.Assert(e.String(), Equals, "file://C:\\foo.git") + c.Assert(e.Path, Equals, abs) + c.Assert(e.String(), Equals, "file://"+abs) } func (s *SuiteCommon) TestNewEndpointFileURL(c *C) { diff --git a/plumbing/transport/http/common.go b/plumbing/transport/http/common.go index 54126fe..120008d 100644 --- a/plumbing/transport/http/common.go +++ b/plumbing/transport/http/common.go @@ -91,9 +91,9 @@ func advertisedReferences(ctx context.Context, s *session, serviceName string) ( } type client struct { - c *http.Client + client *http.Client transports *lru.Cache - m sync.RWMutex + mutex sync.RWMutex } // ClientOptions holds user configurable options for the client. @@ -147,7 +147,7 @@ func NewClientWithOptions(c *http.Client, opts *ClientOptions) transport.Transpo } } cl := &client{ - c: c, + client: c, } if opts != nil { @@ -234,10 +234,10 @@ func newSession(c *client, ep *transport.Endpoint, auth transport.AuthMethod) (* // if the client wasn't configured to have a cache for transports then just configure // the transport and use it directly, otherwise try to use the cache. if c.transports == nil { - tr, ok := c.c.Transport.(*http.Transport) + tr, ok := c.client.Transport.(*http.Transport) if !ok { return nil, fmt.Errorf("expected underlying client transport to be of type: %s; got: %s", - reflect.TypeOf(transport), reflect.TypeOf(c.c.Transport)) + reflect.TypeOf(transport), reflect.TypeOf(c.client.Transport)) } transport = tr.Clone() @@ -258,7 +258,7 @@ func newSession(c *client, ep *transport.Endpoint, auth transport.AuthMethod) (* transport, found = c.fetchTransport(transportOpts) if !found { - transport = c.c.Transport.(*http.Transport).Clone() + transport = c.client.Transport.(*http.Transport).Clone() configureTransport(transport, ep) c.addTransport(transportOpts, transport) } @@ -266,12 +266,12 @@ func newSession(c *client, ep *transport.Endpoint, auth transport.AuthMethod) (* httpClient = &http.Client{ Transport: transport, - CheckRedirect: c.c.CheckRedirect, - Jar: c.c.Jar, - Timeout: c.c.Timeout, + CheckRedirect: c.client.CheckRedirect, + Jar: c.client.Jar, + Timeout: c.client.Timeout, } } else { - httpClient = c.c + httpClient = c.client } s := &session{ @@ -430,11 +430,11 @@ func NewErr(r *http.Response) error { switch r.StatusCode { case http.StatusUnauthorized: - return transport.ErrAuthenticationRequired + return fmt.Errorf("%w: %s", transport.ErrAuthenticationRequired, reason) case http.StatusForbidden: - return transport.ErrAuthorizationFailed + return fmt.Errorf("%w: %s", transport.ErrAuthorizationFailed, reason) case http.StatusNotFound: - return transport.ErrRepositoryNotFound + return fmt.Errorf("%w: %s", transport.ErrRepositoryNotFound, reason) } return plumbing.NewUnexpectedError(&Err{r, reason}) diff --git a/plumbing/transport/http/common_test.go b/plumbing/transport/http/common_test.go index 6bd018b..f0eb68d 100644 --- a/plumbing/transport/http/common_test.go +++ b/plumbing/transport/http/common_test.go @@ -46,7 +46,7 @@ func (s *UploadPackSuite) TestNewClient(c *C) { cl := &http.Client{Transport: roundTripper} r, ok := NewClient(cl).(*client) c.Assert(ok, Equals, true) - c.Assert(r.c, Equals, cl) + c.Assert(r.client, Equals, cl) } func (s *ClientSuite) TestNewBasicAuth(c *C) { @@ -76,15 +76,15 @@ func (s *ClientSuite) TestNewErrOK(c *C) { } func (s *ClientSuite) TestNewErrUnauthorized(c *C) { - s.testNewHTTPError(c, http.StatusUnauthorized, "authentication required") + s.testNewHTTPError(c, http.StatusUnauthorized, ".*authentication required.*") } func (s *ClientSuite) TestNewErrForbidden(c *C) { - s.testNewHTTPError(c, http.StatusForbidden, "authorization failed") + s.testNewHTTPError(c, http.StatusForbidden, ".*authorization failed.*") } func (s *ClientSuite) TestNewErrNotFound(c *C) { - s.testNewHTTPError(c, http.StatusNotFound, "repository not found") + s.testNewHTTPError(c, http.StatusNotFound, ".*repository not found.*") } func (s *ClientSuite) TestNewHTTPError40x(c *C) { diff --git a/plumbing/transport/http/transport.go b/plumbing/transport/http/transport.go index 052f3c8..c8db389 100644 --- a/plumbing/transport/http/transport.go +++ b/plumbing/transport/http/transport.go @@ -14,21 +14,21 @@ type transportOptions struct { } func (c *client) addTransport(opts transportOptions, transport *http.Transport) { - c.m.Lock() + c.mutex.Lock() c.transports.Add(opts, transport) - c.m.Unlock() + c.mutex.Unlock() } func (c *client) removeTransport(opts transportOptions) { - c.m.Lock() + c.mutex.Lock() c.transports.Remove(opts) - c.m.Unlock() + c.mutex.Unlock() } func (c *client) fetchTransport(opts transportOptions) (*http.Transport, bool) { - c.m.RLock() + c.mutex.RLock() t, ok := c.transports.Get(opts) - c.m.RUnlock() + c.mutex.RUnlock() if !ok { return nil, false } diff --git a/plumbing/transport/http/upload_pack_test.go b/plumbing/transport/http/upload_pack_test.go index abb7adf..3a1610a 100644 --- a/plumbing/transport/http/upload_pack_test.go +++ b/plumbing/transport/http/upload_pack_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" + . "github.com/go-git/go-git/v5/internal/test" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/protocol/packp" "github.com/go-git/go-git/v5/plumbing/transport" @@ -37,7 +38,7 @@ func (s *UploadPackSuite) TestAdvertisedReferencesNotExists(c *C) { r, err := s.Client.NewUploadPackSession(s.NonExistentEndpoint, s.EmptyAuth) c.Assert(err, IsNil) info, err := r.AdvertisedReferences() - c.Assert(err, Equals, transport.ErrRepositoryNotFound) + c.Assert(err, ErrorIs, transport.ErrRepositoryNotFound) c.Assert(info, IsNil) } diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index ac4e358..f9c598e 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -230,11 +230,11 @@ func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) { // ~/.ssh/known_hosts // /etc/ssh/ssh_known_hosts func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) { - kh, err := newKnownHosts(files...) - return ssh.HostKeyCallback(kh), err + db, err := newKnownHostsDb(files...) + return db.HostKeyCallback(), err } -func newKnownHosts(files ...string) (knownhosts.HostKeyCallback, error) { +func newKnownHostsDb(files ...string) (*knownhosts.HostKeyDB, error) { var err error if len(files) == 0 { @@ -247,7 +247,7 @@ func newKnownHosts(files ...string) (knownhosts.HostKeyCallback, error) { return nil, err } - return knownhosts.New(files...) + return knownhosts.NewDB(files...) } func getDefaultKnownHostsFiles() ([]string, error) { @@ -301,11 +301,12 @@ type HostKeyCallbackHelper struct { // HostKeyCallback is empty a default callback is created using // NewKnownHostsCallback. func (m *HostKeyCallbackHelper) SetHostKeyCallback(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) { - var err error if m.HostKeyCallback == nil { - if m.HostKeyCallback, err = NewKnownHostsCallback(); err != nil { + db, err := newKnownHostsDb() + if err != nil { return cfg, err } + m.HostKeyCallback = db.HostKeyCallback() } cfg.HostKeyCallback = m.HostKeyCallback diff --git a/plumbing/transport/ssh/auth_method_test.go b/plumbing/transport/ssh/auth_method_test.go index b275018..e3f652e 100644 --- a/plumbing/transport/ssh/auth_method_test.go +++ b/plumbing/transport/ssh/auth_method_test.go @@ -18,7 +18,8 @@ import ( type ( SuiteCommon struct{} - mockKnownHosts struct{} + mockKnownHosts struct{} + mockKnownHostsWithCert struct{} ) func (mockKnownHosts) host() string { return "github.com" } @@ -27,6 +28,19 @@ func (mockKnownHosts) knownHosts() []byte { } func (mockKnownHosts) Network() string { return "tcp" } func (mockKnownHosts) String() string { return "github.com:22" } +func (mockKnownHosts) Algorithms() []string { + return []string{ssh.KeyAlgoRSA, ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512} +} + +func (mockKnownHostsWithCert) host() string { return "github.com" } +func (mockKnownHostsWithCert) knownHosts() []byte { + return []byte(`@cert-authority github.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEAq2A7hRGmdnm9tUDbO9IDSwBK6TbQa+PXYPCPy6rbTrTtw7PHkccKrpp0yVhp5HdEIcKr6pLlVDBfOLX9QUsyCOV0wzfjIJNlGEYsdlLJizHhbn2mUjvSAHQqZETYP81eFzLQNnPHt4EVVUh7VfDESU84KezmD5QlWpXLmvU31/yMf+Se8xhHTvKSCZIFImWwoG6mbUoWf9nzpIoaSjB+weqqUUmpaaasXVal72J+UX2B+2RPW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ==`) +} +func (mockKnownHostsWithCert) Network() string { return "tcp" } +func (mockKnownHostsWithCert) String() string { return "github.com:22" } +func (mockKnownHostsWithCert) Algorithms() []string { + return []string{ssh.CertAlgoRSASHA512v01, ssh.CertAlgoRSASHA256v01, ssh.CertAlgoRSAv01} +} var _ = Suite(&SuiteCommon{}) @@ -230,3 +244,93 @@ func (*SuiteCommon) TestNewKnownHostsCallback(c *C) { err = clb(mock.String(), mock, hostKey) c.Assert(err, IsNil) } + +func (*SuiteCommon) TestNewKnownHostsDbWithoutCert(c *C) { + if runtime.GOOS == "js" { + c.Skip("not available in wasm") + } + + var mock = mockKnownHosts{} + + f, err := util.TempFile(osfs.Default, "", "known-hosts") + c.Assert(err, IsNil) + + _, err = f.Write(mock.knownHosts()) + c.Assert(err, IsNil) + + err = f.Close() + c.Assert(err, IsNil) + + defer util.RemoveAll(osfs.Default, f.Name()) + + f, err = osfs.Default.Open(f.Name()) + c.Assert(err, IsNil) + + defer f.Close() + + db, err := newKnownHostsDb(f.Name()) + c.Assert(err, IsNil) + + algos := db.HostKeyAlgorithms(mock.String()) + c.Assert(algos, HasLen, len(mock.Algorithms())) + + contains := func(container []string, value string) bool { + for _, inner := range container { + if inner == value { + return true + } + } + return false + } + + for _, algorithm := range mock.Algorithms() { + if !contains(algos, algorithm) { + c.Error("algos does not contain ", algorithm) + } + } +} + +func (*SuiteCommon) TestNewKnownHostsDbWithCert(c *C) { + if runtime.GOOS == "js" { + c.Skip("not available in wasm") + } + + var mock = mockKnownHostsWithCert{} + + f, err := util.TempFile(osfs.Default, "", "known-hosts") + c.Assert(err, IsNil) + + _, err = f.Write(mock.knownHosts()) + c.Assert(err, IsNil) + + err = f.Close() + c.Assert(err, IsNil) + + defer util.RemoveAll(osfs.Default, f.Name()) + + f, err = osfs.Default.Open(f.Name()) + c.Assert(err, IsNil) + + defer f.Close() + + db, err := newKnownHostsDb(f.Name()) + c.Assert(err, IsNil) + + algos := db.HostKeyAlgorithms(mock.String()) + c.Assert(algos, HasLen, len(mock.Algorithms())) + + contains := func(container []string, value string) bool { + for _, inner := range container { + if inner == value { + return true + } + } + return false + } + + for _, algorithm := range mock.Algorithms() { + if !contains(algos, algorithm) { + c.Error("algos does not contain ", algorithm) + } + } +} diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index 05dea44..a37024f 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -11,7 +11,6 @@ import ( "github.com/go-git/go-git/v5/plumbing/transport" "github.com/go-git/go-git/v5/plumbing/transport/internal/common" - "github.com/skeema/knownhosts" "github.com/kevinburke/ssh_config" "golang.org/x/crypto/ssh" @@ -127,17 +126,25 @@ func (c *command) connect() error { } hostWithPort := c.getHostWithPort() if config.HostKeyCallback == nil { - kh, err := newKnownHosts() + db, err := newKnownHostsDb() if err != nil { return err } - config.HostKeyCallback = kh.HostKeyCallback() - config.HostKeyAlgorithms = kh.HostKeyAlgorithms(hostWithPort) + + config.HostKeyCallback = db.HostKeyCallback() + config.HostKeyAlgorithms = db.HostKeyAlgorithms(hostWithPort) } else if len(config.HostKeyAlgorithms) == 0 { // Set the HostKeyAlgorithms based on HostKeyCallback. // For background see https://github.com/go-git/go-git/issues/411 as well as // https://github.com/golang/go/issues/29286 for root cause. - config.HostKeyAlgorithms = knownhosts.HostKeyAlgorithms(config.HostKeyCallback, hostWithPort) + db, err := newKnownHostsDb() + if err != nil { + return err + } + + // Note that the knownhost database is used, as it provides additional functionality + // to handle ssh cert-authorities. + config.HostKeyAlgorithms = db.HostKeyAlgorithms(hostWithPort) } overrideConfig(c.config, config) diff --git a/plumbing/transport/test/receive_pack.go b/plumbing/transport/test/receive_pack.go index 9414fba..d4d2b10 100644 --- a/plumbing/transport/test/receive_pack.go +++ b/plumbing/transport/test/receive_pack.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" + . "github.com/go-git/go-git/v5/internal/test" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/format/packfile" "github.com/go-git/go-git/v5/plumbing/protocol/packp" @@ -42,7 +43,7 @@ func (s *ReceivePackSuite) TestAdvertisedReferencesNotExists(c *C) { r, err := s.Client.NewReceivePackSession(s.NonExistentEndpoint, s.EmptyAuth) c.Assert(err, IsNil) ar, err := r.AdvertisedReferences() - c.Assert(err, Equals, transport.ErrRepositoryNotFound) + c.Assert(err, ErrorIs, transport.ErrRepositoryNotFound) c.Assert(ar, IsNil) c.Assert(r.Close(), IsNil) @@ -54,7 +55,7 @@ func (s *ReceivePackSuite) TestAdvertisedReferencesNotExists(c *C) { } writer, err := r.ReceivePack(context.Background(), req) - c.Assert(err, Equals, transport.ErrRepositoryNotFound) + c.Assert(err, ErrorIs, transport.ErrRepositoryNotFound) c.Assert(writer, IsNil) c.Assert(r.Close(), IsNil) } |