go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/msgpackpb/deteministic.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package msgpackpb
    16  
    17  import (
    18  	"bytes"
    19  	"reflect"
    20  	"sort"
    21  
    22  	"github.com/vmihailenco/msgpack/v5"
    23  	"go.chromium.org/luci/common/errors"
    24  )
    25  
    26  func sortedKeys(mapValue reflect.Value) (keys []reflect.Value, arrayLike bool, err error) {
    27  	keys = mapValue.MapKeys()
    28  
    29  	var sortFn func(i, j int) bool
    30  	checkArray := func() {}
    31  	switch mapValue.Type().Key().Kind() {
    32  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    33  		sortFn = func(i, j int) bool { return keys[i].Uint() < keys[j].Uint() }
    34  		checkArray = func() {
    35  			if keys[0].Uint() == 1 && keys[len(keys)-1].Uint() == uint64(len(keys)) {
    36  				arrayLike = true
    37  			}
    38  		}
    39  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    40  		sortFn = func(i, j int) bool { return keys[i].Int() < keys[j].Int() }
    41  		checkArray = func() {
    42  			if keys[0].Int() == 1 && keys[len(keys)-1].Int() == int64(len(keys)) {
    43  				arrayLike = true
    44  			}
    45  		}
    46  	case reflect.String:
    47  		sortFn = func(i, j int) bool { return keys[i].String() < keys[j].String() }
    48  	case reflect.Bool:
    49  		sortFn = func(i, j int) bool {
    50  			a, b := keys[i].Bool(), keys[j].Bool()
    51  			return !a && b
    52  		}
    53  	default:
    54  		err = errors.Reason("cannot sort keys of type %s", keys[0].Type()).Err()
    55  		return
    56  	}
    57  
    58  	if len(keys) > 1 {
    59  		sort.Slice(keys, sortFn)
    60  		checkArray()
    61  	}
    62  
    63  	return
    64  }
    65  
    66  // Unfortunately, the Go msgpack doesn't support deterministic map encoding for
    67  // all map key types :(.
    68  //
    69  // Fortunately, such an encoding for the subset of msgpack we use is relatively
    70  // easy.
    71  func msgpackpbDeterministicEncode(val reflect.Value) (msgpack.RawMessage, error) {
    72  	buf := bytes.Buffer{}
    73  	enc := msgpack.GetEncoder()
    74  	enc.Reset(&buf)
    75  	enc.UseCompactInts(true)
    76  	enc.UseCompactFloats(true)
    77  
    78  	must := func(err error) {
    79  		if err != nil {
    80  			panic(err)
    81  		}
    82  	}
    83  
    84  	var process func(val reflect.Value) error
    85  
    86  	process = func(val reflect.Value) error {
    87  		if val.Kind() == reflect.Interface && !val.IsNil() {
    88  			val = val.Elem()
    89  		}
    90  
    91  		if val.Kind() == reflect.Slice {
    92  			sliceLen := val.Len()
    93  			must(enc.EncodeArrayLen(sliceLen))
    94  			for i := 0; i < sliceLen; i++ {
    95  				if err := process(val.Index(i)); err != nil {
    96  					return err
    97  				}
    98  			}
    99  			return nil
   100  		}
   101  
   102  		if val.Kind() == reflect.Map {
   103  			keys, arrayLike, err := sortedKeys(val)
   104  			if err != nil {
   105  				return err
   106  			}
   107  			if arrayLike {
   108  				must(enc.EncodeArrayLen(len(keys)))
   109  			} else {
   110  				must(enc.EncodeMapLen(len(keys)))
   111  			}
   112  			for _, k := range keys {
   113  				if !arrayLike {
   114  					if err := enc.Encode(k.Interface()); err != nil {
   115  						return err
   116  					}
   117  				}
   118  				if err := process(val.MapIndex(k)); err != nil {
   119  					return err
   120  				}
   121  			}
   122  			return nil
   123  		}
   124  
   125  		return enc.Encode(val.Interface())
   126  	}
   127  
   128  	if err := process(val); err != nil {
   129  		return nil, err
   130  	}
   131  	return buf.Bytes(), nil
   132  }