diff options
Diffstat (limited to 'plumbing/transport/ssh')
-rw-r--r-- | plumbing/transport/ssh/auth_method.go | 13 | ||||
-rw-r--r-- | plumbing/transport/ssh/common.go | 21 |
2 files changed, 18 insertions, 16 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index a3e1ad1..84dfe14 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -179,9 +179,14 @@ type PublicKeysCallback struct { // NewSSHAgentAuth returns a PublicKeysCallback based on a SSH agent, it opens // a pipe with the SSH agent and uses the pipe as the implementer of the public // key callback function. -func NewSSHAgentAuth(user string) (AuthMethod, error) { - if user == "" { - user = DefaultUsername +func NewSSHAgentAuth(u string) (AuthMethod, error) { + if u == "" { + usr, err := user.Current() + if err != nil { + return nil, fmt.Errorf("error getting current user: %q", err) + } + + u = usr.Username } sshAgentAddr := os.Getenv("SSH_AUTH_SOCK") @@ -195,7 +200,7 @@ func NewSSHAgentAuth(user string) (AuthMethod, error) { } return &PublicKeysCallback{ - User: user, + User: u, Callback: agent.NewClient(pipe).Signers, }, nil } diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index 7b44a91..d53fc12 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -3,7 +3,6 @@ package ssh import ( "fmt" - "strings" "gopkg.in/src-d/go-git.v4/plumbing/transport" "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common" @@ -20,6 +19,8 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) { return NewSSHAgentAuth(user) } +const DefaultPort = 22 + type runner struct{} func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) { @@ -110,25 +111,21 @@ func (c *command) connect() error { } func (c *command) getHostWithPort() string { - host := c.endpoint.Host - if strings.Index(c.endpoint.Host, ":") == -1 { - host += ":22" + host := c.endpoint.Host() + port := c.endpoint.Port() + if port <= 0 { + port = DefaultPort } - return host + return fmt.Sprintf("%s:%d", host, port) } func (c *command) setAuthFromEndpoint() error { - var u string - if info := c.endpoint.User; info != nil { - u = info.Username() - } - var err error - c.auth, err = DefaultAuthBuilder(u) + c.auth, err = DefaultAuthBuilder(c.endpoint.User()) return err } func endpointToCommand(cmd string, ep transport.Endpoint) string { - return fmt.Sprintf("%s '%s'", cmd, ep.Path) + return fmt.Sprintf("%s '%s'", cmd, ep.Path()) } |