diff options
-rw-r--r-- | repository.go | 20 | ||||
-rw-r--r-- | repository_test.go | 21 |
2 files changed, 34 insertions, 7 deletions
diff --git a/repository.go b/repository.go index 1342e7d..0cda947 100644 --- a/repository.go +++ b/repository.go @@ -219,9 +219,21 @@ func (r *Repository) Object(h core.Hash) (Object, error) { } } -// Head returns the hash of the HEAD of the repository. If there is no -// HEAD, it then returns the hash of the HEAD of the default remote. If -// there is no default remote, it returns an error. -func (r *Repository) Head() (core.Hash, error) { +// Head returns the hash of the HEAD of the repository or the head of a +// remote, if one is passed. +func (r *Repository) Head(remote string) (core.Hash, error) { + if remote == "" { + return r.localHead() + } + + rem, ok := r.Remotes[remote] + if !ok { + return core.ZeroHash, fmt.Errorf("unable to find remote %q", remote) + } + + return rem.Head() +} + +func (r *Repository) localHead() (core.Hash, error) { return core.ZeroHash, nil } diff --git a/repository_test.go b/repository_test.go index 06f8625..da02a4d 100644 --- a/repository_test.go +++ b/repository_test.go @@ -224,7 +224,7 @@ func (s *SuiteRepository) TestHeadFromFs(c *C) { repo, err := NewRepositoryFromFS(fs, gitPath) c.Assert(err, IsNil, com) - head, err := repo.Head() + head, err := repo.Head("") c.Assert(err, IsNil) c.Assert(head, Equals, fix.head) @@ -236,13 +236,28 @@ func (s *SuiteRepository) TestHeadFromRemote(c *C) { c.Assert(err, IsNil) upSrv := &MockGitUploadPackService{} - r.Remotes["origin"].upSrv = upSrv + r.Remotes[DefaultRemoteName].upSrv = upSrv + err = r.Remotes[DefaultRemoteName].Connect() + c.Assert(err, IsNil) + info, err := upSrv.Info() c.Assert(err, IsNil) expected := info.Head - obtained, err := r.Head() + obtained, err := r.Head(DefaultRemoteName) c.Assert(err, IsNil) c.Assert(obtained, Equals, expected) } + +func (s *SuiteRepository) TestHeadFromRemoteError(c *C) { + r, err := NewRepository(RepositoryFixture, nil) + c.Assert(err, IsNil) + + upSrv := &MockGitUploadPackService{} + r.Remotes[DefaultRemoteName].upSrv = upSrv + + remote := "not found" + _, err = r.Head(remote) + c.Assert(err, ErrorMatches, fmt.Sprintf("unable to find remote %q", remote)) +} |