aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/transport
diff options
context:
space:
mode:
Diffstat (limited to 'plumbing/transport')
-rw-r--r--plumbing/transport/ssh/auth_method.go86
-rw-r--r--plumbing/transport/ssh/common.go8
-rw-r--r--plumbing/transport/ssh/upload_pack_test.go1
3 files changed, 93 insertions, 2 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go
index f53e510..9965f45 100644
--- a/plumbing/transport/ssh/auth_method.go
+++ b/plumbing/transport/ssh/auth_method.go
@@ -5,9 +5,12 @@ import (
"fmt"
"net"
"os"
+ "os/user"
+ "path/filepath"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
+ "golang.org/x/crypto/ssh/knownhosts"
)
var ErrEmptySSHAgentAddr = errors.New("SSH_AUTH_SOCK env variable is required")
@@ -17,6 +20,7 @@ var ErrEmptySSHAgentAddr = errors.New("SSH_AUTH_SOCK env variable is required")
// configuration needed to establish an ssh connection.
type AuthMethod interface {
clientConfig() *ssh.ClientConfig
+ hostKeyCallback() (ssh.HostKeyCallback, error)
}
// The names of the AuthMethod implementations. To be returned by the
@@ -35,6 +39,7 @@ const (
type KeyboardInteractive struct {
User string
Challenge ssh.KeyboardInteractiveChallenge
+ baseAuthMethod
}
func (a *KeyboardInteractive) Name() string {
@@ -56,6 +61,7 @@ func (a *KeyboardInteractive) clientConfig() *ssh.ClientConfig {
type Password struct {
User string
Pass string
+ baseAuthMethod
}
func (a *Password) Name() string {
@@ -78,6 +84,7 @@ func (a *Password) clientConfig() *ssh.ClientConfig {
type PasswordCallback struct {
User string
Callback func() (pass string, err error)
+ baseAuthMethod
}
func (a *PasswordCallback) Name() string {
@@ -100,6 +107,7 @@ func (a *PasswordCallback) clientConfig() *ssh.ClientConfig {
type PublicKeys struct {
User string
Signer ssh.Signer
+ baseAuthMethod
}
func (a *PublicKeys) Name() string {
@@ -122,6 +130,7 @@ func (a *PublicKeys) clientConfig() *ssh.ClientConfig {
type PublicKeysCallback struct {
User string
Callback func() (signers []ssh.Signer, err error)
+ baseAuthMethod
}
func (a *PublicKeysCallback) Name() string {
@@ -163,3 +172,80 @@ func NewSSHAgentAuth(user string) (*PublicKeysCallback, error) {
Callback: agent.NewClient(pipe).Signers,
}, nil
}
+
+// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a
+// know_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT
+//
+// If files is empty, the list of files will be read from the SSH_KNOWN_HOSTS
+// environment variable, example:
+// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file
+//
+// If SSH_KNOWN_HOSTS is not set the following file locations will be used:
+// ~/.ssh/known_hosts
+// /etc/ssh/ssh_known_hosts
+func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) {
+ files, err := getDefaultKnownHostsFiles()
+ if err != nil {
+ return nil, err
+ }
+
+ files, err = filterKnownHostsFiles(files...)
+ if err != nil {
+ return nil, err
+ }
+
+ return knownhosts.New(files...)
+}
+
+func getDefaultKnownHostsFiles() ([]string, error) {
+ files := filepath.SplitList(os.Getenv("SSH_KNOWN_HOSTS"))
+ if len(files) != 0 {
+ return files, nil
+ }
+
+ user, err := user.Current()
+ if err != nil {
+ return nil, err
+ }
+
+ return []string{
+ filepath.Join(user.HomeDir, "/.ssh/known_hosts"),
+ "/etc/ssh/ssh_known_hosts",
+ }, nil
+}
+
+func filterKnownHostsFiles(files ...string) ([]string, error) {
+ var out []string
+ for _, file := range files {
+ _, err := os.Stat(file)
+ if err == nil {
+ out = append(out, file)
+ continue
+ }
+
+ if !os.IsNotExist(err) {
+ return nil, err
+ }
+ }
+
+ if len(out) == 0 {
+ return nil, fmt.Errorf("unable to find any valid know_hosts file, set SSH_KNOWN_HOSTS env variable")
+ }
+
+ return out, nil
+}
+
+type baseAuthMethod struct {
+ // HostKeyCallback is the function type used for verifying server keys.
+ // If nil default callback will be create using NewKnownHostsHostKeyCallback
+ // without argument.
+ HostKeyCallback ssh.HostKeyCallback
+}
+
+func (m *baseAuthMethod) hostKeyCallback() (ssh.HostKeyCallback, error) {
+ if m.HostKeyCallback == nil {
+ return NewKnownHostsCallback()
+ }
+
+ return m.HostKeyCallback, nil
+}
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go
index 5ed64d4..9b484f9 100644
--- a/plumbing/transport/ssh/common.go
+++ b/plumbing/transport/ssh/common.go
@@ -82,7 +82,13 @@ func (c *command) connect() error {
}
var err error
- c.client, err = ssh.Dial("tcp", c.getHostWithPort(), c.auth.clientConfig())
+ config := c.auth.clientConfig()
+ config.HostKeyCallback, err = c.auth.hostKeyCallback()
+ if err != nil {
+ return err
+ }
+
+ c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config)
if err != nil {
return err
}
diff --git a/plumbing/transport/ssh/upload_pack_test.go b/plumbing/transport/ssh/upload_pack_test.go
index 8194770..54d523a 100644
--- a/plumbing/transport/ssh/upload_pack_test.go
+++ b/plumbing/transport/ssh/upload_pack_test.go
@@ -33,5 +33,4 @@ func (s *UploadPackSuite) SetUpSuite(c *C) {
ep, err = transport.NewEndpoint("git@github.com:git-fixtures/non-existent.git")
c.Assert(err, IsNil)
s.UploadPackSuite.NonExistentEndpoint = ep
-
}