diff options
Diffstat (limited to 'plumbing/transport/ssh/auth_method.go')
-rw-r--r-- | plumbing/transport/ssh/auth_method.go | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index f53e510..82e3453 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -5,7 +5,10 @@ import ( "fmt" "net" "os" + "os/user" + "path/filepath" + "github.com/src-d/crypto/ssh/knownhosts" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -17,6 +20,7 @@ var ErrEmptySSHAgentAddr = errors.New("SSH_AUTH_SOCK env variable is required") // configuration needed to establish an ssh connection. type AuthMethod interface { clientConfig() *ssh.ClientConfig + hostKeyCallback() (ssh.HostKeyCallback, error) } // The names of the AuthMethod implementations. To be returned by the @@ -35,6 +39,7 @@ const ( type KeyboardInteractive struct { User string Challenge ssh.KeyboardInteractiveChallenge + baseAuthMethod } func (a *KeyboardInteractive) Name() string { @@ -56,6 +61,7 @@ func (a *KeyboardInteractive) clientConfig() *ssh.ClientConfig { type Password struct { User string Pass string + baseAuthMethod } func (a *Password) Name() string { @@ -78,6 +84,7 @@ func (a *Password) clientConfig() *ssh.ClientConfig { type PasswordCallback struct { User string Callback func() (pass string, err error) + baseAuthMethod } func (a *PasswordCallback) Name() string { @@ -100,6 +107,7 @@ func (a *PasswordCallback) clientConfig() *ssh.ClientConfig { type PublicKeys struct { User string Signer ssh.Signer + baseAuthMethod } func (a *PublicKeys) Name() string { @@ -122,6 +130,7 @@ func (a *PublicKeys) clientConfig() *ssh.ClientConfig { type PublicKeysCallback struct { User string Callback func() (signers []ssh.Signer, err error) + baseAuthMethod } func (a *PublicKeysCallback) Name() string { @@ -163,3 +172,80 @@ func NewSSHAgentAuth(user string) (*PublicKeysCallback, error) { Callback: agent.NewClient(pipe).Signers, }, nil } + +// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a +// know_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT +// +// If files is empty, the list of files will be read from the SSH_KNOWN_HOSTS +// environment variable, example: +// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file +// +// If SSH_KNOWN_HOSTS is not set the following file locations will be used: +// ~/.ssh/known_hosts +// /etc/ssh/ssh_known_hosts +func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) { + files, err := getDefaultKnownHostsFiles() + if err != nil { + return nil, err + } + + files, err = filterKnownHostsFiles(files...) + if err != nil { + return nil, err + } + + return knownhosts.New(files...) +} + +func getDefaultKnownHostsFiles() ([]string, error) { + files := filepath.SplitList(os.Getenv("SSH_KNOWN_HOSTS")) + if len(files) != 0 { + return files, nil + } + + user, err := user.Current() + if err != nil { + return nil, err + } + + return []string{ + filepath.Join(user.HomeDir, "/.ssh/known_hosts"), + "/etc/ssh/ssh_known_hosts", + }, nil +} + +func filterKnownHostsFiles(files ...string) ([]string, error) { + var out []string + for _, file := range files { + _, err := os.Stat(file) + if err == nil { + out = append(out, file) + continue + } + + if !os.IsNotExist(err) { + return nil, err + } + } + + if len(out) == 0 { + return nil, fmt.Errorf("unable to find any valid know_hosts file, set SSH_KNOWN_HOSTS env variable") + } + + return out, nil +} + +type baseAuthMethod struct { + // HostKeyCallback is the function type used for verifying server keys. + // If nil default callback will be create using NewKnownHostsHostKeyCallback + // without argument. + HostKeyCallback ssh.HostKeyCallback +} + +func (m *baseAuthMethod) hostKeyCallback() (ssh.HostKeyCallback, error) { + if m.HostKeyCallback == nil { + return NewKnownHostsCallback() + } + + return m.HostKeyCallback, nil +} |