github.com/m3db/m3@v1.5.0/src/dbnode/encoding/proto/equal.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package proto
    22  
    23  import (
    24  	"bytes"
    25  	"fmt"
    26  	"reflect"
    27  
    28  	"github.com/golang/protobuf/proto"
    29  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    30  	"github.com/jhump/protoreflect/desc"
    31  	"github.com/jhump/protoreflect/dynamic"
    32  )
    33  
    34  // isDefaultValue returns whether the provided value is the same as the default value for
    35  // a given field. For the most part we can rely on the fieldsEqual function and the
    36  // GetDefaultValue() method on the field descriptor, but repeated, map and nested message
    37  // fields require slightly more care.
    38  func isDefaultValue(field *desc.FieldDescriptor, curVal interface{}) (bool, error) {
    39  	if field.IsMap() {
    40  		// If its a repeated field then its a default value if it looks like a zero-length slice.
    41  		mapVal, ok := curVal.(map[interface{}]interface{})
    42  		if !ok {
    43  			// Should never happen.
    44  			return false, fmt.Errorf("current value for repeated field: %s wasn't a slice", field.String())
    45  		}
    46  
    47  		return len(mapVal) == 0, nil
    48  	}
    49  
    50  	if field.IsRepeated() {
    51  		// If its a repeated field then its a default value if it looks like a zero-length slice.
    52  		sliceVal, ok := curVal.([]interface{})
    53  		if !ok {
    54  			// Should never happen.
    55  			return false, fmt.Errorf("current value for repeated field: %s wasn't a slice", field.String())
    56  		}
    57  
    58  		return len(sliceVal) == 0, nil
    59  	}
    60  
    61  	if field.GetType() == dpb.FieldDescriptorProto_TYPE_MESSAGE {
    62  		// If its a nested message then its a default value if it looks the same as a new
    63  		// empty message with the same schema.
    64  		messageSchema := field.GetMessageType()
    65  		// TODO(rartoul): Don't allocate new message.
    66  		return fieldsEqual(dynamic.NewMessage(messageSchema), curVal), nil
    67  	}
    68  
    69  	return fieldsEqual(field.GetDefaultValue(), curVal), nil
    70  }
    71  
    72  // Mostly copy-pasta of a non-exported helper method from the protoreflect
    73  // library.
    74  // https://github.com/jhump/protoreflect/blob/87f824e0b908132b2501fe5652f8ee75a2e8cf06/dynamic/equal.go#L60
    75  func fieldsEqual(aVal, bVal interface{}) bool {
    76  	// Handle nil cases first since reflect.ValueOf will not handle untyped
    77  	// nils gracefully.
    78  	if aVal == nil && bVal == nil {
    79  		return true
    80  	}
    81  	if aVal == nil || bVal == nil {
    82  		return false
    83  	}
    84  
    85  	arv := reflect.ValueOf(aVal)
    86  	brv := reflect.ValueOf(bVal)
    87  	if arv.Type() != brv.Type() {
    88  		// it is possible that one is a dynamic message and one is not
    89  		apm, ok := aVal.(proto.Message)
    90  		if !ok {
    91  			return false
    92  		}
    93  		bpm, ok := bVal.(proto.Message)
    94  		if !ok {
    95  			return false
    96  		}
    97  		if !dynamic.MessagesEqual(apm, bpm) {
    98  			return false
    99  		}
   100  	} else {
   101  		switch arv.Kind() {
   102  		case reflect.Ptr:
   103  			apm, ok := aVal.(proto.Message)
   104  			if !ok {
   105  				// Don't know how to compare pointer values that aren't messages!
   106  				// Maybe this should panic?
   107  				return false
   108  			}
   109  			bpm := bVal.(proto.Message) // we know it will succeed because we know a and b have same type
   110  			if !dynamic.MessagesEqual(apm, bpm) {
   111  				return false
   112  			}
   113  		case reflect.Map:
   114  			if !mapsEqual(arv, brv) {
   115  				return false
   116  			}
   117  
   118  		case reflect.Slice:
   119  			if arv.Type() == typeOfBytes {
   120  				if !bytes.Equal(aVal.([]byte), bVal.([]byte)) {
   121  					return false
   122  				}
   123  			} else {
   124  				if !slicesEqual(arv, brv) {
   125  					return false
   126  				}
   127  			}
   128  
   129  		default:
   130  			if aVal != bVal {
   131  				return false
   132  			}
   133  		}
   134  	}
   135  
   136  	return true
   137  }
   138  
   139  func mapsEqual(a, b reflect.Value) bool {
   140  	if a.Len() != b.Len() {
   141  		return false
   142  	}
   143  
   144  	if a.Len() == 0 && b.Len() == 0 {
   145  		// Optimize the case where maps are frequently empty because MapKeys()
   146  		// function allocates heavily.
   147  		return true
   148  	}
   149  
   150  	for _, k := range a.MapKeys() {
   151  		av := a.MapIndex(k)
   152  		bv := b.MapIndex(k)
   153  		if !bv.IsValid() {
   154  			return false
   155  		}
   156  		if !fieldsEqual(av.Interface(), bv.Interface()) {
   157  			return false
   158  		}
   159  	}
   160  	return true
   161  }
   162  
   163  func slicesEqual(a, b reflect.Value) bool {
   164  	if a.Len() != b.Len() {
   165  		return false
   166  	}
   167  	for i := 0; i < a.Len(); i++ {
   168  		ai := a.Index(i)
   169  		bi := b.Index(i)
   170  		if !fieldsEqual(ai.Interface(), bi.Interface()) {
   171  			return false
   172  		}
   173  	}
   174  	return true
   175  }