github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/parser/validate.go (about) 1 package parser 2 3 import ( 4 "fmt" 5 "sort" 6 7 "google.golang.org/protobuf/proto" 8 "google.golang.org/protobuf/reflect/protoreflect" 9 "google.golang.org/protobuf/types/descriptorpb" 10 11 "github.com/jhump/protocompile/ast" 12 "github.com/jhump/protocompile/internal" 13 "github.com/jhump/protocompile/reporter" 14 "github.com/jhump/protocompile/walk" 15 ) 16 17 func validateBasic(res *result, handler *reporter.Handler) { 18 fd := res.proto 19 isProto3 := fd.GetSyntax() == "proto3" 20 21 _ = walk.DescriptorProtos(fd, func(name protoreflect.FullName, d proto.Message) error { 22 switch d := d.(type) { 23 case *descriptorpb.DescriptorProto: 24 if err := validateMessage(res, isProto3, name, d, handler); err != nil { 25 return err 26 } 27 case *descriptorpb.EnumDescriptorProto: 28 if err := validateEnum(res, isProto3, name, d, handler); err != nil { 29 return err 30 } 31 case *descriptorpb.FieldDescriptorProto: 32 if err := validateField(res, isProto3, name, d, handler); err != nil { 33 return err 34 } 35 } 36 return nil 37 }) 38 } 39 40 func validateMessage(res *result, isProto3 bool, name protoreflect.FullName, md *descriptorpb.DescriptorProto, handler *reporter.Handler) error { 41 scope := fmt.Sprintf("message %s", name) 42 43 if isProto3 && len(md.ExtensionRange) > 0 { 44 n := res.ExtensionRangeNode(md.ExtensionRange[0]) 45 nInfo := res.file.NodeInfo(n) 46 if err := handler.HandleErrorf(nInfo.Start(), "%s: extension ranges are not allowed in proto3", scope); err != nil { 47 return err 48 } 49 } 50 51 if index, err := internal.FindOption(res, handler, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil { 52 return err 53 } else if index >= 0 { 54 opt := md.Options.UninterpretedOption[index] 55 optn := res.OptionNode(opt) 56 md.Options.UninterpretedOption = internal.RemoveOption(md.Options.UninterpretedOption, index) 57 valid := false 58 if opt.IdentifierValue != nil { 59 if opt.GetIdentifierValue() == "true" { 60 valid = true 61 optionNodeInfo := res.file.NodeInfo(optn.GetValue()) 62 if err := handler.HandleErrorf(optionNodeInfo.Start(), "%s: map_entry option should not be set explicitly; use map type instead", scope); err != nil { 63 return err 64 } 65 } else if opt.GetIdentifierValue() == "false" { 66 valid = true 67 md.Options.MapEntry = proto.Bool(false) 68 } 69 } 70 if !valid { 71 optionNodeInfo := res.file.NodeInfo(optn.GetValue()) 72 if err := handler.HandleErrorf(optionNodeInfo.Start(), "%s: expecting bool value for map_entry option", scope); err != nil { 73 return err 74 } 75 } 76 } 77 78 // reserved ranges should not overlap 79 rsvd := make(tagRanges, len(md.ReservedRange)) 80 for i, r := range md.ReservedRange { 81 n := res.MessageReservedRangeNode(r) 82 rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n} 83 84 } 85 sort.Sort(rsvd) 86 for i := 1; i < len(rsvd); i++ { 87 if rsvd[i].start < rsvd[i-1].end { 88 rangeNodeInfo := res.file.NodeInfo(rsvd[i].node) 89 if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end-1, rsvd[i].start, rsvd[i].end-1); err != nil { 90 return err 91 } 92 } 93 } 94 95 // extensions ranges should not overlap 96 exts := make(tagRanges, len(md.ExtensionRange)) 97 for i, r := range md.ExtensionRange { 98 n := res.ExtensionRangeNode(r) 99 exts[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n} 100 } 101 sort.Sort(exts) 102 for i := 1; i < len(exts); i++ { 103 if exts[i].start < exts[i-1].end { 104 rangeNodeInfo := res.file.NodeInfo(exts[i].node) 105 if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: extension ranges overlap: %d to %d and %d to %d", scope, exts[i-1].start, exts[i-1].end-1, exts[i].start, exts[i].end-1); err != nil { 106 return err 107 } 108 } 109 } 110 111 // see if any extension range overlaps any reserved range 112 var i, j int // i indexes rsvd; j indexes exts 113 for i < len(rsvd) && j < len(exts) { 114 if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end || 115 exts[j].start >= rsvd[i].start && exts[j].start < rsvd[i].end { 116 117 var pos ast.SourcePos 118 if rsvd[i].start >= exts[j].start && rsvd[i].start < exts[j].end { 119 rangeNodeInfo := res.file.NodeInfo(rsvd[i].node) 120 pos = rangeNodeInfo.Start() 121 } else { 122 rangeNodeInfo := res.file.NodeInfo(exts[j].node) 123 pos = rangeNodeInfo.Start() 124 } 125 // ranges overlap 126 if err := handler.HandleErrorf(pos, "%s: extension range %d to %d overlaps reserved range %d to %d", scope, exts[j].start, exts[j].end-1, rsvd[i].start, rsvd[i].end-1); err != nil { 127 return err 128 } 129 } 130 if rsvd[i].start < exts[j].start { 131 i++ 132 } else { 133 j++ 134 } 135 } 136 137 // now, check that fields don't re-use tags and don't try to use extension 138 // or reserved ranges or reserved names 139 rsvdNames := map[string]struct{}{} 140 for _, n := range md.ReservedName { 141 rsvdNames[n] = struct{}{} 142 } 143 fieldTags := map[int32]string{} 144 for _, fld := range md.Field { 145 fn := res.FieldNode(fld) 146 if _, ok := rsvdNames[fld.GetName()]; ok { 147 fieldNameNodeInfo := res.file.NodeInfo(fn.FieldName()) 148 if err := handler.HandleErrorf(fieldNameNodeInfo.Start(), "%s: field %s is using a reserved name", scope, fld.GetName()); err != nil { 149 return err 150 } 151 } 152 if existing := fieldTags[fld.GetNumber()]; existing != "" { 153 fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag()) 154 if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: fields %s and %s both have the same tag %d", scope, existing, fld.GetName(), fld.GetNumber()); err != nil { 155 return err 156 } 157 } 158 fieldTags[fld.GetNumber()] = fld.GetName() 159 // check reserved ranges 160 r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end > fld.GetNumber() }) 161 if r < len(rsvd) && rsvd[r].start <= fld.GetNumber() { 162 fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag()) 163 if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: field %s is using tag %d which is in reserved range %d to %d", scope, fld.GetName(), fld.GetNumber(), rsvd[r].start, rsvd[r].end-1); err != nil { 164 return err 165 } 166 } 167 // and check extension ranges 168 e := sort.Search(len(exts), func(index int) bool { return exts[index].end > fld.GetNumber() }) 169 if e < len(exts) && exts[e].start <= fld.GetNumber() { 170 fieldTagNodeInfo := res.file.NodeInfo(fn.FieldTag()) 171 if err := handler.HandleErrorf(fieldTagNodeInfo.Start(), "%s: field %s is using tag %d which is in extension range %d to %d", scope, fld.GetName(), fld.GetNumber(), exts[e].start, exts[e].end-1); err != nil { 172 return err 173 } 174 } 175 } 176 177 return nil 178 } 179 180 func validateEnum(res *result, isProto3 bool, name protoreflect.FullName, ed *descriptorpb.EnumDescriptorProto, handler *reporter.Handler) error { 181 scope := fmt.Sprintf("enum %s", name) 182 183 if len(ed.Value) == 0 { 184 enNode := res.EnumNode(ed) 185 enNodeInfo := res.file.NodeInfo(enNode) 186 if err := handler.HandleErrorf(enNodeInfo.Start(), "%s: enums must define at least one value", scope); err != nil { 187 return err 188 } 189 } 190 191 allowAlias := false 192 var allowAliasOpt *descriptorpb.UninterpretedOption 193 if index, err := internal.FindOption(res, handler, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil { 194 return err 195 } else if index >= 0 { 196 allowAliasOpt = ed.Options.UninterpretedOption[index] 197 valid := false 198 if allowAliasOpt.IdentifierValue != nil { 199 if allowAliasOpt.GetIdentifierValue() == "true" { 200 allowAlias = true 201 valid = true 202 } else if allowAliasOpt.GetIdentifierValue() == "false" { 203 valid = true 204 } 205 } 206 if !valid { 207 optNode := res.OptionNode(allowAliasOpt) 208 optNodeInfo := res.file.NodeInfo(optNode.GetValue()) 209 if err := handler.HandleErrorf(optNodeInfo.Start(), "%s: expecting bool value for allow_alias option", scope); err != nil { 210 return err 211 } 212 } 213 } 214 215 if isProto3 && len(ed.Value) > 0 && ed.Value[0].GetNumber() != 0 { 216 evNode := res.EnumValueNode(ed.Value[0]) 217 evNodeInfo := res.file.NodeInfo(evNode.GetNumber()) 218 if err := handler.HandleErrorf(evNodeInfo.Start(), "%s: proto3 requires that first value in enum have numeric value of 0", scope); err != nil { 219 return err 220 } 221 } 222 223 // check for aliases 224 vals := map[int32]string{} 225 hasAlias := false 226 for _, evd := range ed.Value { 227 existing := vals[evd.GetNumber()] 228 if existing != "" { 229 if allowAlias { 230 hasAlias = true 231 } else { 232 evNode := res.EnumValueNode(evd) 233 evNodeInfo := res.file.NodeInfo(evNode.GetNumber()) 234 if err := handler.HandleErrorf(evNodeInfo.Start(), "%s: values %s and %s both have the same numeric value %d; use allow_alias option if intentional", scope, existing, evd.GetName(), evd.GetNumber()); err != nil { 235 return err 236 } 237 } 238 } 239 vals[evd.GetNumber()] = evd.GetName() 240 } 241 if allowAlias && !hasAlias { 242 optNode := res.OptionNode(allowAliasOpt) 243 optNodeInfo := res.file.NodeInfo(optNode.GetValue()) 244 if err := handler.HandleErrorf(optNodeInfo.Start(), "%s: allow_alias is true but no values are aliases", scope); err != nil { 245 return err 246 } 247 } 248 249 // reserved ranges should not overlap 250 rsvd := make(tagRanges, len(ed.ReservedRange)) 251 for i, r := range ed.ReservedRange { 252 n := res.EnumReservedRangeNode(r) 253 rsvd[i] = tagRange{start: r.GetStart(), end: r.GetEnd(), node: n} 254 } 255 sort.Sort(rsvd) 256 for i := 1; i < len(rsvd); i++ { 257 if rsvd[i].start <= rsvd[i-1].end { 258 rangeNodeInfo := res.file.NodeInfo(rsvd[i].node) 259 if err := handler.HandleErrorf(rangeNodeInfo.Start(), "%s: reserved ranges overlap: %d to %d and %d to %d", scope, rsvd[i-1].start, rsvd[i-1].end, rsvd[i].start, rsvd[i].end); err != nil { 260 return err 261 } 262 } 263 } 264 265 // now, check that fields don't re-use tags and don't try to use extension 266 // or reserved ranges or reserved names 267 rsvdNames := map[string]struct{}{} 268 for _, n := range ed.ReservedName { 269 rsvdNames[n] = struct{}{} 270 } 271 for _, ev := range ed.Value { 272 evn := res.EnumValueNode(ev) 273 if _, ok := rsvdNames[ev.GetName()]; ok { 274 enumValNodeInfo := res.file.NodeInfo(evn.GetName()) 275 if err := handler.HandleErrorf(enumValNodeInfo.Start(), "%s: value %s is using a reserved name", scope, ev.GetName()); err != nil { 276 return err 277 } 278 } 279 // check reserved ranges 280 r := sort.Search(len(rsvd), func(index int) bool { return rsvd[index].end >= ev.GetNumber() }) 281 if r < len(rsvd) && rsvd[r].start <= ev.GetNumber() { 282 enumValNodeInfo := res.file.NodeInfo(evn.GetNumber()) 283 if err := handler.HandleErrorf(enumValNodeInfo.Start(), "%s: value %s is using number %d which is in reserved range %d to %d", scope, ev.GetName(), ev.GetNumber(), rsvd[r].start, rsvd[r].end); err != nil { 284 return err 285 } 286 } 287 } 288 289 return nil 290 } 291 292 func validateField(res *result, isProto3 bool, name protoreflect.FullName, fld *descriptorpb.FieldDescriptorProto, handler *reporter.Handler) error { 293 scope := fmt.Sprintf("field %s", name) 294 295 node := res.FieldNode(fld) 296 if isProto3 { 297 if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP { 298 groupNodeInfo := res.file.NodeInfo(node.GetGroupKeyword()) 299 if err := handler.HandleErrorf(groupNodeInfo.Start(), "%s: groups are not allowed in proto3", scope); err != nil { 300 return err 301 } 302 } else if fld.Label != nil && fld.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { 303 fieldLabelNodeInfo := res.file.NodeInfo(node.FieldLabel()) 304 if err := handler.HandleErrorf(fieldLabelNodeInfo.Start(), "%s: label 'required' is not allowed in proto3", scope); err != nil { 305 return err 306 } 307 } 308 if index, err := internal.FindOption(res, handler, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil { 309 return err 310 } else if index >= 0 { 311 optNode := res.OptionNode(fld.Options.GetUninterpretedOption()[index]) 312 optNameNodeInfo := res.file.NodeInfo(optNode.GetName()) 313 if err := handler.HandleErrorf(optNameNodeInfo.Start(), "%s: default values are not allowed in proto3", scope); err != nil { 314 return err 315 } 316 } 317 } else { 318 if fld.Label == nil && fld.OneofIndex == nil { 319 fieldNameNodeInfo := res.file.NodeInfo(node.FieldName()) 320 if err := handler.HandleErrorf(fieldNameNodeInfo.Start(), "%s: field has no label; proto2 requires explicit 'optional' label", scope); err != nil { 321 return err 322 } 323 } 324 if fld.GetExtendee() != "" && fld.Label != nil && fld.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { 325 fieldLabelNodeInfo := res.file.NodeInfo(node.FieldLabel()) 326 if err := handler.HandleErrorf(fieldLabelNodeInfo.Start(), "%s: extension fields cannot be 'required'", scope); err != nil { 327 return err 328 } 329 } 330 } 331 332 // finally, set any missing label to optional 333 if fld.Label == nil { 334 fld.Label = descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum() 335 } 336 337 return nil 338 } 339 340 type tagRange struct { 341 start int32 342 end int32 343 node ast.RangeDeclNode 344 } 345 346 type tagRanges []tagRange 347 348 func (r tagRanges) Len() int { 349 return len(r) 350 } 351 352 func (r tagRanges) Less(i, j int) bool { 353 return r[i].start < r[j].start || 354 (r[i].start == r[j].start && r[i].end < r[j].end) 355 } 356 357 func (r tagRanges) Swap(i, j int) { 358 r[i], r[j] = r[j], r[i] 359 }