github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/linker/symbols.go (about)

     1  package linker
     2  
     3  import (
     4  	"strings"
     5  	"sync"
     6  
     7  	"google.golang.org/protobuf/proto"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"google.golang.org/protobuf/types/descriptorpb"
    10  
    11  	"github.com/jhump/protocompile/ast"
    12  	"github.com/jhump/protocompile/reporter"
    13  	"github.com/jhump/protocompile/walk"
    14  )
    15  
// Symbols is a symbol table that maps names for all program elements to their
// location in source. It also tracks extension tag numbers. This can be used
// to enforce uniqueness for symbol names and tag numbers across many files and
// many link operations.
//
// This type is thread-safe: every method either acquires mu itself or (for the
// methods with a "Locked" suffix) requires the caller to already hold it.
type Symbols struct {
	// mu guards all of the maps below.
	mu      sync.Mutex
	// files is the set of file descriptors whose symbols have already been
	// committed to this table; it is used to skip repeated imports.
	files   map[protoreflect.FileDescriptor]struct{}
	// symbols maps each fully-qualified name to where it was defined.
	symbols map[protoreflect.FullName]symbolEntry
	// exts maps an extendee's fully-qualified name to the extension tag
	// numbers already claimed for it, each with the position that claimed it.
	exts    map[protoreflect.FullName]map[protoreflect.FieldNumber]ast.SourcePos
}
    28  
// symbolEntry is the value stored in the symbol table for a single
// fully-qualified name.
type symbolEntry struct {
	// pos is the source position at which the symbol was defined.
	pos         ast.SourcePos
	// isEnumValue records whether the symbol is an enum value; collision
	// messages involving enum values get an extra note about their scoping.
	isEnumValue bool
}
    33  
    34  // Import populates the symbol table with all symbols/elements and extension
    35  // tags present in the given file descriptor. If s is nil or if fd has already
    36  // been imported into s, this returns immediately without doing anything. If any
    37  // collisions in symbol names or extension tags are identified, an error will be
    38  // returned and the symbol table will not be updated.
    39  func (s *Symbols) Import(fd protoreflect.FileDescriptor, handler *reporter.Handler) error {
    40  	if s == nil {
    41  		return nil
    42  	}
    43  
    44  	if f, ok := fd.(file); ok {
    45  		// unwrap any file instance
    46  		fd = f.FileDescriptor
    47  	}
    48  
    49  	s.mu.Lock()
    50  	defer s.mu.Unlock()
    51  
    52  	return s.importLocked(fd, handler)
    53  }
    54  
    55  func (s *Symbols) importLocked(fd protoreflect.FileDescriptor, handler *reporter.Handler) error {
    56  	if _, ok := s.files[fd]; ok {
    57  		// already imported
    58  		return nil
    59  	}
    60  
    61  	// make sure deps are imported
    62  	for i := 0; i < fd.Imports().Len(); i++ {
    63  		imp := fd.Imports().Get(i)
    64  		if err := s.importLocked(imp.FileDescriptor, handler); err != nil {
    65  			return err
    66  		}
    67  	}
    68  
    69  	if res, ok := fd.(*result); ok {
    70  		return s.importResultLocked(res, false, true, handler)
    71  	}
    72  
    73  	// first pass: check for conflicts
    74  	if err := s.checkFileLocked(fd, handler); err != nil {
    75  		return err
    76  	}
    77  	if err := handler.Error(); err != nil {
    78  		return err
    79  	}
    80  
    81  	// second pass: commit all symbols
    82  	s.commitFileLocked(fd)
    83  
    84  	return nil
    85  }
    86  
    87  func reportSymbolCollision(pos ast.SourcePos, fqn protoreflect.FullName, additionIsEnumVal bool, existing symbolEntry, handler *reporter.Handler) error {
    88  	// because of weird scoping for enum values, provide more context in error message
    89  	// if this conflict is with an enum value
    90  	var suffix string
    91  	if additionIsEnumVal || existing.isEnumValue {
    92  		suffix = "; protobuf uses C++ scoping rules for enum values, so they exist in the scope enclosing the enum"
    93  	}
    94  	return handler.HandleErrorf(pos, "symbol %q already defined at %v%s", fqn, existing.pos, suffix)
    95  }
    96  
    97  func (s *Symbols) checkFileLocked(f protoreflect.FileDescriptor, handler *reporter.Handler) error {
    98  	return walk.Descriptors(f, func(d protoreflect.Descriptor) error {
    99  		pos := sourcePositionFor(d)
   100  		if existing, ok := s.symbols[d.FullName()]; ok {
   101  			_, isEnumVal := d.(protoreflect.EnumValueDescriptor)
   102  			if err := reportSymbolCollision(pos, d.FullName(), isEnumVal, existing, handler); err != nil {
   103  				return err
   104  			}
   105  		}
   106  
   107  		fld, ok := d.(protoreflect.FieldDescriptor)
   108  		if !ok || !fld.IsExtension() {
   109  			return nil
   110  		}
   111  
   112  		extendee := fld.ContainingMessage().FullName()
   113  		if tags, ok := s.exts[extendee]; ok {
   114  			if existing, ok := tags[fld.Number()]; ok {
   115  				if err := handler.HandleErrorf(pos, "extension with tag %d for message %s already defined at %v", fld.Number(), extendee, existing); err != nil {
   116  					return err
   117  				}
   118  			}
   119  		}
   120  
   121  		return nil
   122  	})
   123  }
   124  
   125  func sourcePositionFor(d protoreflect.Descriptor) ast.SourcePos {
   126  	loc := d.ParentFile().SourceLocations().ByDescriptor(d)
   127  	if isZeroLoc(loc) {
   128  		return ast.UnknownPos(d.ParentFile().Path())
   129  	}
   130  	return ast.SourcePos{
   131  		Filename: d.ParentFile().Path(),
   132  		Line:     loc.StartLine,
   133  		Col:      loc.StartColumn,
   134  	}
   135  }
   136  
   137  func isZeroLoc(loc protoreflect.SourceLocation) bool {
   138  	return loc.Path == nil &&
   139  		loc.StartLine == 0 &&
   140  		loc.StartColumn == 0 &&
   141  		loc.EndLine == 0 &&
   142  		loc.EndColumn == 0
   143  }
   144  
   145  func (s *Symbols) commitFileLocked(f protoreflect.FileDescriptor) {
   146  	if s.symbols == nil {
   147  		s.symbols = map[protoreflect.FullName]symbolEntry{}
   148  	}
   149  	if s.exts == nil {
   150  		s.exts = map[protoreflect.FullName]map[protoreflect.FieldNumber]ast.SourcePos{}
   151  	}
   152  	_ = walk.Descriptors(f, func(d protoreflect.Descriptor) error {
   153  		pos := sourcePositionFor(d)
   154  		name := d.FullName()
   155  		_, isEnumValue := d.(protoreflect.EnumValueDescriptor)
   156  		s.symbols[name] = symbolEntry{pos: pos, isEnumValue: isEnumValue}
   157  
   158  		fld, ok := d.(protoreflect.FieldDescriptor)
   159  		if !ok || !fld.IsExtension() {
   160  			return nil
   161  		}
   162  
   163  		extendee := fld.ContainingMessage().FullName()
   164  		tags := s.exts[extendee]
   165  		if tags == nil {
   166  			tags = map[protoreflect.FieldNumber]ast.SourcePos{}
   167  			s.exts[extendee] = tags
   168  		}
   169  		tags[fld.Number()] = pos
   170  
   171  		return nil
   172  	})
   173  
   174  	if s.files == nil {
   175  		s.files = map[protoreflect.FileDescriptor]struct{}{}
   176  	}
   177  	s.files[f] = struct{}{}
   178  }
   179  
   180  func (s *Symbols) importResult(r *result, populatePool bool, checkExts bool, handler *reporter.Handler) error {
   181  	s.mu.Lock()
   182  	defer s.mu.Unlock()
   183  
   184  	if _, ok := s.files[r]; ok {
   185  		// already imported
   186  		return nil
   187  	}
   188  
   189  	return s.importResultLocked(r, populatePool, checkExts, handler)
   190  }
   191  
   192  func (s *Symbols) importResultLocked(r *result, populatePool bool, checkExts bool, handler *reporter.Handler) error {
   193  	// first pass: check for conflicts
   194  	if err := s.checkResultLocked(r, checkExts, handler); err != nil {
   195  		return err
   196  	}
   197  	if err := handler.Error(); err != nil {
   198  		return err
   199  	}
   200  
   201  	// second pass: commit all symbols
   202  	s.commitResultLocked(r, populatePool)
   203  
   204  	return nil
   205  }
   206  
   207  func (s *Symbols) checkResultLocked(r *result, checkExts bool, handler *reporter.Handler) error {
   208  	resultSyms := map[protoreflect.FullName]symbolEntry{}
   209  	return walk.DescriptorProtos(r.Proto(), func(fqn protoreflect.FullName, d proto.Message) error {
   210  		_, isEnumVal := d.(*descriptorpb.EnumValueDescriptorProto)
   211  		file := r.FileNode()
   212  		node := r.Node(d)
   213  		pos := nameStart(file, node)
   214  		// check symbols already in this symbol table
   215  		if existing, ok := s.symbols[fqn]; ok {
   216  			if err := reportSymbolCollision(pos, fqn, isEnumVal, existing, handler); err != nil {
   217  				return err
   218  			}
   219  		}
   220  
   221  		// also check symbols from this result (that are not yet in symbol table)
   222  		if existing, ok := resultSyms[fqn]; ok {
   223  			if err := reportSymbolCollision(pos, fqn, isEnumVal, existing, handler); err != nil {
   224  				return err
   225  			}
   226  		}
   227  		resultSyms[fqn] = symbolEntry{
   228  			pos:         pos,
   229  			isEnumValue: isEnumVal,
   230  		}
   231  
   232  		if !checkExts {
   233  			return nil
   234  		}
   235  
   236  		fld, ok := d.(*descriptorpb.FieldDescriptorProto)
   237  		if !ok {
   238  			return nil
   239  		}
   240  		extendee := fld.GetExtendee()
   241  		if extendee == "" {
   242  			return nil
   243  		}
   244  
   245  		extendeeFqn := protoreflect.FullName(strings.TrimPrefix(extendee, "."))
   246  		if tags, ok := s.exts[extendeeFqn]; ok {
   247  			if existing, ok := tags[protoreflect.FieldNumber(fld.GetNumber())]; ok {
   248  				pos := file.NodeInfo(node.(ast.FieldDeclNode).FieldTag()).Start()
   249  				if err := handler.HandleErrorf(pos, "extension with tag %d for message %s already defined at %v", fld.GetNumber(), extendeeFqn, existing); err != nil {
   250  					return err
   251  				}
   252  			}
   253  		}
   254  
   255  		return nil
   256  	})
   257  }
   258  
   259  func nameStart(file ast.FileDeclNode, n ast.Node) ast.SourcePos {
   260  	// TODO: maybe ast package needs a NamedNode interface to simplify this?
   261  	switch n := n.(type) {
   262  	case ast.FieldDeclNode:
   263  		return file.NodeInfo(n.FieldName()).Start()
   264  	case ast.MessageDeclNode:
   265  		return file.NodeInfo(n.MessageName()).Start()
   266  	case ast.EnumValueDeclNode:
   267  		return file.NodeInfo(n.GetName()).Start()
   268  	case *ast.EnumNode:
   269  		return file.NodeInfo(n.Name).Start()
   270  	case *ast.ServiceNode:
   271  		return file.NodeInfo(n.Name).Start()
   272  	case ast.RPCDeclNode:
   273  		return file.NodeInfo(n.GetName()).Start()
   274  	default:
   275  		return file.NodeInfo(n).Start()
   276  	}
   277  }
   278  
   279  func (s *Symbols) commitResultLocked(r *result, populatePool bool) {
   280  	if s.symbols == nil {
   281  		s.symbols = map[protoreflect.FullName]symbolEntry{}
   282  	}
   283  	if s.exts == nil {
   284  		s.exts = map[protoreflect.FullName]map[protoreflect.FieldNumber]ast.SourcePos{}
   285  	}
   286  	_ = walk.DescriptorProtos(r.Proto(), func(fqn protoreflect.FullName, d proto.Message) error {
   287  		pos := nameStart(r.FileNode(), r.Node(d))
   288  		_, isEnumValue := d.(protoreflect.EnumValueDescriptor)
   289  		s.symbols[fqn] = symbolEntry{pos: pos, isEnumValue: isEnumValue}
   290  		if populatePool {
   291  			r.descriptorPool[string(fqn)] = d
   292  		}
   293  		return nil
   294  	})
   295  
   296  	if s.files == nil {
   297  		s.files = map[protoreflect.FileDescriptor]struct{}{}
   298  	}
   299  	s.files[r] = struct{}{}
   300  }
   301  
   302  func (s *Symbols) addExtension(extendee protoreflect.FullName, tag protoreflect.FieldNumber, pos ast.SourcePos, handler *reporter.Handler) error {
   303  	s.mu.Lock()
   304  	defer s.mu.Unlock()
   305  
   306  	if s.exts == nil {
   307  		s.exts = map[protoreflect.FullName]map[protoreflect.FieldNumber]ast.SourcePos{}
   308  	}
   309  
   310  	usedExtTags := s.exts[extendee]
   311  	if usedExtTags == nil {
   312  		usedExtTags = map[protoreflect.FieldNumber]ast.SourcePos{}
   313  		s.exts[extendee] = usedExtTags
   314  	}
   315  	if existing, ok := usedExtTags[tag]; ok {
   316  		if err := handler.HandleErrorf(pos, "extension with tag %d for message %s already defined at %v", tag, extendee, existing); err != nil {
   317  			return err
   318  		}
   319  	} else {
   320  		usedExtTags[tag] = pos
   321  	}
   322  	return nil
   323  }