github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/desc/builder/resolver.go (about)

     1  package builder
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"github.com/golang/protobuf/proto"
     9  
    10  	"github.com/hoveychen/protoreflect/desc"
    11  	"github.com/hoveychen/protoreflect/dynamic"
    12  )
    13  
    14  type dependencies struct {
    15  	descs map[*desc.FileDescriptor]struct{}
    16  	exts  dynamic.ExtensionRegistry
    17  }
    18  
    19  func newDependencies() *dependencies {
    20  	return &dependencies{
    21  		descs: map[*desc.FileDescriptor]struct{}{},
    22  	}
    23  }
    24  
    25  func (d *dependencies) add(fd *desc.FileDescriptor) {
    26  	if _, ok := d.descs[fd]; ok {
    27  		// already added
    28  		return
    29  	}
    30  	d.descs[fd] = struct{}{}
    31  	d.exts.AddExtensionsFromFile(fd)
    32  }
    33  
    34  // dependencyResolver is the work-horse for converting a tree of builders into a
    35  // tree of descriptors. It scans a root (usually a file builder) and recursively
    36  // resolves all dependencies (references to builders in other trees as well as
    37  // references to other already-built descriptors). The result of resolution is a
    38  // file descriptor (or an error).
    39  type dependencyResolver struct {
    40  	resolvedRoots map[Builder]*desc.FileDescriptor
    41  	seen          map[Builder]struct{}
    42  	opts          BuilderOptions
    43  }
    44  
    45  func newResolver(opts BuilderOptions) *dependencyResolver {
    46  	return &dependencyResolver{
    47  		resolvedRoots: map[Builder]*desc.FileDescriptor{},
    48  		seen:          map[Builder]struct{}{},
    49  		opts:          opts,
    50  	}
    51  }
    52  
    53  func (r *dependencyResolver) resolveElement(b Builder, seen []Builder) (*desc.FileDescriptor, error) {
    54  	b = getRoot(b)
    55  
    56  	if fd, ok := r.resolvedRoots[b]; ok {
    57  		return fd, nil
    58  	}
    59  
    60  	for _, s := range seen {
    61  		if s == b {
    62  			names := make([]string, len(seen)+1)
    63  			for i, s := range seen {
    64  				names[i] = s.GetName()
    65  			}
    66  			names[len(seen)] = b.GetName()
    67  			return nil, fmt.Errorf("descriptors have cyclic dependency: %s", strings.Join(names, " ->  "))
    68  		}
    69  	}
    70  	seen = append(seen, b)
    71  
    72  	var fd *desc.FileDescriptor
    73  	var err error
    74  	switch b := b.(type) {
    75  	case *FileBuilder:
    76  		fd, err = r.resolveFile(b, b, seen)
    77  	default:
    78  		fd, err = r.resolveSyntheticFile(b, seen)
    79  	}
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	r.resolvedRoots[b] = fd
    84  	return fd, nil
    85  }
    86  
    87  func (r *dependencyResolver) resolveFile(fb *FileBuilder, root Builder, seen []Builder) (*desc.FileDescriptor, error) {
    88  	deps := newDependencies()
    89  	// add explicit imports first
    90  	for fd := range fb.explicitImports {
    91  		deps.add(fd)
    92  	}
    93  	for dep := range fb.explicitDeps {
    94  		if dep == fb {
    95  			// ignore erroneous self references
    96  			continue
    97  		}
    98  		fd, err := r.resolveElement(dep, seen)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		deps.add(fd)
   103  	}
   104  	// now accumulate implicit dependencies based on other types referenced
   105  	for _, mb := range fb.messages {
   106  		if err := r.resolveTypesInMessage(root, seen, deps, mb); err != nil {
   107  			return nil, err
   108  		}
   109  	}
   110  	for _, exb := range fb.extensions {
   111  		if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil {
   112  			return nil, err
   113  		}
   114  	}
   115  	for _, sb := range fb.services {
   116  		if err := r.resolveTypesInService(root, seen, deps, sb); err != nil {
   117  			return nil, err
   118  		}
   119  	}
   120  
   121  	// finally, resolve custom options (which may refer to deps already
   122  	// computed above)
   123  	if err := r.resolveTypesInFileOptions(root, deps, fb); err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	depSlice := make([]*desc.FileDescriptor, 0, len(deps.descs))
   128  	for dep := range deps.descs {
   129  		depSlice = append(depSlice, dep)
   130  	}
   131  
   132  	fp, err := fb.buildProto(depSlice)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	// make sure this file name doesn't collide with any of its dependencies
   138  	fileNames := map[string]struct{}{}
   139  	for _, d := range depSlice {
   140  		addFileNames(d, fileNames)
   141  		fileNames[d.GetName()] = struct{}{}
   142  	}
   143  	unique := makeUnique(fp.GetName(), fileNames)
   144  	if unique != fp.GetName() {
   145  		fp.Name = proto.String(unique)
   146  	}
   147  
   148  	return desc.CreateFileDescriptor(fp, depSlice...)
   149  }
   150  
   151  func addFileNames(fd *desc.FileDescriptor, files map[string]struct{}) {
   152  	if _, ok := files[fd.GetName()]; ok {
   153  		// already added
   154  		return
   155  	}
   156  	files[fd.GetName()] = struct{}{}
   157  	for _, d := range fd.GetDependencies() {
   158  		addFileNames(d, files)
   159  	}
   160  }
   161  
   162  func (r *dependencyResolver) resolveSyntheticFile(b Builder, seen []Builder) (*desc.FileDescriptor, error) {
   163  	// find ancestor to temporarily attach to new file
   164  	curr := b
   165  	for curr.GetParent() != nil {
   166  		curr = curr.GetParent()
   167  	}
   168  	f := NewFile("")
   169  	switch curr := curr.(type) {
   170  	case *MessageBuilder:
   171  		f.messages = append(f.messages, curr)
   172  	case *EnumBuilder:
   173  		f.enums = append(f.enums, curr)
   174  	case *ServiceBuilder:
   175  		f.services = append(f.services, curr)
   176  	case *FieldBuilder:
   177  		if curr.IsExtension() {
   178  			f.extensions = append(f.extensions, curr)
   179  		} else {
   180  			panic("field must be added to message before calling Build()")
   181  		}
   182  	case *OneOfBuilder:
   183  		if _, ok := b.(*OneOfBuilder); ok {
   184  			panic("one-of must be added to message before calling Build()")
   185  		} else {
   186  			// b was a child of one-of which means it must have been a field
   187  			panic("field must be added to message before calling Build()")
   188  		}
   189  	case *MethodBuilder:
   190  		panic("method must be added to service before calling Build()")
   191  	case *EnumValueBuilder:
   192  		panic("enum value must be added to enum before calling Build()")
   193  	default:
   194  		panic(fmt.Sprintf("Unrecognized kind of builder: %T", b))
   195  	}
   196  	curr.setParent(f)
   197  
   198  	// don't forget to reset when done
   199  	defer func() {
   200  		curr.setParent(nil)
   201  	}()
   202  
   203  	return r.resolveFile(f, b, seen)
   204  }
   205  
   206  func (r *dependencyResolver) resolveTypesInMessage(root Builder, seen []Builder, deps *dependencies, mb *MessageBuilder) error {
   207  	for _, b := range mb.fieldsAndOneOfs {
   208  		if flb, ok := b.(*FieldBuilder); ok {
   209  			if err := r.resolveTypesInField(root, seen, deps, flb); err != nil {
   210  				return err
   211  			}
   212  		} else {
   213  			oob := b.(*OneOfBuilder)
   214  			for _, flb := range oob.choices {
   215  				if err := r.resolveTypesInField(root, seen, deps, flb); err != nil {
   216  					return err
   217  				}
   218  			}
   219  		}
   220  	}
   221  	for _, nmb := range mb.nestedMessages {
   222  		if err := r.resolveTypesInMessage(root, seen, deps, nmb); err != nil {
   223  			return err
   224  		}
   225  	}
   226  	for _, exb := range mb.nestedExtensions {
   227  		if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil {
   228  			return err
   229  		}
   230  	}
   231  	return nil
   232  }
   233  
   234  func (r *dependencyResolver) resolveTypesInExtension(root Builder, seen []Builder, deps *dependencies, exb *FieldBuilder) error {
   235  	if err := r.resolveTypesInField(root, seen, deps, exb); err != nil {
   236  		return err
   237  	}
   238  	if exb.foreignExtendee != nil {
   239  		deps.add(exb.foreignExtendee.GetFile())
   240  	} else if err := r.resolveType(root, seen, exb.localExtendee, deps); err != nil {
   241  		return err
   242  	}
   243  	return nil
   244  }
   245  
   246  func (r *dependencyResolver) resolveTypesInService(root Builder, seen []Builder, deps *dependencies, sb *ServiceBuilder) error {
   247  	for _, mtb := range sb.methods {
   248  		if err := r.resolveRpcType(root, seen, mtb.ReqType, deps); err != nil {
   249  			return err
   250  		}
   251  		if err := r.resolveRpcType(root, seen, mtb.RespType, deps); err != nil {
   252  			return err
   253  		}
   254  	}
   255  	return nil
   256  }
   257  
   258  func (r *dependencyResolver) resolveRpcType(root Builder, seen []Builder, t *RpcType, deps *dependencies) error {
   259  	if t.foreignType != nil {
   260  		deps.add(t.foreignType.GetFile())
   261  	} else {
   262  		return r.resolveType(root, seen, t.localType, deps)
   263  	}
   264  	return nil
   265  }
   266  
   267  func (r *dependencyResolver) resolveTypesInField(root Builder, seen []Builder, deps *dependencies, flb *FieldBuilder) error {
   268  	if flb.fieldType.foreignMsgType != nil {
   269  		deps.add(flb.fieldType.foreignMsgType.GetFile())
   270  	} else if flb.fieldType.foreignEnumType != nil {
   271  		deps.add(flb.fieldType.foreignEnumType.GetFile())
   272  	} else if flb.fieldType.localMsgType != nil {
   273  		if flb.fieldType.localMsgType == flb.msgType {
   274  			return r.resolveTypesInMessage(root, seen, deps, flb.msgType)
   275  		} else {
   276  			return r.resolveType(root, seen, flb.fieldType.localMsgType, deps)
   277  		}
   278  	} else if flb.fieldType.localEnumType != nil {
   279  		return r.resolveType(root, seen, flb.fieldType.localEnumType, deps)
   280  	}
   281  	return nil
   282  }
   283  
   284  func (r *dependencyResolver) resolveType(root Builder, seen []Builder, typeBuilder Builder, deps *dependencies) error {
   285  	otherRoot := getRoot(typeBuilder)
   286  	if root == otherRoot {
   287  		// local reference, so it will get resolved when we finish resolving this root
   288  		return nil
   289  	}
   290  	fd, err := r.resolveElement(otherRoot, seen)
   291  	if err != nil {
   292  		return err
   293  	}
   294  	deps.add(fd)
   295  	return nil
   296  }
   297  
   298  func (r *dependencyResolver) resolveTypesInFileOptions(root Builder, deps *dependencies, fb *FileBuilder) error {
   299  	for _, mb := range fb.messages {
   300  		if err := r.resolveTypesInMessageOptions(root, deps, mb); err != nil {
   301  			return err
   302  		}
   303  	}
   304  	for _, eb := range fb.enums {
   305  		if err := r.resolveTypesInEnumOptions(root, deps, eb); err != nil {
   306  			return err
   307  		}
   308  	}
   309  	for _, exb := range fb.extensions {
   310  		if err := r.resolveTypesInOptions(root, deps, exb.Options); err != nil {
   311  			return err
   312  		}
   313  	}
   314  	for _, sb := range fb.services {
   315  		for _, mtb := range sb.methods {
   316  			if err := r.resolveTypesInOptions(root, deps, mtb.Options); err != nil {
   317  				return err
   318  			}
   319  		}
   320  		if err := r.resolveTypesInOptions(root, deps, sb.Options); err != nil {
   321  			return err
   322  		}
   323  	}
   324  	return r.resolveTypesInOptions(root, deps, fb.Options)
   325  }
   326  
   327  func (r *dependencyResolver) resolveTypesInMessageOptions(root Builder, deps *dependencies, mb *MessageBuilder) error {
   328  	for _, b := range mb.fieldsAndOneOfs {
   329  		if flb, ok := b.(*FieldBuilder); ok {
   330  			if err := r.resolveTypesInOptions(root, deps, flb.Options); err != nil {
   331  				return err
   332  			}
   333  		} else {
   334  			oob := b.(*OneOfBuilder)
   335  			for _, flb := range oob.choices {
   336  				if err := r.resolveTypesInOptions(root, deps, flb.Options); err != nil {
   337  					return err
   338  				}
   339  			}
   340  			if err := r.resolveTypesInOptions(root, deps, oob.Options); err != nil {
   341  				return err
   342  			}
   343  		}
   344  	}
   345  	for _, extr := range mb.ExtensionRanges {
   346  		if err := r.resolveTypesInOptions(root, deps, extr.Options); err != nil {
   347  			return err
   348  		}
   349  	}
   350  	for _, eb := range mb.nestedEnums {
   351  		if err := r.resolveTypesInEnumOptions(root, deps, eb); err != nil {
   352  			return err
   353  		}
   354  	}
   355  	for _, nmb := range mb.nestedMessages {
   356  		if err := r.resolveTypesInMessageOptions(root, deps, nmb); err != nil {
   357  			return err
   358  		}
   359  	}
   360  	for _, exb := range mb.nestedExtensions {
   361  		if err := r.resolveTypesInOptions(root, deps, exb.Options); err != nil {
   362  			return err
   363  		}
   364  	}
   365  	if err := r.resolveTypesInOptions(root, deps, mb.Options); err != nil {
   366  		return err
   367  	}
   368  	return nil
   369  }
   370  
   371  func (r *dependencyResolver) resolveTypesInEnumOptions(root Builder, deps *dependencies, eb *EnumBuilder) error {
   372  	for _, evb := range eb.values {
   373  		if err := r.resolveTypesInOptions(root, deps, evb.Options); err != nil {
   374  			return err
   375  		}
   376  	}
   377  	if err := r.resolveTypesInOptions(root, deps, eb.Options); err != nil {
   378  		return err
   379  	}
   380  	return nil
   381  }
   382  
   383  func (r *dependencyResolver) resolveTypesInOptions(root Builder, deps *dependencies, opts proto.Message) error {
   384  	// nothing to see if opts is nil
   385  	if opts == nil {
   386  		return nil
   387  	}
   388  	if rv := reflect.ValueOf(opts); rv.Kind() == reflect.Ptr && rv.IsNil() {
   389  		return nil
   390  	}
   391  
   392  	msgName := proto.MessageName(opts)
   393  
   394  	extdescs, err := proto.ExtensionDescs(opts)
   395  	if err != nil {
   396  		return err
   397  	}
   398  	for _, ed := range extdescs {
   399  		// see if known dependencies have this option
   400  		extd := deps.exts.FindExtension(msgName, ed.Field)
   401  		if extd != nil {
   402  			// yep! nothing else to do
   403  			continue
   404  		}
   405  		// see if this extension is defined in *this* builder
   406  		if findExtension(root, msgName, ed.Field) {
   407  			// yep!
   408  			continue
   409  		}
   410  		// see if configured extension registry knows about it
   411  		extd = r.opts.Extensions.FindExtension(msgName, ed.Field)
   412  		if extd != nil {
   413  			// extension registry recognized it!
   414  			deps.add(extd.GetFile())
   415  		} else if ed.Filename != "" {
   416  			// if filename is filled in, then this extension is a known
   417  			// extension (registered in proto package)
   418  			fd, err := desc.LoadFileDescriptor(ed.Filename)
   419  			if err != nil {
   420  				return err
   421  			}
   422  			deps.add(fd)
   423  		} else if r.opts.RequireInterpretedOptions {
   424  			// we require options to be interpreted but are not able to!
   425  			return fmt.Errorf("could not interpret custom option for %s, tag %d", msgName, ed.Field)
   426  		}
   427  	}
   428  	return nil
   429  }
   430  
   431  func findExtension(b Builder, messageName string, extTag int32) bool {
   432  	if fb, ok := b.(*FileBuilder); ok && findExtensionInFile(fb, messageName, extTag) {
   433  		return true
   434  	}
   435  	if mb, ok := b.(*MessageBuilder); ok && findExtensionInMessage(mb, messageName, extTag) {
   436  		return true
   437  	}
   438  	return false
   439  }
   440  
   441  func findExtensionInFile(fb *FileBuilder, messageName string, extTag int32) bool {
   442  	for _, extb := range fb.extensions {
   443  		if extb.GetExtendeeTypeName() == messageName && extb.number == extTag {
   444  			return true
   445  		}
   446  	}
   447  	for _, mb := range fb.messages {
   448  		if findExtensionInMessage(mb, messageName, extTag) {
   449  			return true
   450  		}
   451  	}
   452  	return false
   453  }
   454  
   455  func findExtensionInMessage(mb *MessageBuilder, messageName string, extTag int32) bool {
   456  	for _, extb := range mb.nestedExtensions {
   457  		if extb.GetExtendeeTypeName() == messageName && extb.number == extTag {
   458  			return true
   459  		}
   460  	}
   461  	for _, mb := range mb.nestedMessages {
   462  		if findExtensionInMessage(mb, messageName, extTag) {
   463  			return true
   464  		}
   465  	}
   466  	return false
   467  }