diff options
Diffstat (limited to 'plumbing/transport')
-rw-r--r-- | plumbing/transport/ssh/common.go | 39 |
1 files changed, 36 insertions, 3 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index d53fc12..6d1c51b 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -3,6 +3,7 @@ package ssh import ( "fmt" + "reflect" "gopkg.in/src-d/go-git.v4/plumbing/transport" "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common" @@ -11,7 +12,12 @@ import ( ) // DefaultClient is the default SSH client. -var DefaultClient = common.NewClient(&runner{}) +var DefaultClient = NewClient(nil) + +// NewClient creates a new SSH client with an optional *ssh.ClientConfig. +func NewClient(config *ssh.ClientConfig) transport.Transport { + return common.NewClient(&runner{config: config}) +} // DefaultAuthBuilder is the function used to create a default AuthMethod, when // the user doesn't provide any. @@ -21,10 +27,12 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) { const DefaultPort = 22 -type runner struct{} +type runner struct { + config *ssh.ClientConfig +} func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) { - c := &command{command: cmd, endpoint: ep} + c := &command{command: cmd, endpoint: ep, config: r.config} if auth != nil { c.setAuth(auth) } @@ -42,6 +50,7 @@ type command struct { endpoint transport.Endpoint client *ssh.Client auth AuthMethod + config *ssh.ClientConfig } func (c *command) setAuth(auth transport.AuthMethod) error { @@ -95,6 +104,8 @@ func (c *command) connect() error { return err } + overrideConfig(c.config, config) + c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config) if err != nil { return err @@ -129,3 +140,25 @@ func (c *command) setAuthFromEndpoint() error { func endpointToCommand(cmd string, ep transport.Endpoint) string { return fmt.Sprintf("%s '%s'", cmd, ep.Path()) } + +func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) { + if overrides == nil { + return + } + + vo := reflect.ValueOf(*overrides) + vc := reflect.ValueOf(*c) + for i := 0; i < vc.Type().NumField(); i++ { + vcf := vc.Field(i) + vof := vo.Field(i) + if isZeroValue(vcf) { + vcf.Set(vof) + } + } + + *c = vc.Interface().(ssh.ClientConfig) +} + +func isZeroValue(v reflect.Value) bool { + return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) +} |