github.com/systematiccaos/gorm@v1.22.6/utils/tests/utils.go (about)

     1  package tests
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  	"go/ast"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/systematiccaos/gorm/utils"
    12  )
    13  
    14  func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
    15  	for _, name := range names {
    16  		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
    17  		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
    18  		t.Run(name, func(t *testing.T) {
    19  			AssertEqual(t, got, expect)
    20  		})
    21  	}
    22  }
    23  
    24  func AssertEqual(t *testing.T, got, expect interface{}) {
    25  	if !reflect.DeepEqual(got, expect) {
    26  		isEqual := func() {
    27  			if curTime, ok := got.(time.Time); ok {
    28  				format := "2006-01-02T15:04:05Z07:00"
    29  
    30  				if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) {
    31  					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
    32  				}
    33  			} else if fmt.Sprint(got) != fmt.Sprint(expect) {
    34  				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
    35  			}
    36  		}
    37  
    38  		if fmt.Sprint(got) == fmt.Sprint(expect) {
    39  			return
    40  		}
    41  
    42  		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
    43  			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
    44  			return
    45  		}
    46  
    47  		if valuer, ok := got.(driver.Valuer); ok {
    48  			got, _ = valuer.Value()
    49  		}
    50  
    51  		if valuer, ok := expect.(driver.Valuer); ok {
    52  			expect, _ = valuer.Value()
    53  		}
    54  
    55  		if got != nil {
    56  			got = reflect.Indirect(reflect.ValueOf(got)).Interface()
    57  		}
    58  
    59  		if expect != nil {
    60  			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
    61  		}
    62  
    63  		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
    64  			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
    65  			return
    66  		}
    67  
    68  		if reflect.ValueOf(got).Kind() == reflect.Slice {
    69  			if reflect.ValueOf(expect).Kind() == reflect.Slice {
    70  				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
    71  					for i := 0; i < reflect.ValueOf(got).Len(); i++ {
    72  						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
    73  						t.Run(name, func(t *testing.T) {
    74  							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
    75  						})
    76  					}
    77  				} else {
    78  					name := reflect.ValueOf(got).Type().Elem().Name()
    79  					t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
    80  				}
    81  				return
    82  			}
    83  		}
    84  
    85  		if reflect.ValueOf(got).Kind() == reflect.Struct {
    86  			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
    87  				exported := false
    88  				for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
    89  					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
    90  						exported = true
    91  						field := reflect.ValueOf(got).Field(i)
    92  						t.Run(fieldStruct.Name, func(t *testing.T) {
    93  							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
    94  						})
    95  					}
    96  				}
    97  
    98  				if exported {
    99  					return
   100  				}
   101  			}
   102  		}
   103  
   104  		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
   105  			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
   106  			isEqual()
   107  		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
   108  			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
   109  			isEqual()
   110  		}
   111  	}
   112  }
   113  
   114  func Now() *time.Time {
   115  	now := time.Now()
   116  	return &now
   117  }