diff options
Diffstat (limited to 'plumbing/transport/ssh')
-rw-r--r-- | plumbing/transport/ssh/auth_method.go | 82 | ||||
-rw-r--r-- | plumbing/transport/ssh/auth_method_test.go | 8 | ||||
-rw-r--r-- | plumbing/transport/ssh/common.go | 25 | ||||
-rw-r--r-- | plumbing/transport/ssh/common_test.go | 2 | ||||
-rw-r--r-- | plumbing/transport/ssh/upload_pack_test.go | 124 |
5 files changed, 170 insertions, 71 deletions
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go index baae181..a092b29 100644 --- a/plumbing/transport/ssh/auth_method.go +++ b/plumbing/transport/ssh/auth_method.go @@ -25,8 +25,9 @@ const DefaultUsername = "git" // configuration needed to establish an ssh connection. type AuthMethod interface { transport.AuthMethod - clientConfig() *ssh.ClientConfig - hostKeyCallback() (ssh.HostKeyCallback, error) + // ClientConfig should return a valid ssh.ClientConfig to be used to create + // a connection to the SSH server. + ClientConfig() (*ssh.ClientConfig, error) } // The names of the AuthMethod implementations. To be returned by the @@ -45,7 +46,7 @@ const ( type KeyboardInteractive struct { User string Challenge ssh.KeyboardInteractiveChallenge - baseAuthMethod + HostKeyCallbackHelper } func (a *KeyboardInteractive) Name() string { @@ -56,18 +57,20 @@ func (a *KeyboardInteractive) String() string { return fmt.Sprintf("user: %s, name: %s", a.User, a.Name()) } -func (a *KeyboardInteractive) clientConfig() *ssh.ClientConfig { - return &ssh.ClientConfig{ +func (a *KeyboardInteractive) ClientConfig() (*ssh.ClientConfig, error) { + return a.SetHostKeyCallback(&ssh.ClientConfig{ User: a.User, - Auth: []ssh.AuthMethod{ssh.KeyboardInteractiveChallenge(a.Challenge)}, - } + Auth: []ssh.AuthMethod{ + ssh.KeyboardInteractiveChallenge(a.Challenge), + }, + }) } // Password implements AuthMethod by using the given password. type Password struct { - User string - Pass string - baseAuthMethod + User string + Password string + HostKeyCallbackHelper } func (a *Password) Name() string { @@ -78,11 +81,11 @@ func (a *Password) String() string { return fmt.Sprintf("user: %s, name: %s", a.User, a.Name()) } -func (a *Password) clientConfig() *ssh.ClientConfig { - return &ssh.ClientConfig{ +func (a *Password) ClientConfig() (*ssh.ClientConfig, error) { + return a.SetHostKeyCallback(&ssh.ClientConfig{ User: a.User, - Auth: []ssh.AuthMethod{ssh.Password(a.Pass)}, - } + Auth: []ssh.AuthMethod{ssh.Password(a.Password)}, + }) } // PasswordCallback implements AuthMethod by using a callback @@ -90,7 +93,7 @@ func (a *Password) clientConfig() *ssh.ClientConfig { type PasswordCallback struct { User string Callback func() (pass string, err error) - baseAuthMethod + HostKeyCallbackHelper } func (a *PasswordCallback) Name() string { @@ -101,25 +104,25 @@ func (a *PasswordCallback) String() string { return fmt.Sprintf("user: %s, name: %s", a.User, a.Name()) } -func (a *PasswordCallback) clientConfig() *ssh.ClientConfig { - return &ssh.ClientConfig{ +func (a *PasswordCallback) ClientConfig() (*ssh.ClientConfig, error) { + return a.SetHostKeyCallback(&ssh.ClientConfig{ User: a.User, Auth: []ssh.AuthMethod{ssh.PasswordCallback(a.Callback)}, - } + }) } // PublicKeys implements AuthMethod by using the given key pairs. type PublicKeys struct { User string Signer ssh.Signer - baseAuthMethod + HostKeyCallbackHelper } // NewPublicKeys returns a PublicKeys from a PEM encoded private key. An // encryption password should be given if the pemBytes contains a password // encrypted PEM block otherwise password should be empty. It supports RSA // (PKCS#1), DSA (OpenSSL), and ECDSA private keys. -func NewPublicKeys(user string, pemBytes []byte, password string) (AuthMethod, error) { +func NewPublicKeys(user string, pemBytes []byte, password string) (*PublicKeys, error) { block, _ := pem.Decode(pemBytes) if x509.IsEncryptedPEMBlock(block) { key, err := x509.DecryptPEMBlock(block, []byte(password)) @@ -142,7 +145,7 @@ func NewPublicKeys(user string, pemBytes []byte, password string) (AuthMethod, e // NewPublicKeysFromFile returns a PublicKeys from a file containing a PEM // encoded private key. An encryption password should be given if the pemBytes // contains a password encrypted PEM block otherwise password should be empty. -func NewPublicKeysFromFile(user, pemFile, password string) (AuthMethod, error) { +func NewPublicKeysFromFile(user, pemFile, password string) (*PublicKeys, error) { bytes, err := ioutil.ReadFile(pemFile) if err != nil { return nil, err @@ -159,11 +162,11 @@ func (a *PublicKeys) String() string { return fmt.Sprintf("user: %s, name: %s", a.User, a.Name()) } -func (a *PublicKeys) clientConfig() *ssh.ClientConfig { - return &ssh.ClientConfig{ +func (a *PublicKeys) ClientConfig() (*ssh.ClientConfig, error) { + return a.SetHostKeyCallback(&ssh.ClientConfig{ User: a.User, Auth: []ssh.AuthMethod{ssh.PublicKeys(a.Signer)}, - } + }) } func username() (string, error) { @@ -173,9 +176,11 @@ func username() (string, error) { } else { username = os.Getenv("USER") } + if username == "" { return "", errors.New("failed to get username") } + return username, nil } @@ -184,13 +189,13 @@ func username() (string, error) { type PublicKeysCallback struct { User string Callback func() (signers []ssh.Signer, err error) - baseAuthMethod + HostKeyCallbackHelper } // NewSSHAgentAuth returns a PublicKeysCallback based on a SSH agent, it opens // a pipe with the SSH agent and uses the pipe as the implementer of the public // key callback function. -func NewSSHAgentAuth(u string) (AuthMethod, error) { +func NewSSHAgentAuth(u string) (*PublicKeysCallback, error) { var err error if u == "" { u, err = username() @@ -218,11 +223,11 @@ func (a *PublicKeysCallback) String() string { return fmt.Sprintf("user: %s, name: %s", a.User, a.Name()) } -func (a *PublicKeysCallback) clientConfig() *ssh.ClientConfig { - return &ssh.ClientConfig{ +func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) { + return a.SetHostKeyCallback(&ssh.ClientConfig{ User: a.User, Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(a.Callback)}, - } + }) } // NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a @@ -287,17 +292,26 @@ func filterKnownHostsFiles(files ...string) ([]string, error) { return out, nil } -type baseAuthMethod struct { +// HostKeyCallbackHelper is a helper that provides common functionality to +// configure HostKeyCallback into a ssh.ClientConfig. +type HostKeyCallbackHelper struct { // HostKeyCallback is the function type used for verifying server keys. - // If nil default callback will be create using NewKnownHostsHostKeyCallback + // If nil default callback will be create using NewKnownHostsCallback // without argument. HostKeyCallback ssh.HostKeyCallback } -func (m *baseAuthMethod) hostKeyCallback() (ssh.HostKeyCallback, error) { +// SetHostKeyCallback sets the field HostKeyCallback in the given cfg. If +// HostKeyCallback is empty a default callback is created using +// NewKnownHostsCallback. +func (m *HostKeyCallbackHelper) SetHostKeyCallback(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) { + var err error if m.HostKeyCallback == nil { - return NewKnownHostsCallback() + if m.HostKeyCallback, err = NewKnownHostsCallback(); err != nil { + return cfg, err + } } - return m.HostKeyCallback, nil + cfg.HostKeyCallback = m.HostKeyCallback + return cfg, nil } diff --git a/plumbing/transport/ssh/auth_method_test.go b/plumbing/transport/ssh/auth_method_test.go index 2ee5100..1e77ca0 100644 --- a/plumbing/transport/ssh/auth_method_test.go +++ b/plumbing/transport/ssh/auth_method_test.go @@ -32,16 +32,16 @@ func (s *SuiteCommon) TestKeyboardInteractiveString(c *C) { func (s *SuiteCommon) TestPasswordName(c *C) { a := &Password{ - User: "test", - Pass: "", + User: "test", + Password: "", } c.Assert(a.Name(), Equals, PasswordName) } func (s *SuiteCommon) TestPasswordString(c *C) { a := &Password{ - User: "test", - Pass: "", + User: "test", + Password: "", } c.Assert(a.String(), Equals, fmt.Sprintf("user: test, name: %s", PasswordName)) } diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go index af79dfb..f5bc9a7 100644 --- a/plumbing/transport/ssh/common.go +++ b/plumbing/transport/ssh/common.go @@ -31,7 +31,7 @@ type runner struct { config *ssh.ClientConfig } -func (r *runner) Command(cmd string, ep transport.Endpoint, auth transport.AuthMethod) (common.Command, error) { +func (r *runner) Command(cmd string, ep *transport.Endpoint, auth transport.AuthMethod) (common.Command, error) { c := &command{command: cmd, endpoint: ep, config: r.config} if auth != nil { c.setAuth(auth) @@ -47,7 +47,7 @@ type command struct { *ssh.Session connected bool command string - endpoint transport.Endpoint + endpoint *transport.Endpoint client *ssh.Client auth AuthMethod config *ssh.ClientConfig @@ -98,8 +98,7 @@ func (c *command) connect() error { } var err error - config := c.auth.clientConfig() - config.HostKeyCallback, err = c.auth.hostKeyCallback() + config, err := c.auth.ClientConfig() if err != nil { return err } @@ -122,8 +121,8 @@ func (c *command) connect() error { } func (c *command) getHostWithPort() string { - host := c.endpoint.Host() - port := c.endpoint.Port() + host := c.endpoint.Host + port := c.endpoint.Port if port <= 0 { port = DefaultPort } @@ -133,12 +132,12 @@ func (c *command) getHostWithPort() string { func (c *command) setAuthFromEndpoint() error { var err error - c.auth, err = DefaultAuthBuilder(c.endpoint.User()) + c.auth, err = DefaultAuthBuilder(c.endpoint.User) return err } -func endpointToCommand(cmd string, ep transport.Endpoint) string { - return fmt.Sprintf("%s '%s'", cmd, ep.Path()) +func endpointToCommand(cmd string, ep *transport.Endpoint) string { + return fmt.Sprintf("%s '%s'", cmd, ep.Path) } func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) { @@ -154,14 +153,8 @@ func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) { f := t.Field(i) vcf := vc.FieldByName(f.Name) vof := vo.FieldByName(f.Name) - if isZeroValue(vcf) { - vcf.Set(vof) - } + vcf.Set(vof) } *c = vc.Interface().(ssh.ClientConfig) } - -func isZeroValue(v reflect.Value) bool { - return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) -} diff --git a/plumbing/transport/ssh/common_test.go b/plumbing/transport/ssh/common_test.go index 1b07eee..5315e28 100644 --- a/plumbing/transport/ssh/common_test.go +++ b/plumbing/transport/ssh/common_test.go @@ -37,5 +37,5 @@ func (s *SuiteCommon) TestOverrideConfigKeep(c *C) { } overrideConfig(config, target) - c.Assert(target.User, Equals, "bar") + c.Assert(target.User, Equals, "foo") } diff --git a/plumbing/transport/ssh/upload_pack_test.go b/plumbing/transport/ssh/upload_pack_test.go index cb9baa5..56d1601 100644 --- a/plumbing/transport/ssh/upload_pack_test.go +++ b/plumbing/transport/ssh/upload_pack_test.go @@ -1,47 +1,139 @@ package ssh import ( + "fmt" + "io" + "io/ioutil" + "log" + "net" "os" + "os/exec" + "path/filepath" + "strings" "gopkg.in/src-d/go-git.v4/plumbing/transport" "gopkg.in/src-d/go-git.v4/plumbing/transport/test" + "github.com/gliderlabs/ssh" + "gopkg.in/src-d/go-git-fixtures.v3" + stdssh "golang.org/x/crypto/ssh" . "gopkg.in/check.v1" ) type UploadPackSuite struct { test.UploadPackSuite + fixtures.Suite + + port int + base string } var _ = Suite(&UploadPackSuite{}) func (s *UploadPackSuite) SetUpSuite(c *C) { - s.setAuthBuilder(c) - s.UploadPackSuite.Client = DefaultClient + s.Suite.SetUpSuite(c) + + l, err := net.Listen("tcp", "localhost:0") + c.Assert(err, IsNil) - ep, err := transport.NewEndpoint("git@github.com:git-fixtures/basic.git") + s.port = l.Addr().(*net.TCPAddr).Port + s.base, err = ioutil.TempDir(os.TempDir(), fmt.Sprintf("go-git-ssh-%d", s.port)) c.Assert(err, IsNil) - s.UploadPackSuite.Endpoint = ep - ep, err = transport.NewEndpoint("git@github.com:git-fixtures/empty.git") + DefaultAuthBuilder = func(user string) (AuthMethod, error) { + return &Password{User: user}, nil + } + + s.UploadPackSuite.Client = NewClient(&stdssh.ClientConfig{ + HostKeyCallback: stdssh.InsecureIgnoreHostKey(), + }) + + s.UploadPackSuite.Endpoint = s.prepareRepository(c, fixtures.Basic().One(), "basic.git") + s.UploadPackSuite.EmptyEndpoint = s.prepareRepository(c, fixtures.ByTag("empty").One(), "empty.git") + s.UploadPackSuite.NonExistentEndpoint = s.newEndpoint(c, "non-existent.git") + + server := &ssh.Server{Handler: handlerSSH} + go func() { + log.Fatal(server.Serve(l)) + }() +} + +func (s *UploadPackSuite) prepareRepository(c *C, f *fixtures.Fixture, name string) *transport.Endpoint { + fs := f.DotGit() + + err := fixtures.EnsureIsBare(fs) c.Assert(err, IsNil) - s.UploadPackSuite.EmptyEndpoint = ep - ep, err = transport.NewEndpoint("git@github.com:git-fixtures/non-existent.git") + path := filepath.Join(s.base, name) + err = os.Rename(fs.Root(), path) c.Assert(err, IsNil) - s.UploadPackSuite.NonExistentEndpoint = ep + + return s.newEndpoint(c, name) +} + +func (s *UploadPackSuite) newEndpoint(c *C, name string) *transport.Endpoint { + ep, err := transport.NewEndpoint(fmt.Sprintf( + "ssh://git@localhost:%d/%s/%s", s.port, filepath.ToSlash(s.base), name, + )) + + c.Assert(err, IsNil) + return ep } -func (s *UploadPackSuite) setAuthBuilder(c *C) { - privateKey := os.Getenv("SSH_TEST_PRIVATE_KEY") - if privateKey != "" { - DefaultAuthBuilder = func(user string) (AuthMethod, error) { - return NewPublicKeysFromFile(user, privateKey, "") - } +func handlerSSH(s ssh.Session) { + cmd, stdin, stderr, stdout, err := buildCommand(s.Command()) + if err != nil { + fmt.Println(err) + return } - if privateKey == "" && os.Getenv("SSH_AUTH_SOCK") == "" { - c.Skip("SSH_AUTH_SOCK or SSH_TEST_PRIVATE_KEY are required") + if err := cmd.Start(); err != nil { + fmt.Println(err) return } + + go func() { + defer stdin.Close() + io.Copy(stdin, s) + }() + + go func() { + defer stderr.Close() + io.Copy(s.Stderr(), stderr) + }() + + defer stdout.Close() + io.Copy(s, stdout) + + if err := cmd.Wait(); err != nil { + return + } +} + +func buildCommand(c []string) (cmd *exec.Cmd, stdin io.WriteCloser, stderr, stdout io.ReadCloser, err error) { + if len(c) != 2 { + err = fmt.Errorf("invalid command") + return + } + + // fix for Windows environments + path := strings.Replace(c[1], "/C:/", "C:/", 1) + + cmd = exec.Command(c[0], path) + stdout, err = cmd.StdoutPipe() + if err != nil { + return + } + + stdin, err = cmd.StdinPipe() + if err != nil { + return + } + + stderr, err = cmd.StderrPipe() + if err != nil { + return + } + + return } |