aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/transport/ssh
diff options
context:
space:
mode:
authorMáximo Cuadros <mcuadros@gmail.com>2017-07-05 03:11:25 -0700
committerGitHub <noreply@github.com>2017-07-05 03:11:25 -0700
commit102d4b5aeb9b3cbd544c59706a1b0dd9300ddcc8 (patch)
tree3df8c173979f3fc945ba336d418c556d84501922 /plumbing/transport/ssh
parent5354ebc084d5300cd8d371a837a0d9dab408c561 (diff)
parent7368129bc7e7394e8af7262a0764292c3fb8d3c5 (diff)
downloadgo-git-102d4b5aeb9b3cbd544c59706a1b0dd9300ddcc8.tar.gz
Merge pull request #423 from smola/ssh-options
transport/ssh: allow passing SSH options
Diffstat (limited to 'plumbing/transport/ssh')
-rw-r--r--plumbing/transport/ssh/common.go39
1 files changed, 36 insertions, 3 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go
index d53fc12..6d1c51b 100644
--- a/plumbing/transport/ssh/common.go
+++ b/plumbing/transport/ssh/common.go
@@ -3,6 +3,7 @@ package ssh
import (
"fmt"
+ "reflect"
"gopkg.in/src-d/go-git.v4/plumbing/transport"
"gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
@@ -11,7 +12,12 @@ import (
)
// DefaultClient is the default SSH client.
-var DefaultClient = common.NewClient(&runner{})
+var DefaultClient = NewClient(nil)
+
+// NewClient creates a new SSH client with an optional *ssh.ClientConfig.
+func NewClient(config *ssh.ClientConfig) transport.Transport {
+ return common.NewClient(&runner{config: config})
+}
// DefaultAuthBuilder is the function used to create a default AuthMethod, when
// the user doesn't provide any.
@@ -21,10 +27,12 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) {
const DefaultPort = 22
-type runner struct{}
+type runner struct {
+ config *ssh.ClientConfig
+}
func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) {
- c := &command{command: cmd, endpoint: ep}
+ c := &command{command: cmd, endpoint: ep, config: r.config}
if auth != nil {
c.setAuth(auth)
}
@@ -42,6 +50,7 @@ type command struct {
endpoint transport.Endpoint
client *ssh.Client
auth AuthMethod
+ config *ssh.ClientConfig
}
func (c *command) setAuth(auth transport.AuthMethod) error {
@@ -95,6 +104,8 @@ func (c *command) connect() error {
return err
}
+ overrideConfig(c.config, config)
+
c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config)
if err != nil {
return err
@@ -129,3 +140,25 @@ func (c *command) setAuthFromEndpoint() error {
func endpointToCommand(cmd string, ep transport.Endpoint) string {
return fmt.Sprintf("%s '%s'", cmd, ep.Path())
}
+
+func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) {
+ if overrides == nil {
+ return
+ }
+
+ vo := reflect.ValueOf(*overrides)
+ vc := reflect.ValueOf(*c)
+ for i := 0; i < vc.Type().NumField(); i++ {
+ vcf := vc.Field(i)
+ vof := vo.Field(i)
+ if isZeroValue(vcf) {
+ vcf.Set(vof)
+ }
+ }
+
+ *c = vc.Interface().(ssh.ClientConfig)
+}
+
+func isZeroValue(v reflect.Value) bool {
+ return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
+}