From e3cb5921c8f3b730a8bbd21877176197c20b8fc7 Mon Sep 17 00:00:00 2001 From: Joshua Sjoding Date: Sat, 27 Feb 2016 00:30:35 -0800 Subject: Added objfile format used for loose git objects --- formats/objfile/common.go | 73 ++++++++++++++++++++++ formats/objfile/common_test.go | 69 +++++++++++++++++++++ formats/objfile/reader.go | 123 +++++++++++++++++++++++++++++++++++++ formats/objfile/reader_test.go | 68 +++++++++++++++++++++ formats/objfile/writer.go | 133 +++++++++++++++++++++++++++++++++++++++++ formats/objfile/writer_test.go | 53 ++++++++++++++++ 6 files changed, 519 insertions(+) create mode 100644 formats/objfile/common.go create mode 100644 formats/objfile/common_test.go create mode 100644 formats/objfile/reader.go create mode 100644 formats/objfile/reader_test.go create mode 100644 formats/objfile/writer.go create mode 100644 formats/objfile/writer_test.go (limited to 'formats/objfile') diff --git a/formats/objfile/common.go b/formats/objfile/common.go new file mode 100644 index 0000000..7389086 --- /dev/null +++ b/formats/objfile/common.go @@ -0,0 +1,73 @@ +package objfile + +import ( + "errors" + "io" + "strconv" + + "gopkg.in/src-d/go-git.v3/core" +) + +var ( + // ErrClosed is returned when the objfile Reader or Writer is already closed. + ErrClosed = errors.New("objfile: already closed") + // ErrHeader is returned when the objfile has an invalid header. + ErrHeader = errors.New("objfile: invalid header") +) + +type header struct { + t core.ObjectType + size int64 +} + +func (h *header) Read(r io.Reader) error { + t, err := h.readSlice(r, ' ') + if err != nil { + return err + } + + h.t, err = core.ParseObjectType(string(t)) + if err != nil { + return err + } + + size, err := h.readSlice(r, 0) + if err != nil { + return err + } + + h.size, err = strconv.ParseInt(string(size), 10, 64) + if err != nil { + return err + } + + return nil +} + +func (h *header) Write(w io.Writer) error { + b := h.t.Bytes() + b = append(b, ' ') + b = append(b, []byte(strconv.FormatInt(h.size, 10))...) + b = append(b, 0) + _, err := w.Write(b) + return err +} + +// readSlice reads one byte at a time from r until it encounters delim or an +// error. +func (h *header) readSlice(r io.Reader, delim byte) ([]byte, error) { + var buf [1]byte + value := make([]byte, 0, 16) + for { + if n, err := r.Read(buf[:]); err != nil && (err != io.EOF || n == 0) { + if err == io.EOF { + return nil, ErrHeader + } + return nil, err + } + if buf[0] == delim { + return value, nil + } + value = append(value, buf[0]) + } +} diff --git a/formats/objfile/common_test.go b/formats/objfile/common_test.go new file mode 100644 index 0000000..4727685 --- /dev/null +++ b/formats/objfile/common_test.go @@ -0,0 +1,69 @@ +package objfile + +import ( + "encoding/base64" + "testing" + + . "gopkg.in/check.v1" + "gopkg.in/src-d/go-git.v3/core" +) + +type objfileFixture struct { + hash string // hash of data + t core.ObjectType // object type + content string // base64-encoded content + data string // base64-encoded objfile data +} + +var objfileFixtures = []objfileFixture{ + { + "e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("")), + "eAFLyslPUjBgAAAJsAHw", + }, + { + "a8a940627d132695a9769df883f85992f0ff4a43", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("this is a test")), + "eAFLyslPUjA0YSjJyCxWAKJEhZLU4hIAUDYHOg==", + }, + { + "4dc2174801ac4a3d36886210fd086fbe134cf7b2", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("this\nis\n\n\na\nmultiline\n\ntest.\n")), + "eAFLyslPUjCyZCjJyCzmAiIurkSu3NKcksyczLxULq6S1OISPS4A1I8LMQ==", + }, + { + "13e6f47dd57798bfdc728d91f5c6d7f40c5bb5fc", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("this tests\r\nCRLF\r\nencoded files.\r\n")), + "eAFLyslPUjA2YSjJyCxWKEktLinm5XIO8nHj5UrNS85PSU1RSMvMSS3W4+UCABp3DNE=", + }, + { + "72a7bc4667ab068e954172437b993d9fbaa137cb", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("test@example.com")), + "eAFLyslPUjA0YyhJLS5xSK1IzC3ISdVLzs8FAGVtCIA=", + }, + { + "bb2b40e85ec0455d1de72daff71583f0dd72a33f", + core.BlobObject, + base64.StdEncoding.EncodeToString([]byte("package main\r\n\r\nimport (\r\n\t\"fmt\"\r\n\t\"io\"\r\n\t\"os\"\r\n\r\n\t\"gopkg.in/src-d/go-git.v3\"\r\n)\r\n\r\nfunc main() {\r\n\tfmt.Printf(\"Retrieving %q ...\\n\", os.Args[2])\r\n\tr, err := git.NewRepository(os.Args[2], nil)\r\n\tif err != nil {\r\n\t\tpanic(err)\r\n\t}\r\n\r\n\tif err := r.Pull(\"origin\", \"refs/heads/master\"); err != nil {\r\n\t\tpanic(err)\r\n\t}\r\n\r\n\tdumpCommits(r)\r\n}\r\n\r\nfunc dumpCommits(r *git.Repository) {\r\n\titer := r.Commits()\r\n\tdefer iter.Close()\r\n\r\n\tfor {\r\n\t\tcommit, err := iter.Next()\r\n\t\tif err != nil {\r\n\t\t\tif err == io.EOF {\r\n\t\t\t\tbreak\r\n\t\t\t}\r\n\r\n\t\t\tpanic(err)\r\n\t\t}\r\n\r\n\t\tfmt.Println(commit)\r\n\t}\r\n}\r\n")), + "eAGNUU1LAzEU9JpC/0NcEFJps2ARQdmDFD3W0qt6SHez8dHdZH1JqyL+d/Oy/aDgQVh47LzJTGayatyKX99MzzpVrpXRvFVgh4PhANrOYeBiOGBZ3YaMJrg0nI+D/o3r1kaCzT2Wkyo3bmIgyO00rkfEqDe2TIJixL/jgagjFwg21CJb6oCgt2ANv3jnUsoXm4258/IejX++eo0CDMdcI/LbgpPuXH8sdec8BIdf4sgccwsN0aFO9POCgGTIOmWhFFGE9j/p1jtWFEW52DSNyByCAXLPUNc+f9Oq8nmrfNCYje7+o1lt2m7m2haCF2SVnFL6kw2/pBzHEH0rEH0oI8q9BF220nWEaSdnjfNaRDDCtcM+WZnsDgUl4lx/BuKxv6rYY0XBwcmHp8deh7EVarWmQ7uC2Glre/TweI0VvTk5xaTx+wWX66Gs", + }, + { + "e94db0f9ffca44dc7bade6a3591f544183395a7c", + core.TreeObject, + "MTAwNjQ0IFRlc3QgMS50eHQAqKlAYn0TJpWpdp34g/hZkvD/SkMxMDA2NDQgVGVzdCAyLnR4dABNwhdIAaxKPTaIYhD9CG++E0z3sjEwMDY0NCBUZXN0IDMudHh0ABPm9H3Vd5i/3HKNkfXG1/QMW7X8MTAwNjQ0IFRlc3QgNC50eHQAcqe8RmerBo6VQXJDe5k9n7qhN8sxMDA2NDQgVGVzdCA1LnR4dAC7K0DoXsBFXR3nLa/3FYPw3XKjPw==", + "eAErKUpNVTC0NGAwNDAwMzFRCEktLlEw1CupKGFYsdIhqVZYberKsrk/mn9ETvrw38sZWZURWJXvIXEPxjVetmYdSQJ/OfL3Cft834SsyhisSvjZl9qr5TP23ynqnfj12PUvPNFb/yCrMgGrKlq+xy19NVvfVMci5+qZtvN3LTQ/jazKFKxqt7bDi7gDrrGyz3XXfxdt/nC3aLE9AA2STmk=", + }, + { + "9d7f8a56eaf92469dee8a856e716a03387ddb076", + core.CommitObject, + "dHJlZSBlOTRkYjBmOWZmY2E0NGRjN2JhZGU2YTM1OTFmNTQ0MTgzMzk1YTdjCmF1dGhvciBKb3NodWEgU2pvZGluZyA8am9zaHVhLnNqb2RpbmdAc2NqYWxsaWFuY2UuY29tPiAxNDU2NTMxNTgzIC0wODAwCmNvbW1pdHRlciBKb3NodWEgU2pvZGluZyA8am9zaHVhLnNqb2RpbmdAc2NqYWxsaWFuY2UuY29tPiAxNDU2NTMxNTgzIC0wODAwCgpUZXN0IENvbW1pdAo=", + "eAGtjksOgjAUAF33FO8CktZ+aBNjTNy51Qs8Xl8FAjSh5f4SvILLmcVkKM/zUOEi3amuzMDBxE6mkBKhMZHaDiM71DaoZI1RXutgsSWBW+3zCs9c+g3hNeY4LB+4jgc35cf3QiNO04ALcUN5voEy1lmtrNdwll5Ksdt9oPIfUuLNpcLjCIov3ApFmQ==", + }, +} + +func Test(t *testing.T) { TestingT(t) } diff --git a/formats/objfile/reader.go b/formats/objfile/reader.go new file mode 100644 index 0000000..b3c2e5c --- /dev/null +++ b/formats/objfile/reader.go @@ -0,0 +1,123 @@ +package objfile + +import ( + "errors" + "io" + + "gopkg.in/src-d/go-git.v3/core" + + "github.com/klauspost/compress/zlib" +) + +var ( + // ErrZLib is returned when the objfile contains invalid zlib data. + ErrZLib = errors.New("objfile: invalid zlib data") +) + +// Reader reads and decodes compressed objfile data from a provided io.Reader. +// +// Reader implements io.ReadCloser. Close should be called when finished with +// the Reader. Close will not close the underlying io.Reader. +type Reader struct { + header header + hash core.Hash // final computed hash stored after Close + + r io.Reader // provided reader wrapped in decompressor and tee + decompressor io.ReadCloser // provided reader wrapped in decompressor, retained for calling Close + h core.Hasher // streaming SHA1 hash of decoded data +} + +// NewReader returns a new Reader reading from r. +// +// Calling NewReader causes it to immediately read in header data from r +// containing size and type information. Any errors encountered in that +// process will be returned in err. +// +// The returned Reader implements io.ReadCloser. Close should be called when +// finished with the Reader. Close will not close the underlying io.Reader. +func NewReader(r io.Reader) (*Reader, error) { + reader := &Reader{} + return reader, reader.init(r) +} + +// init prepares the zlib decompressor for the given input as well as a hasher +// for computing its hash. +// +// init immediately reads header data from the input and stores it. This leaves +// the Reader in a state that is ready to read content. +func (r *Reader) init(input io.Reader) (err error) { + r.decompressor, err = zlib.NewReader(input) + if err != nil { + // TODO: Make this error match the ZLibErr in formats/packfile/reader.go? + return ErrZLib + } + + err = r.header.Read(r.decompressor) + if err != nil { + r.decompressor.Close() + return + } + + r.h = core.NewHasher(r.header.t, r.header.size) + r.r = io.TeeReader(r.decompressor, r.h) // All reads from the decompressor also write to the hash + + return +} + +// Read reads len(p) bytes into p from the object data stream. It returns +// the number of bytes read (0 <= n <= len(p)) and any error encountered. Even +// if Read returns n < len(p), it may use all of p as scratch space during the +// call. +// +// If Read encounters the end of the data stream it will return err == io.EOF, +// either in the current call if n > 0 or in a subsequent call. +func (r *Reader) Read(p []byte) (n int, err error) { + if r.r == nil { + return 0, ErrClosed + } + + return r.r.Read(p) +} + +// Type returns the type of the object. +func (r *Reader) Type() core.ObjectType { + return r.header.t +} + +// Size returns the uncompressed size of the object in bytes. +func (r *Reader) Size() int64 { + return r.header.size +} + +// Hash returns the hash of the object data stream that has been read so far. +// It can be called before or after Close. +func (r *Reader) Hash() core.Hash { + if r.r != nil { + return r.h.Sum() // Not yet closed, return hash of data read so far + } + return r.hash +} + +// Close releases any resources consumed by the Reader. +// +// Calling Close does not close the wrapped io.Reader originally passed to +// NewReader. +func (r *Reader) Close() (err error) { + if r.r == nil { + // TODO: Consider returning ErrClosed here? + return nil // Already closed + } + + // Release the decompressor's resources + err = r.decompressor.Close() + + // Save the hash because we're about to throw away the hasher + r.hash = r.h.Sum() + + // Release references + r.r = nil // Indicates closed state + r.decompressor = nil + r.h.Hash = nil + + return +} diff --git a/formats/objfile/reader_test.go b/formats/objfile/reader_test.go new file mode 100644 index 0000000..871eefe --- /dev/null +++ b/formats/objfile/reader_test.go @@ -0,0 +1,68 @@ +package objfile + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + + . "gopkg.in/check.v1" + "gopkg.in/src-d/go-git.v3/core" + + "github.com/klauspost/compress/zlib" +) + +type SuiteReader struct{} + +var _ = Suite(&SuiteReader{}) + +func (s *SuiteReader) TestReadObjfile(c *C) { + for k, fixture := range objfileFixtures { + comment := fmt.Sprintf("test %d: ", k) + hash := core.NewHash(fixture.hash) + content, _ := base64.StdEncoding.DecodeString(fixture.content) + data, _ := base64.StdEncoding.DecodeString(fixture.data) + + testReader(c, bytes.NewReader(data), hash, fixture.t, content, comment) + } +} + +func testReader(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, content []byte, comment string) { + r, err := NewReader(source) + c.Assert(err, IsNil) + c.Assert(r.Type(), Equals, typ) + rc, err := ioutil.ReadAll(r) + c.Assert(err, IsNil) + c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) + c.Assert(r.Hash(), Equals, hash) // Test Hash() before close + c.Assert(r.Close(), IsNil) + c.Assert(r.Hash(), Equals, hash) // Test Hash() after close +} + +func (s *SuiteReader) TestReadEmptyObjfile(c *C) { + source := bytes.NewReader([]byte{}) + _, err := NewReader(source) + c.Assert(err, Equals, ErrZLib) +} + +func (s *SuiteReader) TestReadEmptyContent(c *C) { + b := new(bytes.Buffer) + w := zlib.NewWriter(b) + c.Assert(w.Close(), IsNil) + _, err := NewReader(b) + c.Assert(err, Equals, ErrHeader) +} + +func (s *SuiteReader) TestReadGarbage(c *C) { + source := bytes.NewReader([]byte("!@#$RO!@NROSADfinq@o#irn@oirfn")) + _, err := NewReader(source) + c.Assert(err, Equals, ErrZLib) +} + +func (s *SuiteReader) TestReadCorruptZLib(c *C) { + data, _ := base64.StdEncoding.DecodeString("eAFLysaalPUjBgAAAJsAHw") + source := bytes.NewReader(data) + _, err := NewReader(source) + c.Assert(err, NotNil) +} diff --git a/formats/objfile/writer.go b/formats/objfile/writer.go new file mode 100644 index 0000000..d80256c --- /dev/null +++ b/formats/objfile/writer.go @@ -0,0 +1,133 @@ +package objfile + +import ( + "errors" + "io" + + "gopkg.in/src-d/go-git.v3/core" + + "github.com/klauspost/compress/zlib" +) + +var ( + // ErrOverflow is returned when an attempt is made to write more data than + // was declared in NewWriter. + ErrOverflow = errors.New("objfile: declared data length exceeded (overflow)") +) + +// Writer writes and encodes data in compressed objfile format to a provided +// io.Writer. +// +// Writer implements io.WriteCloser. Close should be called when finished with +// the Writer. Close will not close the underlying io.Writer. +type Writer struct { + header header + hash core.Hash // final computed hash stored after Close + + w io.Writer // provided writer wrapped in compressor and tee + compressor io.WriteCloser // provided writer wrapped in compressor, retained for calling Close + h core.Hasher // streaming SHA1 hash of encoded data + written int64 // Number of bytes written +} + +// NewWriter returns a new Writer writing to w. +// +// The provided t is the type of object being written. The provided size is the +// number of uncompressed bytes being written. +// +// Calling NewWriter causes it to immediately write header data containing +// size and type information. Any errors encountered in that process will be +// returned in err. +// +// The returned Writer implements io.WriteCloser. Close should be called when +// finished with the Writer. Close will not close the underlying io.Writer. +func NewWriter(w io.Writer, t core.ObjectType, size int64) (*Writer, error) { + writer := &Writer{ + header: header{t: t, size: size}, + } + return writer, writer.init(w) +} + +// init prepares the zlib compressor for the given output as well as a hasher +// for computing its hash. +// +// init immediately writes header data to the output. This leaves the writer in +// a state that is ready to write content. +func (w *Writer) init(output io.Writer) (err error) { + w.compressor = zlib.NewWriter(output) + + err = w.header.Write(w.compressor) + if err != nil { + defer w.compressor.Close() + return + } + + w.h = core.NewHasher(w.header.t, w.header.size) + w.w = io.MultiWriter(w.compressor, w.h) // All writes to the compressor also write to the hash + + return +} + +// Write reads len(p) from p to the object data stream. It returns the number of +// bytes written from p (0 <= n <= len(p)) and any error encountered that caused +// the write to stop early. The slice data contained in p will not be modified. +// +// If writing len(p) bytes would exceed the size provided in NewWriter, +// ErrOverflow is returned without writing any data. +func (w *Writer) Write(p []byte) (n int, err error) { + if w.w == nil { + return 0, ErrClosed + } + + if w.written+int64(len(p)) > w.header.size { + return 0, ErrOverflow + } + + n, err = w.w.Write(p) + w.written += int64(n) + + return +} + +// Type returns the type of the object. +func (w *Writer) Type() core.ObjectType { + return w.header.t +} + +// Size returns the uncompressed size of the object in bytes. +func (w *Writer) Size() int64 { + return w.header.size +} + +// Hash returns the hash of the object data stream that has been written so far. +// It can be called before or after Close. +func (w *Writer) Hash() core.Hash { + if w.w != nil { + return w.h.Sum() // Not yet closed, return hash of data written so far + } + return w.hash +} + +// Close releases any resources consumed by the Writer. +// +// Calling Close does not close the wrapped io.Writer originally passed to +// NewWriter. +func (w *Writer) Close() (err error) { + if w.w == nil { + // TODO: Consider returning ErrClosed here? + return nil // Already closed + } + + // Release the compressor's resources + err = w.compressor.Close() + + // Save the hash because we're about to throw away the hasher + w.hash = w.h.Sum() + + // Release references + w.w = nil // Indicates closed state + w.compressor = nil + w.h.Hash = nil + + return +} diff --git a/formats/objfile/writer_test.go b/formats/objfile/writer_test.go new file mode 100644 index 0000000..03b8370 --- /dev/null +++ b/formats/objfile/writer_test.go @@ -0,0 +1,53 @@ +package objfile + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + + . "gopkg.in/check.v1" + "gopkg.in/src-d/go-git.v3/core" +) + +type SuiteWriter struct{} + +var _ = Suite(&SuiteWriter{}) + +func (s *SuiteWriter) TestWriteObjfile(c *C) { + for k, fixture := range objfileFixtures { + comment := fmt.Sprintf("test %d: ", k) + hash := core.NewHash(fixture.hash) + content, _ := base64.StdEncoding.DecodeString(fixture.content) + buffer := new(bytes.Buffer) + + // Write the data out to the buffer + testWriter(c, buffer, hash, fixture.t, content, comment) + + // Read the data back in from the buffer to be sure it matches + testReader(c, buffer, hash, fixture.t, content, comment) + } +} + +func testWriter(c *C, dest io.Writer, hash core.Hash, typ core.ObjectType, content []byte, comment string) { + length := int64(len(content)) + w, err := NewWriter(dest, typ, length) + c.Assert(err, IsNil) + c.Assert(w.Type(), Equals, typ) + c.Assert(w.Size(), Equals, length) + written, err := io.Copy(w, bytes.NewReader(content)) + c.Assert(err, IsNil) + c.Assert(written, Equals, length) + c.Assert(w.Hash(), Equals, hash) // Test Hash() before close + c.Assert(w.Close(), IsNil) + c.Assert(w.Hash(), Equals, hash) // Test Hash() after close +} + +func (s *SuiteWriter) TestWriteOverflow(c *C) { + w, err := NewWriter(new(bytes.Buffer), core.BlobObject, 8) + c.Assert(err, IsNil) + _, err = w.Write([]byte("1234")) + c.Assert(err, IsNil) + _, err = w.Write([]byte("56789")) + c.Assert(err, Equals, ErrOverflow) +} -- cgit From 31f920a06aa5d7e7cf363645dac02f6e798fffb1 Mon Sep 17 00:00:00 2001 From: Joshua Sjoding Date: Sat, 27 Feb 2016 14:07:22 -0800 Subject: Improved objfile error handling and test coverage --- formats/objfile/common.go | 8 +++++++- formats/objfile/common_test.go | 31 +++++++++++++++++++++++++++++++ formats/objfile/reader_test.go | 3 +++ formats/objfile/writer.go | 11 ++++++++++- formats/objfile/writer_test.go | 16 ++++++++++++++++ 5 files changed, 67 insertions(+), 2 deletions(-) (limited to 'formats/objfile') diff --git a/formats/objfile/common.go b/formats/objfile/common.go index 7389086..2f0585f 100644 --- a/formats/objfile/common.go +++ b/formats/objfile/common.go @@ -13,6 +13,8 @@ var ( ErrClosed = errors.New("objfile: already closed") // ErrHeader is returned when the objfile has an invalid header. ErrHeader = errors.New("objfile: invalid header") + // ErrNegativeSize is returned when a negative object size is declared. + ErrNegativeSize = errors.New("objfile: negative object size") ) type header struct { @@ -38,7 +40,11 @@ func (h *header) Read(r io.Reader) error { h.size, err = strconv.ParseInt(string(size), 10, 64) if err != nil { - return err + return ErrHeader + } + + if h.size < 0 { + return ErrNegativeSize } return nil diff --git a/formats/objfile/common_test.go b/formats/objfile/common_test.go index 4727685..0c5a4cd 100644 --- a/formats/objfile/common_test.go +++ b/formats/objfile/common_test.go @@ -1,6 +1,7 @@ package objfile import ( + "bytes" "encoding/base64" "testing" @@ -67,3 +68,33 @@ var objfileFixtures = []objfileFixture{ } func Test(t *testing.T) { TestingT(t) } + +type SuiteCommon struct{} + +var _ = Suite(&SuiteCommon{}) + +func (s *SuiteCommon) TestHeaderReadEmpty(c *C) { + var h header + c.Assert(h.Read(new(bytes.Buffer)), Equals, ErrHeader) +} + +func (s *SuiteCommon) TestHeaderReadGarbage(c *C) { + var h header + c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), Equals, ErrHeader) + c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5, '0'})), Equals, ErrHeader) +} + +func (s *SuiteCommon) TestHeaderReadInvalidType(c *C) { + var h header + c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, ' ', 4, 5, 0})), Equals, core.ErrInvalidType) +} + +func (s *SuiteCommon) TestHeaderReadInvalidSize(c *C) { + var h header + c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', 'a', 0})), Equals, ErrHeader) +} + +func (s *SuiteCommon) TestHeaderReadNegativeSize(c *C) { + var h header + c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', '-', '1', 0})), Equals, ErrNegativeSize) +} diff --git a/formats/objfile/reader_test.go b/formats/objfile/reader_test.go index 871eefe..caebb60 100644 --- a/formats/objfile/reader_test.go +++ b/formats/objfile/reader_test.go @@ -35,9 +35,12 @@ func testReader(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, con rc, err := ioutil.ReadAll(r) c.Assert(err, IsNil) c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) + c.Assert(r.Size(), Equals, int64(len(content))) c.Assert(r.Hash(), Equals, hash) // Test Hash() before close c.Assert(r.Close(), IsNil) c.Assert(r.Hash(), Equals, hash) // Test Hash() after close + _, err = r.Read(make([]byte, 0, 1)) + c.Assert(err, Equals, ErrClosed) } func (s *SuiteReader) TestReadEmptyObjfile(c *C) { diff --git a/formats/objfile/writer.go b/formats/objfile/writer.go index d80256c..d9d40f0 100644 --- a/formats/objfile/writer.go +++ b/formats/objfile/writer.go @@ -39,9 +39,18 @@ type Writer struct { // size and type information. Any errors encountered in that process will be // returned in err. // +// If an invalid t is provided, core.ErrInvalidType is returned. If a negative +// size is provided, ErrNegativeSize is returned. +// // The returned Writer implements io.WriteCloser. Close should be called when // finished with the Writer. Close will not close the underlying io.Writer. func NewWriter(w io.Writer, t core.ObjectType, size int64) (*Writer, error) { + if !t.Valid() { + return nil, core.ErrInvalidType + } + if size < 0 { + return nil, ErrNegativeSize + } writer := &Writer{ header: header{t: t, size: size}, } @@ -58,7 +67,7 @@ func (w *Writer) init(output io.Writer) (err error) { err = w.header.Write(w.compressor) if err != nil { - defer w.compressor.Close() + w.compressor.Close() return } diff --git a/formats/objfile/writer_test.go b/formats/objfile/writer_test.go index 03b8370..0061f3f 100644 --- a/formats/objfile/writer_test.go +++ b/formats/objfile/writer_test.go @@ -38,9 +38,12 @@ func testWriter(c *C, dest io.Writer, hash core.Hash, typ core.ObjectType, conte written, err := io.Copy(w, bytes.NewReader(content)) c.Assert(err, IsNil) c.Assert(written, Equals, length) + c.Assert(w.Size(), Equals, int64(len(content))) c.Assert(w.Hash(), Equals, hash) // Test Hash() before close c.Assert(w.Close(), IsNil) c.Assert(w.Hash(), Equals, hash) // Test Hash() after close + _, err = w.Write([]byte{1}) + c.Assert(err, Equals, ErrClosed) } func (s *SuiteWriter) TestWriteOverflow(c *C) { @@ -51,3 +54,16 @@ func (s *SuiteWriter) TestWriteOverflow(c *C) { _, err = w.Write([]byte("56789")) c.Assert(err, Equals, ErrOverflow) } + +func (s *SuiteWriter) TestNewWriterInvalidType(c *C) { + var t core.ObjectType + _, err := NewWriter(new(bytes.Buffer), t, 8) + c.Assert(err, Equals, core.ErrInvalidType) +} + +func (s *SuiteWriter) TestNewWriterInvalidSize(c *C) { + _, err := NewWriter(new(bytes.Buffer), core.BlobObject, -1) + c.Assert(err, Equals, ErrNegativeSize) + _, err = NewWriter(new(bytes.Buffer), core.BlobObject, -1651860) + c.Assert(err, Equals, ErrNegativeSize) +} -- cgit