From b9c0a09435392913c0054382500c805cd7cb596b Mon Sep 17 00:00:00 2001 From: Máximo Cuadros Date: Sun, 25 Sep 2016 23:58:59 +0200 Subject: formats: objfile idomatic reader/writer --- formats/objfile/common.go | 79 ----- formats/objfile/common_test.go | 31 -- formats/objfile/reader.go | 152 ++++----- formats/objfile/reader_test.go | 64 +--- formats/objfile/writer.go | 145 ++++----- formats/objfile/writer_test.go | 51 +-- storage/filesystem/internal/dotgit/dotgit.go | 348 ++++++++------------- storage/filesystem/internal/dotgit/dotgit_test.go | 138 ++++---- storage/filesystem/internal/dotgit/refs.go | 149 --------- storage/filesystem/internal/dotgit/writers.go | 263 ++++++++++++++++ storage/filesystem/internal/dotgit/writers_test.go | 89 ++++++ storage/filesystem/object.go | 44 +-- utils/fs/os.go | 23 +- 13 files changed, 758 insertions(+), 818 deletions(-) delete mode 100644 formats/objfile/common.go delete mode 100644 storage/filesystem/internal/dotgit/refs.go create mode 100644 storage/filesystem/internal/dotgit/writers.go create mode 100644 storage/filesystem/internal/dotgit/writers_test.go diff --git a/formats/objfile/common.go b/formats/objfile/common.go deleted file mode 100644 index 839f92c..0000000 --- a/formats/objfile/common.go +++ /dev/null @@ -1,79 +0,0 @@ -package objfile - -import ( - "errors" - "io" - "strconv" - - "gopkg.in/src-d/go-git.v4/core" -) - -var ( - // ErrClosed is returned when the objfile Reader or Writer is already closed. - ErrClosed = errors.New("objfile: already closed") - // ErrHeader is returned when the objfile has an invalid header. - ErrHeader = errors.New("objfile: invalid header") - // ErrNegativeSize is returned when a negative object size is declared. 
- ErrNegativeSize = errors.New("objfile: negative object size") -) - -type header struct { - t core.ObjectType - size int64 -} - -func (h *header) Read(r io.Reader) error { - t, err := h.readSlice(r, ' ') - if err != nil { - return err - } - - h.t, err = core.ParseObjectType(string(t)) - if err != nil { - return err - } - - size, err := h.readSlice(r, 0) - if err != nil { - return err - } - - h.size, err = strconv.ParseInt(string(size), 10, 64) - if err != nil { - return ErrHeader - } - - if h.size < 0 { - return ErrNegativeSize - } - - return nil -} - -func (h *header) Write(w io.Writer) error { - b := h.t.Bytes() - b = append(b, ' ') - b = append(b, []byte(strconv.FormatInt(h.size, 10))...) - b = append(b, 0) - _, err := w.Write(b) - return err -} - -// readSlice reads one byte at a time from r until it encounters delim or an -// error. -func (h *header) readSlice(r io.Reader, delim byte) ([]byte, error) { - var buf [1]byte - value := make([]byte, 0, 16) - for { - if n, err := r.Read(buf[:]); err != nil && (err != io.EOF || n == 0) { - if err == io.EOF { - return nil, ErrHeader - } - return nil, err - } - if buf[0] == delim { - return value, nil - } - value = append(value, buf[0]) - } -} diff --git a/formats/objfile/common_test.go b/formats/objfile/common_test.go index e443e73..682dfbb 100644 --- a/formats/objfile/common_test.go +++ b/formats/objfile/common_test.go @@ -1,7 +1,6 @@ package objfile import ( - "bytes" "encoding/base64" "testing" @@ -68,33 +67,3 @@ var objfileFixtures = []objfileFixture{ } func Test(t *testing.T) { TestingT(t) } - -type SuiteCommon struct{} - -var _ = Suite(&SuiteCommon{}) - -func (s *SuiteCommon) TestHeaderReadEmpty(c *C) { - var h header - c.Assert(h.Read(new(bytes.Buffer)), Equals, ErrHeader) -} - -func (s *SuiteCommon) TestHeaderReadGarbage(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), Equals, ErrHeader) - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, 3, 4, 5, '0'})), Equals, ErrHeader) -} - 
-func (s *SuiteCommon) TestHeaderReadInvalidType(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{1, 2, ' ', 4, 5, 0})), Equals, core.ErrInvalidType) -} - -func (s *SuiteCommon) TestHeaderReadInvalidSize(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', 'a', 0})), Equals, ErrHeader) -} - -func (s *SuiteCommon) TestHeaderReadNegativeSize(c *C) { - var h header - c.Assert(h.Read(bytes.NewBuffer([]byte{'b', 'l', 'o', 'b', ' ', '-', '1', 0})), Equals, ErrNegativeSize) -} diff --git a/formats/objfile/reader.go b/formats/objfile/reader.go index 99ed754..ab67c6a 100644 --- a/formats/objfile/reader.go +++ b/formats/objfile/reader.go @@ -1,69 +1,96 @@ package objfile import ( + "compress/zlib" "errors" "io" + "strconv" + "gopkg.in/src-d/go-git.v3/formats/packfile" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) var ( - // ErrZLib is returned when the objfile contains invalid zlib data. - ErrZLib = errors.New("objfile: invalid zlib data") + ErrClosed = errors.New("objfile: already closed") + ErrHeader = errors.New("objfile: invalid header") + ErrNegativeSize = errors.New("objfile: negative object size") ) // Reader reads and decodes compressed objfile data from a provided io.Reader. -// // Reader implements io.ReadCloser. Close should be called when finished with // the Reader. Close will not close the underlying io.Reader. type Reader struct { - header header - hash core.Hash // final computed hash stored after Close - - r io.Reader // provided reader wrapped in decompressor and tee - decompressor io.ReadCloser // provided reader wrapped in decompressor, retained for calling Close - h core.Hasher // streaming SHA1 hash of decoded data + multi io.Reader + zlib io.ReadCloser + hasher core.Hasher } // NewReader returns a new Reader reading from r. -// -// Calling NewReader causes it to immediately read in header data from r -// containing size and type information. 
Any errors encountered in that -// process will be returned in err. -// -// The returned Reader implements io.ReadCloser. Close should be called when -// finished with the Reader. Close will not close the underlying io.Reader. func NewReader(r io.Reader) (*Reader, error) { - reader := &Reader{} - return reader, reader.init(r) + zlib, err := zlib.NewReader(r) + if err != nil { + return nil, packfile.ErrZLib.AddDetails(err.Error()) + } + + return &Reader{ + zlib: zlib, + }, nil } -// init prepares the zlib decompressor for the given input as well as a hasher -// for computing its hash. -// -// init immediately reads header data from the input and stores it. This leaves -// the Reader in a state that is ready to read content. -func (r *Reader) init(input io.Reader) (err error) { - r.decompressor, err = zlib.NewReader(input) +// Header reads the type and the size of object, and prepares the reader for read +func (r *Reader) Header() (t core.ObjectType, size int64, err error) { + var raw []byte + raw, err = r.readUntil(' ') + if err != nil { + return + } + + t, err = core.ParseObjectType(string(raw)) if err != nil { - // TODO: Make this error match the ZLibErr in formats/packfile/reader.go? - return ErrZLib + return } - err = r.header.Read(r.decompressor) + raw, err = r.readUntil(0) if err != nil { - r.decompressor.Close() return } - r.h = core.NewHasher(r.header.t, r.header.size) - r.r = io.TeeReader(r.decompressor, r.h) // All reads from the decompressor also write to the hash + size, err = strconv.ParseInt(string(raw), 10, 64) + if err != nil { + err = ErrHeader + return + } + defer r.prepareForRead(t, size) return } +// readUntil reads one byte at a time from r until it encounters delim or an +// error. 
+func (r *Reader) readUntil(delim byte) ([]byte, error) { + var buf [1]byte + value := make([]byte, 0, 16) + for { + if n, err := r.zlib.Read(buf[:]); err != nil && (err != io.EOF || n == 0) { + if err == io.EOF { + return nil, ErrHeader + } + return nil, err + } + + if buf[0] == delim { + return value, nil + } + + value = append(value, buf[0]) + } +} + +func (r *Reader) prepareForRead(t core.ObjectType, size int64) { + r.hasher = core.NewHasher(t, size) + r.multi = io.TeeReader(r.zlib, r.hasher) +} + // Read reads len(p) bytes into p from the object data stream. It returns // the number of bytes read (0 <= n <= len(p)) and any error encountered. Even // if Read returns n < len(p), it may use all of p as scratch space during the @@ -72,65 +99,20 @@ func (r *Reader) init(input io.Reader) (err error) { // If Read encounters the end of the data stream it will return err == io.EOF, // either in the current call if n > 0 or in a subsequent call. func (r *Reader) Read(p []byte) (n int, err error) { - if r.r == nil { - return 0, ErrClosed - } - - return r.r.Read(p) -} - -// Type returns the type of the object. -func (r *Reader) Type() core.ObjectType { - return r.header.t -} - -// Size returns the uncompressed size of the object in bytes. -func (r *Reader) Size() int64 { - return r.header.size + return r.multi.Read(p) } // Hash returns the hash of the object data stream that has been read so far. -// It can be called before or after Close. func (r *Reader) Hash() core.Hash { - if r.r != nil { - return r.h.Sum() // Not yet closed, return hash of data read so far - } - return r.hash + return r.hasher.Sum() } -// Close releases any resources consumed by the Reader. -// -// Calling Close does not close the wrapped io.Reader originally passed to -// NewReader. -func (r *Reader) Close() (err error) { - if r.r == nil { - // TODO: Consider returning ErrClosed here? 
- return nil // Already closed - } - - // Release the decompressor's resources - err = r.decompressor.Close() - - // Save the hash because we're about to throw away the hasher - r.hash = r.h.Sum() - - // Release references - r.r = nil // Indicates closed state - r.decompressor = nil - r.h.Hash = nil - - return -} - -// FillObject fills the given object from an object entry -func (r *Reader) FillObject(obj core.Object) error { - obj.SetType(r.header.t) - obj.SetSize(r.header.size) - w, err := obj.Writer() - if err != nil { +// Close releases any resources consumed by the Reader. Calling Close does not +// close the wrapped io.Reader originally passed to NewReader. +func (r *Reader) Close() error { + if err := r.zlib.Close(); err != nil { return err } - _, err = io.Copy(w, r.r) - return err + return nil } diff --git a/formats/objfile/reader_test.go b/formats/objfile/reader_test.go index 964a071..a383fd2 100644 --- a/formats/objfile/reader_test.go +++ b/formats/objfile/reader_test.go @@ -9,8 +9,6 @@ import ( . 
"gopkg.in/check.v1" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) type SuiteReader struct{} @@ -28,76 +26,42 @@ func (s *SuiteReader) TestReadObjfile(c *C) { } } -func testReader(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, content []byte, com string) { +func testReader(c *C, source io.Reader, hash core.Hash, t core.ObjectType, content []byte, com string) { r, err := NewReader(source) c.Assert(err, IsNil) - c.Assert(r.Type(), Equals, typ) + + typ, size, err := r.Header() + c.Assert(err, IsNil) + c.Assert(typ, Equals, t) + c.Assert(content, HasLen, int(size)) + rc, err := ioutil.ReadAll(r) c.Assert(err, IsNil) c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) - c.Assert(r.Size(), Equals, int64(len(content))) + c.Assert(r.Hash(), Equals, hash) // Test Hash() before close c.Assert(r.Close(), IsNil) - c.Assert(r.Hash(), Equals, hash) // Test Hash() after close - _, err = r.Read(make([]byte, 0, 1)) - c.Assert(err, Equals, ErrClosed) + } func (s *SuiteReader) TestReadEmptyObjfile(c *C) { source := bytes.NewReader([]byte{}) _, err := NewReader(source) - c.Assert(err, Equals, ErrZLib) -} - -func (s *SuiteReader) TestReadEmptyContent(c *C) { - b := new(bytes.Buffer) - w := zlib.NewWriter(b) - c.Assert(w.Close(), IsNil) - _, err := NewReader(b) - c.Assert(err, Equals, ErrHeader) + c.Assert(err, NotNil) } func (s *SuiteReader) TestReadGarbage(c *C) { source := bytes.NewReader([]byte("!@#$RO!@NROSADfinq@o#irn@oirfn")) _, err := NewReader(source) - c.Assert(err, Equals, ErrZLib) + c.Assert(err, NotNil) } func (s *SuiteReader) TestReadCorruptZLib(c *C) { data, _ := base64.StdEncoding.DecodeString("eAFLysaalPUjBgAAAJsAHw") source := bytes.NewReader(data) - _, err := NewReader(source) - c.Assert(err, NotNil) -} - -func (s *SuiteReader) TestFillObject(c *C) { - for k, fixture := range objfileFixtures { - com := fmt.Sprintf("test %d: 
", k) - hash := core.NewHash(fixture.hash) - content, _ := base64.StdEncoding.DecodeString(fixture.content) - data, _ := base64.StdEncoding.DecodeString(fixture.data) - - testFillObject(c, bytes.NewReader(data), hash, fixture.t, content, com) - } -} - -func testFillObject(c *C, source io.Reader, hash core.Hash, typ core.ObjectType, content []byte, com string) { - var o core.Object = &core.MemoryObject{} r, err := NewReader(source) c.Assert(err, IsNil) - err = r.FillObject(o) - c.Assert(err, IsNil) - c.Assert(o.Type(), Equals, typ) - c.Assert(o.Size(), Equals, int64(len(content))) - c.Assert(o.Hash(), Equals, hash) - or, err := o.Reader() - c.Assert(err, IsNil) - rc, err := ioutil.ReadAll(or) - c.Assert(err, IsNil) - c.Assert(rc, DeepEquals, content, Commentf("%scontent=%s, expected=%s", base64.StdEncoding.EncodeToString(rc), base64.StdEncoding.EncodeToString(content))) - c.Assert(or.Close(), IsNil) - _, err = or.Read(make([]byte, 0, 1)) - c.Assert(err, Equals, nil) - ow, err := o.Writer() - c.Assert(ow, Equals, o) + + _, _, err = r.Header() + c.Assert(err, NotNil) } diff --git a/formats/objfile/writer.go b/formats/objfile/writer.go index 8337a3a..d2f2314 100644 --- a/formats/objfile/writer.go +++ b/formats/objfile/writer.go @@ -1,142 +1,109 @@ package objfile import ( + "compress/zlib" "errors" "io" + "strconv" "gopkg.in/src-d/go-git.v4/core" - - "github.com/klauspost/compress/zlib" ) var ( - // ErrOverflow is returned when an attempt is made to write more data than - // was declared in NewWriter. ErrOverflow = errors.New("objfile: declared data length exceeded (overflow)") ) // Writer writes and encodes data in compressed objfile format to a provided -// io.Writer. -// -// Writer implements io.WriteCloser. Close should be called when finished with -// the Writer. Close will not close the underlying io.Writer. +// io.Writer. Close should be called when finished with the Writer. Close will +// not close the underlying io.Writer. 
type Writer struct { - header header - hash core.Hash // final computed hash stored after Close + raw io.Writer + zlib io.WriteCloser + hasher core.Hasher + multi io.Writer - w io.Writer // provided writer wrapped in compressor and tee - compressor io.WriteCloser // provided writer wrapped in compressor, retained for calling Close - h core.Hasher // streaming SHA1 hash of encoded data - written int64 // Number of bytes written + closed bool + pending int64 // number of unwritten bytes } // NewWriter returns a new Writer writing to w. // -// The provided t is the type of object being written. The provided size is the -// number of uncompressed bytes being written. -// -// Calling NewWriter causes it to immediately write header data containing -// size and type information. Any errors encountered in that process will be -// returned in err. -// -// If an invalid t is provided, core.ErrInvalidType is returned. If a negative -// size is provided, ErrNegativeSize is returned. -// // The returned Writer implements io.WriteCloser. Close should be called when // finished with the Writer. Close will not close the underlying io.Writer. -func NewWriter(w io.Writer, t core.ObjectType, size int64) (*Writer, error) { +func NewWriter(w io.Writer) *Writer { + return &Writer{ + raw: w, + zlib: zlib.NewWriter(w), + } +} + +// WriteHeader writes the type and the size and prepares to accept the object's +// contents. If an invalid t is provided, core.ErrInvalidType is returned. If a +// negative size is provided, ErrNegativeSize is returned. +func (w *Writer) WriteHeader(t core.ObjectType, size int64) error { if !t.Valid() { - return nil, core.ErrInvalidType + return core.ErrInvalidType } if size < 0 { - return nil, ErrNegativeSize + return ErrNegativeSize } - writer := &Writer{ - header: header{t: t, size: size}, - } - return writer, writer.init(w) -} -// init prepares the zlib compressor for the given output as well as a hasher -// for computing its hash. 
-// -// init immediately writes header data to the output. This leaves the writer in -// a state that is ready to write content. -func (w *Writer) init(output io.Writer) (err error) { - w.compressor = zlib.NewWriter(output) - - err = w.header.Write(w.compressor) - if err != nil { - w.compressor.Close() - return - } + b := t.Bytes() + b = append(b, ' ') + b = append(b, []byte(strconv.FormatInt(size, 10))...) + b = append(b, 0) - w.h = core.NewHasher(w.header.t, w.header.size) - w.w = io.MultiWriter(w.compressor, w.h) // All writes to the compressor also write to the hash + defer w.prepareForWrite(t, size) + _, err := w.zlib.Write(b) - return + return err } -// Write reads len(p) from p to the object data stream. It returns the number of -// bytes written from p (0 <= n <= len(p)) and any error encountered that caused -// the write to stop early. The slice data contained in p will not be modified. -// -// If writing len(p) bytes would exceed the size provided in NewWriter, -// ErrOverflow is returned without writing any data. +func (w *Writer) prepareForWrite(t core.ObjectType, size int64) { + w.pending = size + + w.hasher = core.NewHasher(t, size) + w.multi = io.MultiWriter(w.zlib, w.hasher) +} + +// Write writes the object's contents. Write returns the error ErrOverflow if +// more than size bytes are written after WriteHeader. func (w *Writer) Write(p []byte) (n int, err error) { - if w.w == nil { + if w.closed { return 0, ErrClosed } - if w.written+int64(len(p)) > w.header.size { - return 0, ErrOverflow + overwrite := false + if int64(len(p)) > w.pending { + p = p[0:w.pending] + overwrite = true } - n, err = w.w.Write(p) - w.written += int64(n) + n, err = w.multi.Write(p) + w.pending -= int64(n) + if err == nil && overwrite { + err = ErrOverflow + return + } return } -// Type returns the type of the object. -func (w *Writer) Type() core.ObjectType { - return w.header.t -} - -// Size returns the uncompressed size of the object in bytes. 
-func (w *Writer) Size() int64 { - return w.header.size -} - // Hash returns the hash of the object data stream that has been written so far. // It can be called before or after Close. func (w *Writer) Hash() core.Hash { - if w.w != nil { - return w.h.Sum() // Not yet closed, return hash of data written so far - } - return w.hash + return w.hasher.Sum() // Not yet closed, return hash of data written so far } // Close releases any resources consumed by the Writer. // // Calling Close does not close the wrapped io.Writer originally passed to // NewWriter. -func (w *Writer) Close() (err error) { - if w.w == nil { - // TODO: Consider returning ErrClosed here? - return nil // Already closed +func (w *Writer) Close() error { + if err := w.zlib.Close(); err != nil { + return err } - // Release the compressor's resources - err = w.compressor.Close() - - // Save the hash because we're about to throw away the hasher - w.hash = w.h.Sum() - - // Release references - w.w = nil // Indicates closed state - w.compressor = nil - w.h.Hash = nil - - return + w.closed = true + return nil } diff --git a/formats/objfile/writer_test.go b/formats/objfile/writer_test.go index ab5a5bf..18bba79 100644 --- a/formats/objfile/writer_test.go +++ b/formats/objfile/writer_test.go @@ -16,54 +16,65 @@ var _ = Suite(&SuiteWriter{}) func (s *SuiteWriter) TestWriteObjfile(c *C) { for k, fixture := range objfileFixtures { + buffer := bytes.NewBuffer(nil) + com := fmt.Sprintf("test %d: ", k) hash := core.NewHash(fixture.hash) content, _ := base64.StdEncoding.DecodeString(fixture.content) - buffer := new(bytes.Buffer) // Write the data out to the buffer - testWriter(c, buffer, hash, fixture.t, content, com) + testWriter(c, buffer, hash, fixture.t, content) // Read the data back in from the buffer to be sure it matches testReader(c, buffer, hash, fixture.t, content, com) } } -func testWriter(c *C, dest io.Writer, hash core.Hash, typ core.ObjectType, content []byte, com string) { - length := 
int64(len(content)) - w, err := NewWriter(dest, typ, length) +func testWriter(c *C, dest io.Writer, hash core.Hash, t core.ObjectType, content []byte) { + size := int64(len(content)) + w := NewWriter(dest) + + err := w.WriteHeader(t, size) c.Assert(err, IsNil) - c.Assert(w.Type(), Equals, typ) - c.Assert(w.Size(), Equals, length) + written, err := io.Copy(w, bytes.NewReader(content)) c.Assert(err, IsNil) - c.Assert(written, Equals, length) - c.Assert(w.Size(), Equals, int64(len(content))) - c.Assert(w.Hash(), Equals, hash) // Test Hash() before close + c.Assert(written, Equals, size) + + c.Assert(w.Hash(), Equals, hash) c.Assert(w.Close(), IsNil) - c.Assert(w.Hash(), Equals, hash) // Test Hash() after close - _, err = w.Write([]byte{1}) - c.Assert(err, Equals, ErrClosed) } func (s *SuiteWriter) TestWriteOverflow(c *C) { - w, err := NewWriter(new(bytes.Buffer), core.BlobObject, 8) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.BlobObject, 8) c.Assert(err, IsNil) - _, err = w.Write([]byte("1234")) + + n, err := w.Write([]byte("1234")) c.Assert(err, IsNil) - _, err = w.Write([]byte("56789")) + c.Assert(n, Equals, 4) + + n, err = w.Write([]byte("56789")) c.Assert(err, Equals, ErrOverflow) + c.Assert(n, Equals, 4) } func (s *SuiteWriter) TestNewWriterInvalidType(c *C) { - var t core.ObjectType - _, err := NewWriter(new(bytes.Buffer), t, 8) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.InvalidObject, 8) c.Assert(err, Equals, core.ErrInvalidType) } func (s *SuiteWriter) TestNewWriterInvalidSize(c *C) { - _, err := NewWriter(new(bytes.Buffer), core.BlobObject, -1) + buf := bytes.NewBuffer(nil) + w := NewWriter(buf) + + err := w.WriteHeader(core.BlobObject, -1) c.Assert(err, Equals, ErrNegativeSize) - _, err = NewWriter(new(bytes.Buffer), core.BlobObject, -1651860) + err = w.WriteHeader(core.BlobObject, -1651860) c.Assert(err, Equals, ErrNegativeSize) } diff --git 
a/storage/filesystem/internal/dotgit/dotgit.go b/storage/filesystem/internal/dotgit/dotgit.go index c4392a2..ba293af 100644 --- a/storage/filesystem/internal/dotgit/dotgit.go +++ b/storage/filesystem/internal/dotgit/dotgit.go @@ -2,18 +2,14 @@ package dotgit import ( - "crypto/sha1" + "bufio" "errors" "fmt" - "io" + "io/ioutil" "os" "strings" - "sync/atomic" - "time" "gopkg.in/src-d/go-git.v4/core" - "gopkg.in/src-d/go-git.v4/formats/idxfile" - "gopkg.in/src-d/go-git.v4/formats/packfile" "gopkg.in/src-d/go-git.v4/utils/fs" ) @@ -24,6 +20,7 @@ const ( objectsPath = "objects" packPath = "pack" + refsPath = "refs" packExt = ".pack" idxExt = ".idx" @@ -38,6 +35,16 @@ var ( ErrPackfileNotFound = errors.New("packfile not found") // ErrConfigNotFound is returned by Config when the config is not found ErrConfigNotFound = errors.New("config file not found") + // ErrPackedRefsDuplicatedRef is returned when a duplicated reference is + // found in the packed-ref file. This is usually the case for corrupted git + // repositories. + ErrPackedRefsDuplicatedRef = errors.New("duplicated ref found in packed-ref file") + // ErrPackedRefsBadFormat is returned when the packed-ref file is corrupt. + ErrPackedRefsBadFormat = errors.New("malformed packed-ref") + // ErrSymRefTargetNotFound is returned when a symbolic reference is + // targeting a non-existing object. This usually means the repository + // is corrupt. + ErrSymRefTargetNotFound = errors.New("symbolic reference target not found") ) // The DotGit type represents a local git repository on disk. 
This @@ -62,45 +69,6 @@ func (d *DotGit) Config() (fs.File, error) { return d.fs.Open(configPath) } -func (d *DotGit) SetRef(r *core.Reference) error { - var content string - switch r.Type() { - case core.SymbolicReference: - content = fmt.Sprintf("ref: %s\n", r.Target()) - case core.HashReference: - content = fmt.Sprintln(r.Hash().String()) - } - - f, err := d.fs.Create(r.Name().String()) - if err != nil { - return err - } - - if _, err := f.Write([]byte(content)); err != nil { - return err - } - return f.Close() -} - -// Refs scans the git directory collecting references, which it returns. -// Symbolic references are resolved and included in the output. -func (d *DotGit) Refs() ([]*core.Reference, error) { - var refs []*core.Reference - if err := d.addRefsFromPackedRefs(&refs); err != nil { - return nil, err - } - - if err := d.addRefsFromRefDir(&refs); err != nil { - return nil, err - } - - if err := d.addRefFromHEAD(&refs); err != nil { - return nil, err - } - - return refs, nil -} - // NewObjectPack return a writer for a new packfile, it saves the packfile to // disk and also generates and save the index for the given packfile. func (d *DotGit) NewObjectPack() (*PackWriter, error) { @@ -165,6 +133,11 @@ func (d *DotGit) ObjectPackIdx(hash core.Hash) (fs.File, error) { return idx, nil } +// NewObject return a writer for a new object file. +func (d *DotGit) NewObject() (*ObjectWriter, error) { + return newObjectWriter(d.fs) +} + // Objects returns a slice with the hashes of objects found under the // .git/objects/ directory. 
func (d *DotGit) Objects() ([]core.Hash, error) { @@ -203,232 +176,185 @@ func (d *DotGit) Object(h core.Hash) (fs.File, error) { return d.fs.Open(file) } -func isHex(s string) bool { - for _, b := range []byte(s) { - if isNum(b) { - continue - } - if isHexAlpha(b) { - continue - } - - return false +func (d *DotGit) SetRef(r *core.Reference) error { + var content string + switch r.Type() { + case core.SymbolicReference: + content = fmt.Sprintf("ref: %s\n", r.Target()) + case core.HashReference: + content = fmt.Sprintln(r.Hash().String()) } - return true -} - -func isNum(b byte) bool { - return b >= '0' && b <= '9' -} - -func isHexAlpha(b byte) bool { - return b >= 'a' && b <= 'f' || b >= 'A' && b <= 'F' -} - -type PackWriter struct { - Notify func(h core.Hash, i idxfile.Idxfile) + f, err := d.fs.Create(r.Name().String()) + if err != nil { + return err + } - fs fs.Filesystem - fr, fw fs.File - synced *syncedReader - checksum core.Hash - index idxfile.Idxfile - result chan error + if _, err := f.Write([]byte(content)); err != nil { + return err + } + return f.Close() } -func newPackWrite(fs fs.Filesystem) (*PackWriter, error) { - seed := sha1.Sum([]byte(time.Now().String())) - tmp := fs.Join(objectsPath, packPath, fmt.Sprintf("tmp_pack_%x", seed)) - - fw, err := fs.Create(tmp) - if err != nil { +// Refs scans the git directory collecting references, which it returns. +// Symbolic references are resolved and included in the output. 
+func (d *DotGit) Refs() ([]*core.Reference, error) { + var refs []*core.Reference + if err := d.addRefsFromPackedRefs(&refs); err != nil { return nil, err } - fr, err := fs.Open(tmp) - if err != nil { + if err := d.addRefsFromRefDir(&refs); err != nil { return nil, err } - writer := &PackWriter{ - fs: fs, - fw: fw, - fr: fr, - synced: newSyncedReader(fw, fr), - result: make(chan error), + if err := d.addRefFromHEAD(&refs); err != nil { + return nil, err } - go writer.buildIndex() - return writer, nil + return refs, nil } -func (w *PackWriter) buildIndex() { - s := packfile.NewScanner(w.synced) - d, err := packfile.NewDecoder(s, nil) +func (d *DotGit) addRefsFromPackedRefs(refs *[]*core.Reference) (err error) { + f, err := d.fs.Open(packedRefsPath) if err != nil { - w.result <- err - return + if os.IsNotExist(err) { + return nil + } + return err } - checksum, err := d.Decode() - if err != nil { - w.result <- err - return - } + defer func() { + if errClose := f.Close(); err == nil { + err = errClose + } + }() - w.checksum = checksum - w.index.PackfileChecksum = checksum - w.index.Version = idxfile.VersionSupported + s := bufio.NewScanner(f) + for s.Scan() { + ref, err := d.processLine(s.Text()) + if err != nil { + return err + } - offsets := d.Offsets() - for h, crc := range d.CRCs() { - w.index.Add(h, uint64(offsets[h]), crc) + if ref != nil { + *refs = append(*refs, ref) + } } - w.result <- err + return s.Err() } -func (w *PackWriter) Write(p []byte) (n int, err error) { - return w.synced.Write(p) +// process lines from a packed-refs file +func (d *DotGit) processLine(line string) (*core.Reference, error) { + switch line[0] { + case '#': // comment - ignore + return nil, nil + case '^': // annotated tag commit of the previous line - ignore + return nil, nil + default: + ws := strings.Split(line, " ") // hash then ref + if len(ws) != 2 { + return nil, ErrPackedRefsBadFormat + } + + return core.NewReferenceFromStrings(ws[1], ws[0]), nil + } } -func (w *PackWriter) 
Close() error { - defer func() { - close(w.result) - }() +func (d *DotGit) addRefsFromRefDir(refs *[]*core.Reference) error { + return d.walkReferencesTree(refs, refsPath) +} + +func (d *DotGit) walkReferencesTree(refs *[]*core.Reference, relPath string) error { + files, err := d.fs.ReadDir(relPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } - pipe := []func() error{ - w.synced.Close, - func() error { return <-w.result }, - w.fr.Close, - w.fw.Close, - w.save, + return err } - for _, f := range pipe { - if err := f(); err != nil { + for _, f := range files { + newRelPath := d.fs.Join(relPath, f.Name()) + if f.IsDir() { + if err = d.walkReferencesTree(refs, newRelPath); err != nil { + return err + } + + continue + } + + ref, err := d.readReferenceFile(".", newRelPath) + if err != nil { return err } - } - if w.Notify != nil { - w.Notify(w.checksum, w.index) + if ref != nil { + *refs = append(*refs, ref) + } } return nil } -func (w *PackWriter) save() error { - base := w.fs.Join(objectsPath, packPath, fmt.Sprintf("pack-%s", w.checksum)) - idx, err := w.fs.Create(fmt.Sprintf("%s.idx", base)) +func (d *DotGit) addRefFromHEAD(refs *[]*core.Reference) error { + ref, err := d.readReferenceFile(".", "HEAD") if err != nil { - return err - } - - if err := w.encodeIdx(idx); err != nil { - return err - } + if os.IsNotExist(err) { + return nil + } - if err := idx.Close(); err != nil { return err } - return w.fs.Rename(w.fw.Filename(), fmt.Sprintf("%s.pack", base)) -} - -func (w *PackWriter) encodeIdx(writer io.Writer) error { - e := idxfile.NewEncoder(writer) - _, err := e.Encode(&w.index) - return err + *refs = append(*refs, ref) + return nil } -type syncedReader struct { - w io.Writer - r io.ReadSeeker +func (d *DotGit) readReferenceFile(refsPath, refFile string) (ref *core.Reference, err error) { + path := d.fs.Join(refsPath, refFile) - blocked, done uint32 - written, read uint64 - news chan bool -} - -func newSyncedReader(w io.Writer, r io.ReadSeeker) 
*syncedReader { - return &syncedReader{ - w: w, - r: r, - news: make(chan bool), + f, err := d.fs.Open(path) + if err != nil { + return nil, err } -} -func (s *syncedReader) Write(p []byte) (n int, err error) { defer func() { - written := atomic.AddUint64(&s.written, uint64(n)) - read := atomic.LoadUint64(&s.read) - if written > read { - s.wake() + if errClose := f.Close(); err == nil { + err = errClose } }() - n, err = s.w.Write(p) - return -} - -func (s *syncedReader) Read(p []byte) (n int, err error) { - defer func() { atomic.AddUint64(&s.read, uint64(n)) }() - - s.sleep() - n, err = s.r.Read(p) - if err == io.EOF && !s.isDone() { - if n == 0 { - return s.Read(p) - } - - return n, nil + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err } - return -} - -func (s *syncedReader) isDone() bool { - return atomic.LoadUint32(&s.done) == 1 -} - -func (s *syncedReader) isBlocked() bool { - return atomic.LoadUint32(&s.blocked) == 1 + line := strings.TrimSpace(string(b)) + return core.NewReferenceFromStrings(refFile, line), nil } -func (s *syncedReader) wake() { - if s.isBlocked() { - // fmt.Println("wake") - atomic.StoreUint32(&s.blocked, 0) - s.news <- true - } -} +func isHex(s string) bool { + for _, b := range []byte(s) { + if isNum(b) { + continue + } + if isHexAlpha(b) { + continue + } -func (s *syncedReader) sleep() { - read := atomic.LoadUint64(&s.read) - written := atomic.LoadUint64(&s.written) - if read >= written { - atomic.StoreUint32(&s.blocked, 1) - // fmt.Println("sleep", read, written) - <-s.news + return false } + return true } -func (s *syncedReader) Seek(offset int64, whence int) (int64, error) { - if whence == io.SeekCurrent { - return s.r.Seek(offset, whence) - } - - p, err := s.r.Seek(offset, whence) - s.read = uint64(p) - - return p, err +func isNum(b byte) bool { + return b >= '0' && b <= '9' } -func (s *syncedReader) Close() error { - atomic.StoreUint32(&s.done, 1) - close(s.news) - return nil +func isHexAlpha(b byte) bool { + return b 
>= 'a' && b <= 'f' || b >= 'A' && b <= 'F' } diff --git a/storage/filesystem/internal/dotgit/dotgit_test.go b/storage/filesystem/internal/dotgit/dotgit_test.go index f105c58..ebd8596 100644 --- a/storage/filesystem/internal/dotgit/dotgit_test.go +++ b/storage/filesystem/internal/dotgit/dotgit_test.go @@ -1,13 +1,9 @@ package dotgit import ( - "fmt" - "io" "io/ioutil" - "log" "os" "path/filepath" - "strconv" "strings" "testing" @@ -26,6 +22,41 @@ type SuiteDotGit struct { var _ = Suite(&SuiteDotGit{}) +func (s *SuiteDotGit) TestSetRefs(c *C) { + tmp, err := ioutil.TempDir("", "dot-git") + c.Assert(err, IsNil) + defer os.RemoveAll(tmp) + + fs := fs.NewOS(tmp) + dir := New(fs) + + err = dir.SetRef(core.NewReferenceFromStrings( + "refs/heads/foo", + "e8d3ffab552895c19b9fcf7aa264d277cde33881", + )) + + c.Assert(err, IsNil) + + err = dir.SetRef(core.NewReferenceFromStrings( + "refs/heads/symbolic", + "ref: refs/heads/foo", + )) + + c.Assert(err, IsNil) + + refs, err := dir.Refs() + c.Assert(err, IsNil) + c.Assert(refs, HasLen, 2) + + ref := findReference(refs, "refs/heads/foo") + c.Assert(ref, NotNil) + c.Assert(ref.Hash().String(), Equals, "e8d3ffab552895c19b9fcf7aa264d277cde33881") + + ref = findReference(refs, "refs/heads/symbolic") + c.Assert(ref, NotNil) + c.Assert(ref.Target().String(), Equals, "refs/heads/foo") +} + func (s *SuiteDotGit) TestRefsFromPackedRefs(c *C) { fs := fixtures.Basic().ByTag(".git").One().DotGit() dir := New(fs) @@ -128,6 +159,31 @@ func (s *SuiteDotGit) TestObjectPackNotFound(c *C) { c.Assert(idx, IsNil) } +func (s *SuiteDotGit) TestNewObject(c *C) { + tmp, err := ioutil.TempDir("", "dot-git") + c.Assert(err, IsNil) + defer os.RemoveAll(tmp) + + fs := fs.NewOS(tmp) + dir := New(fs) + w, err := dir.NewObject() + c.Assert(err, IsNil) + + err = w.WriteHeader(core.BlobObject, 14) + n, err := w.Write([]byte("this is a test")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 14) + + c.Assert(w.Hash().String(), Equals, 
"a8a940627d132695a9769df883f85992f0ff4a43") + + err = w.Close() + c.Assert(err, IsNil) + + i, err := fs.Stat("objects/a8/a940627d132695a9769df883f85992f0ff4a43") + c.Assert(err, IsNil) + c.Assert(i.Size(), Equals, int64(34)) +} + func (s *SuiteDotGit) TestObjects(c *C) { fs := fixtures.ByTag(".git").ByTag("unpacked").One().DotGit() dir := New(fs) @@ -162,77 +218,3 @@ func (s *SuiteDotGit) TestObjectNotFound(c *C) { c.Assert(err, NotNil) c.Assert(file, IsNil) } - -func (s *SuiteDotGit) TestNewObjectPack(c *C) { - f := fixtures.Basic().One() - - dir, err := ioutil.TempDir("", "example") - if err != nil { - log.Fatal(err) - } - - defer os.RemoveAll(dir) - - fs := fs.NewOS(dir) - dot := New(fs) - - w, err := dot.NewObjectPack() - c.Assert(err, IsNil) - - _, err = io.Copy(w, f.Packfile()) - c.Assert(err, IsNil) - - c.Assert(w.Close(), IsNil) - - stat, err := fs.Stat(fmt.Sprintf("objects/pack/pack-%s.pack", f.PackfileHash)) - c.Assert(err, IsNil) - c.Assert(stat.Size(), Equals, int64(84794)) - - stat, err = fs.Stat(fmt.Sprintf("objects/pack/pack-%s.idx", f.PackfileHash)) - c.Assert(err, IsNil) - c.Assert(stat.Size(), Equals, int64(1940)) -} - -func (s *SuiteDotGit) TestSyncedReader(c *C) { - tmpw, err := ioutil.TempFile("", "example") - c.Assert(err, IsNil) - - tmpr, err := os.Open(tmpw.Name()) - c.Assert(err, IsNil) - - defer func() { - tmpw.Close() - tmpr.Close() - os.Remove(tmpw.Name()) - }() - - synced := newSyncedReader(tmpw, tmpr) - - go func() { - for i := 0; i < 281; i++ { - _, err := synced.Write([]byte(strconv.Itoa(i) + "\n")) - c.Assert(err, IsNil) - } - - synced.Close() - }() - - o, err := synced.Seek(1002, io.SeekStart) - c.Assert(err, IsNil) - c.Assert(o, Equals, int64(1002)) - - head := make([]byte, 3) - n, err := io.ReadFull(synced, head) - c.Assert(err, IsNil) - c.Assert(n, Equals, 3) - c.Assert(string(head), Equals, "278") - - o, err = synced.Seek(1010, io.SeekStart) - c.Assert(err, IsNil) - c.Assert(o, Equals, int64(1010)) - - n, err = 
io.ReadFull(synced, head) - c.Assert(err, IsNil) - c.Assert(n, Equals, 3) - c.Assert(string(head), Equals, "280") -} diff --git a/storage/filesystem/internal/dotgit/refs.go b/storage/filesystem/internal/dotgit/refs.go deleted file mode 100644 index 8f28332..0000000 --- a/storage/filesystem/internal/dotgit/refs.go +++ /dev/null @@ -1,149 +0,0 @@ -package dotgit - -import ( - "bufio" - "errors" - "io/ioutil" - "os" - "strings" - - "gopkg.in/src-d/go-git.v4/core" -) - -var ( - // ErrPackedRefsDuplicatedRef is returned when a duplicated - // reference is found in the packed-ref file. This is usually the - // case for corrupted git repositories. - ErrPackedRefsDuplicatedRef = errors.New("duplicated ref found in packed-ref file") - // ErrPackedRefsBadFormat is returned when the packed-ref file - // corrupt. - ErrPackedRefsBadFormat = errors.New("malformed packed-ref") - // ErrSymRefTargetNotFound is returned when a symbolic reference is - // targeting a non-existing object. This usually means the - // repository is corrupt. 
- ErrSymRefTargetNotFound = errors.New("symbolic reference target not found") -) - -const ( - refsPath = "refs" -) - -func (d *DotGit) addRefsFromPackedRefs(refs *[]*core.Reference) (err error) { - f, err := d.fs.Open(packedRefsPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return err - } - - defer func() { - if errClose := f.Close(); err == nil { - err = errClose - } - }() - s := bufio.NewScanner(f) - for s.Scan() { - ref, err := d.processLine(s.Text()) - if err != nil { - return err - } - - if ref != nil { - *refs = append(*refs, ref) - } - } - - return s.Err() -} - -// process lines from a packed-refs file -func (d *DotGit) processLine(line string) (*core.Reference, error) { - switch line[0] { - case '#': // comment - ignore - return nil, nil - case '^': // annotated tag commit of the previous line - ignore - return nil, nil - default: - ws := strings.Split(line, " ") // hash then ref - if len(ws) != 2 { - return nil, ErrPackedRefsBadFormat - } - - return core.NewReferenceFromStrings(ws[1], ws[0]), nil - } -} - -func (d *DotGit) addRefsFromRefDir(refs *[]*core.Reference) error { - return d.walkReferencesTree(refs, refsPath) -} - -func (d *DotGit) walkReferencesTree(refs *[]*core.Reference, relPath string) error { - files, err := d.fs.ReadDir(relPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } - - return err - } - - for _, f := range files { - newRelPath := d.fs.Join(relPath, f.Name()) - if f.IsDir() { - if err = d.walkReferencesTree(refs, newRelPath); err != nil { - return err - } - - continue - } - - ref, err := d.readReferenceFile(".", newRelPath) - if err != nil { - return err - } - - if ref != nil { - *refs = append(*refs, ref) - } - } - - return nil -} - -func (d *DotGit) addRefFromHEAD(refs *[]*core.Reference) error { - ref, err := d.readReferenceFile(".", "HEAD") - if err != nil { - if os.IsNotExist(err) { - return nil - } - - return err - } - - *refs = append(*refs, ref) - return nil -} - -func (d *DotGit) 
readReferenceFile(refsPath, refFile string) (ref *core.Reference, err error) { - path := d.fs.Join(refsPath, refFile) - - f, err := d.fs.Open(path) - if err != nil { - return nil, err - } - - defer func() { - if errClose := f.Close(); err == nil { - err = errClose - } - }() - - b, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - - line := strings.TrimSpace(string(b)) - return core.NewReferenceFromStrings(refFile, line), nil -} diff --git a/storage/filesystem/internal/dotgit/writers.go b/storage/filesystem/internal/dotgit/writers.go new file mode 100644 index 0000000..40b004f --- /dev/null +++ b/storage/filesystem/internal/dotgit/writers.go @@ -0,0 +1,263 @@ +package dotgit + +import ( + "crypto/sha1" + "fmt" + "io" + "sync/atomic" + "time" + + "gopkg.in/src-d/go-git.v4/core" + "gopkg.in/src-d/go-git.v4/formats/idxfile" + "gopkg.in/src-d/go-git.v4/formats/objfile" + "gopkg.in/src-d/go-git.v4/formats/packfile" + "gopkg.in/src-d/go-git.v4/utils/fs" +) + +type PackWriter struct { + Notify func(h core.Hash, i idxfile.Idxfile) + + fs fs.Filesystem + fr, fw fs.File + synced *syncedReader + checksum core.Hash + index idxfile.Idxfile + result chan error +} + +func newPackWrite(fs fs.Filesystem) (*PackWriter, error) { + seed := sha1.Sum([]byte(time.Now().String())) + tmp := fs.Join(objectsPath, packPath, fmt.Sprintf("tmp_pack_%x", seed)) + + fw, err := fs.Create(tmp) + if err != nil { + return nil, err + } + + fr, err := fs.Open(tmp) + if err != nil { + return nil, err + } + + writer := &PackWriter{ + fs: fs, + fw: fw, + fr: fr, + synced: newSyncedReader(fw, fr), + result: make(chan error), + } + + go writer.buildIndex() + return writer, nil +} + +func (w *PackWriter) buildIndex() { + s := packfile.NewScanner(w.synced) + d, err := packfile.NewDecoder(s, nil) + if err != nil { + w.result <- err + return + } + + checksum, err := d.Decode() + if err != nil { + w.result <- err + return + } + + w.checksum = checksum + w.index.PackfileChecksum = checksum + 
w.index.Version = idxfile.VersionSupported + + offsets := d.Offsets() + for h, crc := range d.CRCs() { + w.index.Add(h, uint64(offsets[h]), crc) + } + + w.result <- err +} + +func (w *PackWriter) Write(p []byte) (n int, err error) { + return w.synced.Write(p) +} + +func (w *PackWriter) Close() error { + defer func() { + close(w.result) + }() + + pipe := []func() error{ + w.synced.Close, + func() error { return <-w.result }, + w.fr.Close, + w.fw.Close, + w.save, + } + + for _, f := range pipe { + if err := f(); err != nil { + return err + } + } + + if w.Notify != nil { + w.Notify(w.checksum, w.index) + } + + return nil +} + +func (w *PackWriter) save() error { + base := w.fs.Join(objectsPath, packPath, fmt.Sprintf("pack-%s", w.checksum)) + idx, err := w.fs.Create(fmt.Sprintf("%s.idx", base)) + if err != nil { + return err + } + + if err := w.encodeIdx(idx); err != nil { + return err + } + + if err := idx.Close(); err != nil { + return err + } + + return w.fs.Rename(w.fw.Filename(), fmt.Sprintf("%s.pack", base)) +} + +func (w *PackWriter) encodeIdx(writer io.Writer) error { + e := idxfile.NewEncoder(writer) + _, err := e.Encode(&w.index) + return err +} + +type syncedReader struct { + w io.Writer + r io.ReadSeeker + + blocked, done uint32 + written, read uint64 + news chan bool +} + +func newSyncedReader(w io.Writer, r io.ReadSeeker) *syncedReader { + return &syncedReader{ + w: w, + r: r, + news: make(chan bool), + } +} + +func (s *syncedReader) Write(p []byte) (n int, err error) { + defer func() { + written := atomic.AddUint64(&s.written, uint64(n)) + read := atomic.LoadUint64(&s.read) + if written > read { + s.wake() + } + }() + + n, err = s.w.Write(p) + return +} + +func (s *syncedReader) Read(p []byte) (n int, err error) { + defer func() { atomic.AddUint64(&s.read, uint64(n)) }() + + s.sleep() + n, err = s.r.Read(p) + if err == io.EOF && !s.isDone() { + if n == 0 { + return s.Read(p) + } + + return n, nil + } + + return +} + +func (s *syncedReader) isDone() bool 
{ + return atomic.LoadUint32(&s.done) == 1 +} + +func (s *syncedReader) isBlocked() bool { + return atomic.LoadUint32(&s.blocked) == 1 +} + +func (s *syncedReader) wake() { + if s.isBlocked() { + // fmt.Println("wake") + atomic.StoreUint32(&s.blocked, 0) + s.news <- true + } +} + +func (s *syncedReader) sleep() { + read := atomic.LoadUint64(&s.read) + written := atomic.LoadUint64(&s.written) + if read >= written { + atomic.StoreUint32(&s.blocked, 1) + // fmt.Println("sleep", read, written) + <-s.news + } + +} + +func (s *syncedReader) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekCurrent { + return s.r.Seek(offset, whence) + } + + p, err := s.r.Seek(offset, whence) + s.read = uint64(p) + + return p, err +} + +func (s *syncedReader) Close() error { + atomic.StoreUint32(&s.done, 1) + close(s.news) + return nil +} + +type ObjectWriter struct { + objfile.Writer + fs fs.Filesystem + f fs.File +} + +func newObjectWriter(fs fs.Filesystem) (*ObjectWriter, error) { + seed := sha1.Sum([]byte(time.Now().String())) + tmp := fs.Join(objectsPath, fmt.Sprintf("tmp_obj_%x", seed)) + + f, err := fs.Create(tmp) + if err != nil { + return nil, err + } + + return &ObjectWriter{ + Writer: (*objfile.NewWriter(f)), + fs: fs, + f: f, + }, nil +} + +func (w *ObjectWriter) Close() error { + if err := w.Writer.Close(); err != nil { + return err + } + + if err := w.f.Close(); err != nil { + return err + } + + return w.save() +} + +func (w *ObjectWriter) save() error { + hash := w.Hash().String() + file := w.fs.Join(objectsPath, hash[0:2], hash[2:40]) + + return w.fs.Rename(w.f.Filename(), file) +} diff --git a/storage/filesystem/internal/dotgit/writers_test.go b/storage/filesystem/internal/dotgit/writers_test.go new file mode 100644 index 0000000..ebecbb4 --- /dev/null +++ b/storage/filesystem/internal/dotgit/writers_test.go @@ -0,0 +1,89 @@ +package dotgit + +import ( + "fmt" + "io" + "io/ioutil" + "log" + "os" + "strconv" + + "gopkg.in/src-d/go-git.v4/fixtures" + 
"gopkg.in/src-d/go-git.v4/utils/fs" + + . "gopkg.in/check.v1" +) + +func (s *SuiteDotGit) TestNewObjectPack(c *C) { + f := fixtures.Basic().One() + + dir, err := ioutil.TempDir("", "example") + if err != nil { + log.Fatal(err) + } + + defer os.RemoveAll(dir) + + fs := fs.NewOS(dir) + dot := New(fs) + + w, err := dot.NewObjectPack() + c.Assert(err, IsNil) + + _, err = io.Copy(w, f.Packfile()) + c.Assert(err, IsNil) + + c.Assert(w.Close(), IsNil) + + stat, err := fs.Stat(fmt.Sprintf("objects/pack/pack-%s.pack", f.PackfileHash)) + c.Assert(err, IsNil) + c.Assert(stat.Size(), Equals, int64(84794)) + + stat, err = fs.Stat(fmt.Sprintf("objects/pack/pack-%s.idx", f.PackfileHash)) + c.Assert(err, IsNil) + c.Assert(stat.Size(), Equals, int64(1940)) +} + +func (s *SuiteDotGit) TestSyncedReader(c *C) { + tmpw, err := ioutil.TempFile("", "example") + c.Assert(err, IsNil) + + tmpr, err := os.Open(tmpw.Name()) + c.Assert(err, IsNil) + + defer func() { + tmpw.Close() + tmpr.Close() + os.Remove(tmpw.Name()) + }() + + synced := newSyncedReader(tmpw, tmpr) + + go func() { + for i := 0; i < 281; i++ { + _, err := synced.Write([]byte(strconv.Itoa(i) + "\n")) + c.Assert(err, IsNil) + } + + synced.Close() + }() + + o, err := synced.Seek(1002, io.SeekStart) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(1002)) + + head := make([]byte, 3) + n, err := io.ReadFull(synced, head) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(string(head), Equals, "278") + + o, err = synced.Seek(1010, io.SeekStart) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(1010)) + + n, err = io.ReadFull(synced, head) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(string(head), Equals, "280") +} diff --git a/storage/filesystem/object.go b/storage/filesystem/object.go index 03939ce..f2f5351 100644 --- a/storage/filesystem/object.go +++ b/storage/filesystem/object.go @@ -118,31 +118,30 @@ func (s *ObjectStorage) getFromUnpacked(h core.Hash) (obj core.Object, err error return nil, err } - 
defer func() { - errClose := f.Close() - if err == nil { - err = errClose - } - }() + defer f.Close() obj = s.NewObject() - objReader, err := objfile.NewReader(f) + r, err := objfile.NewReader(f) if err != nil { return nil, err } - defer func() { - errClose := objReader.Close() - if err == nil { - err = errClose - } - }() + defer r.Close() - if err := objReader.FillObject(obj); err != nil { + t, size, err := r.Header() + if err != nil { return nil, err } - return obj, nil + obj.SetType(t) + obj.SetSize(size) + w, err := obj.Writer() + if err != nil { + return nil, err + } + + _, err = io.Copy(w, r) + return obj, err } // Get returns the object with the given hash, by searching for it in @@ -278,11 +277,7 @@ type packfileIter struct { total uint32 } -func newPackfileIter( - f fs.File, - t core.ObjectType, - seen map[core.Hash]bool, -) (core.ObjectIter, error) { +func newPackfileIter(f fs.File, t core.ObjectType, seen map[core.Hash]bool) (core.ObjectIter, error) { s := packfile.NewScanner(f) _, total, err := s.Header() if err != nil { @@ -294,7 +289,14 @@ func newPackfileIter( return nil, err } - return &packfileIter{f: f, d: d, t: t, total: total, seen: seen}, nil + return &packfileIter{ + f: f, + d: d, + t: t, + + total: total, + seen: seen, + }, nil } func (iter *packfileIter) Next() (core.Object, error) { diff --git a/utils/fs/os.go b/utils/fs/os.go index 51af921..a0d197d 100644 --- a/utils/fs/os.go +++ b/utils/fs/os.go @@ -23,11 +23,8 @@ func NewOS(rootDir string) *OS { func (fs *OS) Create(filename string) (File, error) { fullpath := path.Join(fs.RootDir, filename) - dir := filepath.Dir(fullpath) - if dir != "." { - if err := os.MkdirAll(dir, 0755); err != nil { - return nil, err - } + if err := fs.createDir(fullpath); err != nil { + return nil, err } f, err := os.Create(fullpath) @@ -41,6 +38,17 @@ func (fs *OS) Create(filename string) (File, error) { }, nil } +func (fs *OS) createDir(fullpath string) error { + dir := filepath.Dir(fullpath) + if dir != "." 
{ + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + } + + return nil +} + // ReadDir returns the filesystem info for all the archives under the specified // path. func (fs *OS) ReadDir(path string) ([]FileInfo, error) { @@ -62,6 +70,11 @@ func (fs *OS) ReadDir(path string) ([]FileInfo, error) { func (fs *OS) Rename(from, to string) error { from = fs.Join(fs.RootDir, from) to = fs.Join(fs.RootDir, to) + + if err := fs.createDir(to); err != nil { + return err + } + return os.Rename(from, to) } -- cgit