github.com/wfusion/gofusion@v1.1.14/common/utils/cmp/base.go (about)

     1  package cmp
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"time"
     7  
     8  	"github.com/shopspring/decimal"
     9  	"gorm.io/gorm"
    10  
    11  	"github.com/wfusion/gofusion/common/utils"
    12  )
    13  
    14  func ComparablePtr[T comparable](a, b *T) bool {
    15  	if a == nil && b == nil {
    16  		return true
    17  	}
    18  	if a == nil || b == nil {
    19  		return false
    20  	}
    21  
    22  	return *a == *b
    23  }
    24  func SliceComparable[T comparable, TS ~[]T](a, b TS) bool {
    25  	if a == nil && b == nil {
    26  		return true
    27  	}
    28  	if a == nil || b == nil || len(a) != len(b) {
    29  		return false
    30  	}
    31  	for i := 0; i < len(a); i++ {
    32  		if a[i] != b[i] {
    33  			return false
    34  		}
    35  	}
    36  	return true
    37  }
    38  func SliceComparablePtr[T comparable, TS ~[]*T](a, b TS) bool {
    39  	if a == nil && b == nil {
    40  		return true
    41  	}
    42  	if a == nil || b == nil || len(a) != len(b) {
    43  		return false
    44  	}
    45  	for i := 0; i < len(a); i++ {
    46  		if deref(a[i]) != deref(b[i]) {
    47  			return false
    48  		}
    49  	}
    50  	return true
    51  }
    52  
    53  func TimePtr(a, b *time.Time) bool {
    54  	if a == nil && b == nil {
    55  		return true
    56  	}
    57  	if a == nil || b == nil {
    58  		return false
    59  	}
    60  
    61  	return a.Equal(*b) && TimeLocationPtr(a.Location(), b.Location())
    62  }
    63  
    64  func TimeLocationPtr(a, b *time.Location) bool {
    65  	if a == nil && b == nil {
    66  		return true
    67  	}
    68  	if a == nil || b == nil {
    69  		return false
    70  	}
    71  
    72  	return a.String() == b.String()
    73  }
    74  
    75  func DecimalPtr(a, b *decimal.Decimal) bool {
    76  	if a == nil && b == nil {
    77  		return true
    78  	}
    79  	if a == nil || b == nil {
    80  		return false
    81  	}
    82  
    83  	return a.Equal(*b)
    84  }
    85  
    86  func GormModelPtr(a, b *gorm.Model) bool {
    87  	if a == nil && b == nil {
    88  		return true
    89  	}
    90  	if a == nil || b == nil {
    91  		return false
    92  	}
    93  
    94  	return a.ID == b.ID &&
    95  		TimePtr(&a.CreatedAt, &b.CreatedAt) &&
    96  		TimePtr(&a.UpdatedAt, &b.UpdatedAt) &&
    97  		a.DeletedAt.Valid == b.DeletedAt.Valid &&
    98  		TimePtr(&a.DeletedAt.Time, &b.DeletedAt.Time)
    99  }
   100  
   101  type _comparable[T any] interface {
   102  	Equals(other T) bool
   103  }
   104  
   105  func Slice[T _comparable[T], TS ~[]T](a, b TS, sortFn func(i, j T) bool) bool {
   106  	if a == nil && b == nil {
   107  		return true
   108  	}
   109  	if a == nil || b == nil || len(a) != len(b) {
   110  		return false
   111  	}
   112  
   113  	if sortFn != nil {
   114  		utils.SortStable(a, sortFn)
   115  		utils.SortStable(b, sortFn)
   116  	}
   117  
   118  	for i := 0; i < len(a); i++ {
   119  		if !a[i].Equals(b[i]) {
   120  			return false
   121  		}
   122  	}
   123  	return true
   124  }
   125  
   126  func SliceAny[T any, TS ~[]T](a, b TS, sortFn func(i, j T) bool) bool {
   127  	if a == nil && b == nil {
   128  		return true
   129  	}
   130  	if a == nil || b == nil || len(a) != len(b) {
   131  		return false
   132  	}
   133  
   134  	if sortFn != nil {
   135  		utils.SortStable(a, sortFn)
   136  		utils.SortStable(b, sortFn)
   137  	}
   138  
   139  	for i := 0; i < len(a); i++ {
   140  		if !anything(a[i], b[i]) {
   141  			return false
   142  		}
   143  	}
   144  	return true
   145  }
   146  
   147  func Map[T _comparable[T], K comparable](a, b map[K]T) bool {
   148  	if a == nil && b == nil {
   149  		return true
   150  	}
   151  	if a == nil || b == nil || len(a) != len(b) {
   152  		return false
   153  	}
   154  
   155  	for ak, av := range a {
   156  		bv, ok := b[ak]
   157  		if !ok || !av.Equals(bv) {
   158  			return false
   159  		}
   160  	}
   161  
   162  	return true
   163  }
   164  
   165  func MapAny[K comparable, T any](a, b map[K]T) bool {
   166  	if a == nil && b == nil {
   167  		return true
   168  	}
   169  	if a == nil || b == nil || len(a) != len(b) {
   170  		return false
   171  	}
   172  
   173  	for ak, av := range a {
   174  		bv, ok := b[ak]
   175  		if !ok || !anything(av, bv) {
   176  			return false
   177  		}
   178  
   179  	}
   180  	return true
   181  }
   182  
   183  func anything(a, b any) bool {
   184  	switch av := a.(type) {
   185  	case
   186  			bool,
   187  			string, uintptr,
   188  			int, int8, int16, int32, int64,
   189  			uint, uint8, uint16, uint32, uint64,
   190  			float32, float64,
   191  			complex64, complex128:
   192  		return a == b
   193  	case
   194  			*bool,
   195  			*string, *uintptr,
   196  			*int, *int8, *int16, *int32, *int64,
   197  			*uint, *uint8, *uint16, *uint32, *uint64,
   198  			*float32, *float64,
   199  			*complex64, *complex128:
   200  		if a == nil && b == nil {
   201  			return true
   202  		}
   203  		if a == nil || b == nil {
   204  			return false
   205  		}
   206  		return anything(deref(a), deref(b))
   207  	case decimal.Decimal:
   208  		bv := b.(decimal.Decimal)
   209  		return DecimalPtr(&av, &bv)
   210  	case *decimal.Decimal:
   211  		bv := b.(*decimal.Decimal)
   212  		return DecimalPtr(av, bv)
   213  	case time.Time:
   214  		bv := b.(time.Time)
   215  		return TimePtr(&av, &bv)
   216  	case *time.Time:
   217  		bv := b.(*time.Time)
   218  		return TimePtr(av, bv)
   219  	case time.Location:
   220  		bv := b.(time.Location)
   221  		return TimeLocationPtr(&av, &bv)
   222  	case *time.Location:
   223  		bv := b.(*time.Location)
   224  		return TimeLocationPtr(av, bv)
   225  	case []bool:
   226  		return SliceComparable(a.([]bool), b.([]bool))
   227  	case []string:
   228  		return SliceComparable(a.([]string), b.([]string))
   229  	case []uintptr:
   230  		return SliceComparable(a.([]uintptr), b.([]uintptr))
   231  	case []int:
   232  		return SliceComparable(a.([]int), b.([]int))
   233  	case []int8:
   234  		return SliceComparable(a.([]int8), b.([]int8))
   235  	case []int16:
   236  		return SliceComparable(a.([]int16), b.([]int16))
   237  	case []int32:
   238  		return SliceComparable(a.([]int32), b.([]int32))
   239  	case []int64:
   240  		return SliceComparable(a.([]int64), b.([]int64))
   241  	case []uint:
   242  		return SliceComparable(a.([]uint), b.([]uint))
   243  	case []uint8:
   244  		return SliceComparable(a.([]uint8), b.([]uint8))
   245  	case []uint16:
   246  		return SliceComparable(a.([]uint16), b.([]uint16))
   247  	case []uint32:
   248  		return SliceComparable(a.([]uint32), b.([]uint32))
   249  	case []uint64:
   250  		return SliceComparable(a.([]uint64), b.([]uint64))
   251  	case []float32:
   252  		return SliceComparable(a.([]float32), b.([]float32))
   253  	case []float64:
   254  		return SliceComparable(a.([]float64), b.([]float64))
   255  	case []complex64:
   256  		return SliceComparable(a.([]complex64), b.([]complex64))
   257  	case []complex128:
   258  		return SliceComparable(a.([]complex128), b.([]complex128))
   259  	case []any:
   260  		return SliceAny(a.([]any), b.([]any), nil)
   261  	case []map[string]any:
   262  		return SliceAny(a.([]map[string]any), b.([]map[string]any), nil)
   263  	case map[string]any:
   264  		return MapAny(av, b.(map[string]any))
   265  	default:
   266  		return reflect.DeepEqual(a, b)
   267  	}
   268  }
   269  
   270  func deref(p any) (v any) {
   271  	switch pp := p.(type) {
   272  	case *bool:
   273  		v = *pp
   274  	case *string:
   275  		v = *pp
   276  	case *int:
   277  		v = *pp
   278  	case *int8:
   279  		v = *pp
   280  	case *int16:
   281  		v = *pp
   282  	case *int32:
   283  		v = *pp
   284  	case *int64:
   285  		v = *pp
   286  	case *uint:
   287  		v = *pp
   288  	case *uint8:
   289  		v = *pp
   290  	case *uint16:
   291  		v = *pp
   292  	case *uint32:
   293  		v = *pp
   294  	case *uint64:
   295  		v = *pp
   296  	case *float32:
   297  		v = *pp
   298  	case *float64:
   299  		v = *pp
   300  	case *complex64:
   301  		v = *pp
   302  	case *complex128:
   303  		v = *pp
   304  	case *uintptr:
   305  		v = *pp
   306  	case *[]byte:
   307  		v = *pp
   308  	case *any:
   309  		v = *pp
   310  	default:
   311  		panic(fmt.Errorf("unsupported type %T", pp))
   312  	}
   313  	return
   314  }