go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/starlark/starlarkproto/functions.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package starlarkproto
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/json"
    20  	"fmt"
    21  
    22  	"go.starlark.net/starlark"
    23  	"go.starlark.net/starlarkstruct"
    24  
    25  	"google.golang.org/protobuf/encoding/protojson"
    26  	"google.golang.org/protobuf/encoding/prototext"
    27  	"google.golang.org/protobuf/proto"
    28  	"google.golang.org/protobuf/types/descriptorpb"
    29  	"google.golang.org/protobuf/types/dynamicpb"
    30  
    31  	"go.chromium.org/luci/common/errors"
    32  	"go.chromium.org/luci/common/proto/textpb"
    33  )
    34  
    35  // ToTextPB serializes a protobuf message to text proto.
    36  func ToTextPB(msg *Message) ([]byte, error) {
    37  	opts := prototext.MarshalOptions{
    38  		AllowPartial: true,
    39  		Indent:       " ",
    40  		Resolver:     msg.typ.loader.types, // used for google.protobuf.Any fields
    41  	}
    42  	blob, err := opts.Marshal(msg.ToProto())
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	// prototext randomly injects spaces into the generate output. Pass it through
    47  	// a formatter to get rid of them.
    48  	return textpb.Format(blob, msg.MessageType().Descriptor())
    49  }
    50  
    51  // ToJSONPB serializes a protobuf message to JSONPB string.
    52  func ToJSONPB(msg *Message, useProtoNames bool) ([]byte, error) {
    53  	opts := protojson.MarshalOptions{
    54  		AllowPartial:  true,
    55  		Resolver:      msg.typ.loader.types, // used for google.protobuf.Any fields
    56  		UseProtoNames: useProtoNames,
    57  	}
    58  	blob, err := opts.Marshal(msg.ToProto())
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	// protojson randomly injects spaces into the generate output. Pass it through
    63  	// a formatter to get rid of them.
    64  	var out bytes.Buffer
    65  	if err := json.Indent(&out, blob, "", "\t"); err != nil {
    66  		return nil, err
    67  	}
    68  	return bytes.TrimSpace(out.Bytes()), nil
    69  }
    70  
    71  // ToWirePB serializes a protobuf message to binary wire format.
    72  func ToWirePB(msg *Message) ([]byte, error) {
    73  	opts := proto.MarshalOptions{
    74  		AllowPartial:  true,
    75  		Deterministic: true,
    76  	}
    77  	return opts.Marshal(msg.ToProto())
    78  }
    79  
    80  // FromTextPB deserializes a protobuf message given in text proto form.
    81  //
    82  // Unlike the equivalent Starlark proto.from_textpb(...), this low-level native
    83  // function doesn't freeze returned messages, but also doesn't use the message
    84  // cache.
    85  func FromTextPB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) {
    86  	pb := dynamicpb.NewMessage(typ.desc)
    87  	opts := prototext.UnmarshalOptions{
    88  		AllowPartial:   true,
    89  		DiscardUnknown: discardUnknown,
    90  		Resolver:       typ.loader.types, // used for google.protobuf.Any fields
    91  	}
    92  	if err := opts.Unmarshal(blob, pb); err != nil {
    93  		return nil, err
    94  	}
    95  	return typ.MessageFromProto(pb), nil
    96  }
    97  
    98  // FromJSONPB deserializes a protobuf message given as JBONPB string.
    99  //
   100  // Unlike the equivalent Starlark proto.from_jsonpb(...), this low-level native
   101  // function doesn't freeze returned messages, but also doesn't use the message
   102  // cache.
   103  func FromJSONPB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) {
   104  	pb := dynamicpb.NewMessage(typ.desc)
   105  	opts := protojson.UnmarshalOptions{
   106  		AllowPartial:   true,
   107  		DiscardUnknown: discardUnknown,
   108  		Resolver:       typ.loader.types, // used for google.protobuf.Any fields
   109  	}
   110  	if err := opts.Unmarshal(blob, pb); err != nil {
   111  		return nil, err
   112  	}
   113  	return typ.MessageFromProto(pb), nil
   114  }
   115  
   116  // FromWirePB deserializes a protobuf message given as a wire-encoded blob.
   117  //
   118  // Unlike the equivalent Starlark proto.from_wirepb(...), this low-level native
   119  // function doesn't freeze returned messages, but also doesn't use the message
   120  // cache.
   121  func FromWirePB(typ *MessageType, blob []byte, discardUnknown bool) (*Message, error) {
   122  	pb := dynamicpb.NewMessage(typ.desc)
   123  	opts := proto.UnmarshalOptions{
   124  		AllowPartial:   true,
   125  		DiscardUnknown: discardUnknown,
   126  		Resolver:       typ.loader.types, // used for google.protobuf.Any fields
   127  	}
   128  	if err := opts.Unmarshal(blob, pb); err != nil {
   129  		return nil, err
   130  	}
   131  	return typ.MessageFromProto(pb), nil
   132  }
   133  
   134  // ProtoLib returns a dict with single struct named "proto" that holds public
   135  // Starlark API for working with proto messages.
   136  //
   137  // Exported functions:
   138  //
   139  //	def new_descriptor_set(name=None, blob=None, deps=None):
   140  //	  """Returns a new DescriptorSet.
   141  //
   142  //	  Args:
   143  //	    name: name of this set for debug and error messages, default is '???'.
   144  //	    blob: raw serialized FileDescriptorSet, if any.
   145  //	    deps: an iterable of DescriptorSet's with dependencies, if any.
   146  //
   147  //	  Returns:
   148  //	    New DescriptorSet.
   149  //	  """
   150  //
   151  //	def new_loader(*descriptor_sets):
   152  //	  """Returns a new proto loader."""
   153  //
   154  //	def default_loader():
   155  //	  """Returns a loader used by default when registering descriptor sets."""
   156  //
   157  //	def message_type(msg):
   158  //	  """Returns proto.MessageType of the given message."""
   159  //
   160  //	def to_textpb(msg):
   161  //	  """Serializes a protobuf message to text proto.
   162  //
   163  //	  Args:
   164  //	    msg: a *Message to serialize.
   165  //
   166  //	  Returns:
   167  //	    A str representing msg in text format.
   168  //	  """
   169  //
   170  //	def to_jsonpb(msg, use_proto_names = None):
   171  //	  """Serializes a protobuf message to JSONPB string.
   172  //
   173  //	  Args:
   174  //	    msg: a *Message to serialize.
   175  //	    use_proto_names: boolean, whether to use snake_case in field names
   176  //	      instead of camelCase. The default is False.
   177  //
   178  //	  Returns:
   179  //	    A str representing msg in JSONPB format.
   180  //	  """
   181  //
   182  //	def to_wirepb(msg):
   183  //	  """Serializes a protobuf message to a string using binary wire encoding.
   184  //
   185  //	  Args:
   186  //	    msg: a *Message to serialize.
   187  //
   188  //	  Returns:
   189  //	    A str representing msg in binary wire format.
   190  //	  """
   191  //
   192  //	def from_textpb(ctor, body):
   193  //	  """Deserializes a protobuf message given in text proto form.
   194  //
   195  //	  Unknown fields are not allowed.
   196  //
   197  //	  Args:
   198  //	    ctor: a message constructor function.
   199  //	    body: a string with serialized message.
   200  //	    discard_unknown: boolean, whether to discard unrecognized fields. The
   201  //	      default is False.
   202  //
   203  //	  Returns:
   204  //	    Deserialized frozen message constructed via `ctor`.
   205  //	  """
   206  //
   207  //	def from_jsonpb(ctor, body):
   208  //	  """Deserializes a protobuf message given as JBONPB string.
   209  //
   210  //	  Unknown fields are silently skipped.
   211  //
   212  //	  Args:
   213  //	    ctor: a message constructor function.
   214  //	    body: a string with serialized message.
   215  //	    discard_unknown: boolean, whether to discard unrecognized fields. The
   216  //	      default is True.
   217  //
   218  //	  Returns:
   219  //	    Deserialized frozen message constructed via `ctor`.
   220  //	  """
   221  //
   222  //	def from_wirepb(ctor, body):
   223  //	  """Deserializes a protobuf message given its wire serialization.
   224  //
   225  //	  Unknown fields are silently skipped.
   226  //
   227  //	  Args:
   228  //	    ctor: a message constructor function.
   229  //	    body: a string with serialized message.
   230  //	    discard_unknown: boolean, whether to discard unrecognized fields. The
   231  //	      default is True.
   232  //
   233  //	  Returns:
   234  //	    Deserialized frozen message constructed via `ctor`.
   235  //	  """
   236  //
   237  //	def struct_to_textpb(s):
   238  //	  """Converts a struct to a text proto string.
   239  //
   240  //	  Args:
   241  //	    s: a struct object. May not contain dicts.
   242  //
   243  //	  Returns:
   244  //	    A str containing a text format protocol buffer message.
   245  //	  """
   246  //
   247  //	def clone(msg):
   248  //	  """Returns a deep copy of a given proto message.
   249  //
   250  //	  Args:
   251  //	    msg: a proto message to make a copy of.
   252  //
   253  //	  Returns:
   254  //	    A deep copy of the message
   255  //	  """
   256  //
   257  //	def has(msg, field):
   258  //	  """Checks if a proto message has the given optional field set.
   259  //
   260  //	  Args:
   261  //	    msg: a message to check.
   262  //	    field: a string name of the field to check.
   263  //
   264  //	  Returns:
   265  //	    True if the message has the field set.
   266  //	  """
   267  func ProtoLib() starlark.StringDict {
   268  	return starlark.StringDict{
   269  		"proto": starlarkstruct.FromStringDict(starlark.String("proto"), starlark.StringDict{
   270  			"new_descriptor_set": starlark.NewBuiltin("new_descriptor_set", newDescriptorSet),
   271  			"new_loader":         starlark.NewBuiltin("new_loader", newLoader),
   272  			"default_loader":     starlark.NewBuiltin("default_loader", defaultLoader),
   273  			"message_type":       starlark.NewBuiltin("message_type", messageType),
   274  			"to_textpb":          marshallerBuiltin("to_textpb", ToTextPB),
   275  			"to_jsonpb":          toJSONPBBuiltin("to_jsonpb"),
   276  			"to_wirepb":          marshallerBuiltin("to_wirepb", ToWirePB),
   277  			"from_textpb":        unmarshallerBuiltin("from_textpb", FromTextPB, false),
   278  			"from_jsonpb":        unmarshallerBuiltin("from_jsonpb", FromJSONPB, true),
   279  			"from_wirepb":        unmarshallerBuiltin("from_wirepb", FromWirePB, true),
   280  			"struct_to_textpb":   starlark.NewBuiltin("struct_to_textpb", structToTextPb),
   281  			"clone":              starlark.NewBuiltin("clone", clone),
   282  			"has":                starlark.NewBuiltin("has", has),
   283  		}),
   284  	}
   285  }
   286  
   287  // newDescriptorSet constructs *DescriptorSet.
   288  func newDescriptorSet(_ *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   289  	var name string
   290  	var blob string
   291  	var deps starlark.Value
   292  	err := starlark.UnpackArgs("new_descriptor_set", args, kwargs,
   293  		"name?", &name,
   294  		"blob?", &blob,
   295  		"deps?", &deps,
   296  	)
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	// Name is optional.
   302  	if name == "" {
   303  		name = "???"
   304  	}
   305  
   306  	// Blob is also optional. If given, it is a serialized FileDescriptorSet.
   307  	var fdps []*descriptorpb.FileDescriptorProto
   308  	if blob != "" {
   309  		fds := &descriptorpb.FileDescriptorSet{}
   310  		if err := proto.Unmarshal([]byte(blob), fds); err != nil {
   311  			return nil, fmt.Errorf("new_descriptor_set: for parameter \"blob\": %s", err)
   312  		}
   313  		fdps = fds.GetFile()
   314  	}
   315  
   316  	// Collect []*DescriptorSet from 'deps'.
   317  	var sets []*DescriptorSet
   318  	if deps != nil && deps != starlark.None {
   319  		iter := starlark.Iterate(deps)
   320  		if iter == nil {
   321  			return nil, fmt.Errorf("new_descriptor_set: for parameter \"deps\": got %s, want an iterable", deps.Type())
   322  		}
   323  		defer iter.Done()
   324  		var x starlark.Value
   325  		for iter.Next(&x) {
   326  			ds, ok := x.(*DescriptorSet)
   327  			if !ok {
   328  				return nil, fmt.Errorf("new_descriptor_set: for parameter \"deps\" #%d: got %s, want proto.DescriptorSet", len(sets), x.Type())
   329  			}
   330  			sets = append(sets, ds)
   331  		}
   332  	}
   333  
   334  	// Checks all imports can be resolved.
   335  	ds, err := NewDescriptorSet(name, fdps, sets)
   336  	if err != nil {
   337  		return nil, fmt.Errorf("new_descriptor_set: %s", err)
   338  	}
   339  	return ds, nil
   340  }
   341  
   342  // newLoader constructs *Loader and populates it with given descriptor sets.
   343  func newLoader(_ *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   344  	if len(kwargs) > 0 {
   345  		return nil, errors.New("new_loader: unexpected keyword arguments")
   346  	}
   347  	sets := make([]*DescriptorSet, len(args))
   348  	for i, v := range args {
   349  		ds, ok := v.(*DescriptorSet)
   350  		if !ok {
   351  			return nil, fmt.Errorf("new_loader: for parameter %d: got %s, want proto.DescriptorSet", i+1, v.Type())
   352  		}
   353  		sets[i] = ds
   354  	}
   355  	l := NewLoader()
   356  	for _, ds := range sets {
   357  		if err := l.AddDescriptorSet(ds); err != nil {
   358  			return nil, fmt.Errorf("new_loader: %s", err)
   359  		}
   360  	}
   361  	return l, nil
   362  }
   363  
   364  // defaultLoader returns *Loader installed in the thread via SetDefaultLoader.
   365  func defaultLoader(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   366  	if err := starlark.UnpackArgs("default_loader", args, kwargs); err != nil {
   367  		return nil, err
   368  	}
   369  	if l := DefaultLoader(th); l != nil {
   370  		return l, nil
   371  	}
   372  	return starlark.None, nil
   373  }
   374  
   375  // messageType returns MessageType of the given message.
   376  func messageType(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   377  	var msg *Message
   378  	if err := starlark.UnpackArgs("message_type", args, kwargs, "msg", &msg); err != nil {
   379  		return nil, err
   380  	}
   381  	return msg.MessageType(), nil
   382  }
   383  
   384  // marshallerBuiltin implements Starlark shim for To*PB() functions.
   385  func marshallerBuiltin(name string, impl func(*Message) ([]byte, error)) *starlark.Builtin {
   386  	return starlark.NewBuiltin(name, func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   387  		var msg *Message
   388  		if err := starlark.UnpackArgs(name, args, kwargs, "msg", &msg); err != nil {
   389  			return nil, err
   390  		}
   391  		blob, err := impl(msg)
   392  		if err != nil {
   393  			return nil, fmt.Errorf("%s: %s", name, err)
   394  		}
   395  		return starlark.String(blob), nil
   396  	})
   397  }
   398  
   399  // toJSONPBBuiltin implements Starlark shim for the ToJSONPB function.
   400  func toJSONPBBuiltin(name string) *starlark.Builtin {
   401  	return starlark.NewBuiltin(name, func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   402  		var msg *Message
   403  		var useProtoNames starlark.Bool
   404  		if err := starlark.UnpackArgs(name, args, kwargs, "msg", &msg, "use_proto_names?", &useProtoNames); err != nil {
   405  			return nil, err
   406  		}
   407  		blob, err := ToJSONPB(msg, bool(useProtoNames))
   408  		if err != nil {
   409  			return nil, fmt.Errorf("%s: %s", name, err)
   410  		}
   411  		return starlark.String(blob), nil
   412  	})
   413  }
   414  
   415  // unmarshallerBuiltin implements Starlark shim for From*PB() functions.
   416  //
   417  // It also knows how to use the message cache in the thread to cache
   418  // deserialized messages.
   419  func unmarshallerBuiltin(name string, impl func(*MessageType, []byte, bool) (*Message, error), discardUnknownDefault bool) *starlark.Builtin {
   420  	return starlark.NewBuiltin(name, func(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   421  		var ctor starlark.Value
   422  		var body string
   423  		discardUnknown := starlark.Bool(discardUnknownDefault)
   424  		if err := starlark.UnpackArgs(name, args, kwargs, "ctor", &ctor, "body", &body, "discard_unknown?", &discardUnknown); err != nil {
   425  			return nil, err
   426  		}
   427  		typ, ok := ctor.(*MessageType)
   428  		if !ok {
   429  			return nil, fmt.Errorf("%s: got %s, expecting a proto message constructor", name, ctor.Type())
   430  		}
   431  
   432  		cache := messageCache(th)
   433  		cacheName := fmt.Sprintf("%s:%s", name, discardUnknown)
   434  		if cache != nil {
   435  			cached, err := cache.Fetch(th, cacheName, body, typ)
   436  			if err != nil {
   437  				return nil, fmt.Errorf("%s: internal message cache error when fetching: %s", name, err)
   438  			}
   439  			if cached != nil {
   440  				if cached.MessageType() != typ {
   441  					panic(fmt.Sprintf("the message cache returned message of type %s, but %s was expected", cached.MessageType(), typ))
   442  				}
   443  				if !cached.IsFrozen() {
   444  					panic("the message cache returned non-frozen message")
   445  				}
   446  				return cached, nil
   447  			}
   448  		}
   449  
   450  		msg, err := impl(typ, []byte(body), bool(discardUnknown))
   451  		if err != nil {
   452  			return nil, fmt.Errorf("%s: %s", name, err)
   453  		}
   454  		msg.Freeze()
   455  
   456  		if cache != nil {
   457  			if err := cache.Store(th, cacheName, body, msg); err != nil {
   458  				return nil, fmt.Errorf("%s: internal message cache error when storing: %s", name, err)
   459  			}
   460  		}
   461  
   462  		return msg, nil
   463  	})
   464  }
   465  
   466  // clone returns a copy of a given message.
   467  func clone(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   468  	var msg *Message
   469  	if err := starlark.UnpackArgs("clone", args, kwargs, "msg", &msg); err != nil {
   470  		return nil, err
   471  	}
   472  	return msg.MessageType().MessageFromProto(proto.Clone(msg.ToProto())), nil
   473  }
   474  
   475  // has checks a presence of an optional field.
   476  func has(th *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   477  	var msg *Message
   478  	var field string
   479  	if err := starlark.UnpackArgs("has", args, kwargs, "msg", &msg, "field", &field); err != nil {
   480  		return nil, err
   481  	}
   482  	return starlark.Bool(msg.HasProtoField(field)), nil
   483  }
   484  
   485  // TODO(vadimsh): Remove once users switch to protos.
   486  
   487  // structToTextPb takes a struct and returns a string containing a text format
   488  // protocol buffer.
   489  func structToTextPb(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   490  	var val starlark.Value
   491  	if err := starlark.UnpackArgs("struct_to_textpb", args, kwargs, "struct", &val); err != nil {
   492  		return nil, err
   493  	}
   494  	s, ok := val.(*starlarkstruct.Struct)
   495  	if !ok {
   496  		return nil, fmt.Errorf("struct_to_textpb: got %s, expecting a struct", val.Type())
   497  	}
   498  	var buf bytes.Buffer
   499  	err := writeProtoStruct(&buf, 0, s)
   500  	if err != nil {
   501  		return nil, err
   502  	}
   503  	return starlark.String(buf.String()), nil
   504  }
   505  
   506  // Based on
   507  // https://github.com/google/starlark-go/blob/32ce6ec36500ded2e2340a430fae42bc43da8467/starlarkstruct/struct.go
   508  func writeProtoStruct(out *bytes.Buffer, depth int, s *starlarkstruct.Struct) error {
   509  	for _, name := range s.AttrNames() {
   510  		val, err := s.Attr(name)
   511  		if err != nil {
   512  			return err
   513  		}
   514  		if err = writeProtoField(out, depth, name, val); err != nil {
   515  			return err
   516  		}
   517  	}
   518  	return nil
   519  }
   520  
   521  func writeProtoField(out *bytes.Buffer, depth int, field string, v starlark.Value) error {
   522  	if depth > 16 {
   523  		return fmt.Errorf("struct_to_textpb: depth limit exceeded")
   524  	}
   525  
   526  	switch v := v.(type) {
   527  	case *starlarkstruct.Struct:
   528  		fmt.Fprintf(out, "%*s%s: <\n", 2*depth, "", field)
   529  		if err := writeProtoStruct(out, depth+1, v); err != nil {
   530  			return err
   531  		}
   532  		fmt.Fprintf(out, "%*s>\n", 2*depth, "")
   533  		return nil
   534  
   535  	case *starlark.List, starlark.Tuple:
   536  		iter := starlark.Iterate(v)
   537  		defer iter.Done()
   538  		var elem starlark.Value
   539  		for iter.Next(&elem) {
   540  			if err := writeProtoField(out, depth, field, elem); err != nil {
   541  				return err
   542  			}
   543  		}
   544  		return nil
   545  	}
   546  
   547  	// scalars
   548  	fmt.Fprintf(out, "%*s%s: ", 2*depth, "", field)
   549  	switch v := v.(type) {
   550  	case starlark.Bool:
   551  		fmt.Fprintf(out, "%t", v)
   552  
   553  	case starlark.Int:
   554  		out.WriteString(v.String())
   555  
   556  	case starlark.Float:
   557  		fmt.Fprintf(out, "%g", v)
   558  
   559  	case starlark.String:
   560  		fmt.Fprintf(out, "%q", string(v))
   561  
   562  	default:
   563  		return fmt.Errorf("struct_to_textpb: cannot convert %s to proto", v.Type())
   564  	}
   565  	out.WriteByte('\n')
   566  	return nil
   567  }