diff options
Diffstat (limited to 'formats/packp/pktline')
-rw-r--r-- | formats/packp/pktline/pktline.go | 114 | ||||
-rw-r--r-- | formats/packp/pktline/pktline_test.go | 199 | ||||
-rw-r--r-- | formats/packp/pktline/scanner.go | 133 | ||||
-rw-r--r-- | formats/packp/pktline/scanner_test.go | 208 |
4 files changed, 654 insertions, 0 deletions
diff --git a/formats/packp/pktline/pktline.go b/formats/packp/pktline/pktline.go new file mode 100644 index 0000000..58c36fe --- /dev/null +++ b/formats/packp/pktline/pktline.go @@ -0,0 +1,114 @@ +// Package pktline implements reading and creating pkt-lines as per +// https://github.com/git/git/blob/master/Documentation/technical/protocol-common.txt. +package pktline + +import ( + "bytes" + "errors" + "io" + "strings" +) + +const ( + // MaxPayloadSize is the maximum payload size of a pkt-line in bytes. + MaxPayloadSize = 65516 +) + +var ( + flush = []byte{'0', '0', '0', '0'} +) + +// PktLine values represent a succession of pkt-lines. +// Values from this type are not zero-value safe, see the functions New +// and NewFromString below. +type PktLine struct { + io.Reader +} + +// ErrPayloadTooLong is returned by New and NewFromString when any of +// the provided payloads is bigger than MaxPayloadSize. +var ErrPayloadTooLong = errors.New("payload is too long") + +// New returns the concatenation of several pkt-lines, each of them with +// the payload specified by the contents of each input byte slice. An +// empty payload byte slice will produce a flush-pkt. +func New(payloads ...[]byte) (PktLine, error) { + ret := []io.Reader{} + for _, p := range payloads { + if err := add(&ret, p); err != nil { + return PktLine{}, err + } + } + + return PktLine{io.MultiReader(ret...)}, nil +} + +func add(dst *[]io.Reader, e []byte) error { + if len(e) > MaxPayloadSize { + return ErrPayloadTooLong + } + + if len(e) == 0 { + *dst = append(*dst, bytes.NewReader(flush)) + return nil + } + + n := len(e) + 4 + *dst = append(*dst, bytes.NewReader(int16ToHex(n))) + *dst = append(*dst, bytes.NewReader(e)) + + return nil +} + +// susbtitutes fmt.Sprintf("%04x", n) to avoid memory garbage +// generation. +func int16ToHex(n int) []byte { + var ret [4]byte + ret[0] = byteToAsciiHex(byte(n & 0xf000 >> 12)) + ret[1] = byteToAsciiHex(byte(n & 0x0f00 >> 8)) + ret[2] = byteToAsciiHex(byte(n & 0x00f0 >> 4)) + ret[3] = byteToAsciiHex(byte(n & 0x000f)) + + return ret[:] +} + +// turns a byte into its hexadecimal ascii representation. Example: +// from 11 (0xb) into 'b'. +func byteToAsciiHex(n byte) byte { + if n < 10 { + return byte('0' + n) + } + + return byte('a' - 10 + n) +} + +// NewFromStrings returns the concatenation of several pkt-lines, each +// of them with the payload specified by the contents of each input +// string. An empty payload string will produce a flush-pkt. +func NewFromStrings(payloads ...string) (PktLine, error) { + ret := []io.Reader{} + for _, p := range payloads { + if err := addString(&ret, p); err != nil { + return PktLine{}, err + } + } + + return PktLine{io.MultiReader(ret...)}, nil +} + +func addString(dst *[]io.Reader, s string) error { + if len(s) > MaxPayloadSize { + return ErrPayloadTooLong + } + + if len(s) == 0 { + *dst = append(*dst, bytes.NewReader(flush)) + return nil + } + + n := len(s) + 4 + *dst = append(*dst, bytes.NewReader(int16ToHex(n))) + *dst = append(*dst, strings.NewReader(s)) + + return nil +} diff --git a/formats/packp/pktline/pktline_test.go b/formats/packp/pktline/pktline_test.go new file mode 100644 index 0000000..3c18f53 --- /dev/null +++ b/formats/packp/pktline/pktline_test.go @@ -0,0 +1,199 @@ +package pktline_test + +import ( + "io" + "io/ioutil" + "os" + "strings" + "testing" + + "gopkg.in/src-d/go-git.v4/formats/packp/pktline" + + . "gopkg.in/check.v1" +) + +func Test(t *testing.T) { TestingT(t) } + +type SuitePktLine struct { +} + +var _ = Suite(&SuitePktLine{}) + +func (s *SuitePktLine) TestNew(c *C) { + for i, test := range [...]struct { + input [][]byte + expected []byte + }{ + { + input: [][]byte{}, + expected: []byte{}, + }, { + input: [][]byte{ + []byte(nil), + }, + expected: []byte("0000"), + }, { + input: [][]byte{ + []byte{}, + }, + expected: []byte("0000"), + }, { + input: [][]byte{ + []byte(""), + }, + expected: []byte("0000"), + }, { + input: [][]byte{ + []byte("hello\n"), + }, + expected: []byte("000ahello\n"), + }, { + input: [][]byte{ + []byte("hello\n"), + []byte("world!\n"), + []byte(""), + []byte("foo"), + []byte(""), + }, + expected: []byte("000ahello\n000bworld!\n00000007foo0000"), + }, { + input: [][]byte{ + []byte(strings.Repeat("a", pktline.MaxPayloadSize)), + }, + expected: []byte("fff0" + strings.Repeat("a", pktline.MaxPayloadSize)), + }, + } { + r, err := pktline.New(test.input...) + c.Assert(err, IsNil, Commentf("input %d = %v", i, test.input)) + + obtained, err := ioutil.ReadAll(r) + c.Assert(err, IsNil, Commentf("input %d = %v", i, test.input)) + + c.Assert(obtained, DeepEquals, test.expected, + Commentf("input %d = %v", i, test.input)) + } +} + +func (s *SuitePktLine) TestNewErrPayloadTooLong(c *C) { + for _, input := range [...][][]byte{ + [][]byte{ + []byte(strings.Repeat("a", pktline.MaxPayloadSize+1)), + }, + [][]byte{ + []byte("hello world!"), + []byte(""), + []byte(strings.Repeat("a", pktline.MaxPayloadSize+1)), + }, + [][]byte{ + []byte("hello world!"), + []byte(strings.Repeat("a", pktline.MaxPayloadSize+1)), + []byte("foo"), + }, + } { + _, err := pktline.New(input...) + + c.Assert(err, Equals, pktline.ErrPayloadTooLong, + Commentf("%v\n", input)) + } +} + +func (s *SuitePktLine) TestNewFromStrings(c *C) { + for _, test := range [...]struct { + input []string + expected []byte + }{ + { + input: []string(nil), + expected: []byte{}, + }, { + input: []string{}, + expected: []byte{}, + }, { + input: []string{""}, + expected: []byte("0000"), + }, { + input: []string{"hello\n"}, + expected: []byte("000ahello\n"), + }, { + input: []string{"hello\n", "world!\n", "", "foo", ""}, + expected: []byte("000ahello\n000bworld!\n00000007foo0000"), + }, { + input: []string{ + strings.Repeat("a", pktline.MaxPayloadSize), + }, + expected: []byte("fff0" + strings.Repeat("a", pktline.MaxPayloadSize)), + }, + } { + r, err := pktline.NewFromStrings(test.input...) + c.Assert(err, IsNil) + + obtained, err := ioutil.ReadAll(r) + c.Assert(err, IsNil) + + c.Assert(obtained, DeepEquals, test.expected, + Commentf("input = %v\n", test.input)) + } +} + +func (s *SuitePktLine) TestNewFromStringsErrPayloadTooLong(c *C) { + for _, input := range [...][]string{ + []string{ + strings.Repeat("a", pktline.MaxPayloadSize+1), + }, + []string{ + "hello world!", + "", + strings.Repeat("a", pktline.MaxPayloadSize+1), + }, + []string{ + "hello world!", + strings.Repeat("a", pktline.MaxPayloadSize+1), + "foo", + }, + } { + _, err := pktline.NewFromStrings(input...) + + c.Assert(err, Equals, pktline.ErrPayloadTooLong, + Commentf("%v\n", input)) + } +} + +func ExampleNew() { + // These are the payloads we want to turn into pkt-lines, + // the empty slice at the end will generate a flush-pkt. + payloads := [][]byte{ + []byte{'h', 'e', 'l', 'l', 'o', '\n'}, + []byte{'w', 'o', 'r', 'l', 'd', '!', '\n'}, + []byte{}, + } + + // Create the pkt-lines, ignoring errors... + pktlines, _ := pktline.New(payloads...) + + // Send the raw data to stdout, ignoring errors... + _, _ = io.Copy(os.Stdout, pktlines) + + // Output: 000ahello + // 000bworld! + // 0000 +} + +func ExampleNewFromStrings() { + // These are the payloads we want to turn into pkt-lines, + // the empty string at the end will generate a flush-pkt. + payloads := []string{ + "hello\n", + "world!\n", + "", + } + + // Create the pkt-lines, ignoring errors... + pktlines, _ := pktline.NewFromStrings(payloads...) + + // Send the raw data to stdout, ignoring errors... + _, _ = io.Copy(os.Stdout, pktlines) + + // Output: 000ahello + // 000bworld! + // 0000 +} diff --git a/formats/packp/pktline/scanner.go b/formats/packp/pktline/scanner.go new file mode 100644 index 0000000..3ce2adf --- /dev/null +++ b/formats/packp/pktline/scanner.go @@ -0,0 +1,133 @@ +package pktline + +import ( + "errors" + "io" +) + +const ( + lenSize = 4 +) + +// ErrInvalidPktLen is returned by Err() when an invalid pkt-len is found. +var ErrInvalidPktLen = errors.New("invalid pkt-len found") + +// Scanner provides a convenient interface for reading the payloads of a +// series of pkt-lines. It takes an io.Reader providing the source, +// which then can be tokenized through repeated calls to the Scan +// method. +// +// After each Scan call, the Bytes method will return the payload of the +// corresponding pkt-line on a shared buffer, which will be 65516 bytes +// or smaller. Flush pkt-lines are represented by empty byte slices. +// +// Scanning stops at EOF or the first I/O error. +type Scanner struct { + r io.Reader // The reader provided by the client + err error // Sticky error + payload []byte // Last pkt-payload + len [lenSize]byte // Last pkt-len +} + +// NewScanner returns a new Scanner to read from r. +func NewScanner(r io.Reader) *Scanner { + return &Scanner{ + r: r, + } +} + +// Err returns the first error encountered by the Scanner. +func (s *Scanner) Err() error { + return s.err +} + +// Scan advances the Scanner to the next pkt-line, whose payload will +// then be available through the Bytes method. Scanning stops at EOF +// or the first I/O error. After Scan returns false, the Err method +// will return any error that occurred during scanning, except that if +// it was io.EOF, Err will return nil. +func (s *Scanner) Scan() bool { + var l int + l, s.err = s.readPayloadLen() + if s.err == io.EOF { + s.err = nil + return false + } + if s.err != nil { + return false + } + + if cap(s.payload) < l { + s.payload = make([]byte, 0, l) + } + + if _, s.err = io.ReadFull(s.r, s.payload[:l]); s.err != nil { + return false + } + s.payload = s.payload[:l] + + return true +} + +// Bytes returns the most recent payload generated by a call to Scan. +// The underlying array may point to data that will be overwritten by a +// subsequent call to Scan. It does no allocation. +func (s *Scanner) Bytes() []byte { + return s.payload +} + +// Method readPayloadLen returns the payload length by reading the +// pkt-len and substracting the pkt-len size. +func (s *Scanner) readPayloadLen() (int, error) { + if _, err := io.ReadFull(s.r, s.len[:]); err != nil { + if err == io.EOF { + return 0, err + } + return 0, ErrInvalidPktLen + } + + n, err := hexDecode(s.len) + if err != nil { + return 0, err + } + + switch { + case n == 0: + return 0, nil + case n <= lenSize: + return 0, ErrInvalidPktLen + case n > MaxPayloadSize+lenSize: + return 0, ErrInvalidPktLen + default: + return n - lenSize, nil + } +} + +// Turns the hexadecimal representation of a number in a byte slice into +// a number. This function substitute strconv.ParseUint(string(buf), 16, +// 16) and/or hex.Decode, to avoid generating new strings, thus helping the +// GC. +func hexDecode(buf [lenSize]byte) (int, error) { + var ret int + for i := 0; i < lenSize; i++ { + n, err := asciiHexToByte(buf[i]) + if err != nil { + return 0, ErrInvalidPktLen + } + ret = 16*ret + int(n) + } + return ret, nil +} + +// turns the hexadecimal ascii representation of a byte into its +// numerical value. Example: from 'b' to 11 (0xb). +func asciiHexToByte(b byte) (byte, error) { + switch { + case b >= '0' && b <= '9': + return b - '0', nil + case b >= 'a' && b <= 'f': + return b - 'a' + 10, nil + default: + return 0, ErrInvalidPktLen + } +} diff --git a/formats/packp/pktline/scanner_test.go b/formats/packp/pktline/scanner_test.go new file mode 100644 index 0000000..08ca51f --- /dev/null +++ b/formats/packp/pktline/scanner_test.go @@ -0,0 +1,208 @@ +package pktline_test + +import ( + "fmt" + "io" + "strings" + + "gopkg.in/src-d/go-git.v4/formats/packp/pktline" + + . "gopkg.in/check.v1" +) + +type SuiteScanner struct{} + +var _ = Suite(&SuiteScanner{}) + +func (s *SuiteScanner) TestInvalid(c *C) { + for _, test := range [...]string{ + "0001", "0002", "0003", "0004", + "0001asdfsadf", "0004foo", + "fff1", "fff2", + "gorka", + "0", "003", + " 5a", "5 a", "5 \n", + "-001", "-000", + } { + r := strings.NewReader(test) + sc := pktline.NewScanner(r) + _ = sc.Scan() + c.Assert(sc.Err(), ErrorMatches, pktline.ErrInvalidPktLen.Error(), + Commentf("data = %q", test)) + } +} + +func (s *SuiteScanner) TestEmptyReader(c *C) { + r := strings.NewReader("") + sc := pktline.NewScanner(r) + hasPayload := sc.Scan() + c.Assert(hasPayload, Equals, false) + c.Assert(sc.Err(), Equals, nil) +} + +func (s *SuiteScanner) TestFlush(c *C) { + r, err := pktline.NewFromStrings("") + c.Assert(err, IsNil) + sc := pktline.NewScanner(r) + c.Assert(sc.Scan(), Equals, true) + payload := sc.Bytes() + c.Assert(len(payload), Equals, 0) +} + +func (s *SuiteScanner) TestPktLineTooShort(c *C) { + r := strings.NewReader("010cfoobar") + + sc := pktline.NewScanner(r) + + c.Assert(sc.Scan(), Equals, false) + c.Assert(sc.Err(), ErrorMatches, "unexpected EOF") +} + +func (s *SuiteScanner) TestScanAndPayload(c *C) { + for _, test := range [...]string{ + "a", + "a\n", + strings.Repeat("a", 100), + strings.Repeat("a", 100) + "\n", + strings.Repeat("\x00", 100), + strings.Repeat("\x00", 100) + "\n", + strings.Repeat("a", pktline.MaxPayloadSize), + strings.Repeat("a", pktline.MaxPayloadSize-1) + "\n", + } { + r, err := pktline.NewFromStrings(test) + c.Assert(err, IsNil, Commentf("input len=%x, contents=%.10q\n", len(test), test)) + sc := pktline.NewScanner(r) + + c.Assert(sc.Scan(), Equals, true, + Commentf("test = %.20q...", test)) + obtained := sc.Bytes() + c.Assert(obtained, DeepEquals, []byte(test), + Commentf("in = %.20q out = %.20q", test, string(obtained))) + } +} + +func (s *SuiteScanner) TestSkip(c *C) { + for _, test := range [...]struct { + input []string + n int + expected []byte + }{ + { + input: []string{ + "first", + "second", + "third", + ""}, + n: 1, + expected: []byte("second"), + }, + { + input: []string{ + "first", + "second", + "third", + ""}, + n: 2, + expected: []byte("third"), + }, + } { + r, err := pktline.NewFromStrings(test.input...) + c.Assert(err, IsNil) + sc := pktline.NewScanner(r) + for i := 0; i < test.n; i++ { + c.Assert(sc.Scan(), Equals, true, + Commentf("scan error = %s", sc.Err())) + } + c.Assert(sc.Scan(), Equals, true, + Commentf("scan error = %s", sc.Err())) + obtained := sc.Bytes() + c.Assert(obtained, DeepEquals, test.expected, + Commentf("\nin = %.20q\nout = %.20q\nexp = %.20q", + test.input, obtained, test.expected)) + } +} + +func (s *SuiteScanner) TestEOF(c *C) { + r, err := pktline.NewFromStrings("first", "second") + c.Assert(err, IsNil) + sc := pktline.NewScanner(r) + for sc.Scan() { + } + c.Assert(sc.Err(), IsNil) +} + +// A section are several non flush-pkt lines followed by a flush-pkt, which +// how the git protocol sends long messages. +func (s *SuiteScanner) TestReadSomeSections(c *C) { + nSections := 2 + nLines := 4 + data := sectionsExample(c, nSections, nLines) + sc := pktline.NewScanner(data) + + sectionCounter := 0 + lineCounter := 0 + for sc.Scan() { + if len(sc.Bytes()) == 0 { + sectionCounter++ + } + lineCounter++ + } + c.Assert(sc.Err(), IsNil) + c.Assert(sectionCounter, Equals, nSections) + c.Assert(lineCounter, Equals, (1+nLines)*nSections) +} + +// returns nSection sections, each of them with nLines pkt-lines (not +// counting the flush-pkt: +// +// 0009 0.0\n +// 0009 0.1\n +// ... +// 0000 +// and so on +func sectionsExample(c *C, nSections, nLines int) io.Reader { + ss := []string{} + for section := 0; section < nSections; section++ { + for line := 0; line < nLines; line++ { + line := fmt.Sprintf(" %d.%d\n", section, line) + ss = append(ss, line) + } + ss = append(ss, "") + } + + ret, err := pktline.NewFromStrings(ss...) + c.Assert(err, IsNil) + + return ret +} + +func ExampleScanner() { + // A reader is needed as input. + input := strings.NewReader("000ahello\n" + + "000bworld!\n" + + "0000", + ) + + // Create the scanner... + s := pktline.NewScanner(input) + + // and scan every pkt-line found in the input. + for s.Scan() { + payload := s.Bytes() + if len(payload) == 0 { // zero sized payloads correspond to flush-pkts. + fmt.Println("FLUSH-PKT DETECTED\n") + } else { // otherwise, you will be able to access the full payload. + fmt.Printf("PAYLOAD = %q\n", string(payload)) + } + } + + // this will catch any error when reading from the input, if any. + if s.Err() != nil { + fmt.Println(s.Err()) + } + + // Output: + // PAYLOAD = "hello\n" + // PAYLOAD = "world!\n" + // FLUSH-PKT DETECTED +} |