aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/transport/ssh/common.go
diff options
context:
space:
mode:
Diffstat (limited to 'plumbing/transport/ssh/common.go')
-rw-r--r--plumbing/transport/ssh/common.go52
1 files changed, 40 insertions, 12 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go
index 872379a..e4a3d18 100644
--- a/plumbing/transport/ssh/common.go
+++ b/plumbing/transport/ssh/common.go
@@ -9,13 +9,21 @@ import (
"gopkg.in/src-d/go-git.v4/plumbing/transport"
"gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
- "golang.org/x/crypto/ssh"
"github.com/kevinburke/ssh_config"
+ "golang.org/x/crypto/ssh"
)
// DefaultClient is the default SSH client.
var DefaultClient = NewClient(nil)
+// DefaultSSHConfig is the reader used to access parameters stored in the
+// system's ssh_config files. If nil all the ssh_config are ignored.
+var DefaultSSHConfig sshConfig = ssh_config.DefaultUserSettings
+
+type sshConfig interface {
+ Get(alias, key string) string
+}
+
// NewClient creates a new SSH client with an optional *ssh.ClientConfig.
func NewClient(config *ssh.ClientConfig) transport.Transport {
return common.NewClient(&runner{config: config})
@@ -123,26 +131,46 @@ func (c *command) connect() error {
}
func (c *command) getHostWithPort() string {
+ if addr, found := c.doGetHostWithPortFromSSHConfig(); found {
+ return addr
+ }
+
host := c.endpoint.Host
+ port := c.endpoint.Port
+ if port <= 0 {
+ port = DefaultPort
+ }
- configHost := ssh_config.Get(host, "Hostname")
- if (configHost != "") {
- host = configHost
+ return fmt.Sprintf("%s:%d", host, port)
+}
+
+func (c *command) doGetHostWithPortFromSSHConfig() (addr string, found bool) {
+ if DefaultSSHConfig == nil {
+ return
}
+ host := c.endpoint.Host
port := c.endpoint.Port
- configPort := ssh_config.Get(host, "Port")
- if (configPort != "") {
- i, err := strconv.Atoi(configPort)
- if err != nil {
+
+ configHost := DefaultSSHConfig.Get(c.endpoint.Host, "Hostname")
+ if configHost != "" {
+ host = configHost
+ found = true
+ }
+
+ if !found {
+ return
+ }
+
+ configPort := DefaultSSHConfig.Get(c.endpoint.Host, "Port")
+ if configPort != "" {
+ if i, err := strconv.Atoi(configPort); err == nil {
port = i
}
}
- if port <= 0 {
- port = DefaultPort
- }
- return fmt.Sprintf("%s:%d", host, port)
+ addr = fmt.Sprintf("%s:%d", host, port)
+ return
}
func (c *command) setAuthFromEndpoint() error {