github.com/bakjos/protoreflect@v1.9.2/desc/protoparse/validate.go (about)

     1  package protoparse
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	"github.com/golang/protobuf/proto"
     8  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
     9  
    10  	"github.com/bakjos/protoreflect/desc/protoparse/ast"
    11  )
    12  
    13  func validateBasic(res *parseResult, containsErrors bool) {
    14  	fd := res.fd
    15  	isProto3 := fd.GetSyntax() == "proto3"
    16  
    17  	for _, md := range fd.MessageType {
    18  		if validateMessage(res, isProto3, "", md, containsErrors) != nil {
    19  			return
    20  		}
    21  	}
    22  
    23  	for _, ed := range fd.EnumType {
    24  		if validateEnum(res, isProto3, "", ed, containsErrors) != nil {
    25  			return
    26  		}
    27  	}
    28  
    29  	for _, fld := range fd.Extension {
    30  		if validateField(res, isProto3, "", fld) != nil {
    31  			return
    32  		}
    33  	}
    34  }
    35  
    36  func validateMessage(res *parseResult, isProto3 bool, prefix string, md *dpb.DescriptorProto, containsErrors bool) error {
    37  	nextPrefix := md.GetName() + "."
    38  
    39  	for _, fld := range md.Field {
    40  		if err := validateField(res, isProto3, nextPrefix, fld); err != nil {
    41  			return err
    42  		}
    43  	}
    44  	for _, fld := range md.Extension {
    45  		if err := validateField(res, isProto3, nextPrefix, fld); err != nil {
    46  			return err
    47  		}
    48  	}
    49  	for _, ed := range md.EnumType {
    50  		if err := validateEnum(res, isProto3, nextPrefix, ed, containsErrors); err != nil {
    51  			return err
    52  		}
    53  	}
    54  	for _, nmd := range md.NestedType {
    55  		if err := validateMessage(res, isProto3, nextPrefix, nmd, containsErrors); err != nil {
    56  			return err
    57  		}
    58  	}
    59  
    60  	scope := fmt.Sprintf("message %s%s", prefix, md.GetName())
    61  
    62  	if isProto3 && len(md.ExtensionRange) > 0 {
    63  		n := res.getExtensionRangeNode(md.ExtensionRange[0])
    64  		if err := res.errs.handleErrorWithPos(n.Start(), "%s: extension ranges are not allowed in proto3", scope); err != nil {
    65  			return err
    66  		}
    67  	}
    68  
    69  	if index, err := findOption(res, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil {
    70  		return err
    71  	} else if index >= 0 {
    72  		opt := md.Options.UninterpretedOption[index]
    73  		optn := res.getOptionNode(opt)
    74  		md.Options.UninterpretedOption = removeOption(md.Options.UninterpretedOption, index)
    75  		valid := false
    76  		if opt.IdentifierValue != nil {
    77  			if opt.GetIdentifierValue() == "true" {
    78  				valid = true
    79  				if err := res.errs.handleErrorWithPos(optn.GetValue().Start(), "%s: map_entry option should not be set explicitly; use map type instead", scope); err != nil {
    80  					return err
    81  				}
    82  			} else if opt.GetIdentifierValue() == "false" {
    83  				valid = true
    84  				md.Options.MapEntry = proto.Bool(false)
    85  			}
    86  		}
    87  		if !valid {
    88  			if err := res.errs.handleErrorWithPos(optn.GetValue().Start(), "%s: expecting bool value for map_entry option", scope); err != nil {
    89  				return err
    90  			}
    91  		}
    92  	}
    93  
    94  	// reserved ranges should not overlap
    95  	rsvd := make(tagRanges, len(md.ReservedRange))
    96  	for i, r := range md.ReservedRange {
    97  		n := res.getMessageReservedRangeNode(r)
    98  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
    99  
   100  	}
   101  	sort.Sort(rsvd)
   102  	for i := 1; i < len(rsvd); i++ {
   103  		if rsvd[i].start < rsvd[i-1].end {
   104  			if err := res.errs.handleErrorWithPos(rsvd[i].node.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
   105  				return err
   106  			}
   107  		}
   108  	}
   109  
   110  	// extensions ranges should not overlap
   111  	exts := make(tagRanges, len(md.ExtensionRange))
   112  	for i, r := range md.ExtensionRange {
   113  		n := res.getExtensionRangeNode(r)
   114  		exts[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   115  	}
   116  	sort.Sort(exts)
   117  	for i := 1; i < len(exts); i++ {
   118  		if exts[i].start < exts[i-1].end {
   119  			if err := res.errs.handleErrorWithPos(exts[i].node.Start(), "%s: extension ranges overlap: %d to %d and %d to %d", scope, exts[i-1].start, exts[i-1].end-1, exts[i].start, exts[i].end-1); err != nil {
   120  				return err
   121  			}
   122  		}
   123  	}
   124  
   125  	// see if any extension range overlaps any reserved range
   126  	var i, j int // i indexes rsvd; j indexes exts
   127  	for i < len(rsvd) && j < len(exts) {
   128  		if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end ||
   129  			exts[j].start >= rsvd[i].start && exts[j].start < rsvd[i].end {
   130  
   131  			var pos *SourcePos
   132  			if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end {
   133  				pos = rsvd[i].node.Start()
   134  			} else {
   135  				pos = exts[j].node.Start()
   136  			}
   137  			// ranges overlap
   138  			if err := res.errs.handleErrorWithPos(pos, "%s: extension range %d to %d overlaps reserved range %d to %d", scope, exts[j].start, exts[j].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
   139  				return err
   140  			}
   141  		}
   142  		if rsvd[i].start < exts[j].start {
   143  			i++
   144  		} else {
   145  			j++
   146  		}
   147  	}
   148  
   149  	// now, check that fields don't re-use tags and don't try to use extension
   150  	// or reserved ranges or reserved names
   151  	rsvdNames := map[string]struct{}{}
   152  	for _, n := range md.ReservedName {
   153  		rsvdNames[n] = struct{}{}
   154  	}
   155  	fieldTags := map[int32]string{}
   156  	for _, fld := range md.Field {
   157  		fn := res.getFieldNode(fld)
   158  		if _, ok := rsvdNames[fld.GetName()]; ok {
   159  			if err := res.errs.handleErrorWithPos(fn.FieldName().Start(), "%s: field %s is using a reserved name", scope, fld.GetName()); err != nil {
   160  				return err
   161  			}
   162  		}
   163  		if existing := fieldTags[fld.GetNumber()]; existing != "" {
   164  			if err := res.errs.handleErrorWithPos(fn.FieldTag().Start(), "%s: fields %s and %s both have the same tag %d", scope, existing, fld.GetName(), fld.GetNumber()); err != nil {
   165  				return err
   166  			}
   167  		}
   168  		fieldTags[fld.GetNumber()] = fld.GetName()
   169  		// check reserved ranges
   170  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end > fld.GetNumber() })
   171  		if r < len(rsvd) && rsvd[r].start <= fld.GetNumber() {
   172  			if err := res.errs.handleErrorWithPos(fn.FieldTag().Start(), "%s: field %s is using tag %d which is in reserved range %d to %d", scope, fld.GetName(), fld.GetNumber(), rsvd[r].start, rsvd[r].end-1); err != nil {
   173  				return err
   174  			}
   175  		}
   176  		// and check extension ranges
   177  		e := sort.Search(len(exts), func(index int) bool { return exts[index].end > fld.GetNumber() })
   178  		if e < len(exts) && exts[e].start <= fld.GetNumber() {
   179  			if err := res.errs.handleErrorWithPos(fn.FieldTag().Start(), "%s: field %s is using tag %d which is in extension range %d to %d", scope, fld.GetName(), fld.GetNumber(), exts[e].start, exts[e].end-1); err != nil {
   180  				return err
   181  			}
   182  		}
   183  	}
   184  
   185  	return nil
   186  }
   187  
   188  func validateEnum(res *parseResult, isProto3 bool, prefix string, ed *dpb.EnumDescriptorProto, containsErrors bool) error {
   189  	scope := fmt.Sprintf("enum %s%s", prefix, ed.GetName())
   190  
   191  	if !containsErrors && len(ed.Value) == 0 {
   192  		// we only check this if file parsing had no errors; otherwise, the file may have
   193  		// had an enum value, but the parser encountered an error processing it, in which
   194  		// case the value would be absent from the descriptor. In such a case, this error
   195  		// would be confusing and incorrect, so we just skip this check.
   196  		enNode := res.getEnumNode(ed)
   197  		if err := res.errs.handleErrorWithPos(enNode.Start(), "%s: enums must define at least one value", scope); err != nil {
   198  			return err
   199  		}
   200  	}
   201  
   202  	allowAlias := false
   203  	var allowAliasOpt *dpb.UninterpretedOption
   204  	if index, err := findOption(res, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil {
   205  		return err
   206  	} else if index >= 0 {
   207  		allowAliasOpt = ed.Options.UninterpretedOption[index]
   208  		valid := false
   209  		if allowAliasOpt.IdentifierValue != nil {
   210  			if allowAliasOpt.GetIdentifierValue() == "true" {
   211  				allowAlias = true
   212  				valid = true
   213  			} else if allowAliasOpt.GetIdentifierValue() == "false" {
   214  				valid = true
   215  			}
   216  		}
   217  		if !valid {
   218  			optNode := res.getOptionNode(allowAliasOpt)
   219  			if err := res.errs.handleErrorWithPos(optNode.GetValue().Start(), "%s: expecting bool value for allow_alias option", scope); err != nil {
   220  				return err
   221  			}
   222  		}
   223  	}
   224  
   225  	if isProto3 && len(ed.Value) > 0 && ed.Value[0].GetNumber() != 0 {
   226  		evNode := res.getEnumValueNode(ed.Value[0])
   227  		if err := res.errs.handleErrorWithPos(evNode.GetNumber().Start(), "%s: proto3 requires that first value in enum have numeric value of 0", scope); err != nil {
   228  			return err
   229  		}
   230  	}
   231  
   232  	// check for aliases
   233  	vals := map[int32]string{}
   234  	hasAlias := false
   235  	for _, evd := range ed.Value {
   236  		existing := vals[evd.GetNumber()]
   237  		if existing != "" {
   238  			if allowAlias {
   239  				hasAlias = true
   240  			} else {
   241  				evNode := res.getEnumValueNode(evd)
   242  				if err := res.errs.handleErrorWithPos(evNode.GetNumber().Start(), "%s: values %s and %s both have the same numeric value %d; use allow_alias option if intentional", scope, existing, evd.GetName(), evd.GetNumber()); err != nil {
   243  					return err
   244  				}
   245  			}
   246  		}
   247  		vals[evd.GetNumber()] = evd.GetName()
   248  	}
   249  	if allowAlias && !hasAlias {
   250  		optNode := res.getOptionNode(allowAliasOpt)
   251  		if err := res.errs.handleErrorWithPos(optNode.GetValue().Start(), "%s: allow_alias is true but no values are aliases", scope); err != nil {
   252  			return err
   253  		}
   254  	}
   255  
   256  	// reserved ranges should not overlap
   257  	rsvd := make(tagRanges, len(ed.ReservedRange))
   258  	for i, r := range ed.ReservedRange {
   259  		n := res.getEnumReservedRangeNode(r)
   260  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   261  	}
   262  	sort.Sort(rsvd)
   263  	for i := 1; i < len(rsvd); i++ {
   264  		if rsvd[i].start <= rsvd[i-1].end {
   265  			if err := res.errs.handleErrorWithPos(rsvd[i].node.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end, rsvd[i].start, rsvd[i].end); err != nil {
   266  				return err
   267  			}
   268  		}
   269  	}
   270  
   271  	// now, check that fields don't re-use tags and don't try to use extension
   272  	// or reserved ranges or reserved names
   273  	rsvdNames := map[string]struct{}{}
   274  	for _, n := range ed.ReservedName {
   275  		rsvdNames[n] = struct{}{}
   276  	}
   277  	for _, ev := range ed.Value {
   278  		evn := res.getEnumValueNode(ev)
   279  		if _, ok := rsvdNames[ev.GetName()]; ok {
   280  			if err := res.errs.handleErrorWithPos(evn.GetName().Start(), "%s: value %s is using a reserved name", scope, ev.GetName()); err != nil {
   281  				return err
   282  			}
   283  		}
   284  		// check reserved ranges
   285  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end >= ev.GetNumber() })
   286  		if r < len(rsvd) && rsvd[r].start <= ev.GetNumber() {
   287  			if err := res.errs.handleErrorWithPos(evn.GetNumber().Start(), "%s: value %s is using number %d which is in reserved range %d to %d", scope, ev.GetName(), ev.GetNumber(), rsvd[r].start, rsvd[r].end); err != nil {
   288  				return err
   289  			}
   290  		}
   291  	}
   292  
   293  	return nil
   294  }
   295  
   296  func validateField(res *parseResult, isProto3 bool, prefix string, fld *dpb.FieldDescriptorProto) error {
   297  	scope := fmt.Sprintf("field %s%s", prefix, fld.GetName())
   298  
   299  	node := res.getFieldNode(fld)
   300  	if isProto3 {
   301  		if fld.GetType() == dpb.FieldDescriptorProto_TYPE_GROUP {
   302  			if err := res.errs.handleErrorWithPos(node.GetGroupKeyword().Start(), "%s: groups are not allowed in proto3", scope); err != nil {
   303  				return err
   304  			}
   305  		} else if fld.Label != nil && fld.GetLabel() == dpb.FieldDescriptorProto_LABEL_REQUIRED {
   306  			if err := res.errs.handleErrorWithPos(node.FieldLabel().Start(), "%s: label 'required' is not allowed in proto3", scope); err != nil {
   307  				return err
   308  			}
   309  		}
   310  		if index, err := findOption(res, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil {
   311  			return err
   312  		} else if index >= 0 {
   313  			optNode := res.getOptionNode(fld.Options.GetUninterpretedOption()[index])
   314  			if err := res.errs.handleErrorWithPos(optNode.GetName().Start(), "%s: default values are not allowed in proto3", scope); err != nil {
   315  				return err
   316  			}
   317  		}
   318  	} else {
   319  		if fld.Label == nil && fld.OneofIndex == nil {
   320  			if err := res.errs.handleErrorWithPos(node.FieldName().Start(), "%s: field has no label; proto2 requires explicit 'optional' label", scope); err != nil {
   321  				return err
   322  			}
   323  		}
   324  		if fld.GetExtendee() != "" && fld.Label != nil && fld.GetLabel() == dpb.FieldDescriptorProto_LABEL_REQUIRED {
   325  			if err := res.errs.handleErrorWithPos(node.FieldLabel().Start(), "%s: extension fields cannot be 'required'", scope); err != nil {
   326  				return err
   327  			}
   328  		}
   329  	}
   330  
   331  	// finally, set any missing label to optional
   332  	if fld.Label == nil {
   333  		fld.Label = dpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum()
   334  	}
   335  
   336  	return nil
   337  }
   338  
   339  type tagRange struct {
   340  	start int32
   341  	end   int32
   342  	node  ast.RangeDeclNode
   343  }
   344  
   345  type tagRanges []tagRange
   346  
   347  func (r tagRanges) Len() int {
   348  	return len(r)
   349  }
   350  
   351  func (r tagRanges) Less(i, j int) bool {
   352  	return r[i].start < r[j].start ||
   353  		(r[i].start == r[j].start && r[i].end < r[j].end)
   354  }
   355  
   356  func (r tagRanges) Swap(i, j int) {
   357  	r[i], r[j] = r[j], r[i]
   358  }