go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/msgpackpb/marshal.go (about) 1 // Copyright 2022 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package msgpackpb 16 17 import ( 18 "bytes" 19 "io" 20 "math" 21 "reflect" 22 "sort" 23 24 "github.com/vmihailenco/msgpack/v5" 25 "go.chromium.org/luci/common/errors" 26 "go.chromium.org/luci/common/proto/reflectutil" 27 "google.golang.org/protobuf/proto" 28 "google.golang.org/protobuf/reflect/protoreflect" 29 ) 30 31 // internal type for marshalMessage 32 // 33 // This maps a field number to a value for that field. 34 // 35 // A fieldVal can either be `known` (i.e. is defined in the proto) 36 // or `unknown` (i.e. field was present in an unmarshaled msgpackpb 37 // message). Exactly one of `(fd, v)` or `raw` will be set. 38 type fieldVal struct { 39 n int32 // the proto field tag number 40 41 // set if field was `known` 42 fd protoreflect.FieldDescriptor 43 v protoreflect.Value 44 45 // set if field was `unknown` 46 raw msgpack.RawMessage 47 } 48 49 func (o *options) marshalValue(enc *msgpack.Encoder, fd protoreflect.FieldDescriptor, val protoreflect.Value) error { 50 kind := fd.Kind() 51 if fd.IsMap() { 52 kind = fd.MapValue().Kind() 53 } 54 55 switch kind { 56 case protoreflect.BoolKind: 57 // note: this should only ever encode `true`, because proto range should 58 // skip it if it's false. 59 return enc.EncodeBool(val.Bool()) 60 61 case protoreflect.Int32Kind, protoreflect.Int64Kind: 62 return enc.EncodeInt(val.Int()) 63 64 case protoreflect.EnumKind: 65 return enc.EncodeInt(int64(val.Enum())) 66 67 case protoreflect.Uint32Kind, protoreflect.Uint64Kind: 68 return enc.EncodeUint(val.Uint()) 69 70 case protoreflect.FloatKind: 71 // this mimics lua's handling of floats-containing-integers 72 73 // convert to float32 here is potentially lossy, so we do it before 74 // math.Floor. Conversion from float32 to float64 is NOT lossy. 75 f := float32(val.Float()) 76 if math.Floor(float64(f)) == float64(f) { 77 return enc.EncodeInt(int64(f)) 78 } 79 return enc.EncodeFloat32(f) 80 81 case protoreflect.DoubleKind: 82 // this mimics lua's handling of floats-containing-integers 83 f := val.Float() 84 if math.Floor(f) == f { 85 return enc.EncodeInt(int64(f)) 86 } 87 return enc.EncodeFloat64(f) 88 89 case protoreflect.StringKind: 90 sVal := val.String() 91 if ival, ok := o.internMarshalTable[sVal]; ok { 92 return enc.EncodeUint(uint64(ival)) 93 } 94 95 return enc.EncodeString(sVal) 96 97 case protoreflect.MessageKind: 98 return o.marshalMessage(enc, val.Message()) 99 } 100 return errors.Reason("marshalValue: invalid kind %q", kind).Err() 101 } 102 103 func (o *options) appendRawMsgpackMsg(raw []byte, to *[]fieldVal, tf takenFields) error { 104 dec := msgpack.GetDecoder() 105 defer func() { 106 if dec != nil { 107 msgpack.PutDecoder(dec) 108 } 109 }() 110 111 dec.Reset(bytes.NewReader(raw)) 112 dec.SetMapDecoder((*msgpack.Decoder).DecodeTypedMap) 113 114 msgItemLen, nextKey, err := getMapLen(dec) 115 if err != nil { 116 return errors.Annotate(err, "expected message length").Err() 117 } 118 119 for i := 0; i < msgItemLen; i++ { 120 tag, err := getNextMsgTag(dec, nextKey) 121 if err != nil { 122 return errors.Annotate(err, "reading message %d'th tag", i).Err() 123 } 124 if err = tf.add(tag); err != nil { 125 return errors.Annotate(err, "reading message %d'th tag", i).Err() 126 } 127 128 var rawVal msgpack.RawMessage 129 if o.deterministic { 130 var valI any 131 valI, err = dec.DecodeInterfaceLoose() 132 if err == nil { 133 rawVal, err = msgpackpbDeterministicEncode(reflect.ValueOf(valI)) 134 } 135 } else { 136 rawVal, err = dec.DecodeRaw() 137 } 138 if err != nil { 139 return errors.Annotate(err, "reading message %d't field", i).Err() 140 } 141 142 *to = append(*to, fieldVal{ 143 n: tag, 144 raw: rawVal, 145 }) 146 } 147 148 return nil 149 } 150 151 type takenFields map[int32]struct{} 152 153 func (t takenFields) add(tag int32) error { 154 if tag == 0 { 155 return errors.New("invalid tag 0") 156 } 157 158 if _, ok := t[tag]; ok { 159 return errors.Reason("duplicate tag %d", tag).Err() 160 } 161 t[tag] = struct{}{} 162 163 return nil 164 } 165 166 func (o *options) marshalMessage(enc *msgpack.Encoder, msg protoreflect.Message) (err error) { 167 tf := takenFields{} 168 populatedFields := make([]fieldVal, 0, msg.Descriptor().Fields().Len()) 169 msg.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 170 fv := fieldVal{fd: fd, v: v} 171 fv.n = int32(fd.Number()) 172 if err := tf.add(fv.n); err != nil { 173 panic(errors.Annotate(err, "impossible").Err()) 174 } 175 populatedFields = append(populatedFields, fv) 176 return true 177 }) 178 179 unknownFieldsRaw := msg.GetUnknown() 180 if len(unknownFieldsRaw) > 0 { 181 if o.unknownFieldBehavior == disallowUnknownFields { 182 return errors.Reason("message has unknown fields").Err() 183 } 184 185 var uf UnknownFields 186 if err := proto.Unmarshal(unknownFieldsRaw, &uf); err != nil { 187 return errors.Reason("unmarshaling unknown msgpack fields").Err() 188 } 189 if len(uf.ProtoReflect().GetUnknown()) > 0 { 190 return errors.Reason("unknown non-msgpack fields unsupported").Err() 191 } 192 193 if o.unknownFieldBehavior == preserveUnknownFields { 194 if err := o.appendRawMsgpackMsg(uf.MsgpackpbData, &populatedFields, tf); err != nil { 195 return errors.Reason("parsing unknown fields").Err() 196 } 197 } 198 } 199 200 encodeLen := func() error { 201 return enc.EncodeMapLen(len(populatedFields)) 202 } 203 encodeKey := func(fv *fieldVal) error { 204 return enc.EncodeInt(int64(fv.n)) 205 } 206 207 if o.deterministic { 208 sort.Slice(populatedFields, func(i, j int) bool { return populatedFields[i].n < populatedFields[j].n }) 209 count := int32(len(populatedFields)) 210 if count > 0 && populatedFields[0].n == 1 && populatedFields[len(populatedFields)-1].n == count { 211 encodeLen = func() error { 212 return enc.EncodeArrayLen(int(count)) 213 } 214 encodeKey = func(fv *fieldVal) error { return nil } 215 } 216 } 217 218 if err := encodeLen(); err != nil { 219 return err 220 } 221 for _, fv := range populatedFields { 222 if err := encodeKey(&fv); err != nil { 223 return err 224 } 225 226 if len(fv.raw) > 0 { 227 if err := enc.Encode(fv.raw); err != nil { 228 return err 229 } 230 continue 231 } 232 233 fd := fv.fd 234 name := fd.Name() 235 236 // list[*] 237 if fd.IsList() { 238 lst := fv.v.List() 239 if err := enc.EncodeArrayLen(lst.Len()); err != nil { 240 return err 241 } 242 for i := 0; i < lst.Len(); i++ { 243 if err := o.marshalValue(enc, fd, lst.Get(i)); err != nil { 244 return errors.Annotate(err, "%s[%d]", name, i).Err() 245 } 246 } 247 continue 248 } 249 250 // map[simple]* 251 if fd.IsMap() { 252 m := fv.v.Map() 253 if err := enc.EncodeMapLen(m.Len()); err != nil { 254 return err 255 } 256 rangeFn := m.Range 257 if o.deterministic { 258 rangeFn = func(f func(protoreflect.MapKey, protoreflect.Value) bool) { 259 reflectutil.MapRangeSorted(m, fd.MapKey().Kind(), f) 260 } 261 } 262 var encodeKey func(protoreflect.MapKey) error 263 if len(o.internMarshalTable) > 0 && fd.MapKey().Kind() == protoreflect.StringKind { 264 encodeKey = func(mk protoreflect.MapKey) error { 265 sval := mk.String() 266 if ival, ok := o.internMarshalTable[sval]; ok { 267 if err := enc.EncodeUint(uint64(ival)); err != nil { 268 return err 269 } 270 return nil 271 } 272 return enc.EncodeString(sval) 273 } 274 } else { 275 encodeKey = func(mk protoreflect.MapKey) error { 276 return enc.Encode(mk.Interface()) 277 } 278 } 279 rangeFn(func(mk protoreflect.MapKey, v protoreflect.Value) bool { 280 if err = encodeKey(mk); err == nil { 281 err = o.marshalValue(enc, fd, v) 282 } 283 err = errors.Annotate(err, "%s[%s]", name, mk).Err() 284 return err == nil 285 }) 286 if err != nil { 287 return err 288 } 289 continue 290 } 291 292 if err := o.marshalValue(enc, fd, fv.v); err != nil { 293 return errors.Annotate(err, "%s", name).Err() 294 } 295 } 296 297 return 298 } 299 300 // MarshalStream is like Marshal but outputs to an io.Writer instead of 301 // returning a string. 302 func MarshalStream(writer io.Writer, msg proto.Message, opts ...Option) error { 303 o := &options{} 304 for _, fn := range opts { 305 fn(o) 306 } 307 308 enc := msgpack.GetEncoder() 309 defer msgpack.PutEncoder(enc) 310 311 enc.Reset(writer) 312 enc.UseCompactInts(true) 313 enc.UseCompactFloats(true) 314 err := o.marshalMessage(enc, msg.ProtoReflect()) 315 316 return err 317 } 318 319 // Marshal encodes all the known fields in msg to a msgpack string. 320 // 321 // By default, this will emit any unknown msgpack fields (generated by the 322 // Unmarshal method in this package) back to the serialized message. Pass 323 // IgnoreUnknownFields or DisallowUnknownFields to affect this behavior. 324 // 325 // This can also produce a deterministic encoding if Deterministic is passed as 326 // an option. Otherwise this will do a faster non-determnistic encoding without 327 // trying to sort field tags or map keys. 328 // 329 // Returns an error if `msg` contains unknown fields. 330 func Marshal(msg proto.Message, opts ...Option) (msgpack.RawMessage, error) { 331 ret := bytes.Buffer{} 332 err := MarshalStream(&ret, msg, opts...) 333 return ret.Bytes(), err 334 }