diff options
Diffstat (limited to 'plumbing/transport/ssh/common.go')
-rw-r--r-- | plumbing/transport/ssh/common.go | 49 |
1 files changed, 44 insertions, 5 deletions
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index 46e7913..1531603 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -4,12 +4,14 @@ package ssh import ( "context" "fmt" + "net" "reflect" "strconv" "strings" "github.com/go-git/go-git/v5/plumbing/transport" "github.com/go-git/go-git/v5/plumbing/transport/internal/common" + "github.com/skeema/knownhosts" "github.com/kevinburke/ssh_config" "golang.org/x/crypto/ssh" @@ -121,10 +123,24 @@ func (c *command) connect() error { if err != nil { return err } + hostWithPort := c.getHostWithPort() + if config.HostKeyCallback == nil { + kh, err := newKnownHosts() + if err != nil { + return err + } + config.HostKeyCallback = kh.HostKeyCallback() + config.HostKeyAlgorithms = kh.HostKeyAlgorithms(hostWithPort) + } else if len(config.HostKeyAlgorithms) == 0 { + // Set the HostKeyAlgorithms based on HostKeyCallback. + // For background see https://github.com/go-git/go-git/issues/411 as well as + // https://github.com/golang/go/issues/29286 for root cause. + config.HostKeyAlgorithms = knownhosts.HostKeyAlgorithms(config.HostKeyCallback, hostWithPort) + } overrideConfig(c.config, config) - c.client, err = dial("tcp", c.getHostWithPort(), config) + c.client, err = dial("tcp", hostWithPort, c.endpoint.Proxy, config) if err != nil { return err } @@ -139,7 +155,7 @@ func (c *command) connect() error { return nil } -func dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { +func dial(network, addr string, proxyOpts transport.ProxyOptions, config *ssh.ClientConfig) (*ssh.Client, error) { var ( ctx = context.Background() cancel context.CancelFunc @@ -151,10 +167,33 @@ func dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { } defer cancel() - conn, err := proxy.Dial(ctx, network, addr) + var conn net.Conn + var err error + + if proxyOpts.URL != "" { + proxyUrl, err := proxyOpts.FullURL() + if err != nil { + return nil, err + } + dialer, err := proxy.FromURL(proxyUrl, proxy.Direct) + if err != nil { + return nil, err + } + + // Try to use a ContextDialer, but fall back to a Dialer if that goes south. + ctxDialer, ok := dialer.(proxy.ContextDialer) + if !ok { + return nil, fmt.Errorf("expected ssh proxy dialer to be of type %s; got %s", + reflect.TypeOf(ctxDialer), reflect.TypeOf(dialer)) + } + conn, err = ctxDialer.DialContext(ctx, "tcp", addr) + } else { + conn, err = proxy.Dial(ctx, network, addr) + } if err != nil { return nil, err } + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { return nil, err @@ -173,7 +212,7 @@ func (c *command) getHostWithPort() string { port = DefaultPort } - return fmt.Sprintf("%s:%d", host, port) + return net.JoinHostPort(host, strconv.Itoa(port)) } func (c *command) doGetHostWithPortFromSSHConfig() (addr string, found bool) { @@ -201,7 +240,7 @@ func (c *command) doGetHostWithPortFromSSHConfig() (addr string, found bool) { } } - addr = fmt.Sprintf("%s:%d", host, port) + addr = net.JoinHostPort(host, strconv.Itoa(port)) return } |