aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--plumbing/transport/http/common.go15
-rw-r--r--plumbing/transport/http/common_test.go2
-rw-r--r--plumbing/transport/ssh/auth_method.go82
-rw-r--r--plumbing/transport/ssh/auth_method_test.go8
-rw-r--r--plumbing/transport/ssh/common.go7
5 files changed, 59 insertions, 55 deletions
diff --git a/plumbing/transport/http/common.go b/plumbing/transport/http/common.go
index 95103f7..10a267a 100644
--- a/plumbing/transport/http/common.go
+++ b/plumbing/transport/http/common.go
@@ -151,17 +151,12 @@ func basicAuthFromEndpoint(ep *transport.Endpoint) *BasicAuth {
return nil
}
- return NewBasicAuth(u, ep.Password)
+ return &BasicAuth{u, ep.Password}
}
// BasicAuth represent a HTTP basic auth
type BasicAuth struct {
- username, password string
-}
-
-// NewBasicAuth returns a basicAuth base on the given user and password
-func NewBasicAuth(username, password string) *BasicAuth {
- return &BasicAuth{username, password}
+ Username, Password string
}
func (a *BasicAuth) setAuth(r *http.Request) {
@@ -169,7 +164,7 @@ func (a *BasicAuth) setAuth(r *http.Request) {
return
}
- r.SetBasicAuth(a.username, a.password)
+ r.SetBasicAuth(a.Username, a.Password)
}
// Name is name of the auth
@@ -179,11 +174,11 @@ func (a *BasicAuth) Name() string {
func (a *BasicAuth) String() string {
masked := "*******"
- if a.password == "" {
+ if a.Password == "" {
masked = "<empty>"
}
- return fmt.Sprintf("%s - %s:%s", a.Name(), a.username, masked)
+ return fmt.Sprintf("%s - %s:%s", a.Name(), a.Username, masked)
}
// Err is a dedicated error to return errors based on status code
diff --git a/plumbing/transport/http/common_test.go b/plumbing/transport/http/common_test.go
index bd1bec3..c2c0d3e 100644
--- a/plumbing/transport/http/common_test.go
+++ b/plumbing/transport/http/common_test.go
@@ -49,7 +49,7 @@ func (s *UploadPackSuite) TestNewClient(c *C) {
}
func (s *ClientSuite) TestNewBasicAuth(c *C) {
- a := NewBasicAuth("foo", "qux")
+ a := &BasicAuth{"foo", "qux"}
c.Assert(a.Name(), Equals, "http-basic-auth")
c.Assert(a.String(), Equals, "http-basic-auth - foo:*******")
diff --git a/plumbing/transport/ssh/auth_method.go b/plumbing/transport/ssh/auth_method.go
index baae181..a092b29 100644
--- a/plumbing/transport/ssh/auth_method.go
+++ b/plumbing/transport/ssh/auth_method.go
@@ -25,8 +25,9 @@ const DefaultUsername = "git"
// configuration needed to establish an ssh connection.
type AuthMethod interface {
transport.AuthMethod
- clientConfig() *ssh.ClientConfig
- hostKeyCallback() (ssh.HostKeyCallback, error)
+ // ClientConfig should return a valid ssh.ClientConfig to be used to create
+ // a connection to the SSH server.
+ ClientConfig() (*ssh.ClientConfig, error)
}
// The names of the AuthMethod implementations. To be returned by the
@@ -45,7 +46,7 @@ const (
type KeyboardInteractive struct {
User string
Challenge ssh.KeyboardInteractiveChallenge
- baseAuthMethod
+ HostKeyCallbackHelper
}
func (a *KeyboardInteractive) Name() string {
@@ -56,18 +57,20 @@ func (a *KeyboardInteractive) String() string {
return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
}
-func (a *KeyboardInteractive) clientConfig() *ssh.ClientConfig {
- return &ssh.ClientConfig{
+func (a *KeyboardInteractive) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
User: a.User,
- Auth: []ssh.AuthMethod{ssh.KeyboardInteractiveChallenge(a.Challenge)},
- }
+ Auth: []ssh.AuthMethod{
+ ssh.KeyboardInteractiveChallenge(a.Challenge),
+ },
+ })
}
// Password implements AuthMethod by using the given password.
type Password struct {
- User string
- Pass string
- baseAuthMethod
+ User string
+ Password string
+ HostKeyCallbackHelper
}
func (a *Password) Name() string {
@@ -78,11 +81,11 @@ func (a *Password) String() string {
return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
}
-func (a *Password) clientConfig() *ssh.ClientConfig {
- return &ssh.ClientConfig{
+func (a *Password) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
User: a.User,
- Auth: []ssh.AuthMethod{ssh.Password(a.Pass)},
- }
+ Auth: []ssh.AuthMethod{ssh.Password(a.Password)},
+ })
}
// PasswordCallback implements AuthMethod by using a callback
@@ -90,7 +93,7 @@ func (a *Password) clientConfig() *ssh.ClientConfig {
type PasswordCallback struct {
User string
Callback func() (pass string, err error)
- baseAuthMethod
+ HostKeyCallbackHelper
}
func (a *PasswordCallback) Name() string {
@@ -101,25 +104,25 @@ func (a *PasswordCallback) String() string {
return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
}
-func (a *PasswordCallback) clientConfig() *ssh.ClientConfig {
- return &ssh.ClientConfig{
+func (a *PasswordCallback) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
User: a.User,
Auth: []ssh.AuthMethod{ssh.PasswordCallback(a.Callback)},
- }
+ })
}
// PublicKeys implements AuthMethod by using the given key pairs.
type PublicKeys struct {
User string
Signer ssh.Signer
- baseAuthMethod
+ HostKeyCallbackHelper
}
// NewPublicKeys returns a PublicKeys from a PEM encoded private key. An
// encryption password should be given if the pemBytes contains a password
// encrypted PEM block otherwise password should be empty. It supports RSA
// (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
-func NewPublicKeys(user string, pemBytes []byte, password string) (AuthMethod, error) {
+func NewPublicKeys(user string, pemBytes []byte, password string) (*PublicKeys, error) {
block, _ := pem.Decode(pemBytes)
if x509.IsEncryptedPEMBlock(block) {
key, err := x509.DecryptPEMBlock(block, []byte(password))
@@ -142,7 +145,7 @@ func NewPublicKeys(user string, pemBytes []byte, password string) (AuthMethod, e
// NewPublicKeysFromFile returns a PublicKeys from a file containing a PEM
// encoded private key. An encryption password should be given if the pemBytes
// contains a password encrypted PEM block otherwise password should be empty.
-func NewPublicKeysFromFile(user, pemFile, password string) (AuthMethod, error) {
+func NewPublicKeysFromFile(user, pemFile, password string) (*PublicKeys, error) {
bytes, err := ioutil.ReadFile(pemFile)
if err != nil {
return nil, err
@@ -159,11 +162,11 @@ func (a *PublicKeys) String() string {
return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
}
-func (a *PublicKeys) clientConfig() *ssh.ClientConfig {
- return &ssh.ClientConfig{
+func (a *PublicKeys) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
User: a.User,
Auth: []ssh.AuthMethod{ssh.PublicKeys(a.Signer)},
- }
+ })
}
func username() (string, error) {
@@ -173,9 +176,11 @@ func username() (string, error) {
} else {
username = os.Getenv("USER")
}
+
if username == "" {
return "", errors.New("failed to get username")
}
+
return username, nil
}
@@ -184,13 +189,13 @@ func username() (string, error) {
type PublicKeysCallback struct {
User string
Callback func() (signers []ssh.Signer, err error)
- baseAuthMethod
+ HostKeyCallbackHelper
}
// NewSSHAgentAuth returns a PublicKeysCallback based on a SSH agent, it opens
// a pipe with the SSH agent and uses the pipe as the implementer of the public
// key callback function.
-func NewSSHAgentAuth(u string) (AuthMethod, error) {
+func NewSSHAgentAuth(u string) (*PublicKeysCallback, error) {
var err error
if u == "" {
u, err = username()
@@ -218,11 +223,11 @@ func (a *PublicKeysCallback) String() string {
return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
}
-func (a *PublicKeysCallback) clientConfig() *ssh.ClientConfig {
- return &ssh.ClientConfig{
+func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
User: a.User,
Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(a.Callback)},
- }
+ })
}
// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a
@@ -287,17 +292,26 @@ func filterKnownHostsFiles(files ...string) ([]string, error) {
return out, nil
}
-type baseAuthMethod struct {
+// HostKeyCallbackHelper is a helper that provides common functionality to
+// configure HostKeyCallback into a ssh.ClientConfig.
+type HostKeyCallbackHelper struct {
// HostKeyCallback is the function type used for verifying server keys.
- // If nil default callback will be create using NewKnownHostsHostKeyCallback
+ // If nil default callback will be create using NewKnownHostsCallback
// without argument.
HostKeyCallback ssh.HostKeyCallback
}
-func (m *baseAuthMethod) hostKeyCallback() (ssh.HostKeyCallback, error) {
+// SetHostKeyCallback sets the field HostKeyCallback in the given cfg. If
+// HostKeyCallback is empty a default callback is created using
+// NewKnownHostsCallback.
+func (m *HostKeyCallbackHelper) SetHostKeyCallback(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) {
+ var err error
if m.HostKeyCallback == nil {
- return NewKnownHostsCallback()
+ if m.HostKeyCallback, err = NewKnownHostsCallback(); err != nil {
+ return cfg, err
+ }
}
- return m.HostKeyCallback, nil
+ cfg.HostKeyCallback = m.HostKeyCallback
+ return cfg, nil
}
diff --git a/plumbing/transport/ssh/auth_method_test.go b/plumbing/transport/ssh/auth_method_test.go
index 2ee5100..1e77ca0 100644
--- a/plumbing/transport/ssh/auth_method_test.go
+++ b/plumbing/transport/ssh/auth_method_test.go
@@ -32,16 +32,16 @@ func (s *SuiteCommon) TestKeyboardInteractiveString(c *C) {
func (s *SuiteCommon) TestPasswordName(c *C) {
a := &Password{
- User: "test",
- Pass: "",
+ User: "test",
+ Password: "",
}
c.Assert(a.Name(), Equals, PasswordName)
}
func (s *SuiteCommon) TestPasswordString(c *C) {
a := &Password{
- User: "test",
- Pass: "",
+ User: "test",
+ Password: "",
}
c.Assert(a.String(), Equals, fmt.Sprintf("user: test, name: %s", PasswordName))
}
diff --git a/plumbing/transport/ssh/common.go b/plumbing/transport/ssh/common.go
index b7722bb..f5bc9a7 100644
--- a/plumbing/transport/ssh/common.go
+++ b/plumbing/transport/ssh/common.go
@@ -98,8 +98,7 @@ func (c *command) connect() error {
}
var err error
- config := c.auth.clientConfig()
- config.HostKeyCallback, err = c.auth.hostKeyCallback()
+ config, err := c.auth.ClientConfig()
if err != nil {
return err
}
@@ -159,7 +158,3 @@ func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) {
*c = vc.Interface().(ssh.ClientConfig)
}
-
-func isZeroValue(v reflect.Value) bool {
- return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
-}