go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/starlark/starlarkproto/message.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package starlarkproto
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"go.starlark.net/starlark"
    22  	"go.starlark.net/syntax"
    23  
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  	"google.golang.org/protobuf/types/dynamicpb"
    27  
    28  	"github.com/protocolbuffers/txtpbfmt/parser"
    29  )
    30  
    31  // Message is a Starlark value that implements a struct-like type structured
    32  // like a protobuf message.
    33  //
    34  // Implements starlark.Value, starlark.HasAttrs, starlark.HasSetField and
    35  // starlark.Comparable interfaces.
    36  //
    37  // Can be instantiated through Loader as loader.MessageType(...).Message() or
    38  // loader.MessageType(...).MessageFromProto(p).
    39  //
    40  // TODO(vadimsh): Currently not safe for a cross-goroutine use without external
    41  // locking, even when frozen, due to lazy initialization of default fields on
    42  // first access.
    43  type Message struct {
    44  	typ    *MessageType        // type information
    45  	fields starlark.StringDict // populated fields, keyed by proto field name
    46  	frozen bool                // true after Freeze()
    47  }
    48  
    49  var (
    50  	_ starlark.Value       = (*Message)(nil)
    51  	_ starlark.HasAttrs    = (*Message)(nil)
    52  	_ starlark.HasSetField = (*Message)(nil)
    53  	_ starlark.Comparable  = (*Message)(nil)
    54  )
    55  
    56  // Public API used by the hosting environment.
    57  
    58  // MessageType returns type information about this message.
    59  func (m *Message) MessageType() *MessageType {
    60  	return m.typ
    61  }
    62  
    63  // IsFrozen returns true if this message was frozen already.
    64  func (m *Message) IsFrozen() bool {
    65  	return m.frozen
    66  }
    67  
    68  // ToProto returns a new populated proto message of an appropriate type.
    69  func (m *Message) ToProto() proto.Message {
    70  	msg := dynamicpb.NewMessage(m.typ.desc)
    71  	for k, v := range m.fields {
    72  		assign(msg, m.typ.fields[k], v)
    73  	}
    74  	return msg
    75  }
    76  
    77  // FromDict populates fields of this message based on values in an iterable
    78  // mapping (usually a starlark.Dict).
    79  //
    80  // Doesn't reset the message. Basically does this:
    81  //
    82  //	for k in d:
    83  //	  setattr(msg, k, d[k])
    84  //
    85  // Returns an error on type mismatch.
    86  func (m *Message) FromDict(d starlark.IterableMapping) error {
    87  	iter := d.Iterate()
    88  	defer iter.Done()
    89  
    90  	var k starlark.Value
    91  	for iter.Next(&k) {
    92  		key, ok := k.(starlark.String)
    93  		if !ok {
    94  			return fmt.Errorf("got %s dict key, want string", k.Type())
    95  		}
    96  		v, _, _ := d.Get(k)
    97  		if err := m.SetField(key.GoString(), v); err != nil {
    98  			return err
    99  		}
   100  	}
   101  
   102  	return nil
   103  }
   104  
   105  // HasProtoField returns true if the message has the given field initialized.
   106  func (m *Message) HasProtoField(name string) bool {
   107  	// If the field was set through Starlark already, it exists. This also covers
   108  	// "selected" oneof alternatives.
   109  	if _, ok := m.fields[name]; ok {
   110  		return true
   111  	}
   112  
   113  	// Check we have this field defined in the schema at all.
   114  	fd, err := m.fieldDesc(name)
   115  	if err != nil {
   116  		return false
   117  	}
   118  
   119  	// If this is a part of some oneof set, the field is assumed set only if it
   120  	// was explicitly initialized in Starlark (already checked above). So if we
   121  	// are here, then this particular oneof alternative wasn't used.
   122  	if fd.ContainingOneof() != nil {
   123  		return false
   124  	}
   125  
   126  	// Repeated and map fields are assumed to be always preset. They just may be
   127  	// empty.
   128  	if fd.IsList() || fd.IsMap() {
   129  		return true
   130  	}
   131  
   132  	// Singular message-typed fields are set only if they were explicitly
   133  	// initialized (and if we are here, they were not).
   134  	if kind := fd.Kind(); kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
   135  		return false
   136  	}
   137  
   138  	// Singular fields of primitive types are always set, since there's no way to
   139  	// distinguish fields initialized with a default value from unset fields.
   140  	return true
   141  }
   142  
   143  // Basic starlark.Value interface.
   144  
   145  // String returns compact text serialization of this message.
   146  func (m *Message) String() string {
   147  	raw := m.ToProto().(interface{ String() string }).String()
   148  	formatted, err := parser.FormatWithConfig([]byte(raw), parser.Config{
   149  		SkipAllColons: true,
   150  	})
   151  	if err != nil {
   152  		return fmt.Sprintf("<bad proto: %q>", err)
   153  	}
   154  	return strings.TrimSpace(string(formatted))
   155  }
   156  
   157  // Type returns full proto message name.
   158  func (m *Message) Type() string {
   159  	// The receiver is nil when doing type checks with starlark.UnpackArgs. It
   160  	// asks the nil message for its type for the error message.
   161  	if m == nil {
   162  		return "proto.Message"
   163  	}
   164  	return fmt.Sprintf("proto.Message<%s>", m.typ.desc.FullName())
   165  }
   166  
   167  // Freeze makes this message immutable.
   168  func (m *Message) Freeze() {
   169  	if !m.frozen {
   170  		m.fields.Freeze()
   171  		m.frozen = true
   172  	}
   173  }
   174  
   175  // Truth always returns True.
   176  func (m *Message) Truth() starlark.Bool { return starlark.True }
   177  
   178  // Hash returns an error, indicating proto messages are not hashable.
   179  func (m *Message) Hash() (uint32, error) {
   180  	return 0, fmt.Errorf("proto messages (and %s in particular) are not hashable", m.Type())
   181  }
   182  
   183  // HasAttrs and HasSetField interfaces that make the message look like a struct.
   184  
   185  // Attr is called when a field is read from Starlark code.
   186  func (m *Message) Attr(name string) (starlark.Value, error) {
   187  	return m.attrImpl(name, true)
   188  }
   189  
   190  // attrImpl is the actual implementation of Attr.
   191  //
   192  // 'mut' controls how attrImpl behaves if 'name' field is not set.
   193  //
   194  // If 'mut' is true, the field will be set to its default value and this value
   195  // is returned. This updates 'm' as a side effect.
   196  //
   197  // If 'mut' is false, and the field is not message-valued, its default value
   198  // is returned (but 'm' itself is not updated). This applies to repeated fields
   199  // and maps as well. But if the field is message-valued, None is returned.
   200  func (m *Message) attrImpl(name string, mut bool) (starlark.Value, error) {
   201  	// If the field was set through Starlark already, return its value right away.
   202  	val, ok := m.fields[name]
   203  	if ok {
   204  		return val, nil
   205  	}
   206  
   207  	// Check we have this field at all.
   208  	fd, err := m.fieldDesc(name)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	// If this is one alternative of some oneof field, do NOT instantiate it.
   214  	// This is needed to make sure callers are explicitly picking a oneof
   215  	// alternative by assigning a value to it, rather than have it be picked
   216  	// implicitly be reading an attribute (which is weird).
   217  	if fd.ContainingOneof() != nil {
   218  		return starlark.None, nil
   219  	}
   220  
   221  	// Don't auto-initialize message-valued fields if 'mut' is false. This case is
   222  	// special because 'm.msg = Msg{}' and 'm.msg = None' lead to observably
   223  	// different outcomes and we should account for that in 'messagesEqual'.
   224  	if !mut && !fd.IsList() && !fd.IsMap() &&
   225  		(fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind) {
   226  		return starlark.None, nil
   227  	}
   228  
   229  	// If this is not a oneof field, auto-initialize it to its default value. In
   230  	// particular this is important when chaining through fields `a.b.c.d`. We
   231  	// want intermediates to be silently auto-initialized.
   232  	//
   233  	// Note that lazy initialization of fields is an implementation detail. This
   234  	// is significant when considering frozen messages. From the caller's point of
   235  	// view, all fields had had their default values even before the object was
   236  	// frozen. So we lazy-initialize the field, even if the message is frozen, but
   237  	// make sure the new field is frozen itself too.
   238  	//
   239  	// TODO(vadimsh): This is not thread safe and should be improved if a frozen
   240  	// *Message is shared between goroutines. Generally frozen values are
   241  	// assumed to be safe for cross-goroutine use, which is not the case here.
   242  	// If this becomes important, we can force-initialize and freeze all default
   243  	// fields in Freeze(), but this is generally expensive.
   244  	def := toStarlark(m.typ.loader, fd, fd.Default())
   245  	if mut {
   246  		if m.frozen {
   247  			def.Freeze()
   248  		}
   249  		m.fields[name] = def
   250  	}
   251  	return def, nil
   252  }
   253  
   254  // AttrNames lists available attributes.
   255  func (m *Message) AttrNames() []string {
   256  	return m.typ.keys
   257  }
   258  
   259  // SetField is called when a field is assigned to from Starlark code.
   260  func (m *Message) SetField(name string, val starlark.Value) error {
   261  	// Check we have this field defined in the message.
   262  	fd, err := m.fieldDesc(name)
   263  	if err != nil {
   264  		return err
   265  	}
   266  
   267  	// Setting a field to None removes it completely.
   268  	if val == starlark.None {
   269  		if err := m.checkMutable(); err != nil {
   270  			return err
   271  		}
   272  		delete(m.fields, name)
   273  		return nil
   274  	}
   275  
   276  	// Check the type, do implicit type casts.
   277  	rhs, err := prepareRHS(m.typ.loader, fd, val)
   278  	if err != nil {
   279  		return fmt.Errorf("can't assign %s to field %q in %s: %s", val.Type(), name, m.Type(), err)
   280  	}
   281  	if err := m.checkMutable(); err != nil {
   282  		return err
   283  	}
   284  	m.fields[name] = rhs
   285  
   286  	// When assigning to a oneof alternative, clear its all other alternatives.
   287  	if oneof := fd.ContainingOneof(); oneof != nil {
   288  		alts := oneof.Fields()
   289  		for i := 0; i < alts.Len(); i++ {
   290  			if altfd := alts.Get(i); altfd != fd {
   291  				delete(m.fields, string(altfd.Name()))
   292  			}
   293  		}
   294  	}
   295  
   296  	return nil
   297  }
   298  
   299  // Comparable interface to implement '==' and '!='.
   300  
   301  // CompareSameType does 'm <op> y' comparison.
   302  func (m *Message) CompareSameType(op syntax.Token, y starlark.Value, depth int) (bool, error) {
   303  	switch op {
   304  	case syntax.EQL:
   305  		return messagesEqual(m, y.(*Message), depth)
   306  	case syntax.NEQ:
   307  		eq, err := messagesEqual(m, y.(*Message), depth)
   308  		return !eq, err
   309  	default:
   310  		return false, fmt.Errorf("%q is not implemented for %s", op, m.Type())
   311  	}
   312  }
   313  
   314  // messagesEqual compares two messages by value, recursively.
   315  func messagesEqual(l, r *Message, depth int) (bool, error) {
   316  	switch {
   317  	case l == r:
   318  		return true, nil // equal by identity
   319  	case l.typ != r.typ:
   320  		return false, nil // messages of different types are never equal
   321  	}
   322  	// We go through attrImpl(...) to correctly handle default values and oneof's.
   323  	// We don't want to mutate messages though.
   324  	for _, key := range l.typ.keys {
   325  		lv, err := l.attrImpl(key, false)
   326  		if err != nil {
   327  			return false, err
   328  		}
   329  		rv, err := r.attrImpl(key, false)
   330  		if err != nil {
   331  			return false, err
   332  		}
   333  		if eq, err := starlark.EqualDepth(lv, rv, depth-1); !eq || err != nil {
   334  			return false, err
   335  		}
   336  	}
   337  	return true, nil
   338  }
   339  
   340  // fieldDesc returns FieldDescriptor of the corresponding field or an error
   341  // message if there's no such field defined in the proto schema.
   342  func (m *Message) fieldDesc(name string) (protoreflect.FieldDescriptor, error) {
   343  	switch fd := m.typ.fields[name]; {
   344  	case fd != nil:
   345  		return fd, nil
   346  	case m.typ.desc.IsPlaceholder():
   347  		// This happens if 'm' lacks type information because its descriptor wasn't
   348  		// in any of the sets passed to loader.AddDescriptorSet(...). This should
   349  		// not really be happening since AddDescriptorSet(...) checks that all
   350  		// references are resolved. But handle this case anyway for a clearer error
   351  		// message if some unnoticed edge case pops up.
   352  		return nil, fmt.Errorf("internal error: descriptor of %s is not available, can't use this type", m.Type())
   353  	default:
   354  		return nil, fmt.Errorf("%s has no field %q", m.Type(), name)
   355  	}
   356  }
   357  
   358  // checkMutable returns an error if the message is frozen.
   359  func (m *Message) checkMutable() error {
   360  	if m.frozen {
   361  		return fmt.Errorf("cannot modify frozen %s", m.Type())
   362  	}
   363  	return nil
   364  }