github.com/cilium/ebpf@v0.16.0/internal/testutils/checkers.go (about)

     1  package testutils
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  
     8  	"github.com/go-quicktest/qt"
     9  )
    10  
    11  // IsDeepCopy checks that got is a deep copy of want.
    12  //
    13  // All primitive values must be equal, but pointers must be distinct.
    14  // This is different from [reflect.DeepEqual] which will accept equal pointer values.
    15  // That is, reflect.DeepEqual(a, a) is true, while IsDeepCopy(a, a) is false.
    16  func IsDeepCopy[T any](got, want T) qt.Checker {
    17  	return &deepCopyChecker[T]{got, want, make(map[pair]struct{})}
    18  }
    19  
    20  type pair struct {
    21  	got, want reflect.Value
    22  }
    23  
    24  type deepCopyChecker[T any] struct {
    25  	got, want T
    26  	visited   map[pair]struct{}
    27  }
    28  
    29  func (dcc *deepCopyChecker[T]) Check(_ func(key string, value any)) error {
    30  	return dcc.check(reflect.ValueOf(dcc.got), reflect.ValueOf(dcc.want))
    31  }
    32  
    33  func (dcc *deepCopyChecker[T]) check(got, want reflect.Value) error {
    34  	switch want.Kind() {
    35  	case reflect.Interface:
    36  		return dcc.check(got.Elem(), want.Elem())
    37  
    38  	case reflect.Pointer:
    39  		if got.IsNil() && want.IsNil() {
    40  			return nil
    41  		}
    42  
    43  		if got.IsNil() {
    44  			return fmt.Errorf("expected non-nil pointer")
    45  		}
    46  
    47  		if want.IsNil() {
    48  			return fmt.Errorf("expected nil pointer")
    49  		}
    50  
    51  		if got.UnsafePointer() == want.UnsafePointer() {
    52  			return fmt.Errorf("equal pointer values")
    53  		}
    54  
    55  		switch want.Type() {
    56  		case reflect.TypeOf((*bytes.Reader)(nil)):
    57  			// bytes.Reader doesn't allow modifying it's contents, so we
    58  			// allow a shallow copy.
    59  			return nil
    60  		}
    61  
    62  		if _, ok := dcc.visited[pair{got, want}]; ok {
    63  			// Deal with recursive types.
    64  			return nil
    65  		}
    66  
    67  		dcc.visited[pair{got, want}] = struct{}{}
    68  		return dcc.check(got.Elem(), want.Elem())
    69  
    70  	case reflect.Slice:
    71  		if got.IsNil() && want.IsNil() {
    72  			return nil
    73  		}
    74  
    75  		if got.IsNil() {
    76  			return fmt.Errorf("expected non-nil slice")
    77  		}
    78  
    79  		if want.IsNil() {
    80  			return fmt.Errorf("expected nil slice")
    81  		}
    82  
    83  		if got.Len() != want.Len() {
    84  			return fmt.Errorf("expected %d elements, got %d", want.Len(), got.Len())
    85  		}
    86  
    87  		if want.Len() == 0 {
    88  			return nil
    89  		}
    90  
    91  		if got.UnsafePointer() == want.UnsafePointer() {
    92  			return fmt.Errorf("equal backing memory")
    93  		}
    94  
    95  		fallthrough
    96  
    97  	case reflect.Array:
    98  		for i := 0; i < want.Len(); i++ {
    99  			if err := dcc.check(got.Index(i), want.Index(i)); err != nil {
   100  				return fmt.Errorf("index %d: %w", i, err)
   101  			}
   102  		}
   103  
   104  		return nil
   105  
   106  	case reflect.Struct:
   107  		for i := 0; i < want.NumField(); i++ {
   108  			if err := dcc.check(got.Field(i), want.Field(i)); err != nil {
   109  				return fmt.Errorf("%q: %w", want.Type().Field(i).Name, err)
   110  			}
   111  		}
   112  
   113  		return nil
   114  
   115  	case reflect.Map:
   116  		if got.Len() != want.Len() {
   117  			return fmt.Errorf("expected %d items, got %d", want.Len(), got.Len())
   118  		}
   119  
   120  		if got.UnsafePointer() == want.UnsafePointer() {
   121  			return fmt.Errorf("maps are equal")
   122  		}
   123  
   124  		iter := want.MapRange()
   125  		for iter.Next() {
   126  			key := iter.Key()
   127  			got := got.MapIndex(iter.Key())
   128  			if !got.IsValid() {
   129  				return fmt.Errorf("key %v is missing", key)
   130  			}
   131  
   132  			want := iter.Value()
   133  			if err := dcc.check(got, want); err != nil {
   134  				return fmt.Errorf("key %v: %w", key, err)
   135  			}
   136  		}
   137  
   138  		return nil
   139  
   140  	case reflect.Chan, reflect.UnsafePointer:
   141  		return fmt.Errorf("%s is not supported", want.Type())
   142  
   143  	default:
   144  		// Compare by value as usual.
   145  		if !got.Equal(want) {
   146  			return fmt.Errorf("%#v is not equal to %#v", got, want)
   147  		}
   148  
   149  		return nil
   150  	}
   151  }
   152  
   153  func (dcc *deepCopyChecker[T]) Args() []qt.Arg {
   154  	return []qt.Arg{
   155  		{Name: "got", Value: dcc.got},
   156  		{Name: "want", Value: dcc.want},
   157  	}
   158  }