diff options
author | Máximo Cuadros <mcuadros@gmail.com> | 2018-04-02 10:40:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-02 10:40:53 +0200 |
commit | 2cbff8d8ffa42dd1b399b3c94d094bdb2d1cdf5c (patch) | |
tree | 97bc184a52e6b0b654416d36ca1f848221ccf414 /plumbing/protocol/packp/advrefs.go | |
parent | 32931400e23550cc840d2c2ae7bb5a20ef1946e7 (diff) | |
parent | 0b523020ef35b56c1f2481ca3773d974a7b80945 (diff) | |
download | go-git-2cbff8d8ffa42dd1b399b3c94d094bdb2d1cdf5c.tar.gz |
Merge pull request #792 from ajnavarro/fix/support-no-symref-capability
Resolve HEAD if symRefs capability is not supported
Diffstat (limited to 'plumbing/protocol/packp/advrefs.go')
-rw-r--r-- | plumbing/protocol/packp/advrefs.go | 108 |
1 files changed, 99 insertions, 9 deletions
diff --git a/plumbing/protocol/packp/advrefs.go b/plumbing/protocol/packp/advrefs.go index 7d644bc..684e76a 100644 --- a/plumbing/protocol/packp/advrefs.go +++ b/plumbing/protocol/packp/advrefs.go @@ -2,6 +2,7 @@ package packp import ( "fmt" + "sort" "strings" "gopkg.in/src-d/go-git.v4/plumbing" @@ -68,30 +69,119 @@ func (a *AdvRefs) AddReference(r *plumbing.Reference) error { func (a *AdvRefs) AllReferences() (memory.ReferenceStorage, error) { s := memory.ReferenceStorage{} - if err := addRefs(s, a); err != nil { + if err := a.addRefs(s); err != nil { return s, plumbing.NewUnexpectedError(err) } return s, nil } -func addRefs(s storer.ReferenceStorer, ar *AdvRefs) error { - for name, hash := range ar.References { +func (a *AdvRefs) addRefs(s storer.ReferenceStorer) error { + for name, hash := range a.References { ref := plumbing.NewReferenceFromStrings(name, hash.String()) if err := s.SetReference(ref); err != nil { return err } } - return addSymbolicRefs(s, ar) + if a.supportSymrefs() { + return a.addSymbolicRefs(s) + } + + return a.resolveHead(s) } -func addSymbolicRefs(s storer.ReferenceStorer, ar *AdvRefs) error { - if !hasSymrefs(ar) { +// If the server does not support symrefs capability, +// we need to guess the reference where HEAD is pointing to. +// +// Git versions prior to 1.8.4.3 has an special procedure to get +// the reference where is pointing to HEAD: +// - Check if a reference called master exists. If exists and it +// has the same hash as HEAD hash, we can say that HEAD is pointing to master +// - If master does not exists or does not have the same hash as HEAD, +// order references and check in that order if that reference has the same +// hash than HEAD. If yes, set HEAD pointing to that branch hash +// - If no reference is found, throw an error +func (a *AdvRefs) resolveHead(s storer.ReferenceStorer) error { + if a.Head == nil { + return nil + } + + ref, err := s.Reference(plumbing.ReferenceName(plumbing.Master)) + + // check first if HEAD is pointing to master + if err == nil { + ok, err := a.createHeadIfCorrectReference(ref, s) + if err != nil { + return err + } + + if ok { + return nil + } + } + + if err != nil && err != plumbing.ErrReferenceNotFound { + return err + } + + // From here we are trying to guess the branch that HEAD is pointing + refIter, err := s.IterReferences() + if err != nil { + return err + } + + var refNames []string + err = refIter.ForEach(func(r *plumbing.Reference) error { + refNames = append(refNames, string(r.Name())) return nil + }) + if err != nil { + return err + } + + sort.Strings(refNames) + + var headSet bool + for _, refName := range refNames { + ref, err := s.Reference(plumbing.ReferenceName(refName)) + if err != nil { + return err + } + ok, err := a.createHeadIfCorrectReference(ref, s) + if err != nil { + return err + } + if ok { + headSet = true + break + } + } + + if !headSet { + return plumbing.ErrReferenceNotFound } - for _, symref := range ar.Capabilities.Get(capability.SymRef) { + return nil +} + +func (a *AdvRefs) createHeadIfCorrectReference( + reference *plumbing.Reference, + s storer.ReferenceStorer) (bool, error) { + if reference.Hash() == *a.Head { + headRef := plumbing.NewSymbolicReference(plumbing.HEAD, reference.Name()) + if err := s.SetReference(headRef); err != nil { + return false, err + } + + return true, nil + } + + return false, nil +} + +func (a *AdvRefs) addSymbolicRefs(s storer.ReferenceStorer) error { + for _, symref := range a.Capabilities.Get(capability.SymRef) { chunks := strings.Split(symref, ":") if len(chunks) != 2 { err := fmt.Errorf("bad number of `:` in symref value (%q)", symref) @@ -108,6 +198,6 @@ func addSymbolicRefs(s storer.ReferenceStorer, ar *AdvRefs) error { return nil } -func hasSymrefs(ar *AdvRefs) bool { - return ar.Capabilities.Supports(capability.SymRef) +func (a *AdvRefs) supportSymrefs() bool { + return a.Capabilities.Supports(capability.SymRef) } |