aboutsummaryrefslogtreecommitdiffstats
path: root/storage/filesystem/internal/dotgit
diff options
context:
space:
mode:
Diffstat (limited to 'storage/filesystem/internal/dotgit')
-rw-r--r--storage/filesystem/internal/dotgit/dotgit.go154
-rw-r--r--storage/filesystem/internal/dotgit/dotgit_test.go45
2 files changed, 167 insertions, 32 deletions
diff --git a/storage/filesystem/internal/dotgit/dotgit.go b/storage/filesystem/internal/dotgit/dotgit.go
index cacda68..54113d5 100644
--- a/storage/filesystem/internal/dotgit/dotgit.go
+++ b/storage/filesystem/internal/dotgit/dotgit.go
@@ -8,12 +8,12 @@ import (
"io"
"os"
"strings"
+ "sync/atomic"
"time"
"gopkg.in/src-d/go-git.v4/core"
"gopkg.in/src-d/go-git.v4/formats/idxfile"
"gopkg.in/src-d/go-git.v4/formats/packfile"
- "gopkg.in/src-d/go-git.v4/storage/memory"
"gopkg.in/src-d/go-git.v4/utils/fs"
)
@@ -227,35 +227,35 @@ func isHexAlpha(b byte) bool {
}
type PackWriter struct {
- fs fs.Filesystem
- sr io.ReadCloser
- sw io.WriteCloser
- fw fs.File
- mw io.Writer
+ Notify func(h core.Hash, i idxfile.Idxfile)
+ fs fs.Filesystem
+ fr, fw fs.File
+ synced *syncedReader
checksum core.Hash
index idxfile.Idxfile
result chan error
- Notify func(h core.Hash, i idxfile.Idxfile)
}
func newPackWrite(fs fs.Filesystem) (*PackWriter, error) {
- temp := sha1.Sum([]byte(time.Now().String()))
- filename := fmt.Sprintf(".%x", temp)
+ seed := sha1.Sum([]byte(time.Now().String()))
+ tmp := fs.Join(objectsPath, packPath, fmt.Sprintf("tmp_pack_%x", seed))
- fw, err := fs.Create(fs.Join(objectsPath, packPath, filename))
+ fw, err := fs.Create(tmp)
if err != nil {
return nil, err
}
- sr, sw := io.Pipe()
+ fr, err := fs.Open(tmp)
+ if err != nil {
+ return nil, err
+ }
writer := &PackWriter{
fs: fs,
fw: fw,
- sr: sr,
- sw: sw,
- mw: io.MultiWriter(sw, fw),
+ fr: fr,
+ synced: newSyncedReader(fw, fr),
result: make(chan error),
}
@@ -264,10 +264,12 @@ func newPackWrite(fs fs.Filesystem) (*PackWriter, error) {
}
func (w *PackWriter) buildIndex() {
- defer w.sr.Close()
- o := memory.NewStorage().ObjectStorage()
- s := packfile.NewScanner(w.sr)
- d := packfile.NewDecoder(s, o)
+ s := packfile.NewScanner(w.synced)
+ d, err := packfile.NewDecoder(s, nil)
+ if err != nil {
+ w.result <- err
+ return
+ }
checksum, err := d.Decode()
if err != nil {
@@ -287,8 +289,8 @@ func (w *PackWriter) buildIndex() {
w.result <- err
}
-func (w *PackWriter) Write(p []byte) (int, error) {
- return w.mw.Write(p)
+func (w *PackWriter) Write(p []byte) (n int, err error) {
+ return w.synced.Write(p)
}
func (w *PackWriter) Close() error {
@@ -296,20 +298,18 @@ func (w *PackWriter) Close() error {
close(w.result)
}()
- if err := w.fw.Close(); err != nil {
- return err
- }
-
- if err := w.sw.Close(); err != nil {
- return err
- }
-
- if err := <-w.result; err != nil {
- return err
+ pipe := []func() error{
+ func() error { return <-w.result },
+ w.fr.Close,
+ w.fw.Close,
+ w.synced.Close,
+ w.save,
}
- if err := w.save(); err != nil {
- return err
+ for i, f := range pipe {
+ if err := f(); err != nil {
+ return err
+ }
}
if w.Notify != nil {
@@ -342,3 +342,93 @@ func (w *PackWriter) encodeIdx(writer io.Writer) error {
_, err := e.Encode(&w.index)
return err
}
+
+type syncedReader struct {
+ w io.Writer
+ r io.ReadSeeker
+
+ blocked, done uint32
+ written, read uint64
+ news chan bool
+}
+
+func newSyncedReader(w io.Writer, r io.ReadSeeker) *syncedReader {
+ return &syncedReader{
+ w: w,
+ r: r,
+ news: make(chan bool),
+ }
+}
+
+func (s *syncedReader) Write(p []byte) (n int, err error) {
+ defer func() {
+ written := atomic.AddUint64(&s.written, uint64(n))
+ read := atomic.LoadUint64(&s.read)
+ if written > read {
+ s.wake()
+ }
+ }()
+
+ n, err = s.w.Write(p)
+ return
+}
+
+func (s *syncedReader) Read(p []byte) (n int, err error) {
+ defer func() { atomic.AddUint64(&s.read, uint64(n)) }()
+
+ s.sleep()
+ n, err = s.r.Read(p)
+ if err == io.EOF && !s.isDone() {
+ if n == 0 {
+ return s.Read(p)
+ }
+
+ return n, nil
+ }
+
+ return
+}
+
+func (s *syncedReader) isDone() bool {
+ return atomic.LoadUint32(&s.done) == 1
+}
+
+func (s *syncedReader) isBlocked() bool {
+ return atomic.LoadUint32(&s.blocked) == 1
+}
+
+func (s *syncedReader) wake() {
+ if s.isBlocked() {
+ // fmt.Println("wake")
+ atomic.StoreUint32(&s.blocked, 0)
+ s.news <- true
+ }
+}
+
+func (s *syncedReader) sleep() {
+ read := atomic.LoadUint64(&s.read)
+ written := atomic.LoadUint64(&s.written)
+ if read >= written {
+ atomic.StoreUint32(&s.blocked, 1)
+ // fmt.Println("sleep", read, written)
+ <-s.news
+ }
+
+}
+
+func (s *syncedReader) Seek(offset int64, whence int) (int64, error) {
+ if whence == io.SeekCurrent {
+ return s.r.Seek(offset, whence)
+ }
+
+ p, err := s.r.Seek(offset, whence)
+ s.read = uint64(p)
+
+ return p, err
+}
+
+func (s *syncedReader) Close() error {
+ atomic.StoreUint32(&s.done, 1)
+ close(s.news)
+ return nil
+}
diff --git a/storage/filesystem/internal/dotgit/dotgit_test.go b/storage/filesystem/internal/dotgit/dotgit_test.go
index ca2b5b4..f105c58 100644
--- a/storage/filesystem/internal/dotgit/dotgit_test.go
+++ b/storage/filesystem/internal/dotgit/dotgit_test.go
@@ -7,6 +7,7 @@ import (
"log"
"os"
"path/filepath"
+ "strconv"
"strings"
"testing"
@@ -191,3 +192,47 @@ func (s *SuiteDotGit) TestNewObjectPack(c *C) {
c.Assert(err, IsNil)
c.Assert(stat.Size(), Equals, int64(1940))
}
+
+func (s *SuiteDotGit) TestSyncedReader(c *C) {
+ tmpw, err := ioutil.TempFile("", "example")
+ c.Assert(err, IsNil)
+
+ tmpr, err := os.Open(tmpw.Name())
+ c.Assert(err, IsNil)
+
+ defer func() {
+ tmpw.Close()
+ tmpr.Close()
+ os.Remove(tmpw.Name())
+ }()
+
+ synced := newSyncedReader(tmpw, tmpr)
+
+ go func() {
+ for i := 0; i < 281; i++ {
+ _, err := synced.Write([]byte(strconv.Itoa(i) + "\n"))
+ c.Assert(err, IsNil)
+ }
+
+ synced.Close()
+ }()
+
+ o, err := synced.Seek(1002, io.SeekStart)
+ c.Assert(err, IsNil)
+ c.Assert(o, Equals, int64(1002))
+
+ head := make([]byte, 3)
+ n, err := io.ReadFull(synced, head)
+ c.Assert(err, IsNil)
+ c.Assert(n, Equals, 3)
+ c.Assert(string(head), Equals, "278")
+
+ o, err = synced.Seek(1010, io.SeekStart)
+ c.Assert(err, IsNil)
+ c.Assert(o, Equals, int64(1010))
+
+ n, err = io.ReadFull(synced, head)
+ c.Assert(err, IsNil)
+ c.Assert(n, Equals, 3)
+ c.Assert(string(head), Equals, "280")
+}