aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/transport/internal
diff options
context:
space:
mode:
Diffstat (limited to 'plumbing/transport/internal')
-rw-r--r--plumbing/transport/internal/common/common.go54
-rw-r--r--plumbing/transport/internal/common/server.go5
2 files changed, 48 insertions, 11 deletions
diff --git a/plumbing/transport/internal/common/common.go b/plumbing/transport/internal/common/common.go
index 04db770..2db8d54 100644
--- a/plumbing/transport/internal/common/common.go
+++ b/plumbing/transport/internal/common/common.go
@@ -7,6 +7,7 @@ package common
import (
"bufio"
+ "context"
"errors"
"fmt"
"io"
@@ -64,6 +65,13 @@ type Command interface {
Close() error
}
+// CommandKiller expands the Command interface, enableing it for being killed.
+type CommandKiller interface {
+ // Kill and close the session whatever the state it is. It will block until
+ // the command is terminated.
+ Kill() error
+}
+
type client struct {
cmdr Commander
}
@@ -212,7 +220,7 @@ func (s *session) handleAdvRefDecodeError(err error) error {
// UploadPack performs a request to the server to fetch a packfile. A reader is
// returned with the packfile content. The reader must be closed after reading.
-func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) {
+func (s *session) UploadPack(ctx context.Context, req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) {
if req.IsEmpty() {
return nil, transport.ErrEmptyUploadPackRequest
}
@@ -227,11 +235,14 @@ func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackRes
s.packRun = true
- if err := uploadPack(s.Stdin, s.Stdout, req); err != nil {
+ in := s.StdinContext(ctx)
+ out := s.StdoutContext(ctx)
+
+ if err := uploadPack(in, out, req); err != nil {
return nil, err
}
- r, err := ioutil.NonEmptyReader(s.Stdout)
+ r, err := ioutil.NonEmptyReader(out)
if err == ioutil.ErrEmptyReader {
if c, ok := s.Stdout.(io.Closer); ok {
_ = c.Close()
@@ -244,22 +255,45 @@ func (s *session) UploadPack(req *packp.UploadPackRequest) (*packp.UploadPackRes
return nil, err
}
- rc := ioutil.NewReadCloser(r, s.Command)
+ rc := ioutil.NewReadCloser(r, s)
return DecodeUploadPackResponse(rc, req)
}
-func (s *session) ReceivePack(req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) {
+func (s *session) StdinContext(ctx context.Context) io.WriteCloser {
+ return ioutil.NewWriteCloserOnError(
+ ioutil.NewContextWriteCloser(ctx, s.Stdin),
+ s.onError,
+ )
+}
+
+func (s *session) StdoutContext(ctx context.Context) io.Reader {
+ return ioutil.NewReaderOnError(
+ ioutil.NewContextReader(ctx, s.Stdout),
+ s.onError,
+ )
+}
+
+func (s *session) onError(err error) {
+ if k, ok := s.Command.(CommandKiller); ok {
+ _ = k.Kill()
+ }
+
+ _ = s.Close()
+}
+
+func (s *session) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) {
if _, err := s.AdvertisedReferences(); err != nil {
return nil, err
}
s.packRun = true
- if err := req.Encode(s.Stdin); err != nil {
+ w := s.StdinContext(ctx)
+ if err := req.Encode(w); err != nil {
return nil, err
}
- if err := s.Stdin.Close(); err != nil {
+ if err := w.Close(); err != nil {
return nil, err
}
@@ -270,11 +304,12 @@ func (s *session) ReceivePack(req *packp.ReferenceUpdateRequest) (*packp.ReportS
}
report := packp.NewReportStatus()
- if err := report.Decode(s.Stdout); err != nil {
+ if err := report.Decode(s.StdoutContext(ctx)); err != nil {
return nil, err
}
if err := report.Error(); err != nil {
+ defer s.Close()
return report, err
}
@@ -300,8 +335,9 @@ func (s *session) finish() error {
}
func (s *session) Close() (err error) {
- defer ioutil.CheckClose(s.Command, &err)
err = s.finish()
+
+ defer ioutil.CheckClose(s.Command, &err)
return
}
diff --git a/plumbing/transport/internal/common/server.go b/plumbing/transport/internal/common/server.go
index dd6cfbe..f4ca692 100644
--- a/plumbing/transport/internal/common/server.go
+++ b/plumbing/transport/internal/common/server.go
@@ -1,6 +1,7 @@
package common
import (
+ "context"
"fmt"
"io"
@@ -34,7 +35,7 @@ func ServeUploadPack(cmd ServerCommand, s transport.UploadPackSession) (err erro
}
var resp *packp.UploadPackResponse
- resp, err = s.UploadPack(req)
+ resp, err = s.UploadPack(context.TODO(), req)
if err != nil {
return err
}
@@ -57,7 +58,7 @@ func ServeReceivePack(cmd ServerCommand, s transport.ReceivePackSession) error {
return fmt.Errorf("error decoding: %s", err)
}
- rs, err := s.ReceivePack(req)
+ rs, err := s.ReceivePack(context.TODO(), req)
if rs != nil {
if err := rs.Encode(cmd.Stdout); err != nil {
return fmt.Errorf("error in encoding report status %s", err)