diff options
Diffstat (limited to 'plumbing/transport/ssh/common.go')
-rw-r--r-- | plumbing/transport/ssh/common.go | 45 |
1 files changed, 31 insertions, 14 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index 6ab0835..6d1c51b 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -3,6 +3,7 @@ package ssh import ( "fmt" + "reflect" "gopkg.in/src-d/go-git.v4/plumbing/transport" "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common" @@ -11,17 +12,13 @@ import ( ) // DefaultClient is the default SSH client. -var DefaultClient = NewClient() +var DefaultClient = NewClient(nil) -// NewClient creates a new SSH client with the given options. -func NewClient(opts ...ClientOption) transport.Transport { - return common.NewClient(&runner{options: opts}) +// NewClient creates a new SSH client with an optional *ssh.ClientConfig. +func NewClient(config *ssh.ClientConfig) transport.Transport { + return common.NewClient(&runner{config: config}) } -// ClientOption is a function that gets a standard ssh.ClientConfig and modifies -// it. -type ClientOption func(config *ssh.ClientConfig) - // DefaultAuthBuilder is the function used to create a default AuthMethod, when // the user doesn't provide any. var DefaultAuthBuilder = func(user string) (AuthMethod, error) { @@ -31,11 +28,11 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) { const DefaultPort = 22 type runner struct { - options []ClientOption + config *ssh.ClientConfig } func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) { - c := &command{command: cmd, endpoint: ep, options: r.options} + c := &command{command: cmd, endpoint: ep, config: r.config} if auth != nil { c.setAuth(auth) } @@ -53,7 +50,7 @@ type command struct { endpoint transport.Endpoint client *ssh.Client auth AuthMethod - options []ClientOption + config *ssh.ClientConfig } func (c *command) setAuth(auth transport.AuthMethod) error { @@ -107,9 +104,7 @@ func (c *command) connect() error { return err } - for _, opt := range c.options { - opt(config) - } + overrideConfig(c.config, config) c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config) if err != nil { @@ -145,3 +140,25 @@ func (c *command) setAuthFromEndpoint() error { func endpointToCommand(cmd string, ep transport.Endpoint) string { return fmt.Sprintf("%s '%s'", cmd, ep.Path()) } + +func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) { + if overrides == nil { + return + } + + vo := reflect.ValueOf(*overrides) + vc := reflect.ValueOf(*c) + for i := 0; i < vc.Type().NumField(); i++ { + vcf := vc.Field(i) + vof := vo.Field(i) + if isZeroValue(vcf) { + vcf.Set(vof) + } + } + + *c = vc.Interface().(ssh.ClientConfig) +} + +func isZeroValue(v reflect.Value) bool { + return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) +} |