github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/parser/validate.go (about)

     1  package parser
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	"google.golang.org/protobuf/proto"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"google.golang.org/protobuf/types/descriptorpb"
    10  
    11  	"github.com/jhump/protocompile/ast"
    12  	"github.com/jhump/protocompile/internal"
    13  	"github.com/jhump/protocompile/reporter"
    14  	"github.com/jhump/protocompile/walk"
    15  )
    16  
    17  func validateBasic(res *result, handler *reporter.Handler) {
    18  	fd := res.proto
    19  	isProto3 := fd.GetSyntax() == "proto3"
    20  
    21  	_ = walk.DescriptorProtos(fd, func(name protoreflect.FullName, d proto.Message) error {
    22  		switch d := d.(type) {
    23  		case *descriptorpb.DescriptorProto:
    24  			if err := validateMessage(res, isProto3, name, d, handler); err != nil {
    25  				return err
    26  			}
    27  		case *descriptorpb.EnumDescriptorProto:
    28  			if err := validateEnum(res, isProto3, name, d, handler); err != nil {
    29  				return err
    30  			}
    31  		case *descriptorpb.FieldDescriptorProto:
    32  			if err := validateField(res, isProto3, name, d, handler); err != nil {
    33  				return err
    34  			}
    35  		}
    36  		return nil
    37  	})
    38  }
    39  
    40  func validateMessage(res *result, isProto3 bool, name protoreflect.FullName, md *descriptorpb.DescriptorProto, handler *reporter.Handler) error {
    41  	scope := fmt.Sprintf("message %s", name)
    42  
    43  	if isProto3 && len(md.ExtensionRange) > 0 {
    44  		n := res.ExtensionRangeNode(md.ExtensionRange[0])
    45  		nInfo := res.file.NodeInfo(n)
    46  		if err := handler.HandleErrorf(nInfo.Start(), "%s: extension ranges are not allowed in proto3", scope); err != nil {
    47  			return err
    48  		}
    49  	}
    50  
    51  	if index, err := internal.FindOption(res, handler, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil {
    52  		return err
    53  	} else if index >= 0 {
    54  		opt := md.Options.UninterpretedOption[index]
    55  		optn := res.OptionNode(opt)
    56  		md.Options.UninterpretedOption = internal.RemoveOption(md.Options.UninterpretedOption, index)
    57  		valid := false
    58  		if opt.IdentifierValue != nil {
    59  			if opt.GetIdentifierValue() == "true" {
    60  				valid = true
    61  				optionNodeInfo := res.file.NodeInfo(optn.GetValue())
    62  				if err := handler.HandleErrorf(optionNodeInfo.Start(), "%s: map_entry option should not be set explicitly; use map type instead", scope); err != nil {
    63  					return err
    64  				}
    65  			} else if opt.GetIdentifierValue() == "false" {
    66  				valid = true
    67  				md.Options.MapEntry = proto.Bool(false)
    68  			}
    69  		}
    70  		if !valid {
    71  			optionNodeInfo := res.file.NodeInfo(optn.GetValue())
    72  			if err := handler.HandleErrorf(optionNodeInfo.Start(), "%s: expecting bool value for map_entry option", scope); err != nil {
    73  				return err
    74  			}
    75  		}
    76  	}
    77  
    78  	// reserved ranges should not overlap
    79  	rsvd := make(tagRanges, len(md.ReservedRange))
    80  	for i, r := range md.ReservedRange {
    81  		n := res.MessageReservedRangeNode(r)
    82  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
    83  
    84  	}
    85  	sort.Sort(rsvd)
    86  	for i := 1; i < len(rsvd); i++ {
    87  		if rsvd[i].start < rsvd[i-1].end {
    88  			rangeNodeInfo := res.file.NodeInfo(rsvd[i].node)
    89  			if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
    90  				return err
    91  			}
    92  		}
    93  	}
    94  
    95  	// extensions ranges should not overlap
    96  	exts := make(tagRanges, len(md.ExtensionRange))
    97  	for i, r := range md.ExtensionRange {
    98  		n := res.ExtensionRangeNode(r)
    99  		exts[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   100  	}
   101  	sort.Sort(exts)
   102  	for i := 1; i < len(exts); i++ {
   103  		if exts[i].start < exts[i-1].end {
   104  			rangeNodeInfo := res.file.NodeInfo(exts[i].node)
   105  			if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: extension ranges overlap: %d to %d and %d to %d", scope, exts[i-1].start, exts[i-1].end-1, exts[i].start, exts[i].end-1); err != nil {
   106  				return err
   107  			}
   108  		}
   109  	}
   110  
   111  	// see if any extension range overlaps any reserved range
   112  	var i, j int // i indexes rsvd; j indexes exts
   113  	for i < len(rsvd) && j < len(exts) {
   114  		if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end ||
   115  			exts[j].start >= rsvd[i].start && exts[j].start < rsvd[i].end {
   116  
   117  			var pos ast.SourcePos
   118  			if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end {
   119  				rangeNodeInfo := res.file.NodeInfo(rsvd[i].node)
   120  				pos = rangeNodeInfo.Start()
   121  			} else {
   122  				rangeNodeInfo := res.file.NodeInfo(exts[j].node)
   123  				pos = rangeNodeInfo.Start()
   124  			}
   125  			// ranges overlap
   126  			if err := handler.HandleErrorf(pos, "%s: extension range %d to %d overlaps reserved range %d to %d", scope, exts[j].start, exts[j].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
   127  				return err
   128  			}
   129  		}
   130  		if rsvd[i].start < exts[j].start {
   131  			i++
   132  		} else {
   133  			j++
   134  		}
   135  	}
   136  
   137  	// now, check that fields don't re-use tags and don't try to use extension
   138  	// or reserved ranges or reserved names
   139  	rsvdNames := map[string]struct{}{}
   140  	for _, n := range md.ReservedName {
   141  		rsvdNames[n] = struct{}{}
   142  	}
   143  	fieldTags := map[int32]string{}
   144  	for _, fld := range md.Field {
   145  		fn := res.FieldNode(fld)
   146  		if _, ok := rsvdNames[fld.GetName()]; ok {
   147  			fieldNameNodeInfo := res.file.NodeInfo(fn.FieldName())
   148  			if err := handler.HandleErrorf(fieldNameNodeInfo.Start(), "%s: field %s is using a reserved name", scope, fld.GetName()); err != nil {
   149  				return err
   150  			}
   151  		}
   152  		if existing := fieldTags[fld.GetNumber()]; existing != "" {
   153  			fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag())
   154  			if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: fields %s and %s both have the same tag %d", scope, existing, fld.GetName(), fld.GetNumber()); err != nil {
   155  				return err
   156  			}
   157  		}
   158  		fieldTags[fld.GetNumber()] = fld.GetName()
   159  		// check reserved ranges
   160  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end > fld.GetNumber() })
   161  		if r < len(rsvd) && rsvd[r].start <= fld.GetNumber() {
   162  			fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag())
   163  			if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: field %s is using tag %d which is in reserved range %d to %d", scope, fld.GetName(), fld.GetNumber(), rsvd[r].start, rsvd[r].end-1); err != nil {
   164  				return err
   165  			}
   166  		}
   167  		// and check extension ranges
   168  		e := sort.Search(len(exts), func(index int) bool { return exts[index].end > fld.GetNumber() })
   169  		if e < len(exts) && exts[e].start <= fld.GetNumber() {
   170  			fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag())
   171  			if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: field %s is using tag %d which is in extension range %d to %d", scope, fld.GetName(), fld.GetNumber(), exts[e].start, exts[e].end-1); err != nil {
   172  				return err
   173  			}
   174  		}
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func validateEnum(res *result, isProto3 bool, name protoreflect.FullName, ed *descriptorpb.EnumDescriptorProto, handler *reporter.Handler) error {
   181  	scope := fmt.Sprintf("enum %s", name)
   182  
   183  	if len(ed.Value) == 0 {
   184  		enNode := res.EnumNode(ed)
   185  		enNodeInfo := res.file.NodeInfo(enNode)
   186  		if err := handler.HandleErrorf(enNodeInfo.Start(), "%s: enums must define at least one value", scope); err != nil {
   187  			return err
   188  		}
   189  	}
   190  
   191  	allowAlias := false
   192  	var allowAliasOpt *descriptorpb.UninterpretedOption
   193  	if index, err := internal.FindOption(res, handler, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil {
   194  		return err
   195  	} else if index >= 0 {
   196  		allowAliasOpt = ed.Options.UninterpretedOption[index]
   197  		valid := false
   198  		if allowAliasOpt.IdentifierValue != nil {
   199  			if allowAliasOpt.GetIdentifierValue() == "true" {
   200  				allowAlias = true
   201  				valid = true
   202  			} else if allowAliasOpt.GetIdentifierValue() == "false" {
   203  				valid = true
   204  			}
   205  		}
   206  		if !valid {
   207  			optNode := res.OptionNode(allowAliasOpt)
   208  			optNodeInfo := res.file.NodeInfo(optNode.GetValue())
   209  			if err := handler.HandleErrorf(optNodeInfo.Start(), "%s: expecting bool value for allow_alias option", scope); err != nil {
   210  				return err
   211  			}
   212  		}
   213  	}
   214  
   215  	if isProto3 && len(ed.Value) > 0 && ed.Value[0].GetNumber() != 0 {
   216  		evNode := res.EnumValueNode(ed.Value[0])
   217  		evNodeInfo := res.file.NodeInfo(evNode.GetNumber())
   218  		if err := handler.HandleErrorf(evNodeInfo.Start(), "%s: proto3 requires that first value in enum have numeric value of 0", scope); err != nil {
   219  			return err
   220  		}
   221  	}
   222  
   223  	// check for aliases
   224  	vals := map[int32]string{}
   225  	hasAlias := false
   226  	for _, evd := range ed.Value {
   227  		existing := vals[evd.GetNumber()]
   228  		if existing != "" {
   229  			if allowAlias {
   230  				hasAlias = true
   231  			} else {
   232  				evNode := res.EnumValueNode(evd)
   233  				evNodeInfo := res.file.NodeInfo(evNode.GetNumber())
   234  				if err := handler.HandleErrorf(evNodeInfo.Start(), "%s: values %s and %s both have the same numeric value %d; use allow_alias option if intentional", scope, existing, evd.GetName(), evd.GetNumber()); err != nil {
   235  					return err
   236  				}
   237  			}
   238  		}
   239  		vals[evd.GetNumber()] = evd.GetName()
   240  	}
   241  	if allowAlias && !hasAlias {
   242  		optNode := res.OptionNode(allowAliasOpt)
   243  		optNodeInfo := res.file.NodeInfo(optNode.GetValue())
   244  		if err := handler.HandleErrorf(optNodeInfo.Start(), "%s: allow_alias is true but no values are aliases", scope); err != nil {
   245  			return err
   246  		}
   247  	}
   248  
   249  	// reserved ranges should not overlap
   250  	rsvd := make(tagRanges, len(ed.ReservedRange))
   251  	for i, r := range ed.ReservedRange {
   252  		n := res.EnumReservedRangeNode(r)
   253  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   254  	}
   255  	sort.Sort(rsvd)
   256  	for i := 1; i < len(rsvd); i++ {
   257  		if rsvd[i].start <= rsvd[i-1].end {
   258  			rangeNodeInfo := res.file.NodeInfo(rsvd[i].node)
   259  			if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end, rsvd[i].start, rsvd[i].end); err != nil {
   260  				return err
   261  			}
   262  		}
   263  	}
   264  
   265  	// now, check that fields don't re-use tags and don't try to use extension
   266  	// or reserved ranges or reserved names
   267  	rsvdNames := map[string]struct{}{}
   268  	for _, n := range ed.ReservedName {
   269  		rsvdNames[n] = struct{}{}
   270  	}
   271  	for _, ev := range ed.Value {
   272  		evn := res.EnumValueNode(ev)
   273  		if _, ok := rsvdNames[ev.GetName()]; ok {
   274  			enumValNodeInfo := res.file.NodeInfo(evn.GetName())
   275  			if err := handler.HandleErrorf(enumValNodeInfo.Start(), "%s: value %s is using a reserved name", scope, ev.GetName()); err != nil {
   276  				return err
   277  			}
   278  		}
   279  		// check reserved ranges
   280  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end >= ev.GetNumber() })
   281  		if r < len(rsvd) && rsvd[r].start <= ev.GetNumber() {
   282  			enumValNodeInfo := res.file.NodeInfo(evn.GetNumber())
   283  			if err := handler.HandleErrorf(enumValNodeInfo.Start(), "%s: value %s is using number %d which is in reserved range %d to %d", scope, ev.GetName(), ev.GetNumber(), rsvd[r].start, rsvd[r].end); err != nil {
   284  				return err
   285  			}
   286  		}
   287  	}
   288  
   289  	return nil
   290  }
   291  
   292  func validateField(res *result, isProto3 bool, name protoreflect.FullName, fld *descriptorpb.FieldDescriptorProto, handler *reporter.Handler) error {
   293  	scope := fmt.Sprintf("field %s", name)
   294  
   295  	node := res.FieldNode(fld)
   296  	if isProto3 {
   297  		if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP {
   298  			groupNodeInfo := res.file.NodeInfo(node.GetGroupKeyword())
   299  			if err := handler.HandleErrorf(groupNodeInfo.Start(), "%s: groups are not allowed in proto3", scope); err != nil {
   300  				return err
   301  			}
   302  		} else if fld.Label != nil && fld.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED {
   303  			fieldLabelNodeInfo := res.file.NodeInfo(node.FieldLabel())
   304  			if err := handler.HandleErrorf(fieldLabelNodeInfo.Start(), "%s: label 'required' is not allowed in proto3", scope); err != nil {
   305  				return err
   306  			}
   307  		}
   308  		if index, err := internal.FindOption(res, handler, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil {
   309  			return err
   310  		} else if index >= 0 {
   311  			optNode := res.OptionNode(fld.Options.GetUninterpretedOption()[index])
   312  			optNameNodeInfo := res.file.NodeInfo(optNode.GetName())
   313  			if err := handler.HandleErrorf(optNameNodeInfo.Start(), "%s: default values are not allowed in proto3", scope); err != nil {
   314  				return err
   315  			}
   316  		}
   317  	} else {
   318  		if fld.Label == nil && fld.OneofIndex == nil {
   319  			fieldNameNodeInfo := res.file.NodeInfo(node.FieldName())
   320  			if err := handler.HandleErrorf(fieldNameNodeInfo.Start(), "%s: field has no label; proto2 requires explicit 'optional' label", scope); err != nil {
   321  				return err
   322  			}
   323  		}
   324  		if fld.GetExtendee() != "" && fld.Label != nil && fld.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED {
   325  			fieldLabelNodeInfo := res.file.NodeInfo(node.FieldLabel())
   326  			if err := handler.HandleErrorf(fieldLabelNodeInfo.Start(), "%s: extension fields cannot be 'required'", scope); err != nil {
   327  				return err
   328  			}
   329  		}
   330  	}
   331  
   332  	// finally, set any missing label to optional
   333  	if fld.Label == nil {
   334  		fld.Label = descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum()
   335  	}
   336  
   337  	return nil
   338  }
   339  
   340  type tagRange struct {
   341  	start int32
   342  	end   int32
   343  	node  ast.RangeDeclNode
   344  }
   345  
   346  type tagRanges []tagRange
   347  
   348  func (r tagRanges) Len() int {
   349  	return len(r)
   350  }
   351  
   352  func (r tagRanges) Less(i, j int) bool {
   353  	return r[i].start < r[j].start ||
   354  		(r[i].start == r[j].start && r[i].end < r[j].end)
   355  }
   356  
   357  func (r tagRanges) Swap(i, j int) {
   358  	r[i], r[j] = r[j], r[i]
   359  }