aboutsummaryrefslogblamecommitdiffstats
path: root/plumbing/transport/ssh/common.go
blob: 6f0f3d48479379158d7fc76b900e0755a7dd6b82 (plain) (tree)






















































































































































                                                                                                   
package ssh

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"gopkg.in/src-d/go-git.v4/plumbing/transport"

	"golang.org/x/crypto/ssh"
)

// New errors introduced by this package.
var (
	// ErrAdvertistedReferencesAlreadyCalled reports a repeated call to
	// AdvertisedReference. The spelling of the identifier is kept as-is
	// because it is exported.
	ErrAdvertistedReferencesAlreadyCalled = errors.New("cannot call AdvertisedReference twice")
	// ErrAlreadyConnected is returned by connect when the session already
	// holds an open connection.
	ErrAlreadyConnected = errors.New("ssh session already created")
	// ErrAuthRequired reports that a connection cannot be made without
	// authentication.
	ErrAuthRequired = errors.New("cannot connect: auth required")
	// ErrNotConnected reports an operation attempted without a connection.
	ErrNotConnected = errors.New("not connected")
	// ErrUploadPackAnswerFormat reports a git-upload-pack answer in an
	// unexpected format.
	ErrUploadPackAnswerFormat = errors.New("git-upload-pack bad answer format")
	// ErrUnsupportedVCS reports a request for a VCS other than git.
	ErrUnsupportedVCS = errors.New("only git is supported")
	// ErrUnsupportedRepo reports a repository host other than github.com.
	ErrUnsupportedRepo = errors.New("only github.com is supported")
)

// Client is the SSH implementation of transport.Client. It carries no state;
// its methods create fetch-pack and send-pack sessions for an endpoint.
type Client struct{}

// DefaultClient is a package-level SSH transport client built with NewClient.
var DefaultClient = NewClient()

// NewClient returns a fresh SSH transport client as a transport.Client.
func NewClient() transport.Client {
	return new(Client)
}

// NewFetchPackSession creates a fetch-pack session for the endpoint ep by
// delegating to newFetchPackSession; its result and error are returned as-is.
func (c *Client) NewFetchPackSession(ep transport.Endpoint) (transport.FetchPackSession, error) {
	return newFetchPackSession(ep)
}

// NewSendPackSession creates a send-pack session for the endpoint ep by
// delegating to newSendPackSession; its result and error are returned as-is.
func (c *Client) NewSendPackSession(ep transport.Endpoint) (transport.SendPackSession, error) {
	return newSendPackSession(ep)
}

// session holds the state of one SSH connection to a remote endpoint.
//
//   - connected is set once connect succeeds and cleared by Close.
//   - endpoint supplies the host to dial and, optionally, the user name.
//   - client is the SSH connection returned by ssh.Dial.
//   - session is the SSH session opened on client.
//   - stdin and stdout are the pipes to the remote command's standard input
//     and output, obtained from session.
//   - sessionDone is not assigned in this part of the file; presumably it
//     holds the channel returned by runCommand (verify against callers).
//   - auth is the authentication method, set by SetAuth or
//     setAuthFromEndpoint and used to build the client configuration.
type session struct {
	connected   bool
	endpoint    transport.Endpoint
	client      *ssh.Client
	session     *ssh.Session
	stdin       io.WriteCloser
	stdout      io.Reader
	sessionDone chan error
	auth        AuthMethod
}

// SetAuth stores auth as the session's authentication method. The value must
// also implement this package's AuthMethod; otherwise the session is left
// unchanged and transport.ErrInvalidAuthMethod is returned.
func (s *session) SetAuth(auth transport.AuthMethod) error {
	if sshAuth, ok := auth.(AuthMethod); ok {
		s.auth = sshAuth
		return nil
	}

	return transport.ErrInvalidAuthMethod
}

// Close marks the session as disconnected and closes the underlying SSH
// client, returning the client's close error. Calling it on a session that
// is not connected does nothing and returns nil.
func (s *session) Close() error {
	if s.connected {
		s.connected = false
		return s.client.Close()
	}

	return nil
}

// connect dials the SSH server of the session's endpoint over TCP and opens
// an SSH session with its stdin/stdout pipes.
//
// If an AuthMethod was set with SetAuth it is used as-is. Otherwise the
// default from setAuthFromEndpoint is used: an auth method based on
// PublicKeysCallback that connects to an SSH agent, using the address stored
// in the SSH_AUTH_SOCK environment var.
//
// It returns ErrAlreadyConnected when the session is already connected, and
// otherwise any error from building the default auth, dialing, or opening
// the SSH session. The session is marked connected only when every step
// succeeds.
func (s *session) connect() error {
	if s.connected {
		return ErrAlreadyConnected
	}

	// Only derive the default auth when the caller did not provide one;
	// doing it unconditionally would discard the method given to SetAuth.
	if s.auth == nil {
		if err := s.setAuthFromEndpoint(); err != nil {
			return err
		}
	}

	var err error
	s.client, err = ssh.Dial("tcp", s.getHostWithPort(), s.auth.clientConfig())
	if err != nil {
		return err
	}

	if err := s.openSSHSession(); err != nil {
		// Best-effort cleanup: the session error is the one worth reporting,
		// so a failure to close the half-initialised client is ignored.
		_ = s.client.Close()
		return err
	}

	s.connected = true
	return nil
}

// getHostWithPort returns the endpoint host in "host:port" form, appending
// the default SSH port 22 when the host carries no port.
//
// A port is considered present only when the last ':' appears after the last
// ']'. This keeps the colons inside a bracketed IPv6 literal such as "[::1]"
// from being mistaken for a port separator, so "[::1]" becomes "[::1]:22"
// while "[::1]:2222" and "example.com:2222" are returned unchanged.
func (s *session) getHostWithPort() string {
	host := s.endpoint.Host

	// With no ']' in the host LastIndex yields -1, so this reduces to
	// "does the host contain a colon" for names and IPv4 addresses.
	if strings.LastIndex(host, ":") > strings.LastIndex(host, "]") {
		return host
	}

	return host + ":22"
}

// setAuthFromEndpoint replaces the session's auth method with the result of
// NewSSHAgentAuth, passing the user name from the endpoint, or an empty
// string when the endpoint has no user information. The error from
// NewSSHAgentAuth is returned unchanged.
func (s *session) setAuthFromEndpoint() error {
	username := ""
	if userInfo := s.endpoint.User; userInfo != nil {
		username = userInfo.Username()
	}

	auth, err := NewSSHAgentAuth(username)
	s.auth = auth
	return err
}

// openSSHSession creates a new SSH session on the already-dialed client and
// stores the session and its stdin and stdout pipes on s. The first step that
// fails aborts the sequence and its error is returned with a short
// description of the step; on failure nothing opened here is closed.
func (s *session) openSSHSession() error {
	var err error

	if s.session, err = s.client.NewSession(); err != nil {
		return fmt.Errorf("cannot open SSH session: %s", err)
	}

	if s.stdin, err = s.session.StdinPipe(); err != nil {
		return fmt.Errorf("cannot pipe remote stdin: %s", err)
	}

	if s.stdout, err = s.session.StdoutPipe(); err != nil {
		return fmt.Errorf("cannot pipe remote stdout: %s", err)
	}

	return nil
}

// runCommand runs cmd on the SSH session in a separate goroutine and returns
// a channel on which the result of session.Run is delivered exactly once
// (nil on success).
//
// The channel has a buffer of one so the goroutine can deliver its result and
// exit even if the caller never receives from the channel; with an unbuffered
// channel the goroutine would stay blocked on the send forever.
func (s *session) runCommand(cmd string) chan error {
	done := make(chan error, 1)
	go func() {
		done <- s.session.Run(cmd)
	}()

	return done
}