package ssh
import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"gopkg.in/src-d/go-git.v4/plumbing/transport"
"gopkg.in/src-d/go-git.v4/plumbing/transport/test"
"github.com/gliderlabs/ssh"
stdssh "golang.org/x/crypto/ssh"
. "gopkg.in/check.v1"
"gopkg.in/src-d/go-git-fixtures.v3"
)
// UploadPackSuite runs the shared upload-pack transport tests
// (test.UploadPackSuite) against an SSH server started by SetUpSuite.
type UploadPackSuite struct {
test.UploadPackSuite
fixtures.Suite
// port is the TCP port of the listener that SetUpSuite opens on localhost.
port int
// base is the temporary directory created by SetUpSuite; prepared
// repositories are moved underneath it.
base string
}
// Hand an UploadPackSuite instance to check.v1's Suite function so the
// runner picks up its tests.
var _ = Suite(&UploadPackSuite{})
// SetUpSuite opens a listener on a free localhost port, prepares the fixture
// repositories under a fresh temporary directory, configures the embedded
// upload-pack suite to reach them over SSH, and finally starts an SSH server
// on the listener in a background goroutine.
func (s *UploadPackSuite) SetUpSuite(c *C) {
	s.Suite.SetUpSuite(c)

	// Port 0 asks the OS for any free port; remember which one was chosen so
	// endpoints can be built against it.
	listener, err := net.Listen("tcp", "localhost:0")
	c.Assert(err, IsNil)
	tcpAddr := listener.Addr().(*net.TCPAddr)
	s.port = tcpAddr.Port

	dirPrefix := fmt.Sprintf("go-git-ssh-%d", s.port)
	s.base, err = ioutil.TempDir(os.TempDir(), dirPrefix)
	c.Assert(err, IsNil)

	// Every user authenticates with a Password method that carries only the
	// user name.
	DefaultAuthBuilder = func(user string) (AuthMethod, error) {
		auth := &Password{User: user}
		return auth, nil
	}

	// Host keys are not verified by this client configuration.
	clientConfig := &stdssh.ClientConfig{
		HostKeyCallback: stdssh.InsecureIgnoreHostKey(),
	}

	shared := &s.UploadPackSuite
	shared.Client = NewClient(clientConfig)
	shared.Endpoint = s.prepareRepository(c, fixtures.Basic().One(), "basic.git")
	shared.EmptyEndpoint = s.prepareRepository(c, fixtures.ByTag("empty").One(), "empty.git")
	shared.NonExistentEndpoint = s.newEndpoint(c, "non-existent.git")

	sshServer := &ssh.Server{Handler: handlerSSH}
	go func() {
		// log.Fatal ends the whole test process with whatever Serve returns.
		log.Fatal(sshServer.Serve(listener))
	}()
}
// prepareRepository takes the fixture's .git filesystem, passes it through
// fixtures.EnsureIsBare, renames its root directory to <s.base>/<name>, and
// returns the SSH endpoint that addresses the relocated repository.
func (s *UploadPackSuite) prepareRepository(c *C, f *fixtures.Fixture, name string) *transport.Endpoint {
	dotGit := f.DotGit()
	c.Assert(fixtures.EnsureIsBare(dotGit), IsNil)

	target := filepath.Join(s.base, name)
	c.Assert(os.Rename(dotGit.Root(), target), IsNil)

	return s.newEndpoint(c, name)
}
// newEndpoint builds the ssh:// endpoint for the repository called name
// below s.base, using the "git" user and the suite's localhost port. The
// base directory is converted to forward slashes before being placed in the
// URL.
func (s *UploadPackSuite) newEndpoint(c *C, name string) *transport.Endpoint {
	root := filepath.ToSlash(s.base)
	rawURL := fmt.Sprintf("ssh://git@localhost:%d/%s/%s", s.port, root, name)

	endpoint, err := transport.NewEndpoint(rawURL)
	c.Assert(err, IsNil)

	return endpoint
}
// handlerSSH serves a single SSH session: it builds a local process from the
// session's command, starts it, and connects the session's input, output and
// stderr streams to the process. Failures to build or start the process are
// printed to standard output and end the handler.
func handlerSSH(s ssh.Session) {
	proc, procIn, procErr, procOut, err := buildCommand(s.Command())
	if err == nil {
		err = proc.Start()
	}
	if err != nil {
		fmt.Println(err)
		return
	}

	// Feed the session's input to the process. This goroutine is not waited
	// for; it closes the process's stdin once the copy ends.
	go func() {
		defer procIn.Close()
		io.Copy(procIn, s)
	}()

	// Both output streams must be fully drained before calling Wait on the
	// process, so track the two forwarding goroutines.
	var pending sync.WaitGroup
	forward := func(dst io.Writer, src io.Reader) {
		defer pending.Done()
		io.Copy(dst, src)
	}

	pending.Add(2)
	go forward(s.Stderr(), procErr)
	go forward(s, procOut)
	pending.Wait()

	// The process's exit error is intentionally discarded.
	_ = proc.Wait()
}
// buildCommand turns the argument vector of an SSH exec request into an
// unstarted *exec.Cmd with pipes attached to its three standard streams.
//
// c must contain exactly two elements: the program to run and the path
// argument passed to it; anything else yields an "invalid command" error.
// The path is passed through stripDriveSlash so that URL-style Windows paths
// such as "/C:/repo" become "C:/repo".
//
// Note the result order: cmd, stdin, stderr, stdout, err. On failure every
// result except err is nil, and any pipe that had already been opened is
// closed.
func buildCommand(c []string) (cmd *exec.Cmd, stdin io.WriteCloser, stderr, stdout io.ReadCloser, err error) {
	if len(c) != 2 {
		return nil, nil, nil, nil, fmt.Errorf("invalid command")
	}

	proc := exec.Command(c[0], stripDriveSlash(c[1]))

	if stdout, err = proc.StdoutPipe(); err != nil {
		return nil, nil, nil, nil, err
	}

	if stdin, err = proc.StdinPipe(); err != nil {
		stdout.Close()
		return nil, nil, nil, nil, err
	}

	if stderr, err = proc.StderrPipe(); err != nil {
		stdin.Close()
		stdout.Close()
		return nil, nil, nil, nil, err
	}

	return proc, stdin, stderr, stdout, nil
}

// stripDriveSlash drops the leading slash from a path that starts with
// "/<letter>:/", the shape a Windows absolute path takes once it has been
// embedded in a URL. Paths that do not start that way are returned unchanged,
// including paths that merely contain such a sequence further in.
func stripDriveSlash(p string) string {
	if len(p) < 4 || p[0] != '/' || !strings.HasPrefix(p[2:], ":/") {
		return p
	}

	drive := p[1]
	isLetter := (drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')
	if !isLetter {
		return p
	}

	return p[1:]
}