aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-test/deep/deep.go
diff options
context:
space:
mode:
authorMichael Muré <batolettre@gmail.com>2018-09-28 20:39:39 +0200
committerMichael Muré <batolettre@gmail.com>2018-09-29 00:51:54 +0200
commit1bf268cebc84a9de1e538cbb54bcc0f434022192 (patch)
treedaeb92cd6b15d56a7a7102f95b73756e5b9597d0 /vendor/github.com/go-test/deep/deep.go
parent8af6f7d98f2fd98f85d6a17bcda49983c272cf48 (diff)
downloadgit-bug-1bf268cebc84a9de1e538cbb54bcc0f434022192.tar.gz
merge package operations into bug, they are tightly coupled anyway
Diffstat (limited to 'vendor/github.com/go-test/deep/deep.go')
-rw-r--r--vendor/github.com/go-test/deep/deep.go352
1 files changed, 352 insertions, 0 deletions
diff --git a/vendor/github.com/go-test/deep/deep.go b/vendor/github.com/go-test/deep/deep.go
new file mode 100644
index 00000000..4ea14cb0
--- /dev/null
+++ b/vendor/github.com/go-test/deep/deep.go
@@ -0,0 +1,352 @@
+// Package deep provides function deep.Equal which is like reflect.DeepEqual but
+// returns a list of differences. This is helpful when comparing complex types
+// like structures and maps.
+package deep
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "reflect"
+ "strings"
+)
+
+var (
+ // FloatPrecision is the number of decimal places to round float values
+ // to when comparing.
+ FloatPrecision = 10
+
+ // MaxDiff specifies the maximum number of differences to return.
+ MaxDiff = 10
+
+ // MaxDepth specifies the maximum levels of a struct to recurse into.
+ MaxDepth = 10
+
+ // LogErrors causes errors to be logged to STDERR when true.
+ LogErrors = false
+
+ // CompareUnexportedFields causes unexported struct fields, like s in
+ // T{s int}, to be comparsed when true.
+ CompareUnexportedFields = false
+)
+
+var (
+ // ErrMaxRecursion is logged when MaxDepth is reached.
+ ErrMaxRecursion = errors.New("recursed to MaxDepth")
+
+ // ErrTypeMismatch is logged when Equal passed two different types of values.
+ ErrTypeMismatch = errors.New("variables are different reflect.Type")
+
+ // ErrNotHandled is logged when a primitive Go kind is not handled.
+ ErrNotHandled = errors.New("cannot compare the reflect.Kind")
+)
+
+type cmp struct {
+ diff []string
+ buff []string
+ floatFormat string
+}
+
+var errorType = reflect.TypeOf((*error)(nil)).Elem()
+
+// Equal compares variables a and b, recursing into their structure up to
+// MaxDepth levels deep, and returns a list of differences, or nil if there are
+// none. Some differences may not be found if an error is also returned.
+//
+// If a type has an Equal method, like time.Equal, it is called to check for
+// equality.
+func Equal(a, b interface{}) []string {
+ aVal := reflect.ValueOf(a)
+ bVal := reflect.ValueOf(b)
+ c := &cmp{
+ diff: []string{},
+ buff: []string{},
+ floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
+ }
+ if a == nil && b == nil {
+ return nil
+ } else if a == nil && b != nil {
+ c.saveDiff(b, "<nil pointer>")
+ } else if a != nil && b == nil {
+ c.saveDiff(a, "<nil pointer>")
+ }
+ if len(c.diff) > 0 {
+ return c.diff
+ }
+
+ c.equals(aVal, bVal, 0)
+ if len(c.diff) > 0 {
+ return c.diff // diffs
+ }
+ return nil // no diffs
+}
+
+func (c *cmp) equals(a, b reflect.Value, level int) {
+ if level > MaxDepth {
+ logError(ErrMaxRecursion)
+ return
+ }
+
+ // Check if one value is nil, e.g. T{x: *X} and T.x is nil
+ if !a.IsValid() || !b.IsValid() {
+ if a.IsValid() && !b.IsValid() {
+ c.saveDiff(a.Type(), "<nil pointer>")
+ } else if !a.IsValid() && b.IsValid() {
+ c.saveDiff("<nil pointer>", b.Type())
+ }
+ return
+ }
+
+ // If differenet types, they can't be equal
+ aType := a.Type()
+ bType := b.Type()
+ if aType != bType {
+ c.saveDiff(aType, bType)
+ logError(ErrTypeMismatch)
+ return
+ }
+
+ // Primitive https://golang.org/pkg/reflect/#Kind
+ aKind := a.Kind()
+ bKind := b.Kind()
+
+ // If both types implement the error interface, compare the error strings.
+ // This must be done before dereferencing because the interface is on a
+ // pointer receiver.
+ if aType.Implements(errorType) && bType.Implements(errorType) {
+ if a.Elem().IsValid() && b.Elem().IsValid() { // both err != nil
+ aString := a.MethodByName("Error").Call(nil)[0].String()
+ bString := b.MethodByName("Error").Call(nil)[0].String()
+ if aString != bString {
+ c.saveDiff(aString, bString)
+ }
+ return
+ }
+ }
+
+ // Dereference pointers and interface{}
+ if aElem, bElem := (aKind == reflect.Ptr || aKind == reflect.Interface),
+ (bKind == reflect.Ptr || bKind == reflect.Interface); aElem || bElem {
+
+ if aElem {
+ a = a.Elem()
+ }
+
+ if bElem {
+ b = b.Elem()
+ }
+
+ c.equals(a, b, level+1)
+ return
+ }
+
+ // Types with an Equal(), like time.Time.
+ eqFunc := a.MethodByName("Equal")
+ if eqFunc.IsValid() {
+ retVals := eqFunc.Call([]reflect.Value{b})
+ if !retVals[0].Bool() {
+ c.saveDiff(a, b)
+ }
+ return
+ }
+
+ switch aKind {
+
+ /////////////////////////////////////////////////////////////////////
+ // Iterable kinds
+ /////////////////////////////////////////////////////////////////////
+
+ case reflect.Struct:
+ /*
+ The variables are structs like:
+ type T struct {
+ FirstName string
+ LastName string
+ }
+ Type = <pkg>.T, Kind = reflect.Struct
+
+ Iterate through the fields (FirstName, LastName), recurse into their values.
+ */
+ for i := 0; i < a.NumField(); i++ {
+ if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
+ continue // skip unexported field, e.g. s in type T struct {s string}
+ }
+
+ c.push(aType.Field(i).Name) // push field name to buff
+
+ // Get the Value for each field, e.g. FirstName has Type = string,
+ // Kind = reflect.String.
+ af := a.Field(i)
+ bf := b.Field(i)
+
+ // Recurse to compare the field values
+ c.equals(af, bf, level+1)
+
+ c.pop() // pop field name from buff
+
+ if len(c.diff) >= MaxDiff {
+ break
+ }
+ }
+ case reflect.Map:
+ /*
+ The variables are maps like:
+ map[string]int{
+ "foo": 1,
+ "bar": 2,
+ }
+ Type = map[string]int, Kind = reflect.Map
+
+ Or:
+ type T map[string]int{}
+ Type = <pkg>.T, Kind = reflect.Map
+
+ Iterate through the map keys (foo, bar), recurse into their values.
+ */
+
+ if a.IsNil() || b.IsNil() {
+ if a.IsNil() && !b.IsNil() {
+ c.saveDiff("<nil map>", b)
+ } else if !a.IsNil() && b.IsNil() {
+ c.saveDiff(a, "<nil map>")
+ }
+ return
+ }
+
+ if a.Pointer() == b.Pointer() {
+ return
+ }
+
+ for _, key := range a.MapKeys() {
+ c.push(fmt.Sprintf("map[%s]", key))
+
+ aVal := a.MapIndex(key)
+ bVal := b.MapIndex(key)
+ if bVal.IsValid() {
+ c.equals(aVal, bVal, level+1)
+ } else {
+ c.saveDiff(aVal, "<does not have key>")
+ }
+
+ c.pop()
+
+ if len(c.diff) >= MaxDiff {
+ return
+ }
+ }
+
+ for _, key := range b.MapKeys() {
+ if aVal := a.MapIndex(key); aVal.IsValid() {
+ continue
+ }
+
+ c.push(fmt.Sprintf("map[%s]", key))
+ c.saveDiff("<does not have key>", b.MapIndex(key))
+ c.pop()
+ if len(c.diff) >= MaxDiff {
+ return
+ }
+ }
+ case reflect.Array:
+ n := a.Len()
+ for i := 0; i < n; i++ {
+ c.push(fmt.Sprintf("array[%d]", i))
+ c.equals(a.Index(i), b.Index(i), level+1)
+ c.pop()
+ if len(c.diff) >= MaxDiff {
+ break
+ }
+ }
+ case reflect.Slice:
+ if a.IsNil() || b.IsNil() {
+ if a.IsNil() && !b.IsNil() {
+ c.saveDiff("<nil slice>", b)
+ } else if !a.IsNil() && b.IsNil() {
+ c.saveDiff(a, "<nil slice>")
+ }
+ return
+ }
+
+ if a.Pointer() == b.Pointer() {
+ return
+ }
+
+ aLen := a.Len()
+ bLen := b.Len()
+ n := aLen
+ if bLen > aLen {
+ n = bLen
+ }
+ for i := 0; i < n; i++ {
+ c.push(fmt.Sprintf("slice[%d]", i))
+ if i < aLen && i < bLen {
+ c.equals(a.Index(i), b.Index(i), level+1)
+ } else if i < aLen {
+ c.saveDiff(a.Index(i), "<no value>")
+ } else {
+ c.saveDiff("<no value>", b.Index(i))
+ }
+ c.pop()
+ if len(c.diff) >= MaxDiff {
+ break
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////
+ // Primitive kinds
+ /////////////////////////////////////////////////////////////////////
+
+ case reflect.Float32, reflect.Float64:
+ // Avoid 0.04147685731961082 != 0.041476857319611
+ // 6 decimal places is close enough
+ aval := fmt.Sprintf(c.floatFormat, a.Float())
+ bval := fmt.Sprintf(c.floatFormat, b.Float())
+ if aval != bval {
+ c.saveDiff(a.Float(), b.Float())
+ }
+ case reflect.Bool:
+ if a.Bool() != b.Bool() {
+ c.saveDiff(a.Bool(), b.Bool())
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if a.Int() != b.Int() {
+ c.saveDiff(a.Int(), b.Int())
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if a.Uint() != b.Uint() {
+ c.saveDiff(a.Uint(), b.Uint())
+ }
+ case reflect.String:
+ if a.String() != b.String() {
+ c.saveDiff(a.String(), b.String())
+ }
+
+ default:
+ logError(ErrNotHandled)
+ }
+}
+
+func (c *cmp) push(name string) {
+ c.buff = append(c.buff, name)
+}
+
+func (c *cmp) pop() {
+ if len(c.buff) > 0 {
+ c.buff = c.buff[0 : len(c.buff)-1]
+ }
+}
+
+func (c *cmp) saveDiff(aval, bval interface{}) {
+ if len(c.buff) > 0 {
+ varName := strings.Join(c.buff, ".")
+ c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
+ } else {
+ c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
+ }
+}
+
+func logError(err error) {
+ if LogErrors {
+ log.Println(err)
+ }
+}