go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/cv/internal/cvtesting/saferesemble.go (about)

     1  // Copyright 2021 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cvtesting
    16  
    17  import (
    18  	"encoding/json"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  	"unsafe"
    23  
    24  	"google.golang.org/protobuf/proto"
    25  
    26  	"go.chromium.org/luci/common/testing/assertions"
    27  
    28  	. "github.com/smartystreets/goconvey/convey"
    29  )
    30  
    31  // SafeShouldResemble compares 2 structs recursively, which may include proto
    32  // fields.
    33  //
    34  // Inner struct or slice of struct property is okay, but should not include
    35  // any pointer to struct or slice of pointer to struct property.
    36  //
    37  // This should work in place of GoConvey's ShouldResemble on most of CV's
    38  // structs which may contain protos.
    39  func SafeShouldResemble(actual any, expected ...any) string {
    40  	if len(expected) != 1 {
    41  		return fmt.Sprintf("expected 1 value, got %d", len(expected))
    42  	}
    43  	if diff := ShouldHaveSameTypeAs(actual, expected[0]); diff != "" {
    44  		return diff
    45  	}
    46  
    47  	rA, rE := reflect.ValueOf(actual), reflect.ValueOf(expected[0])
    48  	switch rA.Kind() {
    49  	case reflect.Struct:
    50  	case reflect.Ptr:
    51  		switch {
    52  		case rA.IsNil() && rE.IsNil():
    53  			return ""
    54  		case !rA.IsNil() && rE.IsNil():
    55  			return "actual is not nil, but nil is expected"
    56  		case rA.IsNil() && !rE.IsNil():
    57  			return "actual is nil, but not nil is expected"
    58  		}
    59  
    60  		if rA.Elem().Kind() == reflect.Struct {
    61  			// Neither is nil at this point, so can dereference both.
    62  			rA, rE = rA.Elem(), rE.Elem()
    63  			break
    64  		}
    65  		fallthrough
    66  	default:
    67  		return fmt.Sprintf("Wrong type %T, must be a pointer to struct or a struct", actual)
    68  	}
    69  
    70  	// Copy the *values* before passing to `compareStructRecursive` because
    71  	// it will reset proto fields to `nil`.
    72  	typ := rA.Type()
    73  	copyA, copyE := reflect.New(typ).Elem(), reflect.New(typ).Elem()
    74  	copyA.Set(rA) // shallow-copy
    75  	copyE.Set(rE) // shallow-copy
    76  	// Because GoConvey's ShouldResemble may hang when comparing protos,
    77  	// first compare proto fields with ShouldResembleProto,
    78  	// then compare the remaining fields with ShouldResemble.
    79  	buf := &strings.Builder{}
    80  	p := &protoFieldsComparator{
    81  		actual:   copyA,
    82  		expected: copyE,
    83  		diffBuf:  buf,
    84  	}
    85  	p.compareRecursiveAndNilify()
    86  	// OK, now compare all non-proto fields.
    87  	if diff := ShouldResemble(copyA.Interface(), copyE.Interface()); diff != "" {
    88  		buf.WriteRune('\n')
    89  		addWithIndent(buf, "non-proto fields differ:\n", poorifyIfConveyJSON(diff))
    90  	}
    91  	return strings.TrimSpace(buf.String())
    92  }
    93  
    94  type protoFieldsComparator struct {
    95  	actual, expected reflect.Value
    96  
    97  	parentFields []string
    98  	diffBuf      *strings.Builder
    99  }
   100  
   101  var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
   102  
   103  func (p *protoFieldsComparator) compareRecursiveAndNilify() {
   104  	structType := p.actual.Type()
   105  	for i := 0; i < structType.NumField(); i++ {
   106  		field := structType.Field(i)
   107  		fieldKind := field.Type.Kind()
   108  		fA, fE := p.actual.Field(i), p.expected.Field(i)
   109  		fullPath := strings.Join(append(p.parentFields, field.Name), ".")
   110  		switch {
   111  		case field.Type.Implements(protoMessageType):
   112  			fallthrough
   113  		case fieldKind == reflect.Slice && field.Type.Elem().Implements(protoMessageType):
   114  			switch {
   115  			case fA.CanInterface() != fE.CanInterface():
   116  				panic(fmt.Errorf("type %s field %s CanInterface behaves differently in actual (%t) and expected (%t)",
   117  					structType.Name(), fullPath, fA.CanInterface(), fE.CanInterface()))
   118  			case fA.CanSet() != fE.CanSet():
   119  				panic(fmt.Errorf("type %s field %s CanSet behaves differently in actual (%t) and expected (%t)",
   120  					structType.Name(), fullPath, fA.CanSet(), fE.CanSet()))
   121  			case !fA.CanInterface():
   122  				// HACK to make private fields interface-able.
   123  				fA = reflect.NewAt(field.Type, unsafe.Pointer(fA.UnsafeAddr())).Elem()
   124  				fE = reflect.NewAt(field.Type, unsafe.Pointer(fE.UnsafeAddr())).Elem()
   125  			}
   126  			if diff := assertions.ShouldResembleProto(fA.Interface(), fE.Interface()); diff != "" {
   127  				addWithIndent(p.diffBuf, "field ."+fullPath+" differs:\n", poorifyIfConveyJSON(diff))
   128  			}
   129  			// Reset proto field to nil.
   130  			zeroOutValue(fA)
   131  			zeroOutValue(fE)
   132  		case fieldKind == reflect.Struct:
   133  			p := &protoFieldsComparator{
   134  				actual:       fA,
   135  				expected:     fE,
   136  				diffBuf:      p.diffBuf,
   137  				parentFields: append(p.parentFields, field.Name),
   138  			}
   139  			p.compareRecursiveAndNilify()
   140  		case fieldKind == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct:
   141  			if fA.Len() != fE.Len() {
   142  				addWithIndent(p.diffBuf, "field ."+fullPath+" differs in length:\n", fmt.Sprintf("expected %d, got %d", fE.Len(), fA.Len()))
   143  				// the element may contain proto fields inside.
   144  				zeroOutValue(fA)
   145  				zeroOutValue(fE)
   146  			} else {
   147  				for i := 0; i < fA.Len(); i++ {
   148  					p := &protoFieldsComparator{
   149  						actual:       fA.Index(i),
   150  						expected:     fE.Index(i),
   151  						diffBuf:      p.diffBuf,
   152  						parentFields: append(p.parentFields, fmt.Sprintf("%s[%d]", field.Name, i)),
   153  					}
   154  					p.compareRecursiveAndNilify()
   155  				}
   156  			}
   157  		default:
   158  			// In practice, this can be a ptr to a struct with a proto inside.
   159  			// Detecting and bailing in such a case is left as future work if it
   160  			// becomes really necessary.
   161  		}
   162  	}
   163  }
   164  
   165  func zeroOutValue(val reflect.Value) {
   166  	valType := val.Type()
   167  	if !val.CanSet() {
   168  		// HACK to workaround setting private fields.
   169  		val = reflect.NewAt(val.Type(), unsafe.Pointer(val.UnsafeAddr())).Elem()
   170  	}
   171  	val.Set(reflect.New(valType).Elem())
   172  }
   173  
   174  func addWithIndent(buf *strings.Builder, section, text string) {
   175  	buf.WriteString(section)
   176  	for _, line := range strings.Split(text, "\n") {
   177  		buf.WriteString("  ")
   178  		buf.WriteString(line)
   179  		buf.WriteRune('\n')
   180  	}
   181  	buf.WriteRune('\n')
   182  }
   183  
   184  // poorifyIfConveyJSON detects rich Convey JSON mascarading as string and
   185  // returns only its "poor" component.
   186  //
   187  // Depending on Convey's config, ShouldResemble-like response may be JSON
   188  // in the following format:
   189  //
   190  //	 {
   191  //		 "Message": ...
   192  //		 "Actual": ...
   193  //		 "Expected": ...
   194  //	 }
   195  //
   196  // If so, we want just the value of the "Message" part.
   197  func poorifyIfConveyJSON(msg string) string {
   198  	out := map[string]any{}
   199  	if err := json.Unmarshal([]byte(msg), &out); err == nil {
   200  		return out["Message"].(string)
   201  	}
   202  	return msg
   203  }