github.com/syumai/protoreflect@v1.7.1-0.20200810020253-2ac7e3b3a321/desc/protoparse/validate.go (about)

     1  package protoparse
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	"github.com/golang/protobuf/proto"
     8  
     9  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    10  )
    11  
    12  func validateBasic(res *parseResult, containsErrors bool) {
    13  	fd := res.fd
    14  	isProto3 := fd.GetSyntax() == "proto3"
    15  
    16  	for _, md := range fd.MessageType {
    17  		if validateMessage(res, isProto3, "", md, containsErrors) != nil {
    18  			return
    19  		}
    20  	}
    21  
    22  	for _, ed := range fd.EnumType {
    23  		if validateEnum(res, isProto3, "", ed, containsErrors) != nil {
    24  			return
    25  		}
    26  	}
    27  
    28  	for _, fld := range fd.Extension {
    29  		if validateField(res, isProto3, "", fld) != nil {
    30  			return
    31  		}
    32  	}
    33  }
    34  
    35  func validateMessage(res *parseResult, isProto3 bool, prefix string, md *dpb.DescriptorProto, containsErrors bool) error {
    36  	nextPrefix := md.GetName() + "."
    37  
    38  	for _, fld := range md.Field {
    39  		if err := validateField(res, isProto3, nextPrefix, fld); err != nil {
    40  			return err
    41  		}
    42  	}
    43  	for _, fld := range md.Extension {
    44  		if err := validateField(res, isProto3, nextPrefix, fld); err != nil {
    45  			return err
    46  		}
    47  	}
    48  	for _, ed := range md.EnumType {
    49  		if err := validateEnum(res, isProto3, nextPrefix, ed, containsErrors); err != nil {
    50  			return err
    51  		}
    52  	}
    53  	for _, nmd := range md.NestedType {
    54  		if err := validateMessage(res, isProto3, nextPrefix, nmd, containsErrors); err != nil {
    55  			return err
    56  		}
    57  	}
    58  
    59  	scope := fmt.Sprintf("message %s%s", prefix, md.GetName())
    60  
    61  	if isProto3 && len(md.ExtensionRange) > 0 {
    62  		n := res.getExtensionRangeNode(md.ExtensionRange[0])
    63  		if err := res.errs.handleErrorWithPos(n.start(), "%s: extension ranges are not allowed in proto3", scope); err != nil {
    64  			return err
    65  		}
    66  	}
    67  
    68  	if index, err := findOption(res, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil {
    69  		return err
    70  	} else if index >= 0 {
    71  		opt := md.Options.UninterpretedOption[index]
    72  		optn := res.getOptionNode(opt)
    73  		md.Options.UninterpretedOption = removeOption(md.Options.UninterpretedOption, index)
    74  		valid := false
    75  		if opt.IdentifierValue != nil {
    76  			if opt.GetIdentifierValue() == "true" {
    77  				valid = true
    78  				if err := res.errs.handleErrorWithPos(optn.getValue().start(), "%s: map_entry option should not be set explicitly; use map type instead", scope); err != nil {
    79  					return err
    80  				}
    81  			} else if opt.GetIdentifierValue() == "false" {
    82  				valid = true
    83  				md.Options.MapEntry = proto.Bool(false)
    84  			}
    85  		}
    86  		if !valid {
    87  			if err := res.errs.handleErrorWithPos(optn.getValue().start(), "%s: expecting bool value for map_entry option", scope); err != nil {
    88  				return err
    89  			}
    90  		}
    91  	}
    92  
    93  	// reserved ranges should not overlap
    94  	rsvd := make(tagRanges, len(md.ReservedRange))
    95  	for i, r := range md.ReservedRange {
    96  		n := res.getMessageReservedRangeNode(r)
    97  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
    98  
    99  	}
   100  	sort.Sort(rsvd)
   101  	for i := 1; i < len(rsvd); i++ {
   102  		if rsvd[i].start < rsvd[i-1].end {
   103  			if err := res.errs.handleErrorWithPos(rsvd[i].node.start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
   104  				return err
   105  			}
   106  		}
   107  	}
   108  
   109  	// extensions ranges should not overlap
   110  	exts := make(tagRanges, len(md.ExtensionRange))
   111  	for i, r := range md.ExtensionRange {
   112  		n := res.getExtensionRangeNode(r)
   113  		exts[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   114  	}
   115  	sort.Sort(exts)
   116  	for i := 1; i < len(exts); i++ {
   117  		if exts[i].start < exts[i-1].end {
   118  			if err := res.errs.handleErrorWithPos(exts[i].node.start(), "%s: extension ranges overlap: %d to %d and %d to %d", scope, exts[i-1].start, exts[i-1].end-1, exts[i].start, exts[i].end-1); err != nil {
   119  				return err
   120  			}
   121  		}
   122  	}
   123  
   124  	// see if any extension range overlaps any reserved range
   125  	var i, j int // i indexes rsvd; j indexes exts
   126  	for i < len(rsvd) && j < len(exts) {
   127  		if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end ||
   128  			exts[j].start >= rsvd[i].start && exts[j].start < rsvd[i].end {
   129  
   130  			var pos *SourcePos
   131  			if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end {
   132  				pos = rsvd[i].node.start()
   133  			} else {
   134  				pos = exts[j].node.start()
   135  			}
   136  			// ranges overlap
   137  			if err := res.errs.handleErrorWithPos(pos, "%s: extension range %d to %d overlaps reserved range %d to %d", scope, exts[j].start, exts[j].end-1, rsvd[i].start, rsvd[i].end-1); err != nil {
   138  				return err
   139  			}
   140  		}
   141  		if rsvd[i].start < exts[j].start {
   142  			i++
   143  		} else {
   144  			j++
   145  		}
   146  	}
   147  
   148  	// now, check that fields don't re-use tags and don't try to use extension
   149  	// or reserved ranges or reserved names
   150  	rsvdNames := map[string]struct{}{}
   151  	for _, n := range md.ReservedName {
   152  		rsvdNames[n] = struct{}{}
   153  	}
   154  	fieldTags := map[int32]string{}
   155  	for _, fld := range md.Field {
   156  		fn := res.getFieldNode(fld)
   157  		if _, ok := rsvdNames[fld.GetName()]; ok {
   158  			if err := res.errs.handleErrorWithPos(fn.fieldName().start(), "%s: field %s is using a reserved name", scope, fld.GetName()); err != nil {
   159  				return err
   160  			}
   161  		}
   162  		if existing := fieldTags[fld.GetNumber()]; existing != "" {
   163  			if err := res.errs.handleErrorWithPos(fn.fieldTag().start(), "%s: fields %s and %s both have the same tag %d", scope, existing, fld.GetName(), fld.GetNumber()); err != nil {
   164  				return err
   165  			}
   166  		}
   167  		fieldTags[fld.GetNumber()] = fld.GetName()
   168  		// check reserved ranges
   169  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end > fld.GetNumber() })
   170  		if r < len(rsvd) && rsvd[r].start <= fld.GetNumber() {
   171  			if err := res.errs.handleErrorWithPos(fn.fieldTag().start(), "%s: field %s is using tag %d which is in reserved range %d to %d", scope, fld.GetName(), fld.GetNumber(), rsvd[r].start, rsvd[r].end-1); err != nil {
   172  				return err
   173  			}
   174  		}
   175  		// and check extension ranges
   176  		e := sort.Search(len(exts), func(index int) bool { return exts[index].end > fld.GetNumber() })
   177  		if e < len(exts) && exts[e].start <= fld.GetNumber() {
   178  			if err := res.errs.handleErrorWithPos(fn.fieldTag().start(), "%s: field %s is using tag %d which is in extension range %d to %d", scope, fld.GetName(), fld.GetNumber(), exts[e].start, exts[e].end-1); err != nil {
   179  				return err
   180  			}
   181  		}
   182  	}
   183  
   184  	return nil
   185  }
   186  
   187  func validateEnum(res *parseResult, isProto3 bool, prefix string, ed *dpb.EnumDescriptorProto, containsErrors bool) error {
   188  	scope := fmt.Sprintf("enum %s%s", prefix, ed.GetName())
   189  
   190  	if !containsErrors && len(ed.Value) == 0 {
   191  		// we only check this if file parsing had no errors; otherwise, the file may have
   192  		// had an enum value, but the parser encountered an error processing it, in which
   193  		// case the value would be absent from the descriptor. In such a case, this error
   194  		// would be confusing and incorrect, so we just skip this check.
   195  		enNode := res.getEnumNode(ed)
   196  		if err := res.errs.handleErrorWithPos(enNode.start(), "%s: enums must define at least one value", scope); err != nil {
   197  			return err
   198  		}
   199  	}
   200  
   201  	allowAlias := false
   202  	if index, err := findOption(res, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil {
   203  		return err
   204  	} else if index >= 0 {
   205  		opt := ed.Options.UninterpretedOption[index]
   206  		valid := false
   207  		if opt.IdentifierValue != nil {
   208  			if opt.GetIdentifierValue() == "true" {
   209  				allowAlias = true
   210  				valid = true
   211  			} else if opt.GetIdentifierValue() == "false" {
   212  				valid = true
   213  			}
   214  		}
   215  		if !valid {
   216  			optNode := res.getOptionNode(opt)
   217  			if err := res.errs.handleErrorWithPos(optNode.getValue().start(), "%s: expecting bool value for allow_alias option", scope); err != nil {
   218  				return err
   219  			}
   220  		}
   221  	}
   222  
   223  	if isProto3 && len(ed.Value) > 0 && ed.Value[0].GetNumber() != 0 {
   224  		evNode := res.getEnumValueNode(ed.Value[0])
   225  		if err := res.errs.handleErrorWithPos(evNode.getNumber().start(), "%s: proto3 requires that first value in enum have numeric value of 0", scope); err != nil {
   226  			return err
   227  		}
   228  	}
   229  
   230  	if !allowAlias {
   231  		// make sure all value numbers are distinct
   232  		vals := map[int32]string{}
   233  		for _, evd := range ed.Value {
   234  			if existing := vals[evd.GetNumber()]; existing != "" {
   235  				evNode := res.getEnumValueNode(evd)
   236  				if err := res.errs.handleErrorWithPos(evNode.getNumber().start(), "%s: values %s and %s both have the same numeric value %d; use allow_alias option if intentional", scope, existing, evd.GetName(), evd.GetNumber()); err != nil {
   237  					return err
   238  				}
   239  			}
   240  			vals[evd.GetNumber()] = evd.GetName()
   241  		}
   242  	}
   243  
   244  	// reserved ranges should not overlap
   245  	rsvd := make(tagRanges, len(ed.ReservedRange))
   246  	for i, r := range ed.ReservedRange {
   247  		n := res.getEnumReservedRangeNode(r)
   248  		rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n}
   249  	}
   250  	sort.Sort(rsvd)
   251  	for i := 1; i < len(rsvd); i++ {
   252  		if rsvd[i].start <= rsvd[i-1].end {
   253  			if err := res.errs.handleErrorWithPos(rsvd[i].node.start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end, rsvd[i].start, rsvd[i].end); err != nil {
   254  				return err
   255  			}
   256  		}
   257  	}
   258  
   259  	// now, check that fields don't re-use tags and don't try to use extension
   260  	// or reserved ranges or reserved names
   261  	rsvdNames := map[string]struct{}{}
   262  	for _, n := range ed.ReservedName {
   263  		rsvdNames[n] = struct{}{}
   264  	}
   265  	for _, ev := range ed.Value {
   266  		evn := res.getEnumValueNode(ev)
   267  		if _, ok := rsvdNames[ev.GetName()]; ok {
   268  			if err := res.errs.handleErrorWithPos(evn.getName().start(), "%s: value %s is using a reserved name", scope, ev.GetName()); err != nil {
   269  				return err
   270  			}
   271  		}
   272  		// check reserved ranges
   273  		r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end >= ev.GetNumber() })
   274  		if r < len(rsvd) && rsvd[r].start <= ev.GetNumber() {
   275  			if err := res.errs.handleErrorWithPos(evn.getNumber().start(), "%s: value %s is using number %d which is in reserved range %d to %d", scope, ev.GetName(), ev.GetNumber(), rsvd[r].start, rsvd[r].end); err != nil {
   276  				return err
   277  			}
   278  		}
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  func validateField(res *parseResult, isProto3 bool, prefix string, fld *dpb.FieldDescriptorProto) error {
   285  	scope := fmt.Sprintf("field %s%s", prefix, fld.GetName())
   286  
   287  	node := res.getFieldNode(fld)
   288  	if isProto3 {
   289  		if fld.GetType() == dpb.FieldDescriptorProto_TYPE_GROUP {
   290  			n := node.(*groupNode)
   291  			if err := res.errs.handleErrorWithPos(n.groupKeyword.start(), "%s: groups are not allowed in proto3", scope); err != nil {
   292  				return err
   293  			}
   294  		} else if fld.Label != nil && fld.GetLabel() == dpb.FieldDescriptorProto_LABEL_REQUIRED {
   295  			if err := res.errs.handleErrorWithPos(node.fieldLabel().start(), "%s: label 'required' is not allowed in proto3", scope); err != nil {
   296  				return err
   297  			}
   298  		} else if fld.Extendee != nil && fld.Label != nil && fld.GetLabel() == dpb.FieldDescriptorProto_LABEL_OPTIONAL {
   299  			if err := res.errs.handleErrorWithPos(node.fieldLabel().start(), "%s: label 'optional' is not allowed on extensions in proto3", scope); err != nil {
   300  				return err
   301  			}
   302  		}
   303  		if index, err := findOption(res, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil {
   304  			return err
   305  		} else if index >= 0 {
   306  			optNode := res.getOptionNode(fld.Options.GetUninterpretedOption()[index])
   307  			if err := res.errs.handleErrorWithPos(optNode.getName().start(), "%s: default values are not allowed in proto3", scope); err != nil {
   308  				return err
   309  			}
   310  		}
   311  	} else {
   312  		if fld.Label == nil && fld.OneofIndex == nil {
   313  			if err := res.errs.handleErrorWithPos(node.fieldName().start(), "%s: field has no label; proto2 requires explicit 'optional' label", scope); err != nil {
   314  				return err
   315  			}
   316  		}
   317  		if fld.GetExtendee() != "" && fld.Label != nil && fld.GetLabel() == dpb.FieldDescriptorProto_LABEL_REQUIRED {
   318  			if err := res.errs.handleErrorWithPos(node.fieldLabel().start(), "%s: extension fields cannot be 'required'", scope); err != nil {
   319  				return err
   320  			}
   321  		}
   322  	}
   323  
   324  	// finally, set any missing label to optional
   325  	if fld.Label == nil {
   326  		fld.Label = dpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum()
   327  	}
   328  
   329  	return nil
   330  }
   331  
   332  type tagRange struct {
   333  	start int32
   334  	end   int32
   335  	node  rangeDecl
   336  }
   337  
   338  type tagRanges []tagRange
   339  
   340  func (r tagRanges) Len() int {
   341  	return len(r)
   342  }
   343  
   344  func (r tagRanges) Less(i, j int) bool {
   345  	return r[i].start < r[j].start ||
   346  		(r[i].start == r[j].start && r[i].end < r[j].end)
   347  }
   348  
   349  func (r tagRanges) Swap(i, j int) {
   350  	r[i], r[j] = r[j], r[i]
   351  }