package git
import (
	"errors"
	"fmt"
	"io"
	"os"
	"sort"

	"srcd.works/go-billy.v1"
	"srcd.works/go-git.v4/plumbing"
	"srcd.works/go-git.v4/plumbing/format/index"
	"srcd.works/go-git.v4/plumbing/object"
)
// ErrWorktreeNotClean is returned by Worktree.Checkout when the status of the
// worktree reports at least one file that is not Unmodified.
var ErrWorktreeNotClean = errors.New("worktree is not clean")
// Worktree ties a Repository to the filesystem holding its checked-out files.
// r is the repository whose commits and index are read and written; fs is the
// filesystem that files are written to and scanned from.
type Worktree struct {
r *Repository
fs billy.Filesystem
}
// Checkout writes every file of the given commit into the worktree and then
// stores a freshly built index describing those files.
//
// It refuses to run on a dirty worktree: when Status reports any change,
// ErrWorktreeNotClean is returned and nothing is written. Any error from
// reading the commit, writing a file or storing the index is returned as is.
func (w *Worktree) Checkout(commit plumbing.Hash) error {
	status, err := w.Status()
	if err != nil {
		return err
	}

	if clean := status.IsClean(); !clean {
		return ErrWorktreeNotClean
	}

	target, err := w.r.Commit(commit)
	if err != nil {
		return err
	}

	iter, err := target.Files()
	if err != nil {
		return err
	}

	// Each file written to the worktree is also recorded in this new index.
	newIndex := &index.Index{Version: 2}
	writeAndIndex := func(file *object.File) error {
		return w.checkoutFile(file, newIndex)
	}

	err = iter.ForEach(writeAndIndex)
	if err != nil {
		return err
	}

	return w.r.s.SetIndex(newIndex)
}
// checkoutFile copies the content of f into the worktree filesystem under
// f.Name, creating or truncating the file with the permission bits of f.Mode,
// and then appends a matching entry to idx.
//
// The destination is closed on every path. It is closed explicitly before the
// file is indexed, so that a failed close is reported to the caller and the
// file is no longer open when indexFile stats it.
func (w *Worktree) checkoutFile(f *object.File, idx *index.Index) error {
	from, err := f.Reader()
	if err != nil {
		return err
	}
	defer from.Close()

	to, err := w.fs.OpenFile(f.Name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode.Perm())
	if err != nil {
		return err
	}

	if _, err := io.Copy(to, from); err != nil {
		// Best-effort close: the copy failure is the error worth reporting.
		to.Close()
		return err
	}

	if err := to.Close(); err != nil {
		return err
	}

	return w.indexFile(f, idx)
}
// fillSystemInfo is an optional hook. When it is non-nil, indexFile calls it
// with the entry being built and the value of FileInfo.Sys() so that it can
// fill in additional fields of the entry.
// NOTE(review): no assignment to this variable is visible here; presumably a
// platform-specific file sets it - confirm.
var fillSystemInfo func(e *index.Entry, sys interface{})
// indexFile stats f.Name on the worktree filesystem and appends an index
// entry for it to idx. The hash and name come from f; the mode, modification
// time and size come from the file on disk. A Stat failure is returned and
// leaves idx untouched.
func (w *Worktree) indexFile(f *object.File, idx *index.Index) error {
	info, err := w.fs.Stat(f.Name)
	if err != nil {
		return err
	}

	var entry index.Entry
	entry.Hash = f.Hash
	entry.Name = f.Name
	entry.Mode = w.getMode(info)
	entry.ModifiedAt = info.ModTime()
	entry.Size = uint32(info.Size())

	// When the FileInfo.Sys() value comes from the os package, ctime, dev,
	// inode, uid and gid can be read from it; otherwise there is nothing to
	// add, which is why the hook is optional.
	if hook := fillSystemInfo; hook != nil {
		hook(&entry, info.Sys())
	}

	idx.Entries = append(idx.Entries, entry)
	return nil
}
// Status compares the stored index with the files found on the worktree
// filesystem and returns the per-file result.
//
// An index entry with no file on disk is marked Deleted; an entry whose file
// exists gets the code computed by compareFileWithEntry; every file on disk
// that has no index entry is marked Untracked. Only the Worktree field of each
// FileStatus is set here.
func (w *Worktree) Status() (Status, error) {
	idx, err := w.r.s.Index()
	if err != nil {
		return nil, err
	}

	onDisk, err := readDirAll(w.fs)
	if err != nil {
		return nil, err
	}

	result := make(Status, 0)
	for i := range idx.Entries {
		entry := &idx.Entries[i]

		info, found := onDisk[entry.Name]
		if !found {
			result.File(entry.Name).Worktree = Deleted
			continue
		}
		// Drop matched files so that only untracked ones remain afterwards.
		delete(onDisk, entry.Name)

		code, err := w.compareFileWithEntry(info, entry)
		if err != nil {
			return nil, err
		}
		result.File(entry.Name).Worktree = code
	}

	for name := range onDisk {
		result.File(name).Worktree = Untracked
	}

	return result, nil
}
// compareFileWithEntry reports whether the file described by fi still matches
// the index entry e. The cheap checks run first: a different size or mode
// yields Modified without reading the file. Otherwise the content hash of
// e.Name is computed and compared with e.Hash.
//
// When hashing fails the result is Modified together with the error.
func (w *Worktree) compareFileWithEntry(fi billy.FileInfo, e *index.Entry) (StatusCode, error) {
	switch {
	case fi.Size() != int64(e.Size):
		return Modified, nil
	case w.getMode(fi) != e.Mode:
		return Modified, nil
	}

	sum, err := calcSHA1(w.fs, e.Name)
	if err != nil {
		return Modified, err
	}
	if sum != e.Hash {
		return Modified, nil
	}

	return Unmodified, nil
}
// getMode maps the mode reported by fi onto one of the object modes. The
// checks are ordered: a directory gives object.TreeMode, a symlink gives
// object.SymlinkMode, a file with any execute bit set gives
// object.ExecutableMode, and everything else gives object.FileMode.
func (w *Worktree) getMode(fi billy.FileInfo) os.FileMode {
	// Any of the owner, group or other execute bits.
	const anyExecBit = 0111

	mode := fi.Mode()
	switch {
	case mode.IsDir():
		return object.TreeMode
	case mode&os.ModeSymlink != 0:
		return object.SymlinkMode
	case mode&anyExecBit != 0:
		return object.ExecutableMode
	default:
		return object.FileMode
	}
}
// Status is the current status of a Worktree: a map from file name to the
// FileStatus of that file.
type Status map[string]*FileStatus
// File returns the FileStatus stored under filename. When the name has no
// entry yet, an empty FileStatus is inserted first and that new value is
// returned, so callers can assign to its fields directly.
func (s Status) File(filename string) *FileStatus {
	if existing, found := s[filename]; found {
		return existing
	}

	created := &FileStatus{}
	s[filename] = created
	return created
}
// IsClean reports whether every file in s is Unmodified in both the staging
// area and the worktree. An empty Status is clean.
func (s Status) IsClean() bool {
	for _, fileStatus := range s {
		if fileStatus.Worktree == Unmodified && fileStatus.Staging == Unmodified {
			continue
		}
		return false
	}
	return true
}
// String renders the status as one line per changed file, in the form
// "<staging><worktree> <name>\n", where the two leading characters are the
// string forms of the Staging and Worktree codes. Files that are Unmodified
// in both columns are omitted. For a file whose Staging code is Renamed, the
// name is printed as "<name> -> <Extra>".
//
// Lines are ordered by file name so that the output is the same on every
// call; map iteration order alone would make it vary.
func (s Status) String() string {
	names := make([]string, 0, len(s))
	for name := range s {
		names = append(names, name)
	}
	sort.Strings(names)

	var output string
	for _, name := range names {
		status := s[name]
		if status.Staging == Unmodified && status.Worktree == Unmodified {
			continue
		}

		if status.Staging == Renamed {
			name = fmt.Sprintf("%s -> %s", name, status.Extra)
		}

		output += fmt.Sprintf("%s%s %s\n", status.Staging, status.Worktree, name)
	}

	return output
}
// FileStatus is the status of a single file in the Worktree.
type FileStatus struct {
// Staging is the code printed in the first column by Status.String.
Staging StatusCode
// Worktree is the code printed in the second column by Status.String; it is
// the field that Worktree.Status fills in.
Worktree StatusCode
// Extra is printed after "->" by Status.String when Staging is Renamed.
Extra string
}
// StatusCode is the status code of a file in the Worktree.
type StatusCode int8

// The status codes are numbered consecutively in declaration order, starting
// with Unmodified at zero, so the zero value of a FileStatus field means
// "unmodified".
const (
	Unmodified StatusCode = iota
	Untracked
	Modified
	Added
	Deleted
	Renamed
	Copied
	UpdatedButUnmerged
)

// String returns the one-character symbol of the code: " " for Unmodified,
// "?" for Untracked, "M", "A", "D", "R", "C" and "U" for the remaining named
// codes, and "-" for any value outside the declared range.
func (c StatusCode) String() string {
	symbols := [...]string{
		Unmodified:         " ",
		Untracked:          "?",
		Modified:           "M",
		Added:              "A",
		Deleted:            "D",
		Renamed:            "R",
		Copied:             "C",
		UpdatedButUnmerged: "U",
	}

	// Negative values and values past the last constant have no symbol.
	if c < Unmodified || c > UpdatedButUnmerged {
		return "-"
	}
	return symbols[c]
}
// calcSHA1 returns the hash that the content of filename has as a blob
// object. The hasher is created with the size reported by Stat and is then
// fed the whole content of the file.
//
// On any open, stat or read error it returns plumbing.ZeroHash and the error.
// The file is closed before returning on every path.
func calcSHA1(fs billy.Filesystem, filename string) (plumbing.Hash, error) {
	file, err := fs.Open(filename)
	if err != nil {
		return plumbing.ZeroHash, err
	}
	// The file is only read, so a close error cannot lose data and is not
	// reported.
	defer file.Close()

	stat, err := fs.Stat(filename)
	if err != nil {
		return plumbing.ZeroHash, err
	}

	h := plumbing.NewHasher(plumbing.BlobObject, stat.Size())
	if _, err := io.Copy(h, file); err != nil {
		return plumbing.ZeroHash, err
	}

	return h.Sum(), nil
}
// readDirAll walks filesystem from its root (the empty path) and returns every
// non-directory entry found, keyed by its joined path. The map is returned
// even when the walk fails, together with the error; in that case it holds
// whatever was collected before the failure.
func readDirAll(filesystem billy.Filesystem) (map[string]billy.FileInfo, error) {
	found := map[string]billy.FileInfo{}
	err := doReadDirAll(filesystem, "", found)
	return found, err
}
// doReadDirAll lists path on fs and records every non-directory entry in
// files under its joined path, descending recursively into directories. The
// first error from ReadDir stops the walk and is returned.
//
// A path that is exactly ".git" is skipped without being read.
// NOTE(review): because the comparison is on the whole path, this presumably
// only matches a top-level .git directory - confirm against fs.Join.
func doReadDirAll(fs billy.Filesystem, path string, files map[string]billy.FileInfo) error {
	if path == ".git" {
		return nil
	}

	entries, err := fs.ReadDir(path)
	if err != nil {
		return err
	}

	for i := 0; i < len(entries); i++ {
		entry := entries[i]
		full := fs.Join(path, entry.Name())

		if entry.IsDir() {
			if err := doReadDirAll(fs, full, files); err != nil {
				return err
			}
			continue
		}

		files[full] = entry
	}

	return nil
}