diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/ioutil/common.go | 96 | ||||
-rw-r--r-- | utils/ioutil/common_test.go | 102 |
2 files changed, 197 insertions, 1 deletions
diff --git a/utils/ioutil/common.go b/utils/ioutil/common.go index 73cc9c3..66044e2 100644 --- a/utils/ioutil/common.go +++ b/utils/ioutil/common.go @@ -3,8 +3,11 @@ package ioutil import ( "bufio" + "context" "errors" "io" + + "github.com/jbenet/go-context/io" ) type readPeeker interface { @@ -52,6 +55,21 @@ func NewReadCloser(r io.Reader, c io.Closer) io.ReadCloser { return &readCloser{Reader: r, closer: c} } +type writeCloser struct { + io.Writer + closer io.Closer +} + +func (r *writeCloser) Close() error { + return r.closer.Close() +} + +// NewWriteCloser creates an `io.WriteCloser` with the given `io.Writer` and +// `io.Closer`. +func NewWriteCloser(w io.Writer, c io.Closer) io.WriteCloser { + return &writeCloser{Writer: w, closer: c} +} + type writeNopCloser struct { io.Writer } @@ -72,3 +90,81 @@ func CheckClose(c io.Closer, err *error) { *err = cerr } } + +// NewContextWriter wraps a writer to make it respect given Context. +// If there is a blocking write, the returned Writer will return whenever the +// context is cancelled (the return values are n=0 and err=ctx.Err()). +func NewContextWriter(ctx context.Context, w io.Writer) io.Writer { + return ctxio.NewWriter(ctx, w) +} + +// NewContextReader wraps a reader to make it respect given Context. +// If there is a blocking read, the returned Reader will return whenever the +// context is cancelled (the return values are n=0 and err=ctx.Err()). +func NewContextReader(ctx context.Context, r io.Reader) io.Reader { + return ctxio.NewReader(ctx, r) +} + +// NewContextWriteCloser as NewContextWriter but with io.Closer interface. +func NewContextWriteCloser(ctx context.Context, w io.WriteCloser) io.WriteCloser { + ctxw := ctxio.NewWriter(ctx, w) + return NewWriteCloser(ctxw, w) +} + +// NewContextReadCloser as NewContextReader but with io.Closer interface. +func NewContextReadCloser(ctx context.Context, r io.ReadCloser) io.ReadCloser { + ctxr := ctxio.NewReader(ctx, r) + return NewReadCloser(ctxr, r) +} + +type readerOnError struct { + io.Reader + notify func(error) +} + +// NewReaderOnError returns a io.Reader that call the notify function when an +// unexpected (!io.EOF) error happends, after call Read function. +func NewReaderOnError(r io.Reader, notify func(error)) io.Reader { + return &readerOnError{r, notify} +} + +// NewReadCloserOnError returns a io.ReadCloser that call the notify function +// when an unexpected (!io.EOF) error happends, after call Read function. +func NewReadCloserOnError(r io.ReadCloser, notify func(error)) io.ReadCloser { + return NewReadCloser(NewReaderOnError(r, notify), r) +} + +func (r *readerOnError) Read(buf []byte) (n int, err error) { + n, err = r.Reader.Read(buf) + if err != nil && err != io.EOF { + r.notify(err) + } + + return +} + +type writerOnError struct { + io.Writer + notify func(error) +} + +// NewWriterOnError returns a io.Writer that call the notify function when an +// unexpected (!io.EOF) error happends, after call Write function. +func NewWriterOnError(w io.Writer, notify func(error)) io.Writer { + return &writerOnError{w, notify} +} + +// NewWriteCloserOnError returns a io.WriteCloser that call the notify function +//when an unexpected (!io.EOF) error happends, after call Write function. +func NewWriteCloserOnError(w io.WriteCloser, notify func(error)) io.WriteCloser { + return NewWriteCloser(NewWriterOnError(w, notify), w) +} + +func (r *writerOnError) Write(p []byte) (n int, err error) { + n, err = r.Writer.Write(p) + if err != nil && err != io.EOF { + r.notify(err) + } + + return +} diff --git a/utils/ioutil/common_test.go b/utils/ioutil/common_test.go index 2d6ef80..27bfa62 100644 --- a/utils/ioutil/common_test.go +++ b/utils/ioutil/common_test.go @@ -2,6 +2,7 @@ package ioutil import ( "bytes" + "context" "io/ioutil" "strings" "testing" @@ -55,6 +56,106 @@ func (s *CommonSuite) TestNewReadCloser(c *C) { c.Assert(closer.called, Equals, 1) } +func (s *CommonSuite) TestNewContextReader(c *C) { + buf := bytes.NewBuffer([]byte("12")) + ctx, close := context.WithCancel(context.Background()) + + r := NewContextReader(ctx, buf) + + b := make([]byte, 1) + n, err := r.Read(b) + c.Assert(n, Equals, 1) + c.Assert(err, IsNil) + + close() + n, err = r.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err, NotNil) +} + +func (s *CommonSuite) TestNewContextReadCloser(c *C) { + buf := NewReadCloser(bytes.NewBuffer([]byte("12")), &closer{}) + ctx, close := context.WithCancel(context.Background()) + + r := NewContextReadCloser(ctx, buf) + + b := make([]byte, 1) + n, err := r.Read(b) + c.Assert(n, Equals, 1) + c.Assert(err, IsNil) + + close() + n, err = r.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err, NotNil) + + c.Assert(r.Close(), IsNil) +} + +func (s *CommonSuite) TestNewContextWriter(c *C) { + buf := bytes.NewBuffer(nil) + ctx, close := context.WithCancel(context.Background()) + + r := NewContextWriter(ctx, buf) + + n, err := r.Write([]byte("1")) + c.Assert(n, Equals, 1) + c.Assert(err, IsNil) + + close() + n, err = r.Write([]byte("1")) + c.Assert(n, Equals, 0) + c.Assert(err, NotNil) +} + +func (s *CommonSuite) TestNewContextWriteCloser(c *C) { + buf := NewWriteCloser(bytes.NewBuffer(nil), &closer{}) + ctx, close := context.WithCancel(context.Background()) + + w := NewContextWriteCloser(ctx, buf) + + n, err := w.Write([]byte("1")) + c.Assert(n, Equals, 1) + c.Assert(err, IsNil) + + close() + n, err = w.Write([]byte("1")) + c.Assert(n, Equals, 0) + c.Assert(err, NotNil) + + c.Assert(w.Close(), IsNil) +} + +func (s *CommonSuite) TestNewWriteCloserOnError(c *C) { + buf := NewWriteCloser(bytes.NewBuffer(nil), &closer{}) + + ctx, close := context.WithCancel(context.Background()) + + var called error + w := NewWriteCloserOnError(NewContextWriteCloser(ctx, buf), func(err error) { + called = err + }) + + close() + w.Write(nil) + + c.Assert(called, NotNil) +} + +func (s *CommonSuite) TestNewReadCloserOnError(c *C) { + buf := NewReadCloser(bytes.NewBuffer(nil), &closer{}) + ctx, close := context.WithCancel(context.Background()) + + var called error + w := NewReadCloserOnError(NewContextReadCloser(ctx, buf), func(err error) { + called = err + }) + + close() + w.Read(nil) + + c.Assert(called, NotNil) +} func ExampleCheckClose() { // CheckClose is commonly used with named return values f := func() (err error) { @@ -68,7 +169,6 @@ func ExampleCheckClose() { // if err is not nil, CheckClose will assign any close errors to it return err - } err := f() |