diff options
Diffstat (limited to 'clients/ssh')
-rw-r--r-- | clients/ssh/git_upload_pack.go | 139 | ||||
-rw-r--r-- | clients/ssh/git_upload_pack_test.go | 5 |
2 files changed, 110 insertions, 34 deletions
diff --git a/clients/ssh/git_upload_pack.go b/clients/ssh/git_upload_pack.go index e3b59a5..e172f1f 100644 --- a/clients/ssh/git_upload_pack.go +++ b/clients/ssh/git_upload_pack.go @@ -10,6 +10,8 @@ import ( "gopkg.in/src-d/go-git.v4/clients/common" "gopkg.in/src-d/go-git.v4/formats/packp/advrefs" + "gopkg.in/src-d/go-git.v4/formats/packp/pktline" + "gopkg.in/src-d/go-git.v4/formats/packp/ulreq" "golang.org/x/crypto/ssh" ) @@ -24,7 +26,8 @@ var ( ErrUnsupportedVCS = errors.New("only git is supported") ErrUnsupportedRepo = errors.New("only github.com is supported") - nak = []byte("0008NAK\n") + nak = []byte("NAK") + eol = []byte("\n") ) // GitUploadPackService holds the service information. @@ -139,75 +142,145 @@ func (s *GitUploadPackService) Disconnect() (err error) { // SSH session on a connected GitUploadPackService, sends the given // upload request to the server and returns a reader for the received // packfile. Closing the returned reader will close the SSH session. -func (s *GitUploadPackService) Fetch(r *common.GitUploadPackRequest) (rc io.ReadCloser, err error) { +func (s *GitUploadPackService) Fetch(req *common.GitUploadPackRequest) (rc io.ReadCloser, err error) { if !s.connected { return nil, ErrNotConnected } - session, err := s.client.NewSession() + session, i, o, done, err := openSSHSession(s.client, s.getCommand()) if err != nil { return nil, fmt.Errorf("cannot open SSH session: %s", err) } - si, err := session.StdinPipe() + if err := talkPackProtocol(i, o, req); err != nil { + return nil, err + } + + return &fetchSession{ + Reader: o, + session: session, + done: done, + }, nil +} + +func openSSHSession(c *ssh.Client, cmd string) ( + *ssh.Session, io.WriteCloser, io.Reader, <-chan error, error) { + + session, err := c.NewSession() if err != nil { - return nil, fmt.Errorf("cannot pipe remote stdin: %s", err) + return nil, nil, nil, nil, fmt.Errorf("cannot open SSH session: %s", err) } - so, err := session.StdoutPipe() + i, err := session.StdinPipe() if err != nil { - return nil, fmt.Errorf("cannot pipe remote stdout: %s", err) + return nil, nil, nil, nil, fmt.Errorf("cannot pipe remote stdin: %s", err) + } + + o, err := session.StdoutPipe() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("cannot pipe remote stdout: %s", err) } done := make(chan error) go func() { - done <- session.Run(s.getCommand()) + done <- session.Run(cmd) }() - if err := skipAdvRef(so); err != nil { - return nil, fmt.Errorf("skipping advertised-refs: %s", err) + return session, i, o, done, nil +} + +// TODO support multi_ack mode +// TODO support multi_ack_detailed mode +// TODO support acks for common objects +// TODO build a proper state machine for all these processing options +func talkPackProtocol(w io.WriteCloser, r io.Reader, + req *common.GitUploadPackRequest) error { + + if err := skipAdvRef(r); err != nil { + return fmt.Errorf("skipping advertised-refs: %s", err) } - // send the upload request - _, err = io.Copy(si, r.Reader()) - if err != nil { - return nil, fmt.Errorf("sending upload-req message: %s", err) + if err := sendUlReq(w, req); err != nil { + return fmt.Errorf("sending upload-req message: %s", err) } - if err := si.Close(); err != nil { - return nil, fmt.Errorf("closing input: %s", err) + if err := sendHaves(w, req); err != nil { + return fmt.Errorf("sending haves message: %s", err) } - // TODO support multi_ack mode - // TODO support multi_ack_detailed mode - // TODO support acks for common objects - // TODO build a proper state machine for all these processing options - buf := make([]byte, len(nak)) - if _, err := io.ReadFull(so, buf); err != nil { - return nil, fmt.Errorf("looking for NAK: %s", err) + if err := sendDone(w); err != nil { + return fmt.Errorf("sending done message: %s", err) } - if !bytes.Equal(buf, nak) { - return nil, fmt.Errorf("NAK answer not found") + + if err := w.Close(); err != nil { + return fmt.Errorf("closing input: %s", err) } - return &fetchSession{ - Reader: so, - session: session, - done: done, - }, nil + if err := readNAK(r); err != nil { + return fmt.Errorf("reading NAK: %s", err) + } + + return nil } -func skipAdvRef(so io.Reader) error { - d := advrefs.NewDecoder(so) +func skipAdvRef(r io.Reader) error { + d := advrefs.NewDecoder(r) ar := advrefs.New() return d.Decode(ar) } +func sendUlReq(w io.Writer, req *common.GitUploadPackRequest) error { + ur := ulreq.New() + ur.Wants = req.Wants + ur.Depth = ulreq.DepthCommits(req.Depth) + e := ulreq.NewEncoder(w) + + return e.Encode(ur) +} + +func sendHaves(w io.Writer, req *common.GitUploadPackRequest) error { + e := pktline.NewEncoder(w) + for _, have := range req.Haves { + if err := e.Encodef("have %s\n", have); err != nil { + return fmt.Errorf("sending haves for %q: err ", have, err) + } + } + + if len(req.Haves) != 0 { + if err := e.Flush(); err != nil { + return fmt.Errorf("sending flush-pkt after haves: %s", err) + } + } + + return nil +} + +func sendDone(w io.Writer) error { + e := pktline.NewEncoder(w) + + return e.Encodef("done\n") +} + +func readNAK(r io.Reader) error { + s := pktline.NewScanner(r) + if !s.Scan() { + return s.Err() + } + + b := s.Bytes() + b = bytes.TrimSuffix(b, eol) + if !bytes.Equal(b, nak) { + return fmt.Errorf("expecting NAK, found %q instead", string(b)) + } + + return nil +} + type fetchSession struct { io.Reader session *ssh.Session - done chan error + done <-chan error } // Close closes the session and collects the output state of the remote diff --git a/clients/ssh/git_upload_pack_test.go b/clients/ssh/git_upload_pack_test.go index ff27cc2..d7160dc 100644 --- a/clients/ssh/git_upload_pack_test.go +++ b/clients/ssh/git_upload_pack_test.go @@ -136,6 +136,9 @@ func (s *RemoteSuite) TestFetchError(c *C) { req := &common.GitUploadPackRequest{} req.Want(core.NewHash("1111111111111111111111111111111111111111")) - _, err := r.Fetch(req) + reader, err := r.Fetch(req) + c.Assert(err, IsNil) + + err = reader.Close() c.Assert(err, Not(IsNil)) } |