From cd435d9a72a866f8dcb5e64d4fd0d821b8569006 Mon Sep 17 00:00:00 2001 From: Máximo Cuadros Date: Tue, 21 Feb 2017 18:17:29 +0100 Subject: plumbing/transport: git, error on empty SSH_AUTH_SOCK --- plumbing/transport/ssh/auth_method.go | 14 +++++++++++--- plumbing/transport/ssh/auth_method_test.go | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) (limited to 'plumbing') diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index 9c3d6f3..f53e510 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -1,6 +1,7 @@ package ssh import ( + "errors" "fmt" "net" "os" @@ -9,6 +10,8 @@ import ( "golang.org/x/crypto/ssh/agent" ) +var ErrEmptySSHAgentAddr = errors.New("SSH_AUTH_SOCK env variable is required") + // AuthMethod is the interface all auth methods for the ssh client // must implement. The clientConfig method returns the ssh client // configuration needed to establish an ssh connection. @@ -138,16 +141,21 @@ func (a *PublicKeysCallback) clientConfig() *ssh.ClientConfig { const DefaultSSHUsername = "git" -// Opens a pipe with the ssh agent and uses the pipe +// NewSSHAgentAuth opens a pipe with the SSH agent and uses the pipe // as the implementer of the public key callback function. func NewSSHAgentAuth(user string) (*PublicKeysCallback, error) { if user == "" { user = DefaultSSHUsername } - pipe, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + sshAgentAddr := os.Getenv("SSH_AUTH_SOCK") + if sshAgentAddr == "" { + return nil, ErrEmptySSHAgentAddr + } + + pipe, err := net.Dial("unix", sshAgentAddr) if err != nil { - return nil, err + return nil, fmt.Errorf("error connecting to SSH agent: %q", err) } return &PublicKeysCallback{ diff --git a/plumbing/transport/ssh/auth_method_test.go b/plumbing/transport/ssh/auth_method_test.go index f9e7dec..412e721 100644 --- a/plumbing/transport/ssh/auth_method_test.go +++ b/plumbing/transport/ssh/auth_method_test.go @@ -2,6 +2,7 @@ package ssh import ( "fmt" + "os" . "gopkg.in/check.v1" ) @@ -89,3 +90,17 @@ func (s *SuiteCommon) TestPublicKeysCallbackString(c *C) { } c.Assert(a.String(), Equals, fmt.Sprintf("user: test, name: %s", PublicKeysCallbackName)) } +func (s *SuiteCommon) TestNewSSHAgentAuth(c *C) { + addr := os.Getenv("SSH_AUTH_SOCK") + err := os.Unsetenv("SSH_AUTH_SOCK") + c.Assert(err, IsNil) + + defer func() { + err := os.Setenv("SSH_AUTH_SOCK", addr) + c.Assert(err, IsNil) + }() + + k, err := NewSSHAgentAuth("foo") + c.Assert(k, IsNil) + c.Assert(err, Equals, ErrEmptySSHAgentAddr) +} -- cgit