aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKuba Podgórski <kuba--@users.noreply.github.com>2018-09-07 10:25:23 +0200
committerMáximo Cuadros <mcuadros@gmail.com>2018-09-07 10:25:23 +0200
commita4b12e4161738af6f724776c0c8c55f90542f06f (patch)
treee1db288d87e0ee1ab2c3cdc02ce6353926aec58d
parentd3cec13ac0b195bfb897ed038a08b5130ab9969e (diff)
downloadgo-git-a4b12e4161738af6f724776c0c8c55f90542f06f.tar.gz
plumbing/transport: ssh check if list of known_hosts files is empty
Signed-off-by: kuba-- <kuba@sourced.tech>
-rw-r--r--plumbing/transport/ssh/auth_method.go14
-rw-r--r--plumbing/transport/ssh/auth_method_test.go62
2 files changed, 69 insertions, 7 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go
index 84cfab2..dbb47c5 100644
--- a/plumbing/transport/ssh/auth_method.go
+++ b/plumbing/transport/ssh/auth_method.go
@@ -236,7 +236,7 @@ func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a
// known_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT
//
-// If files is empty, the list of files will be read from the SSH_KNOWN_HOSTS
+// If list of files is empty, then it will be read from the SSH_KNOWN_HOSTS
// environment variable, example:
// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file
//
@@ -244,13 +244,15 @@ func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
// ~/.ssh/known_hosts
// /etc/ssh/ssh_known_hosts
func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) {
- files, err := getDefaultKnownHostsFiles()
- if err != nil {
- return nil, err
+ var err error
+
+ if len(files) == 0 {
+ if files, err = getDefaultKnownHostsFiles(); err != nil {
+ return nil, err
+ }
}
- files, err = filterKnownHostsFiles(files...)
- if err != nil {
+ if files, err = filterKnownHostsFiles(files...); err != nil {
return nil, err
}
diff --git a/plumbing/transport/ssh/auth_method_test.go b/plumbing/transport/ssh/auth_method_test.go
index 0025669..0cde61e 100644
--- a/plumbing/transport/ssh/auth_method_test.go
+++ b/plumbing/transport/ssh/auth_method_test.go
@@ -1,16 +1,30 @@
package ssh
import (
+ "bufio"
"fmt"
"io/ioutil"
"os"
+ "strings"
+ "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"
. "gopkg.in/check.v1"
)
-type SuiteCommon struct{}
+type (
+ SuiteCommon struct{}
+
+ mockKnownHosts struct{}
+)
+
+func (mockKnownHosts) host() string { return "github.com" }
+func (mockKnownHosts) knownHosts() []byte {
+ return []byte(`github.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEAq2A7hRGmdnm9tUDbO9IDSwBK6TbQa+PXYPCPy6rbTrTtw7PHkccKrpp0yVhp5HdEIcKr6pLlVDBfOLX9QUsyCOV0wzfjIJNlGEYsdlLJizHhbn2mUjvSAHQqZETYP81eFzLQNnPHt4EVVUh7VfDESU84KezmD5QlWpXLmvU31/yMf+Se8xhHTvKSCZIFImWwoG6mbUoWf9nzpIoaSjB+weqqUUmpaaasXVal72J+UX2B+2RPW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ==`)
+}
+func (mockKnownHosts) Network() string { return "tcp" }
+func (mockKnownHosts) String() string { return "github.com:22" }
var _ = Suite(&SuiteCommon{})
@@ -149,3 +163,49 @@ func (*SuiteCommon) TestNewPublicKeysWithInvalidPEM(c *C) {
c.Assert(err, NotNil)
c.Assert(auth, IsNil)
}
+
+func (*SuiteCommon) TestNewKnownHostsCallback(c *C) {
+ var mock = mockKnownHosts{}
+
+ f, err := ioutil.TempFile("", "known-hosts")
+ c.Assert(err, IsNil)
+
+ _, err = f.Write(mock.knownHosts())
+ c.Assert(err, IsNil)
+
+ err = f.Close()
+ c.Assert(err, IsNil)
+
+ defer os.RemoveAll(f.Name())
+
+ f, err = os.Open(f.Name())
+ c.Assert(err, IsNil)
+
+ defer f.Close()
+
+ var hostKey ssh.PublicKey
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ fields := strings.Split(scanner.Text(), " ")
+ if len(fields) != 3 {
+ continue
+ }
+ if strings.Contains(fields[0], mock.host()) {
+ var err error
+ hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
+ if err != nil {
+ c.Fatalf("error parsing %q: %v", fields[2], err)
+ }
+ break
+ }
+ }
+ if hostKey == nil {
+ c.Fatalf("no hostkey for %s", mock.host())
+ }
+
+ clb, err := NewKnownHostsCallback(f.Name())
+ c.Assert(err, IsNil)
+
+ err = clb(mock.String(), mock, hostKey)
+ c.Assert(err, IsNil)
+}