aboutsummaryrefslogblamecommitdiffstats
path: root/plumbing/protocol/packp/sideband/demux.go
blob: 09fe57db608a20fb570c26225746698c9fd5e452 (plain) (tree)


















































































































































                                                                                
package sideband

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	"gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
)

var (
	// ErrMaxPackedExceeded is returned by Read when a packet scanned from the
	// stream is longer than the maximum packet size allowed for the Demuxer's
	// sideband type.
	ErrMaxPackedExceeded = errors.New("max. packed size exceeded")
)

// Progress is the readable side of the progress channel: callers read the
// progress information collected by a Demuxer through it.
type Progress interface {
	io.Reader
}

// Demuxer demultiplexes the progress reports and error information that are
// interleaved with the packfile itself.
//
// A sideband has three different channels: the main one, called PackData,
// contains the packfile data; the ErrorMessage channel contains server errors;
// and the ProgressMessage channel contains information about the ongoing task
// on the server (optional; it can be suppressed by sending the NoProgress or
// Quiet capabilities to the server).
//
// To demultiplex the data stream, call Read to retrieve the PackData channel.
// Incoming ProgressMessage data is stored and can be read from the Progress
// field. If any message is received on the ErrorMessage channel, Read returns
// an error and the connection can be assumed to be closed.
type Demuxer struct {
	t Type             // sideband variant passed to NewDemuxer
	r io.Reader        // underlying stream that s scans
	s *pktline.Scanner // pkt-line scanner over r

	max     int    // maximum accepted packet payload length, channel byte included
	pending []byte // unread remainder of the last PackData payload

	// Progress receives ProgressMessage payloads. NewDemuxer initialises it
	// to an empty *bytes.Buffer, from which the progress can be read back.
	Progress Progress
}

// NewDemuxer returns a Demuxer that reads sideband-multiplexed pkt-lines of
// type t from r.
//
// The maximum accepted packet size depends on t: MaxPackedSize for Sideband,
// MaxPackedSize64k for any other type. Progress starts out as an empty
// *bytes.Buffer.
func NewDemuxer(t Type, r io.Reader) *Demuxer {
	var limit int
	switch t {
	case Sideband:
		limit = MaxPackedSize
	default:
		limit = MaxPackedSize64k
	}

	d := &Demuxer{t: t, r: r, max: limit}
	d.s = pktline.NewScanner(r)
	d.Progress = new(bytes.Buffer)
	return d
}

// Read reads len(b) bytes from the PackData channel into b.
//
// It keeps consuming packets until b is full or an error occurs, so it
// returns fewer than len(b) bytes only together with a non-nil error: io.EOF
// once the stream ends, an error from the underlying reader, or an error
// produced for a message received on the ErrorMessage channel.
//
// ProgressMessage packets are never copied into b; they are diverted to the
// Progress field instead.
func (d *Demuxer) Read(b []byte) (n int, err error) {
	for n < len(b) {
		// Assign to the named results rather than shadowing them with :=.
		var chunk int
		chunk, err = d.doRead(b[n:])
		n += chunk

		if err != nil {
			return n, err
		}
	}

	return n, nil
}

// doRead copies at most one chunk of PackData (as returned by nextPackData)
// into b and returns the number of bytes copied together with any error from
// fetching that chunk. Bytes that do not fit into b are kept in d.pending so
// the next call hands them out first.
func (d *Demuxer) doRead(b []byte) (int, error) {
	chunk, err := d.nextPackData()

	copied := copy(b, chunk)
	if copied < len(chunk) {
		// NOTE(review): chunk may alias memory owned by d.s (assumption,
		// confirm against pktline.Scanner). Keeping it is only safe because
		// nextPackData returns d.pending before scanning another packet.
		d.pending = chunk[copied:]
	}

	return copied, err
}

// nextPackData returns the next piece of PackData payload.
//
// Bytes left in d.pending by a previous, partial read are returned first.
// Otherwise one packet is scanned from the stream and handled according to
// its first (channel) byte:
//
//   - an empty payload yields (nil, nil);
//   - a payload longer than d.max yields ErrMaxPackedExceeded;
//   - PackData yields the payload without its channel byte;
//   - ProgressMessage is written to d.Progress when that value implements
//     io.Writer and is discarded otherwise, yielding (nil, nil) or the write
//     error;
//   - ErrorMessage and unknown channels yield an error.
//
// When the scanner stops, its error is returned, or io.EOF if it has none.
func (d *Demuxer) nextPackData() ([]byte, error) {
	if pending := d.getPending(); len(pending) != 0 {
		return pending, nil
	}

	if !d.s.Scan() {
		if err := d.s.Err(); err != nil {
			return nil, err
		}

		return nil, io.EOF
	}

	content := d.s.Bytes()

	switch size := len(content); {
	case size == 0:
		return nil, nil
	case size > d.max:
		return nil, ErrMaxPackedExceeded
	}

	switch Channel(content[0]) {
	case PackData:
		return content[1:], nil
	case ProgressMessage:
		// Progress is declared as an io.Reader and is only writable when the
		// concrete value (a *bytes.Buffer from NewDemuxer) also implements
		// io.Writer. Callers may replace the exported field or set it to nil;
		// an unchecked assertion would panic then, so progress is dropped.
		w, ok := d.Progress.(io.Writer)
		if !ok {
			return nil, nil
		}

		_, err := w.Write(content[1:])
		return nil, err
	case ErrorMessage:
		return nil, fmt.Errorf("unexpected error: %s", content[1:])
	default:
		return nil, fmt.Errorf("unknown channel %s", content)
	}
}

// getPending hands out the buffered remainder of the last PackData payload
// and clears it. It returns nil when nothing is pending, in which case
// d.pending is left untouched.
func (d *Demuxer) getPending() []byte {
	var out []byte
	if len(d.pending) > 0 {
		out, d.pending = d.pending, nil
	}

	return out
}