go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/msgpackpb/marshal.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package msgpackpb
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  	"math"
    21  	"reflect"
    22  	"sort"
    23  
    24  	"github.com/vmihailenco/msgpack/v5"
    25  	"go.chromium.org/luci/common/errors"
    26  	"go.chromium.org/luci/common/proto/reflectutil"
    27  	"google.golang.org/protobuf/proto"
    28  	"google.golang.org/protobuf/reflect/protoreflect"
    29  )
    30  
    31  // internal type for marshalMessage
    32  //
    33  // This maps a field number to a value for that field.
    34  //
    35  // A fieldVal can either be `known` (i.e. is defined in the proto)
    36  // or `unknown` (i.e. field was present in an unmarshaled msgpackpb
    37  // message). Exactly one of `(fd, v)` or `raw` will be set.
    38  type fieldVal struct {
    39  	n int32 // the proto field tag number
    40  
    41  	// set if field was `known`
    42  	fd protoreflect.FieldDescriptor
    43  	v  protoreflect.Value
    44  
    45  	// set if field was `unknown`
    46  	raw msgpack.RawMessage
    47  }
    48  
    49  func (o *options) marshalValue(enc *msgpack.Encoder, fd protoreflect.FieldDescriptor, val protoreflect.Value) error {
    50  	kind := fd.Kind()
    51  	if fd.IsMap() {
    52  		kind = fd.MapValue().Kind()
    53  	}
    54  
    55  	switch kind {
    56  	case protoreflect.BoolKind:
    57  		// note: this should only ever encode `true`, because proto range should
    58  		// skip it if it's false.
    59  		return enc.EncodeBool(val.Bool())
    60  
    61  	case protoreflect.Int32Kind, protoreflect.Int64Kind:
    62  		return enc.EncodeInt(val.Int())
    63  
    64  	case protoreflect.EnumKind:
    65  		return enc.EncodeInt(int64(val.Enum()))
    66  
    67  	case protoreflect.Uint32Kind, protoreflect.Uint64Kind:
    68  		return enc.EncodeUint(val.Uint())
    69  
    70  	case protoreflect.FloatKind:
    71  		// this mimics lua's handling of floats-containing-integers
    72  
    73  		// convert to float32 here is potentially lossy, so we do it before
    74  		// math.Floor. Conversion from float32 to float64 is NOT lossy.
    75  		f := float32(val.Float())
    76  		if math.Floor(float64(f)) == float64(f) {
    77  			return enc.EncodeInt(int64(f))
    78  		}
    79  		return enc.EncodeFloat32(f)
    80  
    81  	case protoreflect.DoubleKind:
    82  		// this mimics lua's handling of floats-containing-integers
    83  		f := val.Float()
    84  		if math.Floor(f) == f {
    85  			return enc.EncodeInt(int64(f))
    86  		}
    87  		return enc.EncodeFloat64(f)
    88  
    89  	case protoreflect.StringKind:
    90  		sVal := val.String()
    91  		if ival, ok := o.internMarshalTable[sVal]; ok {
    92  			return enc.EncodeUint(uint64(ival))
    93  		}
    94  
    95  		return enc.EncodeString(sVal)
    96  
    97  	case protoreflect.MessageKind:
    98  		return o.marshalMessage(enc, val.Message())
    99  	}
   100  	return errors.Reason("marshalValue: invalid kind %q", kind).Err()
   101  }
   102  
   103  func (o *options) appendRawMsgpackMsg(raw []byte, to *[]fieldVal, tf takenFields) error {
   104  	dec := msgpack.GetDecoder()
   105  	defer func() {
   106  		if dec != nil {
   107  			msgpack.PutDecoder(dec)
   108  		}
   109  	}()
   110  
   111  	dec.Reset(bytes.NewReader(raw))
   112  	dec.SetMapDecoder((*msgpack.Decoder).DecodeTypedMap)
   113  
   114  	msgItemLen, nextKey, err := getMapLen(dec)
   115  	if err != nil {
   116  		return errors.Annotate(err, "expected message length").Err()
   117  	}
   118  
   119  	for i := 0; i < msgItemLen; i++ {
   120  		tag, err := getNextMsgTag(dec, nextKey)
   121  		if err != nil {
   122  			return errors.Annotate(err, "reading message %d'th tag", i).Err()
   123  		}
   124  		if err = tf.add(tag); err != nil {
   125  			return errors.Annotate(err, "reading message %d'th tag", i).Err()
   126  		}
   127  
   128  		var rawVal msgpack.RawMessage
   129  		if o.deterministic {
   130  			var valI any
   131  			valI, err = dec.DecodeInterfaceLoose()
   132  			if err == nil {
   133  				rawVal, err = msgpackpbDeterministicEncode(reflect.ValueOf(valI))
   134  			}
   135  		} else {
   136  			rawVal, err = dec.DecodeRaw()
   137  		}
   138  		if err != nil {
   139  			return errors.Annotate(err, "reading message %d't field", i).Err()
   140  		}
   141  
   142  		*to = append(*to, fieldVal{
   143  			n:   tag,
   144  			raw: rawVal,
   145  		})
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  type takenFields map[int32]struct{}
   152  
   153  func (t takenFields) add(tag int32) error {
   154  	if tag == 0 {
   155  		return errors.New("invalid tag 0")
   156  	}
   157  
   158  	if _, ok := t[tag]; ok {
   159  		return errors.Reason("duplicate tag %d", tag).Err()
   160  	}
   161  	t[tag] = struct{}{}
   162  
   163  	return nil
   164  }
   165  
   166  func (o *options) marshalMessage(enc *msgpack.Encoder, msg protoreflect.Message) (err error) {
   167  	tf := takenFields{}
   168  	populatedFields := make([]fieldVal, 0, msg.Descriptor().Fields().Len())
   169  	msg.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   170  		fv := fieldVal{fd: fd, v: v}
   171  		fv.n = int32(fd.Number())
   172  		if err := tf.add(fv.n); err != nil {
   173  			panic(errors.Annotate(err, "impossible").Err())
   174  		}
   175  		populatedFields = append(populatedFields, fv)
   176  		return true
   177  	})
   178  
   179  	unknownFieldsRaw := msg.GetUnknown()
   180  	if len(unknownFieldsRaw) > 0 {
   181  		if o.unknownFieldBehavior == disallowUnknownFields {
   182  			return errors.Reason("message has unknown fields").Err()
   183  		}
   184  
   185  		var uf UnknownFields
   186  		if err := proto.Unmarshal(unknownFieldsRaw, &uf); err != nil {
   187  			return errors.Reason("unmarshaling unknown msgpack fields").Err()
   188  		}
   189  		if len(uf.ProtoReflect().GetUnknown()) > 0 {
   190  			return errors.Reason("unknown non-msgpack fields unsupported").Err()
   191  		}
   192  
   193  		if o.unknownFieldBehavior == preserveUnknownFields {
   194  			if err := o.appendRawMsgpackMsg(uf.MsgpackpbData, &populatedFields, tf); err != nil {
   195  				return errors.Reason("parsing unknown fields").Err()
   196  			}
   197  		}
   198  	}
   199  
   200  	encodeLen := func() error {
   201  		return enc.EncodeMapLen(len(populatedFields))
   202  	}
   203  	encodeKey := func(fv *fieldVal) error {
   204  		return enc.EncodeInt(int64(fv.n))
   205  	}
   206  
   207  	if o.deterministic {
   208  		sort.Slice(populatedFields, func(i, j int) bool { return populatedFields[i].n < populatedFields[j].n })
   209  		count := int32(len(populatedFields))
   210  		if count > 0 && populatedFields[0].n == 1 && populatedFields[len(populatedFields)-1].n == count {
   211  			encodeLen = func() error {
   212  				return enc.EncodeArrayLen(int(count))
   213  			}
   214  			encodeKey = func(fv *fieldVal) error { return nil }
   215  		}
   216  	}
   217  
   218  	if err := encodeLen(); err != nil {
   219  		return err
   220  	}
   221  	for _, fv := range populatedFields {
   222  		if err := encodeKey(&fv); err != nil {
   223  			return err
   224  		}
   225  
   226  		if len(fv.raw) > 0 {
   227  			if err := enc.Encode(fv.raw); err != nil {
   228  				return err
   229  			}
   230  			continue
   231  		}
   232  
   233  		fd := fv.fd
   234  		name := fd.Name()
   235  
   236  		// list[*]
   237  		if fd.IsList() {
   238  			lst := fv.v.List()
   239  			if err := enc.EncodeArrayLen(lst.Len()); err != nil {
   240  				return err
   241  			}
   242  			for i := 0; i < lst.Len(); i++ {
   243  				if err := o.marshalValue(enc, fd, lst.Get(i)); err != nil {
   244  					return errors.Annotate(err, "%s[%d]", name, i).Err()
   245  				}
   246  			}
   247  			continue
   248  		}
   249  
   250  		// map[simple]*
   251  		if fd.IsMap() {
   252  			m := fv.v.Map()
   253  			if err := enc.EncodeMapLen(m.Len()); err != nil {
   254  				return err
   255  			}
   256  			rangeFn := m.Range
   257  			if o.deterministic {
   258  				rangeFn = func(f func(protoreflect.MapKey, protoreflect.Value) bool) {
   259  					reflectutil.MapRangeSorted(m, fd.MapKey().Kind(), f)
   260  				}
   261  			}
   262  			var encodeKey func(protoreflect.MapKey) error
   263  			if len(o.internMarshalTable) > 0 && fd.MapKey().Kind() == protoreflect.StringKind {
   264  				encodeKey = func(mk protoreflect.MapKey) error {
   265  					sval := mk.String()
   266  					if ival, ok := o.internMarshalTable[sval]; ok {
   267  						if err := enc.EncodeUint(uint64(ival)); err != nil {
   268  							return err
   269  						}
   270  						return nil
   271  					}
   272  					return enc.EncodeString(sval)
   273  				}
   274  			} else {
   275  				encodeKey = func(mk protoreflect.MapKey) error {
   276  					return enc.Encode(mk.Interface())
   277  				}
   278  			}
   279  			rangeFn(func(mk protoreflect.MapKey, v protoreflect.Value) bool {
   280  				if err = encodeKey(mk); err == nil {
   281  					err = o.marshalValue(enc, fd, v)
   282  				}
   283  				err = errors.Annotate(err, "%s[%s]", name, mk).Err()
   284  				return err == nil
   285  			})
   286  			if err != nil {
   287  				return err
   288  			}
   289  			continue
   290  		}
   291  
   292  		if err := o.marshalValue(enc, fd, fv.v); err != nil {
   293  			return errors.Annotate(err, "%s", name).Err()
   294  		}
   295  	}
   296  
   297  	return
   298  }
   299  
   300  // MarshalStream is like Marshal but outputs to an io.Writer instead of
   301  // returning a string.
   302  func MarshalStream(writer io.Writer, msg proto.Message, opts ...Option) error {
   303  	o := &options{}
   304  	for _, fn := range opts {
   305  		fn(o)
   306  	}
   307  
   308  	enc := msgpack.GetEncoder()
   309  	defer msgpack.PutEncoder(enc)
   310  
   311  	enc.Reset(writer)
   312  	enc.UseCompactInts(true)
   313  	enc.UseCompactFloats(true)
   314  	err := o.marshalMessage(enc, msg.ProtoReflect())
   315  
   316  	return err
   317  }
   318  
   319  // Marshal encodes all the known fields in msg to a msgpack string.
   320  //
   321  // By default, this will emit any unknown msgpack fields (generated by the
   322  // Unmarshal method in this package) back to the serialized message. Pass
   323  // IgnoreUnknownFields or DisallowUnknownFields to affect this behavior.
   324  //
   325  // This can also produce a deterministic encoding if Deterministic is passed as
   326  // an option. Otherwise this will do a faster non-determnistic encoding without
   327  // trying to sort field tags or map keys.
   328  //
   329  // Returns an error if `msg` contains unknown fields.
   330  func Marshal(msg proto.Message, opts ...Option) (msgpack.RawMessage, error) {
   331  	ret := bytes.Buffer{}
   332  	err := MarshalStream(&ret, msg, opts...)
   333  	return ret.Bytes(), err
   334  }