github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/unsafe/reflectx/equality.go (about)

     1  package reflectx
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  )
     8  
     9  // IsIdenticalType checks whether the given two object types have same
    10  // struct fields and memory layout (same order, same name and same type).
    11  // It's useful to check generated types are exactly same in different
    12  // packages, e.g. Thrift, Protobuf, Msgpack, etc.
    13  //
    14  // If two types are identical, it is expected that unsafe pointer casting
    15  // between the two types won't crash the program.
    16  // If the given two types are not identical, the returned diff message
    17  // contains the detail difference.
    18  func IsIdenticalType(a, b any) (equal bool, diff string) {
    19  	typ1 := reflect.TypeOf(a)
    20  	typ2 := reflect.TypeOf(b)
    21  	return newStrictTypecmp().isEqualType(typ1, typ2)
    22  }
    23  
    24  // IsIdenticalThriftType checks whether the given two object types have same
    25  // struct fields and memory layout, in case that a field's name does not
    26  // match, but the thrift tag's first two fields match, it also considers
    27  // the field matches.
    28  //
    29  // It is almost same with IsIdenticalType, but helps the situation that
    30  // different Thrift generators which generate different field names.
    31  //
    32  // If two types are identical, it is expected that unsafe pointer casting
    33  // between the two types won't crash the program.
    34  // If the given two types are not identical, the returned diff message
    35  // contains the detail difference.
    36  func IsIdenticalThriftType(a, b any) (equal bool, diff string) {
    37  	typ1 := reflect.TypeOf(a)
    38  	typ2 := reflect.TypeOf(b)
    39  	return newThriftTypecmp().isEqualType(typ1, typ2)
    40  }
    41  
    42  type typ1typ2 struct {
    43  	typ1, typ2 reflect.Type
    44  }
    45  
    46  const (
    47  	notequal = 1
    48  	isequal  = 2
    49  	checking = 3
    50  )
    51  
    52  type cmpresult struct {
    53  	result int
    54  	diff   string
    55  }
    56  
    57  type typecmp struct {
    58  	seen map[typ1typ2]*cmpresult
    59  
    60  	fieldCmp func(typ reflect.Type, f1, f2 reflect.StructField) (bool, string)
    61  }
    62  
    63  func newStrictTypecmp() *typecmp {
    64  	cmp := &typecmp{
    65  		seen: make(map[typ1typ2]*cmpresult),
    66  	}
    67  	cmp.fieldCmp = cmp.isStrictEqualField
    68  	return cmp
    69  }
    70  
    71  func newThriftTypecmp() *typecmp {
    72  	cmp := &typecmp{
    73  		seen: make(map[typ1typ2]*cmpresult),
    74  	}
    75  	cmp.fieldCmp = cmp.isEqualThriftField
    76  	return cmp
    77  }
    78  
    79  func (p *typecmp) isEqualType(typ1, typ2 reflect.Type) (bool, string) {
    80  	if typ1.Kind() != typ2.Kind() {
    81  		return false, fmt.Sprintf("type kind not equal: %s, %s", _t(typ1), _t(typ2))
    82  	}
    83  	if typ1.Kind() == reflect.Ptr {
    84  		if typ1.Elem().Kind() != typ2.Elem().Kind() {
    85  			return false, fmt.Sprintf("pointer type not equal: %s, %s", _t(typ1), _t(typ2))
    86  		}
    87  		typ1 = typ1.Elem()
    88  		typ2 = typ2.Elem()
    89  		return p.isEqualType(typ1, typ2)
    90  	}
    91  	switch typ1.Kind() {
    92  	case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
    93  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
    94  		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
    95  		return true, ""
    96  	case reflect.Struct:
    97  		return p.isEqualStruct(typ1, typ2)
    98  	case reflect.Slice:
    99  		return p.isEqualSlice(typ1, typ2)
   100  	case reflect.Map:
   101  		return p.isEqualMap(typ1, typ2)
   102  	}
   103  	return false, fmt.Sprintf("unsupported types: %s, %s", _t(typ1), _t(typ2))
   104  }
   105  
   106  func (p *typecmp) isEqualStruct(typ1, typ2 reflect.Type) (bool, string) {
   107  	if typ1.Kind() != reflect.Struct || typ2.Kind() != reflect.Struct {
   108  		return false, fmt.Sprintf("type is not struct: %s, %s", _t(typ1), _t(typ2))
   109  	}
   110  
   111  	typidx := typ1typ2{typ1, typ2}
   112  	if cmpr := p.seen[typidx]; cmpr != nil {
   113  		// In case of recursive type, cmpr.result will be checking here,
   114  		// we treat it as equal, the final result will be updated below.
   115  		return cmpr.result != notequal, cmpr.diff
   116  	}
   117  
   118  	p.seen[typidx] = &cmpresult{checking, ""}
   119  	if typ1.NumField() != typ2.NumField() {
   120  		diff := fmt.Sprintf("struct field num not match: %s, %s", _t(typ1), _t(typ2))
   121  		p.seen[typidx] = &cmpresult{notequal, diff}
   122  		return false, diff
   123  	}
   124  	fnum := typ1.NumField()
   125  	for i := 0; i < fnum; i++ {
   126  		f1 := typ1.Field(i)
   127  		f2 := typ2.Field(i)
   128  		if equal, diff := p.fieldCmp(typ1, f1, f2); !equal {
   129  			diff = fmt.Sprintf("struct field not equal: %s", diff)
   130  			p.seen[typidx] = &cmpresult{notequal, diff}
   131  			return false, diff
   132  		}
   133  	}
   134  	p.seen[typidx] = &cmpresult{isequal, ""}
   135  	return true, ""
   136  }
   137  
   138  func (p *typecmp) isStrictEqualField(typ reflect.Type, field1, field2 reflect.StructField) (bool, string) {
   139  	if field1.Name != field2.Name {
   140  		return false, fmt.Sprintf("field name not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   141  	}
   142  	if field1.Offset != field2.Offset {
   143  		return false, fmt.Sprintf("field offset not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   144  	}
   145  
   146  	typ1 := field1.Type
   147  	typ2 := field2.Type
   148  	if typ1.Kind() != typ2.Kind() {
   149  		return false, fmt.Sprintf("field type not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   150  	}
   151  	if typ1.Kind() == reflect.Ptr {
   152  		typ1 = typ1.Elem()
   153  		typ2 = typ2.Elem()
   154  	}
   155  	equal, diff := p.isEqualType(typ1, typ2)
   156  	if equal {
   157  		return true, ""
   158  	}
   159  	return false, fmt.Sprintf("field type not euqal: %s", diff)
   160  }
   161  
   162  func (p *typecmp) isEqualThriftField(typ reflect.Type, field1, field2 reflect.StructField) (bool, string) {
   163  	if field1.Name != field2.Name {
   164  		if !isEqualThriftTag(field1, field2) {
   165  			return false, fmt.Sprintf("field name and thrift tag both not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   166  		}
   167  	}
   168  	if field1.Offset != field2.Offset {
   169  		return false, fmt.Sprintf("field offset not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   170  	}
   171  
   172  	typ1 := field1.Type
   173  	typ2 := field2.Type
   174  	if typ1.Kind() != typ2.Kind() {
   175  		return false, fmt.Sprintf("field type not equal: %s, %s", _f(typ, field1), _f(typ, field2))
   176  	}
   177  	if typ1.Kind() == reflect.Ptr {
   178  		typ1 = typ1.Elem()
   179  		typ2 = typ2.Elem()
   180  	}
   181  	equal, diff := p.isEqualType(typ1, typ2)
   182  	if equal {
   183  		return true, ""
   184  	}
   185  	return false, fmt.Sprintf("field type not euqal: %s", diff)
   186  }
   187  
   188  func isEqualThriftTag(f1, f2 reflect.StructField) bool {
   189  	// Be compatible with standard thrift tag.
   190  	tag1 := f1.Tag.Get("thrift")
   191  	tag2 := f2.Tag.Get("thrift")
   192  	if tag1 != "" && tag2 != "" {
   193  		if tag1 == tag2 {
   194  			return true
   195  		}
   196  		parts1 := strings.Split(tag1, ",")
   197  		parts2 := strings.Split(tag2, ",")
   198  		if len(parts1) >= 2 && len(parts2) >= 2 &&
   199  			parts1[0] == parts2[0] && // parts[0] is the field's name
   200  			parts1[1] == parts2[1] { // parts[1] is the field's id number
   201  			return true
   202  		}
   203  	}
   204  	// Be compatible with github.com/cloudwego/frugal.
   205  	tag1 = f1.Tag.Get("frugal")
   206  	tag2 = f2.Tag.Get("frugal")
   207  	if tag1 != "" && tag2 != "" {
   208  		if tag1 == tag2 {
   209  			return true
   210  		}
   211  		parts1 := strings.Split(tag1, ",")
   212  		parts2 := strings.Split(tag2, ",")
   213  		if len(parts1) >= 3 && len(parts2) >= 3 &&
   214  			parts1[0] == parts2[0] && // parts[0] is the field's id number
   215  			parts1[2] == parts2[2] { // parts[2] is the field's type
   216  			return true
   217  		}
   218  	}
   219  	return false
   220  }
   221  
   222  func (p *typecmp) isEqualSlice(typ1, typ2 reflect.Type) (bool, string) {
   223  	if typ1.Kind() != reflect.Slice || typ2.Kind() != reflect.Slice {
   224  		return false, fmt.Sprintf("type is not slice: %s, %s", _t(typ1), _t(typ2))
   225  	}
   226  	typ1 = typ1.Elem()
   227  	typ2 = typ2.Elem()
   228  	return p.isEqualType(typ1, typ2)
   229  }
   230  
   231  func (p *typecmp) isEqualMap(typ1, typ2 reflect.Type) (bool, string) {
   232  	if typ1.Kind() != reflect.Map || typ2.Kind() != reflect.Map {
   233  		return false, fmt.Sprintf("type is not map: %s, %s", _t(typ1), _t(typ2))
   234  	}
   235  
   236  	keyTyp1 := typ1.Key()
   237  	keyTyp2 := typ2.Key()
   238  	if equal, diff := p.isEqualType(keyTyp1, keyTyp2); !equal {
   239  		return false, fmt.Sprintf("map key: %s", diff)
   240  	}
   241  
   242  	elemTyp1 := typ1.Elem()
   243  	elemTyp2 := typ2.Elem()
   244  	if equal, diff := p.isEqualType(elemTyp1, elemTyp2); !equal {
   245  		return false, fmt.Sprintf("map value: %s", diff)
   246  	}
   247  
   248  	return true, ""
   249  }
   250  
   251  func _t(typ reflect.Type) string {
   252  	return fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name())
   253  }
   254  
   255  func _f(typ reflect.Type, field reflect.StructField) string {
   256  	return fmt.Sprintf("%s.%s", _t(typ), field.Name)
   257  }