github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/linker/validate.go (about)

     1  package linker
     2  
     3  import (
     4  	"google.golang.org/protobuf/reflect/protoreflect"
     5  	"google.golang.org/protobuf/types/descriptorpb"
     6  
     7  	"github.com/jhump/protocompile/internal"
     8  	"github.com/jhump/protocompile/reporter"
     9  )
    10  
    11  // ValidateExtensions runs some validation checks on extensions that can only
    12  // be done after files are linked and options are interpreted.
    13  func (r *result) ValidateExtensions(handler *reporter.Handler) error {
    14  	return r.validateExtensions(r, handler)
    15  }
    16  
    17  func (r *result) validateExtensions(d hasExtensionsAndMessages, handler *reporter.Handler) error {
    18  	for i := 0; i < d.Extensions().Len(); i++ {
    19  		if err := r.validateExtension(d.Extensions().Get(i), handler); err != nil {
    20  			return err
    21  		}
    22  	}
    23  	for i := 0; i < d.Messages().Len(); i++ {
    24  		if err := r.validateExtensions(d.Messages().Get(i), handler); err != nil {
    25  			return err
    26  		}
    27  	}
    28  	return nil
    29  }
    30  
    31  func (r *result) validateExtension(fld protoreflect.FieldDescriptor, handler *reporter.Handler) error {
    32  	// NB: It's a little gross that we don't enforce these in validateBasic().
    33  	// But it requires linking to resolve the extendee, so we can interrogate
    34  	// its descriptor.
    35  	if xtd, ok := fld.(protoreflect.ExtensionTypeDescriptor); ok {
    36  		fld = xtd.Descriptor()
    37  	}
    38  	fd := fld.(*fldDescriptor)
    39  	if fld.ContainingMessage().Options().(*descriptorpb.MessageOptions).GetMessageSetWireFormat() {
    40  		// Message set wire format requires that all extensions be messages
    41  		// themselves (no scalar extensions)
    42  		if fld.Kind() != protoreflect.MessageKind {
    43  			file := r.FileNode()
    44  			pos := file.NodeInfo(r.FieldNode(fd.proto).FieldType()).Start()
    45  			return handler.HandleErrorf(pos, "messages with message-set wire format cannot contain scalar extensions, only messages")
    46  		}
    47  		if fld.Cardinality() == protoreflect.Repeated {
    48  			file := r.FileNode()
    49  			pos := file.NodeInfo(r.FieldNode(fd.proto).FieldLabel()).Start()
    50  			return handler.HandleErrorf(pos, "messages with message-set wire format cannot contain repeated extensions, only optional")
    51  		}
    52  	} else {
    53  		// In validateBasic() we just made sure these were within bounds for any message. But
    54  		// now that things are linked, we can check if the extendee is messageset wire format
    55  		// and, if not, enforce tighter limit.
    56  		if fld.Number() > internal.MaxNormalTag {
    57  			file := r.FileNode()
    58  			pos := file.NodeInfo(r.FieldNode(fd.proto).FieldTag()).Start()
    59  			return handler.HandleErrorf(pos, "tag number %d is higher than max allowed tag number (%d)", fld.Number(), internal.MaxNormalTag)
    60  		}
    61  	}
    62  
    63  	return nil
    64  }