diff --git a/assert/assertion_compare.go b/assert/assertion_compare.go index 718d58202..96027d1ec 100644 --- a/assert/assertion_compare.go +++ b/assert/assertion_compare.go @@ -3,6 +3,7 @@ package assert import ( "fmt" "reflect" + "time" ) type CompareType int @@ -30,6 +31,8 @@ var ( float64Type = reflect.TypeOf(float64(1)) stringType = reflect.TypeOf("") + + timeType = reflect.TypeOf(time.Time{}) ) func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { @@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { return compareLess, true } } + // Check for known struct types we can check for compare results. + case reflect.Struct: + { + // All structs enter here. We're not interested in most types. + if !obj1Value.CanConvert(timeType) { + break + } + + // time.Time can compared! + timeObj1, ok := obj1.(time.Time) + if !ok { + timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) + } + + timeObj2, ok := obj2.(time.Time) + if !ok { + timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) + } + + return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) + } } return compareEqual, false diff --git a/assert/assertion_compare_test.go b/assert/assertion_compare_test.go index 667dbf7c4..1af4b5da2 100644 --- a/assert/assertion_compare_test.go +++ b/assert/assertion_compare_test.go @@ -6,6 +6,7 @@ import ( "reflect" "runtime" "testing" + "time" ) func TestCompare(t *testing.T) { @@ -22,6 +23,7 @@ func TestCompare(t *testing.T) { type customFloat32 float32 type customFloat64 float64 type customString string + type customTime time.Time for _, currCase := range []struct { less interface{} greater interface{} @@ -52,6 +54,8 @@ func TestCompare(t *testing.T) { {less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"}, {less: float64(1.23), greater: float64(2.34), cType: "float64"}, {less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"}, + {less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"}, + {less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"}, } { resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) if !isComparable { @@ -59,7 +63,8 @@ func TestCompare(t *testing.T) { } if resLess != compareLess { - t.Errorf("object less should be less than greater for type " + currCase.cType) + t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType, + currCase.less, currCase.greater) } resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())