go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/starlark/starlarkproto/functions.go (about) 1 // Copyright 2019 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package starlarkproto 16 17 import ( 18 "bytes" 19 "encoding/json" 20 "fmt" 21 22 "go.starlark.net/starlark" 23 "go.starlark.net/starlarkstruct" 24 25 "google.golang.org/protobuf/encoding/protojson" 26 "google.golang.org/protobuf/encoding/prototext" 27 "google.golang.org/protobuf/proto" 28 "google.golang.org/protobuf/types/descriptorpb" 29 "google.golang.org/protobuf/types/dynamicpb" 30 31 "go.chromium.org/luci/common/errors" 32 "go.chromium.org/luci/common/proto/textpb" 33 ) 34 35 // ToTextPB serializes a protobuf message to text proto. 36 func ToTextPB(msg *Message) ([]byte, error) { 37 opts := prototext.MarshalOptions{ 38 AllowPartial: true, 39 Indent: " ", 40 Resolver: msg.typ.loader.types, // used for google.protobuf.Any fields 41 } 42 blob, err := opts.Marshal(msg.ToProto()) 43 if err != nil { 44 return nil, err 45 } 46 // prototext randomly injects spaces into the generate output. Pass it through 47 // a formatter to get rid of them. 48 return textpb.Format(blob, msg.MessageType().Descriptor()) 49 } 50 51 // ToJSONPB serializes a protobuf message to JSONPB string. 52 func ToJSONPB(msg *Message, useProtoNames bool) ([]byte, error) { 53 opts := protojson.MarshalOptions{ 54 AllowPartial: true, 55 Resolver: msg.typ.loader.types, // used for google.protobuf.Any fields 56 UseProtoNames: useProtoNames, 57 } 58 blob, err := opts.Marshal(msg.ToProto()) 59 if err != nil { 60 return nil, err 61 } 62 // protojson randomly injects spaces into the generate output. Pass it through 63 // a formatter to get rid of them. 64 var out bytes.Buffer 65 if err := json.Indent(&out, blob, "", "\t"); err != nil { 66 return nil, err 67 } 68 return bytes.TrimSpace(out.Bytes()), nil 69 } 70 71 // ToWirePB serializes a protobuf message to binary wire format. 72 func ToWirePB(msg *Message) ([]byte, error) { 73 opts := proto.MarshalOptions{ 74 AllowPartial: true, 75 Deterministic: true, 76 } 77 return opts.Marshal(msg.ToProto()) 78 } 79 80 // FromTextPB deserializes a protobuf message given in text proto form. 81 // 82 // Unlike the equivalent Starlark proto.from_textpb(...), this low-level native 83 // function doesn't freeze returned messages, but also doesn't use the message 84 // cache. 85 func FromTextPB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) { 86 pb := dynamicpb.NewMessage(typ.desc) 87 opts := prototext.UnmarshalOptions{ 88 AllowPartial: true, 89 DiscardUnknown: discardUnknown, 90 Resolver: typ.loader.types, // used for google.protobuf.Any fields 91 } 92 if err := opts.Unmarshal(blob, pb); err != nil { 93 return nil, err 94 } 95 return typ.MessageFromProto(pb), nil 96 } 97 98 // FromJSONPB deserializes a protobuf message given as JBONPB string. 99 // 100 // Unlike the equivalent Starlark proto.from_jsonpb(...), this low-level native 101 // function doesn't freeze returned messages, but also doesn't use the message 102 // cache. 103 func FromJSONPB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) { 104 pb := dynamicpb.NewMessage(typ.desc) 105 opts := protojson.UnmarshalOptions{ 106 AllowPartial: true, 107 DiscardUnknown: discardUnknown, 108 Resolver: typ.loader.types, // used for google.protobuf.Any fields 109 } 110 if err := opts.Unmarshal(blob, pb); err != nil { 111 return nil, err 112 } 113 return typ.MessageFromProto(pb), nil 114 } 115 116 // FromWirePB deserializes a protobuf message given as a wire-encoded blob. 117 // 118 // Unlike the equivalent Starlark proto.from_wirepb(...), this low-level native 119 // function doesn't freeze returned messages, but also doesn't use the message 120 // cache. 121 func FromWirePB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) { 122 pb := dynamicpb.NewMessage(typ.desc) 123 opts := proto.UnmarshalOptions{ 124 AllowPartial: true, 125 DiscardUnknown: discardUnknown, 126 Resolver: typ.loader.types, // used for google.protobuf.Any fields 127 } 128 if err := opts.Unmarshal(blob, pb); err != nil { 129 return nil, err 130 } 131 return typ.MessageFromProto(pb), nil 132 } 133 134 // ProtoLib returns a dict with single struct named "proto" that holds public 135 // Starlark API for working with proto messages. 136 // 137 // Exported functions: 138 // 139 // def new_descriptor_set(name=None, blob=None, deps=None): 140 // """Returns a new DescriptorSet. 141 // 142 // Args: 143 // name: name of this set for debug and error messages, default is '???'. 144 // blob: raw serialized FileDescriptorSet, if any. 145 // deps: an iterable of DescriptorSet's with dependencies, if any. 146 // 147 // Returns: 148 // New DescriptorSet. 149 // """ 150 // 151 // def new_loader(*descriptor_sets): 152 // """Returns a new proto loader.""" 153 // 154 // def default_loader(): 155 // """Returns a loader used by default when registering descriptor sets.""" 156 // 157 // def message_type(msg): 158 // """Returns proto.MessageType of the given message.""" 159 // 160 // def to_textpb(msg): 161 // """Serializes a protobuf message to text proto. 162 // 163 // Args: 164 // msg: a *Message to serialize. 165 // 166 // Returns: 167 // A str representing msg in text format. 168 // """ 169 // 170 // def to_jsonpb(msg, use_proto_names = None): 171 // """Serializes a protobuf message to JSONPB string. 172 // 173 // Args: 174 // msg: a *Message to serialize. 175 // use_proto_names: boolean, whether to use snake_case in field names 176 // instead of camelCase. The default is False. 177 // 178 // Returns: 179 // A str representing msg in JSONPB format. 180 // """ 181 // 182 // def to_wirepb(msg): 183 // """Serializes a protobuf message to a string using binary wire encoding. 184 // 185 // Args: 186 // msg: a *Message to serialize. 187 // 188 // Returns: 189 // A str representing msg in binary wire format. 190 // """ 191 // 192 // def from_textpb(ctor, body): 193 // """Deserializes a protobuf message given in text proto form. 194 // 195 // Unknown fields are not allowed. 196 // 197 // Args: 198 // ctor: a message constructor function. 199 // body: a string with serialized message. 200 // discard_unknown: boolean, whether to discard unrecognized fields. The 201 // default is False. 202 // 203 // Returns: 204 // Deserialized frozen message constructed via `ctor`. 205 // """ 206 // 207 // def from_jsonpb(ctor, body): 208 // """Deserializes a protobuf message given as JBONPB string. 209 // 210 // Unknown fields are silently skipped. 211 // 212 // Args: 213 // ctor: a message constructor function. 214 // body: a string with serialized message. 215 // discard_unknown: boolean, whether to discard unrecognized fields. The 216 // default is True. 217 // 218 // Returns: 219 // Deserialized frozen message constructed via `ctor`. 220 // """ 221 // 222 // def from_wirepb(ctor, body): 223 // """Deserializes a protobuf message given its wire serialization. 224 // 225 // Unknown fields are silently skipped. 226 // 227 // Args: 228 // ctor: a message constructor function. 229 // body: a string with serialized message. 230 // discard_unknown: boolean, whether to discard unrecognized fields. The 231 // default is True. 232 // 233 // Returns: 234 // Deserialized frozen message constructed via `ctor`. 235 // """ 236 // 237 // def struct_to_textpb(s): 238 // """Converts a struct to a text proto string. 239 // 240 // Args: 241 // s: a struct object. May not contain dicts. 242 // 243 // Returns: 244 // A str containing a text format protocol buffer message. 245 // """ 246 // 247 // def clone(msg): 248 // """Returns a deep copy of a given proto message. 249 // 250 // Args: 251 // msg: a proto message to make a copy of. 252 // 253 // Returns: 254 // A deep copy of the message 255 // """ 256 // 257 // def has(msg, field): 258 // """Checks if a proto message has the given optional field set. 259 // 260 // Args: 261 // msg: a message to check. 262 // field: a string name of the field to check. 263 // 264 // Returns: 265 // True if the message has the field set. 266 // """ 267 func ProtoLib() starlark.StringDict { 268 return starlark.StringDict{ 269 "proto": starlarkstruct.FromStringDict(starlark.String("proto"), starlark.StringDict{ 270 "new_descriptor_set": starlark.NewBuiltin("new_descriptor_set", newDescriptorSet), 271 "new_loader": starlark.NewBuiltin("new_loader", newLoader), 272 "default_loader": starlark.NewBuiltin("default_loader", defaultLoader), 273 "message_type": starlark.NewBuiltin("message_type", messageType), 274 "to_textpb": marshallerBuiltin("to_textpb", ToTextPB), 275 "to_jsonpb": toJSONPBBuiltin("to_jsonpb"), 276 "to_wirepb": marshallerBuiltin("to_wirepb", ToWirePB), 277 "from_textpb": unmarshallerBuiltin("from_textpb", FromTextPB, false), 278 "from_jsonpb": unmarshallerBuiltin("from_jsonpb", FromJSONPB, true), 279 "from_wirepb": unmarshallerBuiltin("from_wirepb", FromWirePB, true), 280 "struct_to_textpb": starlark.NewBuiltin("struct_to_textpb", structToTextPb), 281 "clone": starlark.NewBuiltin("clone", clone), 282 "has": starlark.NewBuiltin("has", has), 283 }), 284 } 285 } 286 287 // newDescriptorSet constructs *DescriptorSet. 288 func newDescriptorSet(_ *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 289 var name string 290 var blob string 291 var deps starlark.Value 292 err := starlark.UnpackArgs("new_descriptor_set", args, kwargs, 293 "name?", &name, 294 "blob?", &blob, 295 "deps?", &deps, 296 ) 297 if err != nil { 298 return nil, err 299 } 300 301 // Name is optional. 302 if name == "" { 303 name = "???" 304 } 305 306 // Blob is also optional. If given, it is a serialized FileDescriptorSet. 307 var fdps []*descriptorpb.FileDescriptorProto 308 if blob != "" { 309 fds := &descriptorpb.FileDescriptorSet{} 310 if err := proto.Unmarshal([]byte(blob), fds); err != nil { 311 return nil, fmt.Errorf("new_descriptor_set: for parameter \"blob\": %s", err) 312 } 313 fdps = fds.GetFile() 314 } 315 316 // Collect []*DescriptorSet from 'deps'. 317 var sets []*DescriptorSet 318 if deps != nil && deps != starlark.None { 319 iter := starlark.Iterate(deps) 320 if iter == nil { 321 return nil, fmt.Errorf("new_descriptor_set: for parameter \"deps\": got %s, want an iterable", deps.Type()) 322 } 323 defer iter.Done() 324 var x starlark.Value 325 for iter.Next(&x) { 326 ds, ok := x.(*DescriptorSet) 327 if !ok { 328 return nil, fmt.Errorf("new_descriptor_set: for parameter \"deps\" #%d: got %s, want proto.DescriptorSet", len(sets), x.Type()) 329 } 330 sets = append(sets, ds) 331 } 332 } 333 334 // Checks all imports can be resolved. 335 ds, err := NewDescriptorSet(name, fdps, sets) 336 if err != nil { 337 return nil, fmt.Errorf("new_descriptor_set: %s", err) 338 } 339 return ds, nil 340 } 341 342 // newLoader constructs *Loader and populates it with given descriptor sets. 343 func newLoader(_ *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 344 if len(kwargs) > 0 { 345 return nil, errors.New("new_loader: unexpected keyword arguments") 346 } 347 sets := make([]*DescriptorSet, len(args)) 348 for i, v := range args { 349 ds, ok := v.(*DescriptorSet) 350 if !ok { 351 return nil, fmt.Errorf("new_loader: for parameter %d: got %s, want proto.DescriptorSet", i+1, v.Type()) 352 } 353 sets[i] = ds 354 } 355 l := NewLoader() 356 for _, ds := range sets { 357 if err := l.AddDescriptorSet(ds); err != nil { 358 return nil, fmt.Errorf("new_loader: %s", err) 359 } 360 } 361 return l, nil 362 } 363 364 // defaultLoader returns *Loader installed in the thread via SetDefaultLoader. 365 func defaultLoader(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 366 if err := starlark.UnpackArgs("default_loader", args, kwargs); err != nil { 367 return nil, err 368 } 369 if l := DefaultLoader(th); l != nil { 370 return l, nil 371 } 372 return starlark.None, nil 373 } 374 375 // messageType returns MessageType of the given message. 376 func messageType(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 377 var msg *Message 378 if err := starlark.UnpackArgs("message_type", args, kwargs, "msg", &msg); err != nil { 379 return nil, err 380 } 381 return msg.MessageType(), nil 382 } 383 384 // marshallerBuiltin implements Starlark shim for To*PB() functions. 385 func marshallerBuiltin(name string, impl func(*Message) ([]byte, error)) *starlark.Builtin { 386 return starlark.NewBuiltin(name, func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 387 var msg *Message 388 if err := starlark.UnpackArgs(name, args, kwargs, "msg", &msg); err != nil { 389 return nil, err 390 } 391 blob, err := impl(msg) 392 if err != nil { 393 return nil, fmt.Errorf("%s: %s", name, err) 394 } 395 return starlark.String(blob), nil 396 }) 397 } 398 399 // toJSONPBBuiltin implements Starlark shim for the ToJSONPB function. 400 func toJSONPBBuiltin(name string) *starlark.Builtin { 401 return starlark.NewBuiltin(name, func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 402 var msg *Message 403 var useProtoNames starlark.Bool 404 if err := starlark.UnpackArgs(name, args, kwargs, "msg", &msg, "use_proto_names?", &useProtoNames); err != nil { 405 return nil, err 406 } 407 blob, err := ToJSONPB(msg, bool(useProtoNames)) 408 if err != nil { 409 return nil, fmt.Errorf("%s: %s", name, err) 410 } 411 return starlark.String(blob), nil 412 }) 413 } 414 415 // unmarshallerBuiltin implements Starlark shim for From*PB() functions. 416 // 417 // It also knows how to use the message cache in the thread to cache 418 // deserialized messages. 419 func unmarshallerBuiltin(name string, impl func(*MessageType, []byte, bool) (*Message, error), discardUnknownDefault bool) *starlark.Builtin { 420 return starlark.NewBuiltin(name, func(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 421 var ctor starlark.Value 422 var body string 423 discardUnknown := starlark.Bool(discardUnknownDefault) 424 if err := starlark.UnpackArgs(name, args, kwargs, "ctor", &ctor, "body", &body, "discard_unknown?", &discardUnknown); err != nil { 425 return nil, err 426 } 427 typ, ok := ctor.(*MessageType) 428 if !ok { 429 return nil, fmt.Errorf("%s: got %s, expecting a proto message constructor", name, ctor.Type()) 430 } 431 432 cache := messageCache(th) 433 cacheName := fmt.Sprintf("%s:%s", name, discardUnknown) 434 if cache != nil { 435 cached, err := cache.Fetch(th, cacheName, body, typ) 436 if err != nil { 437 return nil, fmt.Errorf("%s: internal message cache error when fetching: %s", name, err) 438 } 439 if cached != nil { 440 if cached.MessageType() != typ { 441 panic(fmt.Sprintf("the message cache returned message of type %s, but %s was expected", cached.MessageType(), typ)) 442 } 443 if !cached.IsFrozen() { 444 panic("the message cache returned non-frozen message") 445 } 446 return cached, nil 447 } 448 } 449 450 msg, err := impl(typ, []byte(body), bool(discardUnknown)) 451 if err != nil { 452 return nil, fmt.Errorf("%s: %s", name, err) 453 } 454 msg.Freeze() 455 456 if cache != nil { 457 if err := cache.Store(th, cacheName, body, msg); err != nil { 458 return nil, fmt.Errorf("%s: internal message cache error when storing: %s", name, err) 459 } 460 } 461 462 return msg, nil 463 }) 464 } 465 466 // clone returns a copy of a given message. 467 func clone(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 468 var msg *Message 469 if err := starlark.UnpackArgs("clone", args, kwargs, "msg", &msg); err != nil { 470 return nil, err 471 } 472 return msg.MessageType().MessageFromProto(proto.Clone(msg.ToProto())), nil 473 } 474 475 // has checks a presence of an optional field. 476 func has(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 477 var msg *Message 478 var field string 479 if err := starlark.UnpackArgs("has", args, kwargs, "msg", &msg, "field", &field); err != nil { 480 return nil, err 481 } 482 return starlark.Bool(msg.HasProtoField(field)), nil 483 } 484 485 // TODO(vadimsh): Remove once users switch to protos. 486 487 // structToTextPb takes a struct and returns a string containing a text format 488 // protocol buffer. 489 func structToTextPb(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 490 var val starlark.Value 491 if err := starlark.UnpackArgs("struct_to_textpb", args, kwargs, "struct", &val); err != nil { 492 return nil, err 493 } 494 s, ok := val.(*starlarkstruct.Struct) 495 if !ok { 496 return nil, fmt.Errorf("struct_to_textpb: got %s, expecting a struct", val.Type()) 497 } 498 var buf bytes.Buffer 499 err := writeProtoStruct(&buf, 0, s) 500 if err != nil { 501 return nil, err 502 } 503 return starlark.String(buf.String()), nil 504 } 505 506 // Based on 507 // https://github.com/google/starlark-go/blob/32ce6ec36500ded2e2340a430fae42bc43da8467/starlarkstruct/struct.go 508 func writeProtoStruct(out *bytes.Buffer, depth int, s *starlarkstruct.Struct) error { 509 for _, name := range s.AttrNames() { 510 val, err := s.Attr(name) 511 if err != nil { 512 return err 513 } 514 if err = writeProtoField(out, depth, name, val); err != nil { 515 return err 516 } 517 } 518 return nil 519 } 520 521 func writeProtoField(out *bytes.Buffer, depth int, field string, v starlark.Value) error { 522 if depth > 16 { 523 return fmt.Errorf("struct_to_textpb: depth limit exceeded") 524 } 525 526 switch v := v.(type) { 527 case *starlarkstruct.Struct: 528 fmt.Fprintf(out, "%*s%s: <\n", 2*depth, "", field) 529 if err := writeProtoStruct(out, depth+1, v); err != nil { 530 return err 531 } 532 fmt.Fprintf(out, "%*s>\n", 2*depth, "") 533 return nil 534 535 case *starlark.List, starlark.Tuple: 536 iter := starlark.Iterate(v) 537 defer iter.Done() 538 var elem starlark.Value 539 for iter.Next(&elem) { 540 if err := writeProtoField(out, depth, field, elem); err != nil { 541 return err 542 } 543 } 544 return nil 545 } 546 547 // scalars 548 fmt.Fprintf(out, "%*s%s: ", 2*depth, "", field) 549 switch v := v.(type) { 550 case starlark.Bool: 551 fmt.Fprintf(out, "%t", v) 552 553 case starlark.Int: 554 out.WriteString(v.String()) 555 556 case starlark.Float: 557 fmt.Fprintf(out, "%g", v) 558 559 case starlark.String: 560 fmt.Fprintf(out, "%q", string(v)) 561 562 default: 563 return fmt.Errorf("struct_to_textpb: cannot convert %s to proto", v.Type()) 564 } 565 out.WriteByte('\n') 566 return nil 567 }