github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/prototest/prototest.go (about)

     1  package prototest
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"reflect"
     8  	"time"
     9  
    10  	"github.com/segmentio/kafka-go/protocol"
    11  )
    12  
    13  func deepEqual(x1, x2 interface{}) bool {
    14  	if x1 == nil {
    15  		return x2 == nil
    16  	}
    17  	if r1, ok := x1.(protocol.RecordReader); ok {
    18  		if r2, ok := x2.(protocol.RecordReader); ok {
    19  			return deepEqualRecords(r1, r2)
    20  		}
    21  		return false
    22  	}
    23  	if b1, ok := x1.(protocol.Bytes); ok {
    24  		if b2, ok := x2.(protocol.Bytes); ok {
    25  			return deepEqualBytes(b1, b2)
    26  		}
    27  		return false
    28  	}
    29  	if t1, ok := x1.(time.Time); ok {
    30  		if t2, ok := x2.(time.Time); ok {
    31  			return t1.Equal(t2)
    32  		}
    33  		return false
    34  	}
    35  	return deepEqualValue(reflect.ValueOf(x1), reflect.ValueOf(x2))
    36  }
    37  
    38  func deepEqualValue(v1, v2 reflect.Value) bool {
    39  	t1 := v1.Type()
    40  	t2 := v2.Type()
    41  
    42  	if t1 != t2 {
    43  		return false
    44  	}
    45  
    46  	switch v1.Kind() {
    47  	case reflect.Bool:
    48  		return v1.Bool() == v2.Bool()
    49  	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    50  		return v1.Int() == v2.Int()
    51  	case reflect.Float64:
    52  		return v1.Float() == v2.Float()
    53  	case reflect.String:
    54  		return v1.String() == v2.String()
    55  	case reflect.Struct:
    56  		return deepEqualStruct(v1, v2)
    57  	case reflect.Ptr:
    58  		return deepEqualPtr(v1, v2)
    59  	case reflect.Slice:
    60  		return deepEqualSlice(v1, v2)
    61  	default:
    62  		panic("comparing values of unsupported type: " + v1.Type().String())
    63  	}
    64  }
    65  
    66  func deepEqualPtr(v1, v2 reflect.Value) bool {
    67  	if v1.IsNil() {
    68  		return v2.IsNil()
    69  	}
    70  	return deepEqual(v1.Elem().Interface(), v2.Elem().Interface())
    71  }
    72  
    73  func deepEqualStruct(v1, v2 reflect.Value) bool {
    74  	t := v1.Type()
    75  	n := t.NumField()
    76  
    77  	for i := 0; i < n; i++ {
    78  		f := t.Field(i)
    79  
    80  		if f.PkgPath != "" { // ignore unexported fields
    81  			continue
    82  		}
    83  
    84  		f1 := v1.Field(i)
    85  		f2 := v2.Field(i)
    86  
    87  		if !deepEqual(f1.Interface(), f2.Interface()) {
    88  			return false
    89  		}
    90  	}
    91  
    92  	return true
    93  }
    94  
    95  func deepEqualSlice(v1, v2 reflect.Value) bool {
    96  	t := v1.Type()
    97  	e := t.Elem()
    98  
    99  	if e.Kind() == reflect.Uint8 { // []byte
   100  		return bytes.Equal(v1.Bytes(), v2.Bytes())
   101  	}
   102  
   103  	n1 := v1.Len()
   104  	n2 := v2.Len()
   105  
   106  	if n1 != n2 {
   107  		return false
   108  	}
   109  
   110  	for i := 0; i < n1; i++ {
   111  		f1 := v1.Index(i)
   112  		f2 := v2.Index(i)
   113  
   114  		if !deepEqual(f1.Interface(), f2.Interface()) {
   115  			return false
   116  		}
   117  	}
   118  
   119  	return true
   120  }
   121  
   122  func deepEqualBytes(s1, s2 protocol.Bytes) bool {
   123  	if s1 == nil {
   124  		return s2 == nil
   125  	}
   126  
   127  	if s2 == nil {
   128  		return false
   129  	}
   130  
   131  	n1 := s1.Len()
   132  	n2 := s2.Len()
   133  
   134  	if n1 != n2 {
   135  		return false
   136  	}
   137  
   138  	b1 := make([]byte, n1)
   139  	b2 := make([]byte, n2)
   140  
   141  	if _, err := s1.(io.ReaderAt).ReadAt(b1, 0); err != nil {
   142  		panic(err)
   143  	}
   144  
   145  	if _, err := s2.(io.ReaderAt).ReadAt(b2, 0); err != nil {
   146  		panic(err)
   147  	}
   148  
   149  	return bytes.Equal(b1, b2)
   150  }
   151  
   152  func deepEqualRecords(r1, r2 protocol.RecordReader) bool {
   153  	for {
   154  		rec1, err1 := r1.ReadRecord()
   155  		rec2, err2 := r2.ReadRecord()
   156  
   157  		if err1 != nil || err2 != nil {
   158  			return errors.Is(err1, err2)
   159  		}
   160  
   161  		if !deepEqualRecord(rec1, rec2) {
   162  			return false
   163  		}
   164  	}
   165  }
   166  
   167  func deepEqualRecord(r1, r2 *protocol.Record) bool {
   168  	if r1.Offset != r2.Offset {
   169  		return false
   170  	}
   171  
   172  	if !r1.Time.Equal(r2.Time) {
   173  		return false
   174  	}
   175  
   176  	if !deepEqualBytes(r1.Key, r2.Key) {
   177  		return false
   178  	}
   179  
   180  	if !deepEqualBytes(r1.Value, r2.Value) {
   181  		return false
   182  	}
   183  
   184  	return deepEqual(r1.Headers, r2.Headers)
   185  }