From b9c0a09435392913c0054382500c805cd7cb596b Mon Sep 17 00:00:00 2001 From: Máximo Cuadros Date: Sun, 25 Sep 2016 23:58:59 +0200 Subject: formats: objfile idomatic reader/writer --- formats/objfile/common.go | 79 --------------------- formats/objfile/common_test.go | 31 --------- formats/objfile/reader.go | 152 ++++++++++++++++++----------------------- formats/objfile/reader_test.go | 64 ++++------------- formats/objfile/writer.go | 145 +++++++++++++++------------------------ formats/objfile/writer_test.go | 51 ++++++++------ 6 files changed, 168 insertions(+), 354 deletions(-) delete mode 100644 formats/objfile/common.go (limited to 'formats') diff --git a/formats/objfile/common.go b/formats/objfile/common.go deleted file mode 100644 index 839f92c..0000000 --- a/formats/objfile/common.go +++ /dev/null @@ -1,79 +0,0 @@ -package objfile - -import ( - "errors" - "io" - "strconv" - - "gopkg.in/src-d/go-git.v4/core" -) - -var ( - // ErrClosed is returned when the objfile Reader or Writer is already closed. - ErrClosed = errors.New("objfile: already closed") - // ErrHeader is returned when the objfile has an invalid header. - ErrHeader = errors.New("objfile: invalid header") - // ErrNegativeSize is returned when a negative object size is declared. - ErrNegativeSize = errors.New("objfile: negative object size") -) - -type header struct { - t core.ObjectType - size int64 -} - -func (h *header) Read(r io.Reader) error { - t, err := h.readSlice(r, ' ') - if err != nil { - return err - } - - h.t, err = core.ParseObjectType(string(t)) - if err != nil { - return err - } - - size, err := h.readSlice(r, 0) - if err != nil { - return err - } - - h.size, err = strconv.ParseInt(string(size), 10, 64) - if err != nil { - return ErrHeader - } - - if h.size < 0 { - return ErrNegativeSize - } - - return nil -} - -func (h *header) Write(w io.Writer) error { - b := h.t.Bytes() - b = append(b, ' ') - b = append(b, []byte(strconv.FormatInt(h.size, 10))...) - b = append(b, 0) - _, err := w.Write(b) - return err -} - -// readSlice reads one byte at a time from r until it encounters delim or an -// error. -func (h *header) readSlice(r io.Reader, delim byte) ([]byte, error) { - var buf [1]byte - value := make([]byte, 0, 16) - for { - if n, err := r.Read(buf[:]); err != nil && (err != io.EOF || n == 0) { - if err == io.EOF { - return nil, ErrHeader - } - return nil, err - } - if buf[0] == delim { - return value, nil - } - value = append(value, buf[0]) - } -} diff --git a/formats/objfile/common_test.go b/formats/objfile/common_test.go index e443e73..682dfbb 100644 --- a/formats/objfile/common_test.go +++ b/formats/objfile/common_test.go @@ -1,7 +1,6 @@ package objfile import ( - "bytes" "encoding/base64" "testing" @@ -68,33 +67,3 @@ var objfileFixtures = []objfileFixture{ } func Test(t *testing.T) { TestingT(t) } - -type SuiteCommon struct{} - -var _ = Suite(&SuiteCommon{}) - -func (s *SuiteCommon) TestHeaderReadEmpty(c *C) { - var h header - c.Assert(h.Read(new(bytes.Buffer)), Equals, ErrHeader) -} - -func (s *SuiteCommon) TestHeaderReadGarbage(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), Equals, ErrHeader) - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5, '0'})), Equals, ErrHeader) -} - -func (s *SuiteCommon) TestHeaderReadInvalidType(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, ' ', 4, 5, 0})), Equals, core.ErrInvalidType) -} - -func (s *SuiteCommon) TestHeaderReadInvalidSize(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', 'a', 0})), Equals, ErrHeader) -} - -func (s *SuiteCommon) TestHeaderReadNegativeSize(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', '-', '1', 0})), Equals, ErrNegativeSize) -} diff --git a/formats/objfile/reader.go b/formats/objfile/reader.go index 99ed754..ab67c6a 100644 --- a/formats/objfile/reader.go +++ b/formats/objfile/reader.go @@ -1,69 +1,96 @@ package objfile import ( + "compress/zlib" "errors" "io" + "strconv" + "gopkg.in/src-d/go-git.v3/formats/packfile" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) var ( - // ErrZLib is returned when the objfile contains invalid zlib data. - ErrZLib = errors.New("objfile: invalid zlib data") + ErrClosed = errors.New("objfile: already closed") + ErrHeader = errors.New("objfile: invalid header") + ErrNegativeSize = errors.New("objfile: negative object size") ) // Reader reads and decodes compressed objfile data from a provided io.Reader. -// // Reader implements io.ReadCloser. Close should be called when finished with // the Reader. Close will not close the underlying io.Reader. type Reader struct { - header header - hash core.Hash // final computed hash stored after Close - - r io.Reader // provided reader wrapped in decompressor and tee - decompressor io.ReadCloser // provided reader wrapped in decompressor, retained for calling Close - h core.Hasher // streaming SHA1 hash of decoded data + multi io.Reader + zlib io.ReadCloser + hasher core.Hasher } // NewReader returns a new Reader reading from r. -// -// Calling NewReader causes it to immediately read in header data from r -// containing size and type information. Any errors encountered in that -// process will be returned in err. -// -// The returned Reader implements io.ReadCloser. Close should be called when -// finished with the Reader. Close will not close the underlying io.Reader. func NewReader(r io.Reader) (*Reader, error) { - reader := &Reader{} - return reader, reader.init(r) + zlib, err := zlib.NewReader(r) + if err != nil { + return nil, packfile.ErrZLib.AddDetails(err.Error()) + } + + return &Reader{ + zlib: zlib, + }, nil } -// init prepares the zlib decompressor for the given input as well as a hasher -// for computing its hash. -// -// init immediately reads header data from the input and stores it. This leaves -// the Reader in a state that is ready to read content. -func (r *Reader) init(input io.Reader) (err error) { - r.decompressor, err = zlib.NewReader(input) +// Header reads the type and the size of object, and prepares the reader for read +func (r *Reader) Header() (t core.ObjectType, size int64, err error) { + var raw []byte + raw, err = r.readUntil(' ') + if err != nil { + return + } + + t, err = core.ParseObjectType(string(raw)) if err != nil { - // TODO: Make this error match the ZLibErr in formats/packfile/reader.go? - return ErrZLib + return } - err = r.header.Read(r.decompressor) + raw, err = r.readUntil(0) if err != nil { - r.decompressor.Close() return } - r.h = core.NewHasher(r.header.t, r.header.size) - r.r = io.TeeReader(r.decompressor, r.h) // All reads from the decompressor also write to the hash + size, err = strconv.ParseInt(string(raw), 10, 64) + if err != nil { + err = ErrHeader + return + } + defer r.prepareForRead(t, size) return } +// readSlice reads one byte at a time from r until it encounters delim or an +// error. +func (r *Reader) readUntil(delim byte) ([]byte, error) { + var buf [1]byte + value := make([]byte, 0, 16) + for { + if n, err := r.zlib.Read(buf[:]); err != nil && (err != io.EOF || n == 0) { + if err == io.EOF { + return nil, ErrHeader + } + return nil, err + } + + if buf[0] == delim { + return value, nil + } + + value = append(value, buf[0]) + } +} + +func (r *Reader) prepareForRead(t core.ObjectType, size int64) { + r.hasher = core.NewHasher(t, size) + r.multi = io.TeeReader(r.zlib, r.hasher) +} + // Read reads len(p) bytes into p from the object data stream. It returns // the number of bytes read (0 <= n <= len(p)) and any error encountered. Even // if Read returns n < len(p), it may use all of p as scratch space during the @@ -72,65 +99,20 @@ func (r *Reader) init(input io.Reader) (err error) { // If Read encounters the end of the data stream it will return err == io.EOF, // either in the current call if n > 0 or in a subsequent call. func (r *Reader) Read(p []byte) (n int, err error) { - if r.r == nil { - return 0, ErrClosed - } - - return r.r.Read(p) -} - -// Type returns the type of the object. -func (r *Reader) Type() core.ObjectType { - return r.header.t -} - -// Size returns the uncompressed size of the object in bytes. -func (r *Reader) Size() int64 { - return r.header.size + return r.multi.Read(p) } // Hash returns the hash of the object data stream that has been read so far. -// It can be called before or after Close. func (r *Reader) Hash() core.Hash { - if r.r != nil { - return r.h.Sum() // Not yet closed, return hash of data read so far - } - return r.hash + return r.hasher.Sum() } -// Close releases any resources consumed by the Reader. -// -// Calling Close does not close the wrapped io.Reader originally passed to -// NewReader. -func (r *Reader) Close() (err error) { - if r.r == nil { - // TODO: Consider returning ErrClosed here? - return nil // Already closed - } - - // Release the decompressor's resources - err = r.decompressor.Close() - - // Save the hash because we're about to throw away the hasher - r.hash = r.h.Sum() - - // Release references - r.r = nil // Indicates closed state - r.decompressor = nil - r.h.Hash = nil - - return -} - -// FillObject fills the given object from an object entry -func (r *Reader) FillObject(obj core.Object) error { - obj.SetType(r.header.t) - obj.SetSize(r.header.size) - w, err := obj.Writer() - if err != nil { +// Close releases any resources consumed by the Reader. Calling Close does not +// close the wrapped io.Reader originally passed to NewReader. +func (r *Reader) Close() error { + if err := r.zlib.Close(); err != nil { return err } - _, err = io.Copy(w, r.r) - return err + return nil } diff --git a/formats/objfile/reader_test.go b/formats/objfile/reader_test.go index 964a071..a383fd2 100644 --- a/formats/objfile/reader_test.go +++ b/formats/objfile/reader_test.go @@ -9,8 +9,6 @@ import ( . "gopkg.in/check.v1" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) type SuiteReader struct{} @@ -28,76 +26,42 @@ func (s *SuiteReader) TestReadObjfile(c *C) { } } -func testReader(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, content []byte, com string) { +func testReader(c *C, source io.Reader, hash core.Hash, t core.ObjectType, content []byte, com string) { r, err := NewReader(source) c.Assert(err, IsNil) - c.Assert(r.Type(), Equals, typ) + + typ, size, err := r.Header() + c.Assert(err, IsNil) + c.Assert(typ, Equals, t) + c.Assert(content, HasLen, int(size)) + rc, err := ioutil.ReadAll(r) c.Assert(err, IsNil) c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) - c.Assert(r.Size(), Equals, int64(len(content))) + c.Assert(r.Hash(), Equals, hash) // Test Hash() before close c.Assert(r.Close(), IsNil) - c.Assert(r.Hash(), Equals, hash) // Test Hash() after close - _, err = r.Read(make([]byte, 0, 1)) - c.Assert(err, Equals, ErrClosed) + } func (s *SuiteReader) TestReadEmptyObjfile(c *C) { source := bytes.NewReader([]byte{}) _, err := NewReader(source) - c.Assert(err, Equals, ErrZLib) -} - -func (s *SuiteReader) TestReadEmptyContent(c *C) { - b := new(bytes.Buffer) - w := zlib.NewWriter(b) - c.Assert(w.Close(), IsNil) - _, err := NewReader(b) - c.Assert(err, Equals, ErrHeader) + c.Assert(err, NotNil) } func (s *SuiteReader) TestReadGarbage(c *C) { source := bytes.NewReader([]byte("!@#$RO!@NROSADfinq@o#irn@oirfn")) _, err := NewReader(source) - c.Assert(err, Equals, ErrZLib) + c.Assert(err, NotNil) } func (s *SuiteReader) TestReadCorruptZLib(c *C) { data, _ := base64.StdEncoding.DecodeString("eAFLysaalPUjBgAAAJsAHw") source := bytes.NewReader(data) - _, err := NewReader(source) - c.Assert(err, NotNil) -} - -func (s *SuiteReader) TestFillObject(c *C) { - for k, fixture := range objfileFixtures { - com := fmt.Sprintf("test %d: ", k) - hash := core.NewHash(fixture.hash) - content, _ := base64.StdEncoding.DecodeString(fixture.content) - data, _ := base64.StdEncoding.DecodeString(fixture.data) - - testFillObject(c, bytes.NewReader(data), hash, fixture.t, content, com) - } -} - -func testFillObject(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, content []byte, com string) { - var o core.Object = &core.MemoryObject{} r, err := NewReader(source) c.Assert(err, IsNil) - err = r.FillObject(o) - c.Assert(err, IsNil) - c.Assert(o.Type(), Equals, typ) - c.Assert(o.Size(), Equals, int64(len(content))) - c.Assert(o.Hash(), Equals, hash) - or, err := o.Reader() - c.Assert(err, IsNil) - rc, err := ioutil.ReadAll(or) - c.Assert(err, IsNil) - c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) - c.Assert(or.Close(), IsNil) - _, err = or.Read(make([]byte, 0, 1)) - c.Assert(err, Equals, nil) - ow, err := o.Writer() - c.Assert(ow, Equals, o) + + _, _, err = r.Header() + c.Assert(err, NotNil) } diff --git a/formats/objfile/writer.go b/formats/objfile/writer.go index 8337a3a..d2f2314 100644 --- a/formats/objfile/writer.go +++ b/formats/objfile/writer.go @@ -1,142 +1,109 @@ package objfile import ( + "compress/zlib" "errors" "io" + "strconv" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) var ( - // ErrOverflow is returned when an attempt is made to write more data than - // was declared in NewWriter. ErrOverflow = errors.New("objfile: declared data length exceeded (overflow)") ) // Writer writes and encodes data in compressed objfile format to a provided -// io.Writer. -// -// Writer implements io.WriteCloser. Close should be called when finished with -// the Writer. Close will not close the underlying io.Writer. +// io.Writerº. Close should be called when finished with the Writer. Close will +// not close the underlying io.Writer. type Writer struct { - header header - hash core.Hash // final computed hash stored after Close + raw io.Writer + zlib io.WriteCloser + hasher core.Hasher + multi io.Writer - w io.Writer // provided writer wrapped in compressor and tee - compressor io.WriteCloser // provided writer wrapped in compressor, retained for calling Close - h core.Hasher // streaming SHA1 hash of encoded data - written int64 // Number of bytes written + closed bool + pending int64 // number of unwritten bytes } // NewWriter returns a new Writer writing to w. // -// The provided t is the type of object being written. The provided size is the -// number of uncompressed bytes being written. -// -// Calling NewWriter causes it to immediately write header data containing -// size and type information. Any errors encountered in that process will be -// returned in err. -// -// If an invalid t is provided, core.ErrInvalidType is returned. If a negative -// size is provided, ErrNegativeSize is returned. -// // The returned Writer implements io.WriteCloser. Close should be called when // finished with the Writer. Close will not close the underlying io.Writer. -func NewWriter(w io.Writer, t core.ObjectType, size int64) (*Writer, error) { +func NewWriter(w io.Writer) *Writer { + return &Writer{ + raw: w, + zlib: zlib.NewWriter(w), + } +} + +// WriteHeader writes the type and the size and prepares to accept the object's +// contents. If an invalid t is provided, core.ErrInvalidType is returned. If a +// negative size is provided, ErrNegativeSize is returned. +func (w *Writer) WriteHeader(t core.ObjectType, size int64) error { if !t.Valid() { - return nil, core.ErrInvalidType + return core.ErrInvalidType } if size < 0 { - return nil, ErrNegativeSize + return ErrNegativeSize } - writer := &Writer{ - header: header{t: t, size: size}, - } - return writer, writer.init(w) -} -// init prepares the zlib compressor for the given output as well as a hasher -// for computing its hash. -// -// init immediately writes header data to the output. This leaves the writer in -// a state that is ready to write content. -func (w *Writer) init(output io.Writer) (err error) { - w.compressor = zlib.NewWriter(output) - - err = w.header.Write(w.compressor) - if err != nil { - w.compressor.Close() - return - } + b := t.Bytes() + b = append(b, ' ') + b = append(b, []byte(strconv.FormatInt(size, 10))...) + b = append(b, 0) - w.h = core.NewHasher(w.header.t, w.header.size) - w.w = io.MultiWriter(w.compressor, w.h) // All writes to the compressor also write to the hash + defer w.prepareForWrite(t, size) + _, err := w.zlib.Write(b) - return + return err } -// Write reads len(p) from p to the object data stream. It returns the number of -// bytes written from p (0 <= n <= len(p)) and any error encountered that caused -// the write to stop early. The slice data contained in p will not be modified. -// -// If writing len(p) bytes would exceed the size provided in NewWriter, -// ErrOverflow is returned without writing any data. +func (w *Writer) prepareForWrite(t core.ObjectType, size int64) { + w.pending = size + + w.hasher = core.NewHasher(t, size) + w.multi = io.MultiWriter(w.zlib, w.hasher) +} + +// Write writes the object's contents. Write returns the error ErrOverflow if +// more than size bytes are written after WriteHeader. func (w *Writer) Write(p []byte) (n int, err error) { - if w.w == nil { + if w.closed { return 0, ErrClosed } - if w.written+int64(len(p)) > w.header.size { - return 0, ErrOverflow + overwrite := false + if int64(len(p)) > w.pending { + p = p[0:w.pending] + overwrite = true } - n, err = w.w.Write(p) - w.written += int64(n) + n, err = w.multi.Write(p) + w.pending -= int64(n) + if err == nil && overwrite { + err = ErrOverflow + return + } return } -// Type returns the type of the object. -func (w *Writer) Type() core.ObjectType { - return w.header.t -} - -// Size returns the uncompressed size of the object in bytes. -func (w *Writer) Size() int64 { - return w.header.size -} - // Hash returns the hash of the object data stream that has been written so far. // It can be called before or after Close. func (w *Writer) Hash() core.Hash { - if w.w != nil { - return w.h.Sum() // Not yet closed, return hash of data written so far - } - return w.hash + return w.hasher.Sum() // Not yet closed, return hash of data written so far } // Close releases any resources consumed by the Writer. // // Calling Close does not close the wrapped io.Writer originally passed to // NewWriter. -func (w *Writer) Close() (err error) { - if w.w == nil { - // TODO: Consider returning ErrClosed here? - return nil // Already closed +func (w *Writer) Close() error { + if err := w.zlib.Close(); err != nil { + return err } - // Release the compressor's resources - err = w.compressor.Close() - - // Save the hash because we're about to throw away the hasher - w.hash = w.h.Sum() - - // Release references - w.w = nil // Indicates closed state - w.compressor = nil - w.h.Hash = nil - - return + w.closed = true + return nil } diff --git a/formats/objfile/writer_test.go b/formats/objfile/writer_test.go index ab5a5bf..18bba79 100644 --- a/formats/objfile/writer_test.go +++ b/formats/objfile/writer_test.go @@ -16,54 +16,65 @@ var _ = Suite(&SuiteWriter{}) func (s *SuiteWriter) TestWriteObjfile(c *C) { for k, fixture := range objfileFixtures { + buffer := bytes.NewBuffer(nil) + com := fmt.Sprintf("test %d: ", k) hash := core.NewHash(fixture.hash) content, _ := base64.StdEncoding.DecodeString(fixture.content) - buffer := new(bytes.Buffer) // Write the data out to the buffer - testWriter(c, buffer, hash, fixture.t, content, com) + testWriter(c, buffer, hash, fixture.t, content) // Read the data back in from the buffer to be sure it matches testReader(c, buffer, hash, fixture.t, content, com) } } -func testWriter(c *C, dest io.Writer, hash core.Hash, typ core.ObjectType, content []byte, com string) { - length := int64(len(content)) - w, err := NewWriter(dest, typ, length) +func testWriter(c *C, dest io.Writer, hash core.Hash, t core.ObjectType, content []byte) { + size := int64(len(content)) + w := NewWriter(dest) + + err := w.WriteHeader(t, size) c.Assert(err, IsNil) - c.Assert(w.Type(), Equals, typ) - c.Assert(w.Size(), Equals, length) + written, err := io.Copy(w, bytes.NewReader(content)) c.Assert(err, IsNil) - c.Assert(written, Equals, length) - c.Assert(w.Size(), Equals, int64(len(content))) - c.Assert(w.Hash(), Equals, hash) // Test Hash() before close + c.Assert(written, Equals, size) + + c.Assert(w.Hash(), Equals, hash) c.Assert(w.Close(), IsNil) - c.Assert(w.Hash(), Equals, hash) // Test Hash() after close - _, err = w.Write([]byte{1}) - c.Assert(err, Equals, ErrClosed) } func (s *SuiteWriter) TestWriteOverflow(c *C) { - w, err := NewWriter(new(bytes.Buffer), core.BlobObject, 8) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.BlobObject, 8) c.Assert(err, IsNil) - _, err = w.Write([]byte("1234")) + + n, err := w.Write([]byte("1234")) c.Assert(err, IsNil) - _, err = w.Write([]byte("56789")) + c.Assert(n, Equals, 4) + + n, err = w.Write([]byte("56789")) c.Assert(err, Equals, ErrOverflow) + c.Assert(n, Equals, 4) } func (s *SuiteWriter) TestNewWriterInvalidType(c *C) { - var t core.ObjectType - _, err := NewWriter(new(bytes.Buffer), t, 8) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.InvalidObject, 8) c.Assert(err, Equals, core.ErrInvalidType) } func (s *SuiteWriter) TestNewWriterInvalidSize(c *C) { - _, err := NewWriter(new(bytes.Buffer), core.BlobObject, -1) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.BlobObject, -1) c.Assert(err, Equals, ErrNegativeSize) - _, err = NewWriter(new(bytes.Buffer), core.BlobObject, -1651860) + err = w.WriteHeader(core.BlobObject, -1651860) c.Assert(err, Equals, ErrNegativeSize) } -- cgit