aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/transport/ssh/common.go
diff options
context:
space:
mode:
Diffstat (limited to 'plumbing/transport/ssh/common.go')
-rw-r--r--plumbing/transport/ssh/common.go45
1 files changed, 31 insertions, 14 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go
index 6ab0835..6d1c51b 100644
--- a/plumbing/transport/ssh/common.go
+++ b/plumbing/transport/ssh/common.go
@@ -3,6 +3,7 @@ package ssh
import (
"fmt"
+ "reflect"
"gopkg.in/src-d/go-git.v4/plumbing/transport"
"gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
@@ -11,17 +12,13 @@ import (
)
// DefaultClient is the default SSH client.
-var DefaultClient = NewClient()
+var DefaultClient = NewClient(nil)
-// NewClient creates a new SSH client with the given options.
-func NewClient(opts ...ClientOption) transport.Transport {
- return common.NewClient(&runner{options: opts})
+// NewClient creates a new SSH client with an optional *ssh.ClientConfig.
+func NewClient(config *ssh.ClientConfig) transport.Transport {
+ return common.NewClient(&runner{config: config})
}
-// ClientOption is a function that gets a standard ssh.ClientConfig and modifies
-// it.
-type ClientOption func(config *ssh.ClientConfig)
-
// DefaultAuthBuilder is the function used to create a default AuthMethod, when
// the user doesn't provide any.
var DefaultAuthBuilder = func(user string) (AuthMethod, error) {
@@ -31,11 +28,11 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) {
const DefaultPort = 22
type runner struct {
- options []ClientOption
+ config *ssh.ClientConfig
}
func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) {
- c := &command{command: cmd, endpoint: ep, options: r.options}
+ c := &command{command: cmd, endpoint: ep, config: r.config}
if auth != nil {
c.setAuth(auth)
}
@@ -53,7 +50,7 @@ type command struct {
endpoint transport.Endpoint
client *ssh.Client
auth AuthMethod
- options []ClientOption
+ config *ssh.ClientConfig
}
func (c *command) setAuth(auth transport.AuthMethod) error {
@@ -107,9 +104,7 @@ func (c *command) connect() error {
return err
}
- for _, opt := range c.options {
- opt(config)
- }
+ overrideConfig(c.config, config)
c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config)
if err != nil {
@@ -145,3 +140,25 @@ func (c *command) setAuthFromEndpoint() error {
func endpointToCommand(cmd string, ep transport.Endpoint) string {
return fmt.Sprintf("%s '%s'", cmd, ep.Path())
}
+
+func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) {
+ if overrides == nil {
+ return
+ }
+
+ vo := reflect.ValueOf(*overrides)
+ vc := reflect.ValueOf(*c)
+ for i := 0; i < vc.Type().NumField(); i++ {
+ vcf := vc.Field(i)
+ vof := vo.Field(i)
+ if isZeroValue(vcf) {
+ vcf.Set(vof)
+ }
+ }
+
+ *c = vc.Interface().(ssh.ClientConfig)
+}
+
+func isZeroValue(v reflect.Value) bool {
+ return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
+}