aboutsummaryrefslogtreecommitdiffstats
path: root/plumbing/format/index
diff options
context:
space:
mode:
Diffstat (limited to 'plumbing/format/index')
-rw-r--r--plumbing/format/index/decoder.go104
-rw-r--r--plumbing/format/index/decoder_test.go102
-rw-r--r--plumbing/format/index/encoder.go34
3 files changed, 200 insertions, 40 deletions
diff --git a/plumbing/format/index/decoder.go b/plumbing/format/index/decoder.go
index 6778cf7..f43b1c5 100644
--- a/plumbing/format/index/decoder.go
+++ b/plumbing/format/index/decoder.go
@@ -24,8 +24,8 @@ var (
// ErrInvalidChecksum is returned by Decode if the SHA1 hash mismatch with
// the read content
ErrInvalidChecksum = errors.New("invalid checksum")
-
- errUnknownExtension = errors.New("unknown extension")
+ // ErrUnknownExtension is returned when an index extension is encountered that is considered mandatory
+ ErrUnknownExtension = errors.New("unknown extension")
)
const (
@@ -39,6 +39,7 @@ const (
// A Decoder reads and decodes index files from an input stream.
type Decoder struct {
+ buf *bufio.Reader
r io.Reader
hash hash.Hash
lastEntry *Entry
@@ -49,8 +50,10 @@ type Decoder struct {
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader) *Decoder {
h := hash.New(hash.CryptoType)
+ buf := bufio.NewReader(r)
return &Decoder{
- r: io.TeeReader(r, h),
+ buf: buf,
+ r: io.TeeReader(buf, h),
hash: h,
extReader: bufio.NewReader(nil),
}
@@ -210,71 +213,76 @@ func (d *Decoder) readExtensions(idx *Index) error {
// count that they are not supported by jgit or libgit
var expected []byte
+ var peeked []byte
var err error
- var header [4]byte
+ // we should always be able to peek for 4 bytes (header) + 4 bytes (extlen) + final hash
+ // if this fails, we know that we're at the end of the index
+ peekLen := 4 + 4 + d.hash.Size()
+
for {
expected = d.hash.Sum(nil)
-
- var n int
- if n, err = io.ReadFull(d.r, header[:]); err != nil {
- if n == 0 {
- err = io.EOF
- }
-
+ peeked, err = d.buf.Peek(peekLen)
+ if len(peeked) < peekLen {
+ // there can't be an extension at this point, so let's bail out
+ err = nil
break
}
+ if err != nil {
+ return err
+ }
- err = d.readExtension(idx, header[:])
+ err = d.readExtension(idx)
if err != nil {
- break
+ return err
}
}
- if err != errUnknownExtension {
+ return d.readChecksum(expected)
+}
+
+func (d *Decoder) readExtension(idx *Index) error {
+ var header [4]byte
+
+ if _, err := io.ReadFull(d.r, header[:]); err != nil {
return err
}
- return d.readChecksum(expected, header)
-}
+ r, err := d.getExtensionReader()
+ if err != nil {
+ return err
+ }
-func (d *Decoder) readExtension(idx *Index, header []byte) error {
switch {
- case bytes.Equal(header, treeExtSignature):
- r, err := d.getExtensionReader()
- if err != nil {
- return err
- }
-
+ case bytes.Equal(header[:], treeExtSignature):
idx.Cache = &Tree{}
d := &treeExtensionDecoder{r}
if err := d.Decode(idx.Cache); err != nil {
return err
}
- case bytes.Equal(header, resolveUndoExtSignature):
- r, err := d.getExtensionReader()
- if err != nil {
- return err
- }
-
+ case bytes.Equal(header[:], resolveUndoExtSignature):
idx.ResolveUndo = &ResolveUndo{}
d := &resolveUndoDecoder{r}
if err := d.Decode(idx.ResolveUndo); err != nil {
return err
}
- case bytes.Equal(header, endOfIndexEntryExtSignature):
- r, err := d.getExtensionReader()
- if err != nil {
- return err
- }
-
+ case bytes.Equal(header[:], endOfIndexEntryExtSignature):
idx.EndOfIndexEntry = &EndOfIndexEntry{}
d := &endOfIndexEntryDecoder{r}
if err := d.Decode(idx.EndOfIndexEntry); err != nil {
return err
}
default:
- return errUnknownExtension
+ // See https://git-scm.com/docs/index-format, which says:
+ // If the first byte is 'A'..'Z' the extension is optional and can be ignored.
+ if header[0] < 'A' || header[0] > 'Z' {
+ return ErrUnknownExtension
+ }
+
+ d := &unknownExtensionDecoder{r}
+ if err := d.Decode(); err != nil {
+ return err
+ }
}
return nil
@@ -290,11 +298,10 @@ func (d *Decoder) getExtensionReader() (*bufio.Reader, error) {
return d.extReader, nil
}
-func (d *Decoder) readChecksum(expected []byte, alreadyRead [4]byte) error {
+func (d *Decoder) readChecksum(expected []byte) error {
var h plumbing.Hash
- copy(h[:4], alreadyRead[:])
- if _, err := io.ReadFull(d.r, h[4:]); err != nil {
+ if _, err := io.ReadFull(d.r, h[:]); err != nil {
return err
}
@@ -476,3 +483,22 @@ func (d *endOfIndexEntryDecoder) Decode(e *EndOfIndexEntry) error {
_, err = io.ReadFull(d.r, e.Hash[:])
return err
}
+
+type unknownExtensionDecoder struct {
+ r *bufio.Reader
+}
+
+func (d *unknownExtensionDecoder) Decode() error {
+ var buf [1024]byte
+
+ for {
+ _, err := d.r.Read(buf[:])
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/plumbing/format/index/decoder_test.go b/plumbing/format/index/decoder_test.go
index 39ab336..4adddda 100644
--- a/plumbing/format/index/decoder_test.go
+++ b/plumbing/format/index/decoder_test.go
@@ -1,6 +1,11 @@
package index
import (
+ "bytes"
+ "crypto"
+ "github.com/go-git/go-git/v5/plumbing/hash"
+ "github.com/go-git/go-git/v5/utils/binary"
+ "io"
"testing"
"github.com/go-git/go-git/v5/plumbing"
@@ -218,3 +223,100 @@ func (s *IndexSuite) TestDecodeEndOfIndexEntry(c *C) {
c.Assert(idx.EndOfIndexEntry.Offset, Equals, uint32(716))
c.Assert(idx.EndOfIndexEntry.Hash.String(), Equals, "922e89d9ffd7cefce93a211615b2053c0f42bd78")
}
+
+func (s *IndexSuite) readSimpleIndex(c *C) *Index {
+ f, err := fixtures.Basic().One().DotGit().Open("index")
+ c.Assert(err, IsNil)
+ defer func() { c.Assert(f.Close(), IsNil) }()
+
+ idx := &Index{}
+ d := NewDecoder(f)
+ err = d.Decode(idx)
+ c.Assert(err, IsNil)
+
+ return idx
+}
+
+func (s *IndexSuite) buildIndexWithExtension(c *C, signature string, data string) []byte {
+ idx := s.readSimpleIndex(c)
+
+ buf := bytes.NewBuffer(nil)
+ e := NewEncoder(buf)
+
+ err := e.encode(idx, false)
+ c.Assert(err, IsNil)
+ err = e.encodeRawExtension(signature, []byte(data))
+ c.Assert(err, IsNil)
+
+ err = e.encodeFooter()
+ c.Assert(err, IsNil)
+
+ return buf.Bytes()
+}
+
+func (s *IndexSuite) TestDecodeUnknownOptionalExt(c *C) {
+ f := bytes.NewReader(s.buildIndexWithExtension(c, "TEST", "testdata"))
+
+ idx := &Index{}
+ d := NewDecoder(f)
+ err := d.Decode(idx)
+ c.Assert(err, IsNil)
+}
+
+func (s *IndexSuite) TestDecodeUnknownMandatoryExt(c *C) {
+ f := bytes.NewReader(s.buildIndexWithExtension(c, "test", "testdata"))
+
+ idx := &Index{}
+ d := NewDecoder(f)
+ err := d.Decode(idx)
+ c.Assert(err, ErrorMatches, ErrUnknownExtension.Error())
+}
+
+func (s *IndexSuite) TestDecodeTruncatedExt(c *C) {
+ idx := s.readSimpleIndex(c)
+
+ buf := bytes.NewBuffer(nil)
+ e := NewEncoder(buf)
+
+ err := e.encode(idx, false)
+ c.Assert(err, IsNil)
+
+ _, err = e.w.Write([]byte("TEST"))
+ c.Assert(err, IsNil)
+
+ err = binary.WriteUint32(e.w, uint32(100))
+ c.Assert(err, IsNil)
+
+ _, err = e.w.Write([]byte("truncated"))
+ c.Assert(err, IsNil)
+
+ err = e.encodeFooter()
+ c.Assert(err, IsNil)
+
+ idx = &Index{}
+ d := NewDecoder(buf)
+ err = d.Decode(idx)
+ c.Assert(err, ErrorMatches, io.EOF.Error())
+}
+
+func (s *IndexSuite) TestDecodeInvalidHash(c *C) {
+ idx := s.readSimpleIndex(c)
+
+ buf := bytes.NewBuffer(nil)
+ e := NewEncoder(buf)
+
+ err := e.encode(idx, false)
+ c.Assert(err, IsNil)
+
+ err = e.encodeRawExtension("TEST", []byte("testdata"))
+ c.Assert(err, IsNil)
+
+ h := hash.New(crypto.SHA1)
+ err = binary.Write(e.w, h.Sum(nil))
+ c.Assert(err, IsNil)
+
+ idx = &Index{}
+ d := NewDecoder(buf)
+ err = d.Decode(idx)
+ c.Assert(err, ErrorMatches, ErrInvalidChecksum.Error())
+}
diff --git a/plumbing/format/index/encoder.go b/plumbing/format/index/encoder.go
index fa2d814..c292c2c 100644
--- a/plumbing/format/index/encoder.go
+++ b/plumbing/format/index/encoder.go
@@ -3,6 +3,7 @@ package index
import (
"bytes"
"errors"
+ "fmt"
"io"
"sort"
"time"
@@ -35,6 +36,11 @@ func NewEncoder(w io.Writer) *Encoder {
// Encode writes the Index to the stream of the encoder.
func (e *Encoder) Encode(idx *Index) error {
+ return e.encode(idx, true)
+}
+
+func (e *Encoder) encode(idx *Index, footer bool) error {
+
// TODO: support v4
// TODO: support extensions
if idx.Version > EncodeVersionSupported {
@@ -49,7 +55,10 @@ func (e *Encoder) Encode(idx *Index) error {
return err
}
- return e.encodeFooter()
+ if footer {
+ return e.encodeFooter()
+ }
+ return nil
}
func (e *Encoder) encodeHeader(idx *Index) error {
@@ -135,6 +144,29 @@ func (e *Encoder) encodeEntry(entry *Entry) error {
return binary.Write(e.w, []byte(entry.Name))
}
+func (e *Encoder) encodeRawExtension(signature string, data []byte) error {
+ if len(signature) != 4 {
+ return fmt.Errorf("invalid signature length")
+ }
+
+ _, err := e.w.Write([]byte(signature))
+ if err != nil {
+ return err
+ }
+
+ err = binary.WriteUint32(e.w, uint32(len(data)))
+ if err != nil {
+ return err
+ }
+
+ _, err = e.w.Write(data)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) {
if t.IsZero() {
return 0, 0, nil