1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 // Package reflect is a fork of go's standard library reflection package, which
6 // allows for deep equal with equality functions defined.
15 // Equalities is a map from type to a function comparing two values of
17 type Equalities map[reflect.Type]reflect.Value
19 // EqualitiesOrDie builds an Equalities from funcs; for convenience, it panics on errors.
20 func EqualitiesOrDie(funcs ...interface{}) Equalities {
22 if err := e.AddFuncs(funcs...); err != nil {
28 // AddFuncs is a shortcut for multiple calls to AddFunc.
29 func (e Equalities) AddFuncs(funcs ...interface{}) error {
30 for _, f := range funcs {
31 if err := e.AddFunc(f); err != nil {
38 // AddFunc uses func as an equality function: it must take
39 // two parameters of the same type, and return a boolean.
40 func (e Equalities) AddFunc(eqFunc interface{}) error {
41 fv := reflect.ValueOf(eqFunc)
43 if ft.Kind() != reflect.Func {
44 return fmt.Errorf("expected func, got: %v", ft)
47 return fmt.Errorf("expected two 'in' params, got: %v", ft)
50 return fmt.Errorf("expected one 'out' param, got: %v", ft)
52 if ft.In(0) != ft.In(1) {
53 return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
55 var forReturnType bool
56 boolType := reflect.TypeOf(forReturnType)
57 if ft.Out(0) != boolType {
58 return fmt.Errorf("expected bool return, got: %v", ft)
64 // Below here is forked from go's reflect/deepequal.go
66 // During deepValueEqual, must keep track of checks that are
67 // in progress. The comparison algorithm assumes that all
68 // checks in progress are true when it reencounters them.
69 // Visited comparisons are stored in a map indexed by visit.
76 // unexportedTypePanic is thrown when you use this DeepEqual on something that has an
77 // unexported type. It indicates a programmer error, so should not occur at runtime,
78 // which is why it's not public and thus impossible to catch.
79 type unexportedTypePanic []reflect.Type
81 func (u unexportedTypePanic) Error() string { return u.String() }
82 func (u unexportedTypePanic) String() string {
83 strs := make([]string, len(u))
85 strs[i] = fmt.Sprintf("%v", t)
87 return "an unexported field was encountered, nested like this: " + strings.Join(strs, " -> ")
90 func makeUsefulPanic(v reflect.Value) {
91 if x := recover(); x != nil {
92 if u, ok := x.(unexportedTypePanic); ok {
93 u = append(unexportedTypePanic{v.Type()}, u...)
100 // Tests for deep equality using reflected types. The map argument tracks
101 // comparisons that have already been seen, which allows short circuiting on
103 func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
104 defer makeUsefulPanic(v1)
106 if !v1.IsValid() || !v2.IsValid() {
107 return v1.IsValid() == v2.IsValid()
109 if v1.Type() != v2.Type() {
112 if fv, ok := e[v1.Type()]; ok {
113 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
116 hard := func(k reflect.Kind) bool {
118 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
124 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
125 addr1 := v1.UnsafeAddr()
126 addr2 := v2.UnsafeAddr()
128 // Canonicalize order to reduce number of entries in visited.
129 addr1, addr2 = addr2, addr1
132 // Short circuit if references are identical ...
137 // ... or already seen
139 v := visit{addr1, addr2, typ}
144 // Remember for later.
150 // We don't need to check length here because length is part of
151 // an array's type, which has already been filtered for.
152 for i := 0; i < v1.Len(); i++ {
153 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
159 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
162 if v1.IsNil() || v1.Len() == 0 {
165 if v1.Len() != v2.Len() {
168 if v1.Pointer() == v2.Pointer() {
171 for i := 0; i < v1.Len(); i++ {
172 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
177 case reflect.Interface:
178 if v1.IsNil() || v2.IsNil() {
179 return v1.IsNil() == v2.IsNil()
181 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
183 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
185 for i, n := 0, v1.NumField(); i < n; i++ {
186 if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
192 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
195 if v1.IsNil() || v1.Len() == 0 {
198 if v1.Len() != v2.Len() {
201 if v1.Pointer() == v2.Pointer() {
204 for _, k := range v1.MapKeys() {
205 if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
211 if v1.IsNil() && v2.IsNil() {
214 // Can't do better than this:
217 // Normal equality suffices
218 if !v1.CanInterface() || !v2.CanInterface() {
219 panic(unexportedTypePanic{})
221 return v1.Interface() == v2.Interface()
225 // DeepEqual is like reflect.DeepEqual, but focused on semantic equality
226 // instead of memory equality.
228 // It will use e's equality functions if it finds types that match.
230 // An empty slice *is* equal to a nil slice for our purposes; same for maps.
232 // Unexported field members cannot be compared and will cause an informative panic; you must add an Equality
233 // function for these types.
234 func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
235 if a1 == nil || a2 == nil {
238 v1 := reflect.ValueOf(a1)
239 v2 := reflect.ValueOf(a2)
240 if v1.Type() != v2.Type() {
243 return e.deepValueEqual(v1, v2, make(map[visit]bool), 0)
246 func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
247 defer makeUsefulPanic(v1)
249 if !v1.IsValid() || !v2.IsValid() {
250 return v1.IsValid() == v2.IsValid()
252 if v1.Type() != v2.Type() {
255 if fv, ok := e[v1.Type()]; ok {
256 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
259 hard := func(k reflect.Kind) bool {
261 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
267 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
268 addr1 := v1.UnsafeAddr()
269 addr2 := v2.UnsafeAddr()
271 // Canonicalize order to reduce number of entries in visited.
272 addr1, addr2 = addr2, addr1
275 // Short circuit if references are identical ...
280 // ... or already seen
282 v := visit{addr1, addr2, typ}
287 // Remember for later.
293 // We don't need to check length here because length is part of
294 // an array's type, which has already been filtered for.
295 for i := 0; i < v1.Len(); i++ {
296 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
302 if v1.IsNil() || v1.Len() == 0 {
305 if v1.Len() > v2.Len() {
308 if v1.Pointer() == v2.Pointer() {
311 for i := 0; i < v1.Len(); i++ {
312 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
321 if v1.Len() > v2.Len() {
324 return v1.String() == v2.String()
325 case reflect.Interface:
329 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
334 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
336 for i, n := 0, v1.NumField(); i < n; i++ {
337 if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) {
343 if v1.IsNil() || v1.Len() == 0 {
346 if v1.Len() > v2.Len() {
349 if v1.Pointer() == v2.Pointer() {
352 for _, k := range v1.MapKeys() {
353 if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
359 if v1.IsNil() && v2.IsNil() {
362 // Can't do better than this:
365 // Normal equality suffices
366 if !v1.CanInterface() || !v2.CanInterface() {
367 panic(unexportedTypePanic{})
369 return v1.Interface() == v2.Interface()
373 // DeepDerivative is similar to DeepEqual except that unset fields in a1 are
374 // ignored (not compared). This allows us to focus on the fields that matter to
375 // the semantic comparison.
377 // The unset fields include a nil pointer and an empty string.
378 func (e Equalities) DeepDerivative(a1, a2 interface{}) bool {
382 v1 := reflect.ValueOf(a1)
383 v2 := reflect.ValueOf(a2)
384 if v1.Type() != v2.Type() {
387 return e.deepValueDerive(v1, v2, make(map[visit]bool), 0)