diff options
Diffstat (limited to 'plumbing/transport')
-rw-r--r-- | plumbing/transport/ssh/auth_method.go | 86 | ||||
-rw-r--r-- | plumbing/transport/ssh/common.go | 8 | ||||
-rw-r--r-- | plumbing/transport/ssh/upload_pack_test.go | 1 |
3 files changed, 93 insertions, 2 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index f53e510..82e3453 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -5,7 +5,10 @@ import ( "fmt" "net" "os" + "os/user" + "path/filepath" + "github.com/src-d/crypto/ssh/knownhosts" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -17,6 +20,7 @@ var ErrEmptySSHAgentAddr = errors.New("SSH_AUTH_SOCK env variable is required") // configuration needed to establish an ssh connection. type AuthMethod interface { clientConfig() *ssh.ClientConfig + hostKeyCallback() (ssh.HostKeyCallback, error) } // The names of the AuthMethod implementations. To be returned by the @@ -35,6 +39,7 @@ const ( type KeyboardInteractive struct { User string Challenge ssh.KeyboardInteractiveChallenge + baseAuthMethod } func (a *KeyboardInteractive) Name() string { @@ -56,6 +61,7 @@ func (a *KeyboardInteractive) clientConfig() *ssh.ClientConfig { type Password struct { User string Pass string + baseAuthMethod } func (a *Password) Name() string { @@ -78,6 +84,7 @@ func (a *Password) clientConfig() *ssh.ClientConfig { type PasswordCallback struct { User string Callback func() (pass string, err error) + baseAuthMethod } func (a *PasswordCallback) Name() string { @@ -100,6 +107,7 @@ func (a *PasswordCallback) clientConfig() *ssh.ClientConfig { type PublicKeys struct { User string Signer ssh.Signer + baseAuthMethod } func (a *PublicKeys) Name() string { @@ -122,6 +130,7 @@ func (a *PublicKeys) clientConfig() *ssh.ClientConfig { type PublicKeysCallback struct { User string Callback func() (signers []ssh.Signer, err error) + baseAuthMethod } func (a *PublicKeysCallback) Name() string { @@ -163,3 +172,80 @@ func NewSSHAgentAuth(user string) (*PublicKeysCallback, error) { Callback: agent.NewClient(pipe).Signers, }, nil } + +// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a +// know_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT +// +// If files is empty, the list of files will be read from the SSH_KNOWN_HOSTS +// environment variable, example: +// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file +// +// If SSH_KNOWN_HOSTS is not set the following file locations will be used: +// ~/.ssh/known_hosts +// /etc/ssh/ssh_known_hosts +func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) { + files, err := getDefaultKnownHostsFiles() + if err != nil { + return nil, err + } + + files, err = filterKnownHostsFiles(files...) + if err != nil { + return nil, err + } + + return knownhosts.New(files...) +} + +func getDefaultKnownHostsFiles() ([]string, error) { + files := filepath.SplitList(os.Getenv("SSH_KNOWN_HOSTS")) + if len(files) != 0 { + return files, nil + } + + user, err := user.Current() + if err != nil { + return nil, err + } + + return []string{ + filepath.Join(user.HomeDir, "/.ssh/known_hosts"), + "/etc/ssh/ssh_known_hosts", + }, nil +} + +func filterKnownHostsFiles(files ...string) ([]string, error) { + var out []string + for _, file := range files { + _, err := os.Stat(file) + if err == nil { + out = append(out, file) + continue + } + + if !os.IsNotExist(err) { + return nil, err + } + } + + if len(out) == 0 { + return nil, fmt.Errorf("unable to find any valid know_hosts file, set SSH_KNOWN_HOSTS env variable") + } + + return out, nil +} + +type baseAuthMethod struct { + // HostKeyCallback is the function type used for verifying server keys. + // If nil default callback will be create using NewKnownHostsHostKeyCallback + // without argument. + HostKeyCallback ssh.HostKeyCallback +} + +func (m *baseAuthMethod) hostKeyCallback() (ssh.HostKeyCallback, error) { + if m.HostKeyCallback == nil { + return NewKnownHostsCallback() + } + + return m.HostKeyCallback, nil +} diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index 5ed64d4..9b484f9 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -82,7 +82,13 @@ func (c *command) connect() error { } var err error - c.client, err = ssh.Dial("tcp", c.getHostWithPort(), c.auth.clientConfig()) + config := c.auth.clientConfig() + config.HostKeyCallback, err = c.auth.hostKeyCallback() + if err != nil { + return err + } + + c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config) if err != nil { return err } diff --git a/plumbing/transport/ssh/upload_pack_test.go b/plumbing/transport/ssh/upload_pack_test.go index 8194770..54d523a 100644 --- a/plumbing/transport/ssh/upload_pack_test.go +++ b/plumbing/transport/ssh/upload_pack_test.go @@ -33,5 +33,4 @@ func (s *UploadPackSuite) SetUpSuite(c *C) { ep, err = transport.NewEndpoint("git@github.com:git-fixtures/non-existent.git") c.Assert(err, IsNil) s.UploadPackSuite.NonExistentEndpoint = ep - } |