go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/stable_hash.go (about)

     1  // Copyright 2023 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proto
    16  
    17  import (
    18  	"encoding/binary"
    19  	"errors"
    20  	"fmt"
    21  	"hash"
    22  	"math"
    23  	"sort"
    24  
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/reflect/protoreflect"
    27  	"google.golang.org/protobuf/types/known/anypb"
    28  )
    29  
    30  const anyName = "google.protobuf.Any"
    31  
    32  // StableHash does a deterministic and ordered walk over the proto message `m`,
    33  // feeding each field into `h`.
    34  //
    35  // This is useful to produce stable, deterministic hashes of proto messages
    36  // where comparison of messages generated from different sources or runtimes is
    37  // important.
    38  //
    39  // Because of this, `m` may not contain any unknown fields, since there is no
    40  // way to canonicalize them without the message definition. If the message has
    41  // any unknown fields, this function returns an error.
    42  // NOTE:
    43  // - The hash value can ONLY be used for backward compatible protobuf change to
    44  // determine if the protobuf message is different. If two protobuf specs are
    45  // incompatible, their value MUST NOT be compared with each other and not
    46  // promised to be different even their messages are clearly different.
    47  // - google.protobuf.Any is supported via Any.UnmarshalNew: if the message is
    48  // not registered, this returns an error.
    49  // - 0-valued scalar fields are not distinguished from absent fields.
    50  func StableHash(h hash.Hash, m proto.Message) error {
    51  	return hashMessage(h, m.ProtoReflect())
    52  }
    53  
    54  func hashValue(h hash.Hash, v protoreflect.Value) error {
    55  	switch v := v.Interface().(type) {
    56  	case int32:
    57  		return hashNumber(h, uint64(v))
    58  	case int64:
    59  		return hashNumber(h, uint64(v))
    60  	case uint32:
    61  		return hashNumber(h, uint64(v))
    62  	case uint64:
    63  		return hashNumber(h, v)
    64  	case float32:
    65  		return hashNumber(h, uint64(math.Float32bits(v)))
    66  	case float64:
    67  		return hashNumber(h, math.Float64bits(v))
    68  	case string:
    69  		return hashBytes(h, []byte(v))
    70  	case []byte:
    71  		return hashBytes(h, v)
    72  	case protoreflect.EnumNumber:
    73  		return hashNumber(h, uint64(v))
    74  	case protoreflect.Message:
    75  		return hashMessage(h, v)
    76  	case protoreflect.List:
    77  		return hashList(h, v)
    78  	case protoreflect.Map:
    79  		return hashMap(h, v)
    80  	case bool:
    81  		var b uint64
    82  		if v {
    83  			b = 1
    84  		}
    85  		return hashNumber(h, b)
    86  	default:
    87  		return fmt.Errorf("unknown type: %T", v)
    88  	}
    89  }
    90  
    91  func hashMessage(h hash.Hash, m protoreflect.Message) error {
    92  	if m.Descriptor().FullName() == anyName {
    93  		a, err := m.Interface().(*anypb.Any).UnmarshalNew()
    94  		if err != nil {
    95  			return err
    96  		}
    97  		return hashMessage(h, a.ProtoReflect())
    98  	}
    99  	if m.GetUnknown() != nil {
   100  		return fmt.Errorf("unknown fields cannot be hashed")
   101  	}
   102  	// Collect a sorted list of populated message fields.
   103  	var fds []protoreflect.FieldDescriptor
   104  	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
   105  		fds = append(fds, fd)
   106  		return true
   107  	})
   108  	sort.Slice(fds, func(i, j int) bool { return fds[i].Number() < fds[j].Number() })
   109  	// Iterate over message fields.
   110  	for _, fd := range fds {
   111  		if err := hashNumber(h, uint64(fd.Number())); err != nil {
   112  			return err
   113  		}
   114  		if err := hashValue(h, m.Get(fd)); err != nil {
   115  			return err
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  func hashList(h hash.Hash, lv protoreflect.List) error {
   122  	if err := hashNumber(h, uint64(lv.Len())); err != nil {
   123  		return err
   124  	}
   125  	for i := 0; i < lv.Len(); i++ {
   126  		if err := hashValue(h, lv.Get(i)); err != nil {
   127  			return err
   128  		}
   129  	}
   130  	return nil
   131  }
   132  
   133  func hashMap(h hash.Hash, mv protoreflect.Map) error {
   134  	if err := hashNumber(h, uint64(mv.Len())); err != nil {
   135  		return err
   136  	}
   137  	if mv.Len() == 0 {
   138  		return nil
   139  	}
   140  
   141  	// Collect a sorted list of populated map entries.
   142  	var ks []protoreflect.MapKey
   143  	mv.Range(func(k protoreflect.MapKey, _ protoreflect.Value) bool {
   144  		ks = append(ks, k)
   145  		return true
   146  	})
   147  
   148  	var sortFn func(i, j int) bool
   149  	switch ks[0].Interface().(type) {
   150  	case bool:
   151  		sortFn = func(i, j int) bool { return !ks[i].Bool() && ks[j].Bool() }
   152  	case int32, int64:
   153  		sortFn = func(i, j int) bool { return ks[i].Int() < ks[j].Int() }
   154  	case uint32, uint64:
   155  		sortFn = func(i, j int) bool { return ks[i].Uint() < ks[j].Uint() }
   156  	case string:
   157  		sortFn = func(i, j int) bool { return ks[i].String() < ks[j].String() }
   158  	default:
   159  		return errors.New("invalid map key type")
   160  	}
   161  	sort.Slice(ks, sortFn)
   162  
   163  	// Iterate over map entries.
   164  	for _, k := range ks {
   165  		if err := hashValue(h, k.Value()); err != nil {
   166  			return err
   167  		}
   168  		if err := hashValue(h, mv.Get(k)); err != nil {
   169  			return err
   170  		}
   171  	}
   172  	return nil
   173  }
   174  
   175  func hashNumber(h hash.Hash, v uint64) error {
   176  	var b [8]byte
   177  	binary.LittleEndian.PutUint64(b[:], v)
   178  	if _, err := h.Write(b[:]); err != nil {
   179  		return err
   180  	}
   181  	return nil
   182  }
   183  
   184  func hashBytes(h hash.Hash, v []byte) error {
   185  	if err := hashNumber(h, uint64(len(v))); err != nil {
   186  		return err
   187  	}
   188  	if _, err := h.Write([]byte(v)); err != nil {
   189  		return err
   190  	}
   191  	return nil
   192  }