diff options
Diffstat (limited to 'plumbing/transport/internal')
-rw-r--r-- | plumbing/transport/internal/common/common.go | 54 | ||||
-rw-r--r-- | plumbing/transport/internal/common/server.go | 5 |
2 files changed, 48 insertions, 11 deletions
diff --git a/plumbing/transport/internal/common/common.go b/plumbing/transport/internal/common/common.go index 04db770..2db8d54 100644 --- a/plumbing/transport/internal/common/common.go +++ b/plumbing/transport/internal/common/common.go @@ -7,6 +7,7 @@ package common import ( "bufio" + "context" "errors" "fmt" "io" @@ -64,6 +65,13 @@ type Command interface { Close() error } +// CommandKiller expands the Command interface, enableing it for being killed. +type CommandKiller interface { + // Kill and close the session whatever the state it is. It will block until + // the command is terminated. + Kill() error +} + type client struct { cmdr Commander } @@ -212,7 +220,7 @@ func (s *session) handleAdvRefDecodeError(err error) error { // UploadPack performs a request to the server to fetch a packfile. A reader is // returned with the packfile content. The reader must be closed after reading. -func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) { +func (s *session) UploadPack(ctx context.Context, req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) { if req.IsEmpty() { return nil, transport.ErrEmptyUploadPackRequest } @@ -227,11 +235,14 @@ func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackRes s.packRun = true - if err := uploadPack(s.Stdin, s.Stdout, req); err != nil { + in := s.StdinContext(ctx) + out := s.StdoutContext(ctx) + + if err := uploadPack(in, out, req); err != nil { return nil, err } - r, err := ioutil.NonEmptyReader(s.Stdout) + r, err := ioutil.NonEmptyReader(out) if err == ioutil.ErrEmptyReader { if c, ok := s.Stdout.(io.Closer); ok { _ = c.Close() @@ -244,22 +255,45 @@ func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackRes return nil, err } - rc := ioutil.NewReadCloser(r, s.Command) + rc := ioutil.NewReadCloser(r, s) return DecodeUploadPackResponse(rc, req) } -func (s *session) ReceivePack(req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) { +func (s *session) StdinContext(ctx context.Context) io.WriteCloser { + return ioutil.NewWriteCloserOnError( + ioutil.NewContextWriteCloser(ctx, s.Stdin), + s.onError, + ) +} + +func (s *session) StdoutContext(ctx context.Context) io.Reader { + return ioutil.NewReaderOnError( + ioutil.NewContextReader(ctx, s.Stdout), + s.onError, + ) +} + +func (s *session) onError(err error) { + if k, ok := s.Command.(CommandKiller); ok { + _ = k.Kill() + } + + _ = s.Close() +} + +func (s *session) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) { if _, err := s.AdvertisedReferences(); err != nil { return nil, err } s.packRun = true - if err := req.Encode(s.Stdin); err != nil { + w := s.StdinContext(ctx) + if err := req.Encode(w); err != nil { return nil, err } - if err := s.Stdin.Close(); err != nil { + if err := w.Close(); err != nil { return nil, err } @@ -270,11 +304,12 @@ func (s *session) ReceivePack(req *packp.ReferenceUpdateRequest) (*packp.ReportS } report := packp.NewReportStatus() - if err := report.Decode(s.Stdout); err != nil { + if err := report.Decode(s.StdoutContext(ctx)); err != nil { return nil, err } if err := report.Error(); err != nil { + defer s.Close() return report, err } @@ -300,8 +335,9 @@ func (s *session) finish() error { } func (s *session) Close() (err error) { - defer ioutil.CheckClose(s.Command, &err) err = s.finish() + + defer ioutil.CheckClose(s.Command, &err) return } diff --git a/plumbing/transport/internal/common/server.go b/plumbing/transport/internal/common/server.go index dd6cfbe..f4ca692 100644 --- a/plumbing/transport/internal/common/server.go +++ b/plumbing/transport/internal/common/server.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "io" @@ -34,7 +35,7 @@ func ServeUploadPack(cmd ServerCommand, s transport.UploadPackSession) (err erro } var resp *packp.UploadPackResponse - resp, err = s.UploadPack(req) + resp, err = s.UploadPack(context.TODO(), req) if err != nil { return err } @@ -57,7 +58,7 @@ func ServeReceivePack(cmd ServerCommand, s transport.ReceivePackSession) error { return fmt.Errorf("error decoding: %s", err) } - rs, err := s.ReceivePack(req) + rs, err := s.ReceivePack(context.TODO(), req) if rs != nil { if err := rs.Encode(cmd.Stdout); err != nil { return fmt.Errorf("error in encoding report status %s", err) |