diff options
author | Santiago M. Mola <santi@mola.io> | 2016-12-16 19:31:01 +0100 |
---|---|---|
committer | Máximo Cuadros <mcuadros@gmail.com> | 2016-12-16 19:31:01 +0100 |
commit | b3adbed0ce15d82bf41d23cc507c5dd47a6c4260 (patch) | |
tree | 156ce904ed4af15d89a85876e3585e4c55204ae8 | |
parent | 950676c36030a8796c0a69a8aae606ff1f448b03 (diff) | |
download | go-git-b3adbed0ce15d82bf41d23cc507c5dd47a6c4260.tar.gz |
remote: make Fetch atomic. (#185)
* Remote now exposes only Fetch. No Connect, Disconnect, etc.
* Repository uses a private fetch method in Remote for Clone/Pull.
* getting capabilities, HEAD or other information from remote
requires using the lower level client.
* add Fetch method to Repository.
-rw-r--r-- | options.go | 20 | ||||
-rw-r--r-- | remote.go | 279 | ||||
-rw-r--r-- | remote_test.go | 94 | ||||
-rw-r--r-- | repository.go | 140 | ||||
-rw-r--r-- | repository_test.go | 23 |
5 files changed, 283 insertions, 273 deletions
@@ -51,19 +51,19 @@ func (o *CloneOptions) Validate() error { return nil } -// PullOptions describe how a pull should be perform +// PullOptions describe how a pull should be perform. type PullOptions struct { - // Name of the remote to be pulled + // Name of the remote to be pulled. If empty, uses the default. RemoteName string - // Remote branch to clone + // Remote branch to clone. If empty, uses HEAD. ReferenceName plumbing.ReferenceName - // Fetch only ReferenceName if true + // Fetch only ReferenceName if true. SingleBranch bool - // Limit fetching to the specified number of commits + // Limit fetching to the specified number of commits. Depth int } -// Validate validate the fields and set the default values +// Validate validate the fields and set the default values. func (o *PullOptions) Validate() error { if o.RemoteName == "" { o.RemoteName = DefaultRemoteName @@ -78,7 +78,9 @@ func (o *PullOptions) Validate() error { // FetchOptions describe how a fetch should be perform type FetchOptions struct { - RefSpecs []config.RefSpec + // Name of the remote to fetch from. Defaults to origin. + RemoteName string + RefSpecs []config.RefSpec // Depth limit fetching to the specified number of commits from the tip of // each remote branch history. Depth int @@ -86,6 +88,10 @@ type FetchOptions struct { // Validate validate the fields and set the default values func (o *FetchOptions) Validate() error { + if o.RemoteName == "" { + o.RemoteName = DefaultRemoteName + } + for _, r := range o.RefSpecs { if !r.IsValid() { return ErrInvalidRefSpec @@ -25,13 +25,6 @@ type Remote struct { c *config.RemoteConfig s Storer p sideband.Progress - - // cache fields, there during the connection is open - endpoint transport.Endpoint - client transport.Client - fetchSession transport.FetchPackSession - advRefs *packp.AdvRefs - refs memory.ReferenceStorage } func newRemote(s Storer, p sideband.Progress, c *config.RemoteConfig) *Remote { @@ -43,86 +36,93 @@ func (r *Remote) Config() *config.RemoteConfig { return r.c } -// Connect with the endpoint -func (r *Remote) Connect() error { - if err := r.initClient(); err != nil { - return err - } +func (r *Remote) String() string { + fetch := r.c.URL + push := r.c.URL - var err error - r.fetchSession, err = r.client.NewFetchPackSession(r.endpoint) - if err != nil { - return err - } + return fmt.Sprintf("%s\t%s (fetch)\n%[1]s\t%s (push)", r.c.Name, fetch, push) +} - return r.retrieveAdvertisedReferences() +// Fetch fetches references from the remote to the local repository. +func (r *Remote) Fetch(o *FetchOptions) error { + _, err := r.fetch(o) + return err } -func (r *Remote) initClient() error { - var err error - r.endpoint, err = transport.NewEndpoint(r.c.URL) - if err != nil { - return err +func (r *Remote) fetch(o *FetchOptions) (refs storer.ReferenceStorer, err error) { + if o.RemoteName == "" { + o.RemoteName = r.c.Name } - if r.client != nil { - return nil + if err := o.Validate(); err != nil { + return nil, err } - r.client, err = client.NewClient(r.endpoint) + if len(o.RefSpecs) == 0 { + o.RefSpecs = r.c.Fetch + } + + s, err := r.newFetchPackSession() if err != nil { - return err + return nil, err } - return nil -} + defer ioutil.CheckClose(s, &err) -func (r *Remote) retrieveAdvertisedReferences() error { - var err error - r.advRefs, err = r.fetchSession.AdvertisedReferences() + ar, err := s.AdvertisedReferences() if err != nil { - return err + return nil, err } - r.refs, err = r.advRefs.AllReferences() - return err -} + req, err := r.newUploadPackRequest(o, ar) + if err != nil { + return nil, err + } -// AdvertisedReferences returns the git-upload-pack advertised references. -func (r *Remote) AdvertisedReferences() *packp.AdvRefs { - return r.advRefs -} + remoteRefs, err := ar.AllReferences() + if err != nil { + return nil, err + } -// Capabilities returns the remote capabilities -func (r *Remote) Capabilities() *capability.List { - return r.advRefs.Capabilities -} + req.Wants, err = getWants(o.RefSpecs, r.s, remoteRefs) + if len(req.Wants) == 0 { + return remoteRefs, NoErrAlreadyUpToDate + } -// Fetch returns a reader using the request -func (r *Remote) Fetch(o *FetchOptions) (err error) { - if err := o.Validate(); err != nil { - return err + req.Haves, err = getHaves(r.s) + if err != nil { + return nil, err } - if len(o.RefSpecs) == 0 { - o.RefSpecs = r.c.Fetch + if err := r.fetchPack(o, s, req); err != nil { + return nil, err } - refs, err := r.getWantedReferences(o.RefSpecs) - if err != nil { - return err + if err := r.updateLocalReferenceStorage(o.RefSpecs, remoteRefs); err != nil { + return nil, err } - if len(refs) == 0 { - return NoErrAlreadyUpToDate + return remoteRefs, err +} + +func (r *Remote) newFetchPackSession() (transport.FetchPackSession, error) { + ep, err := transport.NewEndpoint(r.c.URL) + if err != nil { + return nil, err } - req, err := r.buildRequest(r.s, o, refs) + c, err := client.NewClient(ep) if err != nil { - return err + return nil, err } - reader, err := r.fetchSession.FetchPack(req) + return c.NewFetchPackSession(ep) +} + +func (r *Remote) fetchPack(o *FetchOptions, s transport.FetchPackSession, + req *packp.UploadPackRequest) (err error) { + + reader, err := s.FetchPack(req) if err != nil { return err } @@ -134,21 +134,37 @@ func (r *Remote) Fetch(o *FetchOptions) (err error) { } if err = r.updateObjectStorage( - r.buildSidebandIfSupported(req.Capabilities, reader), + buildSidebandIfSupported(req.Capabilities, reader, r.p), ); err != nil { return err } - return r.updateLocalReferenceStorage(o.RefSpecs, refs) + return err } -func (r *Remote) getWantedReferences(spec []config.RefSpec) ([]*plumbing.Reference, error) { - var refs []*plumbing.Reference - iter, err := r.References() +func getHaves(localRefs storer.ReferenceStorer) ([]plumbing.Hash, error) { + iter, err := localRefs.IterReferences() if err != nil { - return refs, err + return nil, err } + var haves []plumbing.Hash + err = iter.ForEach(func(ref *plumbing.Reference) error { + if ref.Type() != plumbing.HashReference { + return nil + } + + haves = append(haves, ref.Hash()) + return nil + }) + if err != nil { + return nil, err + } + + return haves, nil +} + +func getWants(spec []config.RefSpec, localStorer Storer, remoteRefs storer.ReferenceStorer) ([]plumbing.Hash, error) { wantTags := true for _, s := range spec { if !s.IsWildcard() { @@ -157,60 +173,82 @@ func (r *Remote) getWantedReferences(spec []config.RefSpec) ([]*plumbing.Referen } } - return refs, iter.ForEach(func(ref *plumbing.Reference) error { - if ref.Type() != plumbing.HashReference { - return nil - } + iter, err := remoteRefs.IterReferences() + if err != nil { + return nil, err + } + wants := map[plumbing.Hash]bool{} + err = iter.ForEach(func(ref *plumbing.Reference) error { if !config.MatchAny(spec, ref.Name()) { if !ref.IsTag() || !wantTags { return nil } } - _, err := r.s.EncodedObject(plumbing.CommitObject, ref.Hash()) - if err == plumbing.ErrObjectNotFound { - refs = append(refs, ref) + if ref.Type() == plumbing.SymbolicReference { + ref, err = storer.ResolveReference(remoteRefs, ref.Name()) + if err != nil { + return err + } + } + + if ref.Type() != plumbing.HashReference { return nil } - return err - }) -} + hash := ref.Hash() + exists, err := commitExists(localStorer, hash) + if err != nil { + return err + } -func (r *Remote) buildRequest( - s storer.ReferenceStorer, o *FetchOptions, refs []*plumbing.Reference, -) (*packp.UploadPackRequest, error) { - req := packp.NewUploadPackRequestFromCapabilities(r.advRefs.Capabilities) + if !exists { + wants[hash] = true + } - if o.Depth != 0 { - req.Depth = packp.DepthCommits(o.Depth) - req.Capabilities.Set(capability.Shallow) + return nil + }) + if err != nil { + return nil, err } - if r.p == nil && r.advRefs.Capabilities.Supports(capability.NoProgress) { - req.Capabilities.Set(capability.NoProgress) + var result []plumbing.Hash + for h := range wants { + result = append(result, h) } - for _, ref := range refs { - req.Wants = append(req.Wants, ref.Hash()) - } + return result, nil +} - i, err := s.IterReferences() - if err != nil { - return nil, err +func commitExists(s storer.EncodedObjectStorer, h plumbing.Hash) (bool, error) { + _, err := s.EncodedObject(plumbing.CommitObject, h) + if err == plumbing.ErrObjectNotFound { + return false, nil } - err = i.ForEach(func(ref *plumbing.Reference) error { - if ref.Type() != plumbing.HashReference { - return nil + return true, err +} + +func (r *Remote) newUploadPackRequest(o *FetchOptions, + ar *packp.AdvRefs) (*packp.UploadPackRequest, error) { + + req := packp.NewUploadPackRequestFromCapabilities(ar.Capabilities) + + if o.Depth != 0 { + req.Depth = packp.DepthCommits(o.Depth) + if err := req.Capabilities.Set(capability.Shallow); err != nil { + return nil, err } + } - req.Haves = append(req.Haves, ref.Hash()) - return nil - }) + if r.p == nil && ar.Capabilities.Supports(capability.NoProgress) { + if err := req.Capabilities.Set(capability.NoProgress); err != nil { + return nil, err + } + } - return req, err + return req, nil } func (r *Remote) updateObjectStorage(reader io.Reader) error { @@ -235,7 +273,7 @@ func (r *Remote) updateObjectStorage(reader io.Reader) error { return err } -func (r *Remote) buildSidebandIfSupported(l *capability.List, reader io.Reader) io.Reader { +func buildSidebandIfSupported(l *capability.List, reader io.Reader, p sideband.Progress) io.Reader { var t sideband.Type switch { @@ -248,12 +286,12 @@ func (r *Remote) buildSidebandIfSupported(l *capability.List, reader io.Reader) } d := sideband.NewDemuxer(t, reader) - d.Progress = r.p + d.Progress = p return d } -func (r *Remote) updateLocalReferenceStorage(specs []config.RefSpec, refs []*plumbing.Reference) error { +func (r *Remote) updateLocalReferenceStorage(specs []config.RefSpec, refs memory.ReferenceStorage) error { for _, spec := range specs { for _, ref := range refs { if !spec.Match(ref.Name()) { @@ -272,11 +310,11 @@ func (r *Remote) updateLocalReferenceStorage(specs []config.RefSpec, refs []*plu } } - return r.buildFetchedTags() + return r.buildFetchedTags(refs) } -func (r *Remote) buildFetchedTags() error { - iter, err := r.References() +func (r *Remote) buildFetchedTags(refs storer.ReferenceStorer) error { + iter, err := refs.IterReferences() if err != nil { return err } @@ -306,40 +344,3 @@ func (r *Remote) updateShallow(o *FetchOptions, resp *packp.UploadPackResponse) return r.s.SetShallow(resp.Shallows) } - -// Head returns the Reference of the HEAD -func (r *Remote) Head() *plumbing.Reference { - ref, err := storer.ResolveReference(r.refs, plumbing.HEAD) - if err != nil { - return nil - } - - return ref -} - -// Reference returns a Reference for a ReferenceName. -func (r *Remote) Reference(name plumbing.ReferenceName, resolved bool) (*plumbing.Reference, error) { - if resolved { - return storer.ResolveReference(r.refs, name) - } - - return r.refs.Reference(name) -} - -// References returns an iterator for all references. -func (r *Remote) References() (storer.ReferenceIter, error) { - return r.refs.IterReferences() -} - -// Disconnect from the remote and save the config -func (r *Remote) Disconnect() error { - r.advRefs = nil - return r.fetchSession.Close() -} - -func (r *Remote) String() string { - fetch := r.c.URL - push := r.c.URL - - return fmt.Sprintf("%s\t%s (fetch)\n%[1]s\t%s (push)", r.c.Name, fetch, push) -} diff --git a/remote_test.go b/remote_test.go index 5af2e2c..f75adee 100644 --- a/remote_test.go +++ b/remote_test.go @@ -8,7 +8,6 @@ import ( "gopkg.in/src-d/go-git.v4/config" "gopkg.in/src-d/go-git.v4/plumbing" - "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability" "gopkg.in/src-d/go-git.v4/plumbing/storer" "gopkg.in/src-d/go-git.v4/storage/filesystem" "gopkg.in/src-d/go-git.v4/storage/memory" @@ -23,63 +22,35 @@ type RemoteSuite struct { var _ = Suite(&RemoteSuite{}) -func (s *RemoteSuite) TestConnect(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - - err := r.Connect() - c.Assert(err, IsNil) -} - -func (s *RemoteSuite) TestnewRemoteInvalidEndpoint(c *C) { +func (s *RemoteSuite) TestFetchInvalidEndpoint(c *C) { r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: "qux"}) - - err := r.Connect() - c.Assert(err, NotNil) + err := r.Fetch(&FetchOptions{}) + c.Assert(err, ErrorMatches, ".*invalid endpoint.*") } -func (s *RemoteSuite) TestnewRemoteNonExistentEndpoint(c *C) { +func (s *RemoteSuite) TestFetchNonExistentEndpoint(c *C) { r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: "ssh://non-existent/foo.git"}) - - err := r.Connect() + err := r.Fetch(&FetchOptions{}) c.Assert(err, NotNil) } -func (s *RemoteSuite) TestnewRemoteInvalidSchemaEndpoint(c *C) { +func (s *RemoteSuite) TestFetchInvalidSchemaEndpoint(c *C) { r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: "qux://foo"}) - - err := r.Connect() - c.Assert(err, NotNil) + err := r.Fetch(&FetchOptions{}) + c.Assert(err, ErrorMatches, ".*unsupported scheme.*") } -func (s *RemoteSuite) TestInfo(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.AdvertisedReferences(), IsNil) - c.Assert(r.Connect(), IsNil) - c.Assert(r.AdvertisedReferences(), NotNil) - c.Assert(r.AdvertisedReferences().Capabilities.Get(capability.Agent), NotNil) -} - -func (s *RemoteSuite) TestDefaultBranch(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) - c.Assert(r.Head().Name(), Equals, plumbing.ReferenceName("refs/heads/master")) -} - -func (s *RemoteSuite) TestCapabilities(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) - c.Assert(r.Capabilities().Get(capability.Agent), HasLen, 1) +func (s *RemoteSuite) TestFetchInvalidFetchOptions(c *C) { + r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: "qux://foo"}) + invalid := config.RefSpec("^*$ñ") + err := r.Fetch(&FetchOptions{RefSpecs: []config.RefSpec{invalid}}) + c.Assert(err, Equals, ErrInvalidRefSpec) } func (s *RemoteSuite) TestFetch(c *C) { url := s.GetBasicLocalRepositoryURL() sto := memory.NewStorage() r := newRemote(sto, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*") err := r.Fetch(&FetchOptions{ @@ -104,7 +75,6 @@ func (s *RemoteSuite) TestFetchDepth(c *C) { url := s.GetBasicLocalRepositoryURL() sto := memory.NewStorage() r := newRemote(sto, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*") err := r.Fetch(&FetchOptions{ @@ -140,7 +110,6 @@ func (s *RemoteSuite) TestFetchWithProgress(c *C) { buf := bytes.NewBuffer(nil) r := newRemote(sto, buf, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*") err := r.Fetch(&FetchOptions{ @@ -176,7 +145,6 @@ func (s *RemoteSuite) TestFetchWithPackfileWriter(c *C) { url := s.GetBasicLocalRepositoryURL() r := newRemote(mock, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*") err = r.Fetch(&FetchOptions{ @@ -202,7 +170,6 @@ func (s *RemoteSuite) TestFetchNoErrAlreadyUpToDate(c *C) { url := s.GetBasicLocalRepositoryURL() sto := memory.NewStorage() r := newRemote(sto, nil, &config.RemoteConfig{Name: "foo", URL: url}) - c.Assert(r.Connect(), IsNil) refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*") o := &FetchOptions{ @@ -215,41 +182,6 @@ func (s *RemoteSuite) TestFetchNoErrAlreadyUpToDate(c *C) { c.Assert(err, Equals, NoErrAlreadyUpToDate) } -func (s *RemoteSuite) TestHead(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - - err := r.Connect() - c.Assert(err, IsNil) - c.Assert(r.Head().Hash(), Equals, plumbing.NewHash("6ecf0ef2c2dffb796033e5a02219af86ec6584e5")) -} - -func (s *RemoteSuite) TestRef(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - err := r.Connect() - c.Assert(err, IsNil) - - ref, err := r.Reference(plumbing.HEAD, false) - c.Assert(err, IsNil) - c.Assert(ref.Name(), Equals, plumbing.HEAD) - - ref, err = r.Reference(plumbing.HEAD, true) - c.Assert(err, IsNil) - c.Assert(ref.Name(), Equals, plumbing.ReferenceName("refs/heads/master")) -} - -func (s *RemoteSuite) TestRefs(c *C) { - url := s.GetBasicLocalRepositoryURL() - r := newRemote(nil, nil, &config.RemoteConfig{Name: "foo", URL: url}) - err := r.Connect() - c.Assert(err, IsNil) - - iter, err := r.References() - c.Assert(err, IsNil) - c.Assert(iter, NotNil) -} - func (s *RemoteSuite) TestString(c *C) { r := newRemote(nil, nil, &config.RemoteConfig{ Name: "foo", diff --git a/repository.go b/repository.go index 6a6e01d..7d964a4 100644 --- a/repository.go +++ b/repository.go @@ -164,26 +164,43 @@ func (r *Repository) Clone(o *CloneOptions) error { return err } - if err = remote.Connect(); err != nil { + remoteRefs, err := remote.fetch(&FetchOptions{ + RefSpecs: r.cloneRefSpec(o, c), + Depth: o.Depth, + }) + if err != nil { return err } - defer remote.Disconnect() - - if err := r.updateRemoteConfig(remote, o, c); err != nil { + head, err := storer.ResolveReference(remoteRefs, o.ReferenceName) + if err != nil { return err } - if err = remote.Fetch(&FetchOptions{Depth: o.Depth}); err != nil { + if err := r.createReferences(c.Fetch, o.ReferenceName, head); err != nil { return err } - head, err := remote.Reference(o.ReferenceName, true) - if err != nil { - return err + return r.updateRemoteConfig(remote, o, c, head) +} + +func (r *Repository) cloneRefSpec(o *CloneOptions, + c *config.RemoteConfig) []config.RefSpec { + + if !o.SingleBranch { + return c.Fetch + } + + var rs string + + if o.ReferenceName == plumbing.HEAD { + rs = fmt.Sprintf(refspecSingleBranchHEAD, c.Name) + } else { + rs = fmt.Sprintf(refspecSingleBranch, + o.ReferenceName.Short(), c.Name) } - return r.createReferences(head) + return []config.RefSpec{config.RefSpec(rs)} } func (r *Repository) setIsBare(isBare bool) error { @@ -196,28 +213,21 @@ func (r *Repository) setIsBare(isBare bool) error { return r.s.SetConfig(cfg) } -const refspecSingleBranch = "+refs/heads/%s:refs/remotes/%s/%[1]s" +const ( + refspecSingleBranch = "+refs/heads/%s:refs/remotes/%s/%[1]s" + refspecSingleBranchHEAD = "+HEAD:refs/remotes/%s/HEAD" +) + +func (r *Repository) updateRemoteConfig(remote *Remote, o *CloneOptions, + c *config.RemoteConfig, head *plumbing.Reference) error { -func (r *Repository) updateRemoteConfig( - remote *Remote, o *CloneOptions, c *config.RemoteConfig, -) error { if !o.SingleBranch { return nil } - refs, err := remote.AdvertisedReferences().AllReferences() - if err != nil { - return err - } - - head, err := storer.ResolveReference(refs, o.ReferenceName) - if err != nil { - return err - } - - c.Fetch = []config.RefSpec{ - config.RefSpec(fmt.Sprintf(refspecSingleBranch, head.Name().Short(), c.Name)), - } + c.Fetch = []config.RefSpec{config.RefSpec(fmt.Sprintf( + refspecSingleBranch, head.Name().Short(), c.Name, + ))} cfg, err := r.s.Config() if err != nil { @@ -226,22 +236,59 @@ func (r *Repository) updateRemoteConfig( cfg.Remotes[c.Name] = c return r.s.SetConfig(cfg) - } -func (r *Repository) createReferences(ref *plumbing.Reference) error { - if !ref.IsBranch() { - // detached HEAD mode - head := plumbing.NewHashReference(plumbing.HEAD, ref.Hash()) +func (r *Repository) createReferences(spec []config.RefSpec, + headName plumbing.ReferenceName, resolvedHead *plumbing.Reference) error { + + if !resolvedHead.IsBranch() { + // Detached HEAD mode + head := plumbing.NewHashReference(plumbing.HEAD, resolvedHead.Hash()) return r.s.SetReference(head) } - if err := r.s.SetReference(ref); err != nil { + // Create local reference for the resolved head + if err := r.s.SetReference(resolvedHead); err != nil { return err } - head := plumbing.NewSymbolicReference(plumbing.HEAD, ref.Name()) - return r.s.SetReference(head) + // Create local symbolic HEAD + head := plumbing.NewSymbolicReference(plumbing.HEAD, resolvedHead.Name()) + if err := r.s.SetReference(head); err != nil { + return err + } + + return r.createRemoteHeadReference(spec, resolvedHead) +} + +func (r *Repository) createRemoteHeadReference(spec []config.RefSpec, + resolvedHead *plumbing.Reference) error { + + // Create resolved HEAD reference with remote prefix if it does not + // exist. This is needed when using single branch and HEAD. + for _, rs := range spec { + name := resolvedHead.Name() + if !rs.Match(name) { + continue + } + + name = rs.Dst(name) + _, err := r.s.Reference(name) + if err == plumbing.ErrReferenceNotFound { + ref := plumbing.NewHashReference(name, resolvedHead.Hash()) + if err := r.s.SetReference(ref); err != nil { + return err + } + + continue + } + + if err != nil { + return err + } + } + + return nil } // IsEmpty returns true if the repository is empty @@ -269,32 +316,33 @@ func (r *Repository) Pull(o *PullOptions) error { return err } - if err = remote.Connect(); err != nil { + remoteRefs, err := remote.fetch(&FetchOptions{ + Depth: o.Depth, + }) + if err != nil { return err } - defer remote.Disconnect() - - head, err := remote.Reference(o.ReferenceName, true) + head, err := storer.ResolveReference(remoteRefs, o.ReferenceName) if err != nil { return err } - if err = remote.Connect(); err != nil { + return r.createReferences(remote.c.Fetch, o.ReferenceName, head) +} + +// Fetch fetches changes from a remote repository. +func (r *Repository) Fetch(o *FetchOptions) error { + if err := o.Validate(); err != nil { return err } - defer remote.Disconnect() - - err = remote.Fetch(&FetchOptions{ - Depth: o.Depth, - }) - + remote, err := r.Remote(o.RemoteName) if err != nil { return err } - return r.createReferences(head) + return remote.Fetch(o) } // object.Commit return the commit with the given hash diff --git a/repository_test.go b/repository_test.go index 91a5c71..4d17dce 100644 --- a/repository_test.go +++ b/repository_test.go @@ -84,6 +84,29 @@ func (s *RepositorySuite) TestDeleteRemote(c *C) { c.Assert(alt, IsNil) } +func (s *RepositorySuite) TestFetch(c *C) { + r := NewMemoryRepository() + _, err := r.CreateRemote(&config.RemoteConfig{ + Name: DefaultRemoteName, + URL: s.GetBasicLocalRepositoryURL(), + }) + c.Assert(err, IsNil) + c.Assert(r.Fetch(&FetchOptions{}), IsNil) + + remotes, err := r.Remotes() + c.Assert(err, IsNil) + c.Assert(remotes, HasLen, 1) + + _, err = r.Reference(plumbing.HEAD, false) + c.Assert(err, Equals, plumbing.ErrReferenceNotFound) + + branch, err := r.Reference("refs/remotes/origin/master", false) + c.Assert(err, IsNil) + c.Assert(branch, NotNil) + c.Assert(branch.Type(), Equals, plumbing.HashReference) + c.Assert(branch.Hash().String(), Equals, "6ecf0ef2c2dffb796033e5a02219af86ec6584e5") +} + func (s *RepositorySuite) TestClone(c *C) { r := NewMemoryRepository() |