aboutsummaryrefslogblamecommitdiffstats
path: root/plumbing/transport/ssh/common.go
blob: c327c419271143db966c36cf78149daa1bfb2098 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13












                                                     
     
                                                                       

 
                    
 

                                           
 
                                                             




                                            
                                                            











                                           
                             




















                                                            




                                                                               








                                                                               
                                          

























































                                                                                   




                                                                       










                                                     
















                                                                     
package ssh

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"gopkg.in/src-d/go-git.v4/plumbing/transport"

	"golang.org/x/crypto/ssh"
)

// errAlreadyConnected is the error returned when a connection is requested on
// a session that is already connected.
var errAlreadyConnected = errors.New("ssh session already created")

// client is the SSH transport client. It carries no state of its own.
type client struct{}

// DefaultClient is the default SSH client.
var DefaultClient = new(client)

// NewFetchPackSession returns a fetch-pack session for the endpoint ep. The
// work is delegated to newFetchPackSession; the client itself holds no state.
func (c *client) NewFetchPackSession(ep transport.Endpoint) (transport.FetchPackSession, error) {
	return newFetchPackSession(ep)
}

// NewSendPackSession returns a send-pack session for the endpoint ep. The
// work is delegated to newSendPackSession; the client itself holds no state.
func (c *client) NewSendPackSession(ep transport.Endpoint) (transport.SendPackSession, error) {
	return newSendPackSession(ep)
}

// session holds the state of one SSH connection to a git endpoint: the dialed
// client, the SSH session opened on it, and that session's standard streams.
type session struct {
	// connected is set to true by connect once the SSH session is open and
	// reset to false by Close.
	connected   bool
	// endpoint is the remote being talked to; its Host and User are read when
	// connecting.
	endpoint    transport.Endpoint
	// client is the connection returned by ssh.Dial.
	client      *ssh.Client
	// session is the SSH session opened on client.
	session     *ssh.Session
	// stdin, stdout and stderr are the pipes of the remote command, obtained
	// from session in openSSHSession.
	stdin       io.WriteCloser
	stdout      io.Reader
	stderr      io.Reader
	// NOTE(review): sessionDone is not used in this file; presumably it holds
	// the channel returned by runCommand — confirm against the callers.
	sessionDone chan error
	// auth is the authentication method whose clientConfig is handed to
	// ssh.Dial. It is assigned by SetAuth or setAuthFromEndpoint.
	auth        AuthMethod
}

// SetAuth stores auth as the authentication method of the session. It returns
// transport.ErrInvalidAuthMethod, leaving the session untouched, when auth
// cannot be asserted to this package's AuthMethod.
func (s *session) SetAuth(auth transport.AuthMethod) error {
	if sshAuth, ok := auth.(AuthMethod); ok {
		s.auth = sshAuth
		return nil
	}

	return transport.ErrInvalidAuthMethod
}

// Close tears down the connection: it closes the SSH session and then the
// underlying client, returning the error of the client's Close. Calling Close
// on a session that is not connected does nothing and returns nil.
func (s *session) Close() error {
	if s.connected {
		s.connected = false

		// XXX: the session might already be closed if the full packfile was
		// read, so the error of closing it is deliberately discarded.
		_ = s.session.Close()

		return s.client.Close()
	}

	return nil
}

// connect dials the SSH server of the endpoint over TCP and opens an SSH
// session on the resulting client.
//
// The AuthMethod stored with SetAuth is used when there is one. Otherwise an
// auth method is created with setAuthFromEndpoint, which builds an SSH agent
// based method for the endpoint's user (the agent address is taken from the
// SSH_AUTH_SOCK environment var).
//
// It returns errAlreadyConnected when the session is already connected, and
// otherwise any error produced while building the auth method, dialing or
// opening the session. On a failure after dialing, the client is closed again.
func (s *session) connect() error {
	if s.connected {
		return errAlreadyConnected
	}

	// Fall back to the endpoint-derived auth only when the caller did not
	// provide one; doing it unconditionally would discard the method given
	// to SetAuth.
	if s.auth == nil {
		if err := s.setAuthFromEndpoint(); err != nil {
			// setAuthFromEndpoint assigns s.auth even when it fails; reset it
			// so that a later connect attempts the fallback again instead of
			// using a value that came with an error.
			s.auth = nil
			return err
		}
	}

	var err error
	s.client, err = ssh.Dial("tcp", s.getHostWithPort(), s.auth.clientConfig())
	if err != nil {
		return err
	}

	if err := s.openSSHSession(); err != nil {
		// Best-effort cleanup: the error of opening the session is the one
		// worth reporting.
		_ = s.client.Close()
		return err
	}

	s.connected = true
	return nil
}

// getHostWithPort returns the endpoint host in "host:port" form, appending the
// default SSH port 22 when the host carries no port.
//
// A host ending in "]" is treated as a bracketed IPv6 literal without a port
// (e.g. "[::1]"): it contains colons but still needs the default port. Any
// other host that contains a colon is assumed to include its port already and
// is returned unchanged.
func (s *session) getHostWithPort() string {
	host := s.endpoint.Host
	if strings.HasSuffix(host, "]") || !strings.Contains(host, ":") {
		host += ":22"
	}

	return host
}

// setAuthFromEndpoint replaces the session's auth method with the result of
// NewSSHAgentAuth for the endpoint's user. The user name is empty when the
// endpoint has no user information. The auth field is assigned whatever
// NewSSHAgentAuth returns, even when the returned error is non-nil.
func (s *session) setAuthFromEndpoint() error {
	username := ""
	if s.endpoint.User != nil {
		username = s.endpoint.User.Username()
	}

	agentAuth, err := NewSSHAgentAuth(username)
	s.auth = agentAuth
	return err
}

// openSSHSession opens a new session on the dialed client and stores it,
// together with its stdin, stdout and stderr pipes, in the session fields.
// The steps run in that order and the first failure is returned, wrapped in a
// message naming the step that failed.
func (s *session) openSSHSession() error {
	var err error

	if s.session, err = s.client.NewSession(); err != nil {
		return fmt.Errorf("cannot open SSH session: %s", err)
	}

	if s.stdin, err = s.session.StdinPipe(); err != nil {
		return fmt.Errorf("cannot pipe remote stdin: %s", err)
	}

	if s.stdout, err = s.session.StdoutPipe(); err != nil {
		return fmt.Errorf("cannot pipe remote stdout: %s", err)
	}

	if s.stderr, err = s.session.StderrPipe(); err != nil {
		return fmt.Errorf("cannot pipe remote stderr: %s", err)
	}

	return nil
}

// runCommand runs cmd on the SSH session in a separate goroutine and returns
// the channel on which the result of session.Run is delivered, exactly once.
//
// The channel has a buffer of one so that the goroutine can always hand over
// the result and terminate, even if the caller never receives from the
// channel; with an unbuffered channel it would block forever in that case.
func (s *session) runCommand(cmd string) chan error {
	done := make(chan error, 1)
	go func() {
		done <- s.session.Run(cmd)
	}()

	return done
}

// Message prefixes that identify a "repository not found" failure. The names
// indicate the hosting services (GitHub, Bitbucket) they are matched for.
const (
	githubRepoNotFoundErr    = "ERROR: Repository not found."
	bitbucketRepoNotFoundErr = "conq: repository does not exist."
)

// isRepoNotFoundError reports whether s begins with one of the known
// "repository not found" messages. The match is case-sensitive and anchored
// at the start of s.
func isRepoNotFoundError(s string) bool {
	return strings.HasPrefix(s, githubRepoNotFoundErr) ||
		strings.HasPrefix(s, bitbucketRepoNotFoundErr)
}