go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/starlark/starlarkproto/loader.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package starlarkproto
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  
    23  	"go.starlark.net/starlark"
    24  	"go.starlark.net/starlarkstruct"
    25  
    26  	"google.golang.org/protobuf/reflect/protodesc"
    27  	"google.golang.org/protobuf/reflect/protoreflect"
    28  	"google.golang.org/protobuf/reflect/protoregistry"
    29  	"google.golang.org/protobuf/types/descriptorpb"
    30  )
    31  
    32  // Loader can instantiate Starlark values that correspond to proto messages.
    33  //
    34  // Holds a pool of descriptors that describe all available proto types. Use
    35  // AddDescriptorSet to seed it. Once seeded, use Module to get a Starlark module
    36  // with symbols defined in some registered `*.proto` file.
    37  //
    38  // Loader is also a Starlark value itself, with the following methods:
    39  //   - add_descriptor_set(ds) - see AddDescriptorSet.
    40  //   - module(path) - see Module.
    41  //
    42  // Can be used concurrently. Non-freezable.
    43  type Loader struct {
    44  	m sync.RWMutex
    45  
    46  	files *protoregistry.Files
    47  	types *protoregistry.Types
    48  
    49  	dsets   map[*DescriptorSet]struct{}
    50  	mtypes  map[protoreflect.MessageDescriptor]*MessageType
    51  	modules map[string]*starlarkstruct.Module // *.proto file => its top-level symbols
    52  
    53  	hash uint32 // unique (within the process) value, used by Hash()
    54  }
    55  
    56  // loaderHash is used to give each instance of *Loader its own unique non-reused
    57  // hash value for Hash() method.
    58  var loaderHash uint32 = 1000
    59  
    60  // NewLoader instantiates a new loader with empty proto registry.
    61  func NewLoader() *Loader {
    62  	return &Loader{
    63  		files:   &protoregistry.Files{},
    64  		types:   &protoregistry.Types{},
    65  		dsets:   make(map[*DescriptorSet]struct{}, 0),
    66  		mtypes:  make(map[protoreflect.MessageDescriptor]*MessageType, 0),
    67  		modules: make(map[string]*starlarkstruct.Module, 0),
    68  		hash:    atomic.AddUint32(&loaderHash, 1),
    69  	}
    70  }
    71  
    72  // Key of the default *Loader in starlark.Thread local store.
    73  const threadLoaderKey = "starlarkproto.Loader"
    74  
    75  // DefaultLoader returns a loader installed in the thread via SetDefaultLoader.
    76  //
    77  // Returns nil if there's no default loader.
    78  func DefaultLoader(th *starlark.Thread) *Loader {
    79  	l, _ := th.Local(threadLoaderKey).(*Loader)
    80  	return l
    81  }
    82  
    83  // SetDefaultLoader installs the given loader as default in the thread.
    84  //
    85  // It can be obtained via DefaultLoader or proto.default_loader() from Starlark.
    86  // Note that Starlark code has no way of changing the default loader. It's
    87  // responsibility of the hosting environment to prepare the default loader
    88  // (just like it prepares starlark.Thread itself).
    89  func SetDefaultLoader(th *starlark.Thread, l *Loader) {
    90  	th.SetLocal(threadLoaderKey, l)
    91  }
    92  
    93  // Types returns a registry for looking up or iterating over descriptor types.
    94  func (l *Loader) Types() *protoregistry.Types {
    95  	return l.types
    96  }
    97  
    98  // AddDescriptorSet makes all *.proto files defined in the given descriptor set
    99  // and all its dependencies available for use from Starlark.
   100  //
   101  // AddDescriptorSet is idempotent in a sense that calling AddDescriptorSet(ds)
   102  // multiple times with the exact same 'ds' is not an error. But trying to
   103  // register a proto file through multiple different descriptor sets is an error.
   104  func (l *Loader) AddDescriptorSet(ds *DescriptorSet) error {
   105  	l.m.Lock()
   106  	defer l.m.Unlock()
   107  	return l.addDescriptorSetLocked(ds)
   108  }
   109  
   110  // addDescriptorSetLocked implements AddDescriptorSet.
   111  func (l *Loader) addDescriptorSetLocked(ds *DescriptorSet) error {
   112  	if _, ok := l.dsets[ds]; ok {
   113  		return nil
   114  	}
   115  	for _, dep := range ds.deps {
   116  		if err := l.addDescriptorSetLocked(dep); err != nil {
   117  			return fmt.Errorf("%q: %s", ds.name, err)
   118  		}
   119  	}
   120  	for _, fd := range ds.fdps {
   121  		if err := l.addDescriptorLocked(fd); err != nil {
   122  			return fmt.Errorf("%q: %s", ds.name, err)
   123  		}
   124  	}
   125  	l.dsets[ds] = struct{}{}
   126  	return nil
   127  }
   128  
   129  // addDescriptor adds a single deserialized FileDescriptorProto.
   130  func (l *Loader) addDescriptorLocked(fd *descriptorpb.FileDescriptorProto) error {
   131  	// Load the file descriptor, resolving all references through 'res' which
   132  	// will capture unresolved ones. Note that per comments in protodesc/desc.go,
   133  	// there would be an option to tell protodesc.NewFile to make this check
   134  	// natively.
   135  	res := &resolver{r: l.files}
   136  	f, err := protodesc.NewFile(fd, res)
   137  	if err != nil {
   138  		return fmt.Errorf("resolving imports in %s: %s", fd.GetName(), err)
   139  	}
   140  
   141  	switch {
   142  	case len(res.files) != 0:
   143  		return fmt.Errorf(
   144  			"compiled proto file %s refers to undefined files: %s",
   145  			fd.GetName(), strings.Join(res.files, ", "))
   146  	case len(res.descs) != 0:
   147  		return fmt.Errorf(
   148  			"compiled proto file %s refers to undefined descriptors: %s",
   149  			fd.GetName(), strings.Join(res.descs, ", "))
   150  	}
   151  
   152  	if err := l.files.RegisterFile(f); err != nil {
   153  		return fmt.Errorf("registering %s: %s", fd.GetName(), err)
   154  	}
   155  
   156  	// TODO(vadimsh): Populate l.types somehow. It is used by encoders/decoders
   157  	// to handle google.protobuf.Any fields (which we currently do not support).
   158  
   159  	return nil
   160  }
   161  
   162  // resolver wraps protodesc.Resolver by capturing unresolved references.
   163  type resolver struct {
   164  	r protodesc.Resolver
   165  
   166  	files []string // unresolvable files
   167  	descs []string // unresolvable descriptors
   168  }
   169  
   170  func (r *resolver) FindFileByPath(p string) (protoreflect.FileDescriptor, error) {
   171  	d, err := r.r.FindFileByPath(p)
   172  	if err == protoregistry.NotFound {
   173  		r.files = append(r.files, p)
   174  	}
   175  	return d, err
   176  }
   177  
   178  func (r *resolver) FindDescriptorByName(n protoreflect.FullName) (protoreflect.Descriptor, error) {
   179  	d, err := r.r.FindDescriptorByName(n)
   180  	if err == protoregistry.NotFound {
   181  		r.descs = append(r.descs, string(n))
   182  	}
   183  	return d, err
   184  }
   185  
   186  // Module returns a module with top-level definitions from some *.proto file.
   187  //
   188  // The descriptor of this proto file should be registered already via
   189  // AddDescriptorSet. 'path' here is matched to what's in the descriptor, which
   190  // is a path to *.proto EXACTLY as it was given to 'protoc'.
   191  //
   192  // The name of the module matches the proto package name (per 'package ...'
   193  // statement in the proto file).
   194  func (l *Loader) Module(path string) (*starlarkstruct.Module, error) {
   195  	// Lookup in the cache under the reader lock.
   196  	mod, desc, err := func() (*starlarkstruct.Module, protoreflect.FileDescriptor, error) {
   197  		l.m.RLock()
   198  		defer l.m.RUnlock()
   199  		if mod := l.modules[path]; mod != nil {
   200  			return mod, nil, nil
   201  		}
   202  		desc, err := l.files.FindFileByPath(path)
   203  		if err != nil {
   204  			return nil, nil, fmt.Errorf("loading %s: %s", path, err)
   205  		}
   206  		return nil, desc, nil
   207  	}()
   208  	if mod != nil || err != nil {
   209  		return mod, err
   210  	}
   211  
   212  	l.m.Lock()
   213  	defer l.m.Unlock()
   214  
   215  	// Populate the module dict with top-level symbols in the file.
   216  	mod = &starlarkstruct.Module{
   217  		Name:    string(desc.Package()),
   218  		Members: starlark.StringDict{},
   219  	}
   220  	l.injectMessageTypesLocked(mod.Members, desc.Messages())
   221  	l.injectEnumValuesLocked(mod.Members, desc.Enums())
   222  
   223  	l.modules[path] = mod
   224  	return mod, nil
   225  }
   226  
   227  // MessageType creates new (or returns existing) MessageType.
   228  //
   229  // The return value can be used to instantiate Starlark values via Message() or
   230  // MessageFromProto(m).
   231  func (l *Loader) MessageType(desc protoreflect.MessageDescriptor) *MessageType {
   232  	l.m.RLock()
   233  	mt := l.mtypes[desc]
   234  	l.m.RUnlock()
   235  	if mt != nil {
   236  		return mt
   237  	}
   238  
   239  	l.m.Lock()
   240  	defer l.m.Unlock()
   241  	return l.initMessageTypeLocked(desc)
   242  }
   243  
   244  // initMessageTypeLocked creates *MessageType if it didn't exist before.
   245  func (l *Loader) initMessageTypeLocked(desc protoreflect.MessageDescriptor) *MessageType {
   246  	if typ := l.mtypes[desc]; typ != nil {
   247  		return typ
   248  	}
   249  
   250  	typ := &MessageType{
   251  		loader: l,
   252  		desc:   desc,
   253  		attrs:  starlark.StringDict{},
   254  	}
   255  	typ.initLocked()
   256  
   257  	// Constructor function that uses `typ` to instantiate messages.
   258  	typ.Builtin = starlark.NewBuiltin(typ.Type(), func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   259  		if len(args) != 0 {
   260  			return nil, fmt.Errorf("proto message constructors accept only keyword arguments")
   261  		}
   262  		msg := typ.Message()
   263  		for _, kv := range kwargs {
   264  			if err := msg.SetField(string(kv[0].(starlark.String)), kv[1]); err != nil {
   265  				return nil, err
   266  			}
   267  		}
   268  		return msg, nil
   269  	})
   270  
   271  	// Inject nested symbols.
   272  	l.injectMessageTypesLocked(typ.attrs, desc.Messages())
   273  	l.injectEnumValuesLocked(typ.attrs, desc.Enums())
   274  
   275  	l.mtypes[desc] = typ
   276  	return typ
   277  }
   278  
   279  // injectMessageTypesLocked instantiates constructors for messages in 'msgs' and
   280  // adds them to the dict 'd'.
   281  func (l *Loader) injectMessageTypesLocked(d starlark.StringDict, msgs protoreflect.MessageDescriptors) {
   282  	for i := 0; i < msgs.Len(); i++ {
   283  		desc := msgs.Get(i)
   284  		// map<...> fields are represented by magical map message types. We do not
   285  		// expose them on Starlark level and represent maps as dicts instead.
   286  		if !desc.IsMapEntry() {
   287  			d[string(desc.Name())] = l.initMessageTypeLocked(desc)
   288  		}
   289  	}
   290  }
   291  
   292  // injectEnumValuesLocked takes enum constants defined in 'enums' and puts them
   293  // directly into the given dict as integers.
   294  func (l *Loader) injectEnumValuesLocked(d starlark.StringDict, enums protoreflect.EnumDescriptors) {
   295  	for i := 0; i < enums.Len(); i++ {
   296  		vals := enums.Get(i).Values()
   297  		for j := 0; j < vals.Len(); j++ {
   298  			val := vals.Get(j)
   299  			d[string(val.Name())] = starlark.MakeInt(int(val.Number()))
   300  		}
   301  	}
   302  }
   303  
   304  // Implementation of starlark.Value and starlark.HasAttrs.
   305  
   306  // String returns str(...) representation of the loader.
   307  func (l *Loader) String() string {
   308  	return fmt.Sprintf("proto.Loader(0x%x)", l.hash)
   309  }
   310  
   311  // Type returns "proto.Loader".
   312  func (l *Loader) Type() string {
   313  	return "proto.Loader"
   314  }
   315  
   316  // Freeze is noop for now.
   317  func (l *Loader) Freeze() {}
   318  
   319  // Truth returns True.
   320  func (l *Loader) Truth() starlark.Bool { return starlark.True }
   321  
   322  // Hash returns an integer assigned to this loader when it was created.
   323  func (l *Loader) Hash() (uint32, error) { return l.hash, nil }
   324  
   325  // AtrrNames lists available attributes.
   326  func (l *Loader) AttrNames() []string {
   327  	return []string{
   328  		"add_descriptor_set",
   329  		"module",
   330  	}
   331  }
   332  
   333  // Attr returns an attribute given its name (or nil if not present).
   334  func (l *Loader) Attr(name string) (starlark.Value, error) {
   335  	switch name {
   336  	case "add_descriptor_set":
   337  		return addDescSetBuiltin.BindReceiver(l), nil
   338  	case "module":
   339  		return moduleBuiltin.BindReceiver(l), nil
   340  	default:
   341  		return nil, nil
   342  	}
   343  }
   344  
   345  // Shims for calling Loader methods from Starlark.
   346  
   347  var addDescSetBuiltin = starlark.NewBuiltin("add_descriptor_set", func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   348  	var ds *DescriptorSet
   349  	if err := starlark.UnpackPositionalArgs("add_descriptor_set", args, kwargs, 1, &ds); err != nil {
   350  		return nil, err
   351  	}
   352  	return starlark.None, b.Receiver().(*Loader).AddDescriptorSet(ds)
   353  })
   354  
   355  var moduleBuiltin = starlark.NewBuiltin("module", func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   356  	var path string
   357  	if err := starlark.UnpackPositionalArgs("module", args, kwargs, 1, &path); err != nil {
   358  		return nil, err
   359  	}
   360  	return b.Receiver().(*Loader).Module(path)
   361  })