package transactional
import (
"gopkg.in/src-d/go-git.v4/plumbing"
"gopkg.in/src-d/go-git.v4/plumbing/storer"
"gopkg.in/src-d/go-git.v4/storage"
)
// ReferenceStorage combines a base storer.ReferenceStorer (the embedded
// field) with a temporal one. SetReference and RemoveReference act on the
// temporal storer and on the deleted set; Commit applies the recorded
// deletions and the temporal references to the base storer.
type ReferenceStorage struct {
// Embedded base storer. In this file only Commit calls its mutating
// methods; the other methods only read from it.
storer.ReferenceStorer
// temporal receives every SetReference call and is consulted before the
// base storer when a reference is looked up.
temporal storer.ReferenceStorer
// deleted is the set of reference names that are going to be removed from
// the base storer when Commit is called. An entry is added when
// RemoveReference is called and removed again when SetReference is called
// for the same name.
deleted map[plumbing.ReferenceName]struct{}
// packRefs, if true, means PackRefs is going to be called on the base
// storer when Commit is called.
packRefs bool
}
// NewReferenceStorage returns a ReferenceStorage that uses s as its base
// storer and temporal as its temporal storer. The set of pending deletions
// starts out empty.
func NewReferenceStorage(s, temporal storer.ReferenceStorer) *ReferenceStorage {
	rs := new(ReferenceStorage)
	rs.ReferenceStorer = s
	rs.temporal = temporal
	rs.deleted = map[plumbing.ReferenceName]struct{}{}

	return rs
}
// SetReference stores ref in the temporal storer. A pending deletion of the
// same name is cancelled first, regardless of whether the temporal storer
// then accepts the reference.
func (r *ReferenceStorage) SetReference(ref *plumbing.Reference) error {
	name := ref.Name()
	delete(r.deleted, name)

	return r.temporal.SetReference(ref)
}
// CheckAndSetReference stores ref only if the reference named by old still
// has the same hash as old. When old is nil no check is made and ref is
// stored unconditionally.
//
// It returns storage.ErrReferenceHasChanged when the current hash differs
// from old's, and the lookup error (plumbing.ErrReferenceNotFound included)
// when the current value cannot be read.
func (r *ReferenceStorage) CheckAndSetReference(ref, old *plumbing.Reference) error {
	if old == nil {
		return r.SetReference(ref)
	}

	// Resolve through Reference rather than querying the storers directly:
	// it also honors pending deletions, so a reference removed with
	// RemoveReference is not compared against its stale value in the base
	// storer.
	current, err := r.Reference(old.Name())
	if err != nil {
		return err
	}

	if current.Hash() != old.Hash() {
		return storage.ErrReferenceHasChanged
	}

	return r.SetReference(ref)
}
// Reference looks up the reference named n. Names with a pending deletion
// are reported as plumbing.ErrReferenceNotFound. Otherwise the temporal
// storer is asked first and the base storer is used only when the temporal
// one reports plumbing.ErrReferenceNotFound; any other result of the temporal
// storer is returned as is.
func (r ReferenceStorage) Reference(n plumbing.ReferenceName) (*plumbing.Reference, error) {
	if _, gone := r.deleted[n]; gone {
		return nil, plumbing.ErrReferenceNotFound
	}

	found, err := r.temporal.Reference(n)
	if err != plumbing.ErrReferenceNotFound {
		return found, err
	}

	return r.ReferenceStorer.Reference(n)
}
// IterReferences returns an iterator over the references of the base storer
// followed by those of the temporal storer. Base references whose names have
// a pending deletion are skipped, matching what Reference reports for them.
//
// A name present in both storers is still yielded by both iterators.
func (r ReferenceStorage) IterReferences() (storer.ReferenceIter, error) {
	baseIter, err := r.ReferenceStorer.IterReferences()
	if err != nil {
		return nil, err
	}

	temporalIter, err := r.temporal.IterReferences()
	if err != nil {
		// The base iterator was already opened; release it before bailing out.
		baseIter.Close()
		return nil, err
	}

	// The map is shared with the storage, so deletions recorded while the
	// iterator is being consumed are honored as well.
	deleted := r.deleted
	visibleBase := storer.NewReferenceFilteredIter(func(ref *plumbing.Reference) bool {
		_, gone := deleted[ref.Name()]
		return !gone
	}, baseIter)

	return storer.NewMultiReferenceIter([]storer.ReferenceIter{
		visibleBase,
		temporalIter,
	}), nil
}
// CountLooseRefs returns the sum of the loose reference counts reported by
// the temporal storer and the base storer, queried in that order. If either
// storer fails, -1 and that error are returned.
func (r ReferenceStorage) CountLooseRefs() (int, error) {
	total := 0
	for _, s := range []storer.ReferenceStorer{r.temporal, r.ReferenceStorer} {
		count, err := s.CountLooseRefs()
		if err != nil {
			return -1, err
		}

		total += count
	}

	return total, nil
}
// PackRefs records that the references have to be packed. Nothing is packed
// right away: the base storer's PackRefs is called by Commit. It always
// returns nil.
//
// The receiver is a pointer so that the flag is stored on the storage itself
// and not on a copy that is discarded when the method returns.
func (r *ReferenceStorage) PackRefs() error {
	r.packRefs = true
	return nil
}

// RemoveReference marks n as pending deletion and removes it from the
// temporal storer, returning the temporal storer's result. The base storer is
// not touched until Commit is called.
func (r ReferenceStorage) RemoveReference(n plumbing.ReferenceName) error {
	r.deleted[n] = struct{}{}
	return r.temporal.RemoveReference(n)
}

// Commit applies the pending changes to the base storer: first every name
// with a pending deletion is removed, then every reference held by the
// temporal storer is set, and finally, if PackRefs was requested, the base
// storer's PackRefs is called. It stops at the first error and returns it.
func (r ReferenceStorage) Commit() error {
	for name := range r.deleted {
		if err := r.ReferenceStorer.RemoveReference(name); err != nil {
			return err
		}
	}

	iter, err := r.temporal.IterReferences()
	if err != nil {
		return err
	}

	err = iter.ForEach(func(ref *plumbing.Reference) error {
		return r.ReferenceStorer.SetReference(ref)
	})
	if err != nil {
		return err
	}

	if !r.packRefs {
		return nil
	}

	// Pack only after the references were written, so that they are included.
	return r.ReferenceStorer.PackRefs()
}