diff options
Diffstat (limited to 'remote.go')
-rw-r--r-- | remote.go | 233 |
1 files changed, 228 insertions, 5 deletions
@@ -8,9 +8,11 @@ import ( "gopkg.in/src-d/go-git.v4/config" "gopkg.in/src-d/go-git.v4/plumbing" "gopkg.in/src-d/go-git.v4/plumbing/format/packfile" + "gopkg.in/src-d/go-git.v4/plumbing/object" "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp" "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability" "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband" + "gopkg.in/src-d/go-git.v4/plumbing/revlist" "gopkg.in/src-d/go-git.v4/plumbing/storer" "gopkg.in/src-d/go-git.v4/plumbing/transport" "gopkg.in/src-d/go-git.v4/plumbing/transport/client" @@ -49,6 +51,72 @@ func (r *Remote) Fetch(o *FetchOptions) error { return err } +// Push performs a push to the remote. Returns NoErrAlreadyUpToDate if the +// remote was already up-to-date. +// +// TODO: Support deletes. +// TODO: Support pushing tags. +// TODO: Check if force update is given, otherwise reject non-fast forward. +func (r *Remote) Push(o *PushOptions) (err error) { + if o.RemoteName == "" { + o.RemoteName = r.c.Name + } + + if err := o.Validate(); err != nil { + return err + } + + if o.RemoteName != r.c.Name { + return fmt.Errorf("remote names don't match: %s != %s", o.RemoteName, r.c.Name) + } + + s, err := newSendPackSession(r.c.URL) + if err != nil { + return err + } + + ar, err := s.AdvertisedReferences() + if err != nil { + return err + } + + remoteRefs, err := ar.AllReferences() + if err != nil { + return err + } + + req := packp.NewReferenceUpdateRequestFromCapabilities(ar.Capabilities) + if err := r.addReferencesToUpdate(o.RefSpecs, remoteRefs, req); err != nil { + return err + } + + if len(req.Commands) == 0 { + return NoErrAlreadyUpToDate + } + + commits, err := commitsToPush(r.s, req.Commands) + if err != nil { + return err + } + + haves, err := referencesToHashes(remoteRefs) + if err != nil { + return err + } + + hashesToPush, err := revlist.Objects(r.s, commits, haves) + if err != nil { + return err + } + + rs, err := pushHashes(s, r.s, req, hashesToPush) + if err != nil { + return err + } + + return rs.Error() +} + func (r *Remote) fetch(o *FetchOptions) (refs storer.ReferenceStorer, err error) { if o.RemoteName == "" { o.RemoteName = r.c.Name @@ -62,7 +130,7 @@ func (r *Remote) fetch(o *FetchOptions) (refs storer.ReferenceStorer, err error) o.RefSpecs = r.c.Fetch } - s, err := r.newFetchPackSession() + s, err := newFetchPackSession(r.c.URL) if err != nil { return nil, err } @@ -105,18 +173,36 @@ func (r *Remote) fetch(o *FetchOptions) (refs storer.ReferenceStorer, err error) return remoteRefs, err } -func (r *Remote) newFetchPackSession() (transport.FetchPackSession, error) { - ep, err := transport.NewEndpoint(r.c.URL) +func newFetchPackSession(url string) (transport.FetchPackSession, error) { + c, ep, err := newClient(url) if err != nil { return nil, err } - c, err := client.NewClient(ep) + return c.NewFetchPackSession(ep) +} + +func newSendPackSession(url string) (transport.SendPackSession, error) { + c, ep, err := newClient(url) if err != nil { return nil, err } - return c.NewFetchPackSession(ep) + return c.NewSendPackSession(ep) +} + +func newClient(url string) (transport.Client, transport.Endpoint, error) { + ep, err := transport.NewEndpoint(url) + if err != nil { + return nil, transport.Endpoint{}, err + } + + c, err := client.NewClient(ep) + if err != nil { + return nil, transport.Endpoint{}, err + } + + return c, ep, err } func (r *Remote) fetchPack(o *FetchOptions, s transport.FetchPackSession, @@ -142,6 +228,75 @@ func (r *Remote) fetchPack(o *FetchOptions, s transport.FetchPackSession, return err } +func (r *Remote) addReferencesToUpdate(refspecs []config.RefSpec, + remoteRefs storer.ReferenceStorer, + req *packp.ReferenceUpdateRequest) error { + + for _, rs := range refspecs { + iter, err := r.s.IterReferences() + if err != nil { + return err + } + + err = iter.ForEach(func(ref *plumbing.Reference) error { + return r.addReferenceIfRefSpecMatches( + rs, remoteRefs, ref, req, + ) + }) + if err != nil { + return err + } + } + + return nil +} + +func (r *Remote) addReferenceIfRefSpecMatches(rs config.RefSpec, + remoteRefs storer.ReferenceStorer, localRef *plumbing.Reference, + req *packp.ReferenceUpdateRequest) error { + + if localRef.Type() != plumbing.HashReference { + return nil + } + + if !rs.Match(localRef.Name()) { + return nil + } + + dstName := rs.Dst(localRef.Name()) + oldHash := plumbing.ZeroHash + newHash := localRef.Hash() + + iter, err := remoteRefs.IterReferences() + if err != nil { + return err + } + + err = iter.ForEach(func(remoteRef *plumbing.Reference) error { + if remoteRef.Type() != plumbing.HashReference { + return nil + } + + if dstName != remoteRef.Name() { + return nil + } + + oldHash = remoteRef.Hash() + return nil + }) + + if oldHash == newHash { + return nil + } + + req.Commands = append(req.Commands, &packp.Command{ + Name: dstName, + Old: oldHash, + New: newHash, + }) + return nil +} + func getHaves(localRefs storer.ReferenceStorer) ([]plumbing.Hash, error) { iter, err := localRefs.IterReferences() if err != nil { @@ -337,6 +492,74 @@ func (r *Remote) buildFetchedTags(refs storer.ReferenceStorer) error { }) } +func commitsToPush(s storer.EncodedObjectStorer, commands []*packp.Command) ([]*object.Commit, error) { + var commits []*object.Commit + for _, cmd := range commands { + if cmd.New == plumbing.ZeroHash { + continue + } + + c, err := object.GetCommit(s, cmd.New) + if err != nil { + return nil, err + } + + commits = append(commits, c) + } + + return commits, nil +} + +func referencesToHashes(refs storer.ReferenceStorer) ([]plumbing.Hash, error) { + iter, err := refs.IterReferences() + if err != nil { + return nil, err + } + + var hs []plumbing.Hash + err = iter.ForEach(func(ref *plumbing.Reference) error { + if ref.Type() != plumbing.HashReference { + return nil + } + + hs = append(hs, ref.Hash()) + return nil + }) + if err != nil { + return nil, err + } + + return hs, nil +} + +func pushHashes(sess transport.SendPackSession, sto storer.EncodedObjectStorer, + req *packp.ReferenceUpdateRequest, hs []plumbing.Hash) (*packp.ReportStatus, error) { + + rd, wr := io.Pipe() + req.Packfile = rd + done := make(chan error) + go func() { + e := packfile.NewEncoder(wr, sto, false) + if _, err := e.Encode(hs); err != nil { + done <- wr.CloseWithError(err) + return + } + + done <- wr.Close() + }() + + rs, err := sess.SendPack(req) + if err != nil { + return nil, err + } + + if err := <-done; err != nil { + return nil, err + } + + return rs, nil +} + func (r *Remote) updateShallow(o *FetchOptions, resp *packp.UploadPackResponse) error { if o.Depth == 0 { return nil |