github.com/ManabuSeki/goa-v1@v1.4.3/goagen/codegen/validation.go (about) 1 package codegen 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "math" 8 "strings" 9 "text/template" 10 11 "github.com/goadesign/goa/design" 12 ) 13 14 var ( 15 enumValT *template.Template 16 formatValT *template.Template 17 patternValT *template.Template 18 minMaxValT *template.Template 19 lengthValT *template.Template 20 requiredValT *template.Template 21 ) 22 23 // init instantiates the templates. 24 func init() { 25 var err error 26 fm := template.FuncMap{ 27 "tabs": Tabs, 28 "slice": toSlice, 29 "oneof": oneof, 30 "constant": constant, 31 "goifyAtt": GoifyAtt, 32 "add": Add, 33 } 34 if enumValT, err = template.New("enum").Funcs(fm).Parse(enumValTmpl); err != nil { 35 panic(err) 36 } 37 if formatValT, err = template.New("format").Funcs(fm).Parse(formatValTmpl); err != nil { 38 panic(err) 39 } 40 if patternValT, err = template.New("pattern").Funcs(fm).Parse(patternValTmpl); err != nil { 41 panic(err) 42 } 43 if minMaxValT, err = template.New("minMax").Funcs(fm).Parse(minMaxValTmpl); err != nil { 44 panic(err) 45 } 46 if lengthValT, err = template.New("length").Funcs(fm).Parse(lengthValTmpl); err != nil { 47 panic(err) 48 } 49 if requiredValT, err = template.New("required").Funcs(fm).Parse(requiredValTmpl); err != nil { 50 panic(err) 51 } 52 } 53 54 // Validator is the code generator for the 'Validate' type methods. 55 type Validator struct { 56 arrayValT *template.Template 57 hashValT *template.Template 58 userValT *template.Template 59 seen map[string]*bytes.Buffer 60 } 61 62 // NewValidator instantiates a validate code generator. 
63 func NewValidator() *Validator { 64 var ( 65 v = &Validator{seen: make(map[string]*bytes.Buffer)} 66 err error 67 ) 68 fm := template.FuncMap{ 69 "tabs": Tabs, 70 "slice": toSlice, 71 "oneof": oneof, 72 "constant": constant, 73 "goifyAtt": GoifyAtt, 74 "add": Add, 75 "recurseAttribute": v.recurseAttribute, 76 } 77 v.arrayValT, err = template.New("array").Funcs(fm).Parse(arrayValTmpl) 78 if err != nil { 79 panic(err) 80 } 81 v.hashValT, err = template.New("hash").Funcs(fm).Parse(hashValTmpl) 82 if err != nil { 83 panic(err) 84 } 85 v.userValT, err = template.New("user").Funcs(fm).Parse(userValTmpl) 86 if err != nil { 87 panic(err) 88 } 89 return v 90 } 91 92 // Code produces Go code that runs the validation checks recursively over the given attribute. 93 func (v *Validator) Code(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string { 94 buf := v.recurse(att, nonzero, required, hasDefault, target, context, depth, private) 95 return buf.String() 96 } 97 98 func (v *Validator) arrayValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte { 99 a := att.Type.ToArray() 100 if a == nil { 101 return nil 102 } 103 104 var buf bytes.Buffer 105 106 // Perform any validation on the array type such as MinLength, MaxLength, etc. 
107 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 108 first := true 109 if validation != "" { 110 buf.WriteString(validation) 111 first = false 112 } 113 val := v.Code(a.ElemType, true, false, false, "e", context+"[*]", depth+1, false) 114 if val != "" { 115 switch a.ElemType.Type.(type) { 116 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 117 // For user and media types, call the Validate method 118 val = RunTemplate(v.userValT, map[string]interface{}{ 119 "depth": depth + 2, 120 "target": "e", 121 }) 122 val = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), val, Tabs(depth+1)) 123 } 124 data := map[string]interface{}{ 125 "elemType": a.ElemType, 126 "context": context, 127 "target": target, 128 "depth": 1, 129 "private": private, 130 "validation": val, 131 } 132 validation = RunTemplate(v.arrayValT, data) 133 if !first { 134 buf.WriteByte('\n') 135 } else { 136 first = false 137 } 138 buf.WriteString(validation) 139 } 140 return buf.Bytes() 141 } 142 143 func (v *Validator) hashValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte { 144 h := att.Type.ToHash() 145 if h == nil { 146 return nil 147 } 148 149 var buf bytes.Buffer 150 151 // Perform any validation on the hash type such as MinLength, MaxLength, etc. 
152 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 153 first := true 154 if validation != "" { 155 buf.WriteString(validation) 156 first = false 157 } 158 keyVal := v.Code(h.KeyType, true, false, false, "k", context+"[*]", depth+1, false) 159 if keyVal != "" { 160 switch h.KeyType.Type.(type) { 161 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 162 // For user and media types, call the Validate method 163 keyVal = RunTemplate(v.userValT, map[string]interface{}{ 164 "depth": depth + 2, 165 "target": "k", 166 }) 167 keyVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), keyVal, Tabs(depth+1)) 168 } 169 } 170 elemVal := v.Code(h.ElemType, true, false, false, "e", context+"[*]", depth+1, false) 171 if elemVal != "" { 172 switch h.ElemType.Type.(type) { 173 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 174 // For user and media types, call the Validate method 175 elemVal = RunTemplate(v.userValT, map[string]interface{}{ 176 "depth": depth + 2, 177 "target": "e", 178 }) 179 elemVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), elemVal, Tabs(depth+1)) 180 } 181 } 182 if keyVal != "" || elemVal != "" { 183 data := map[string]interface{}{ 184 "depth": 1, 185 "target": target, 186 "keyValidation": keyVal, 187 "elemValidation": elemVal, 188 } 189 validation = RunTemplate(v.hashValT, data) 190 if !first { 191 buf.WriteByte('\n') 192 } else { 193 first = false 194 } 195 buf.WriteString(validation) 196 } 197 return buf.Bytes() 198 } 199 200 func (v *Validator) recurse(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) *bytes.Buffer { 201 var ( 202 buf = new(bytes.Buffer) 203 first = true 204 ) 205 206 // Break infinite recursions 207 switch dt := att.Type.(type) { 208 case *design.MediaTypeDefinition: 209 if buf, ok := v.seen[dt.TypeName]; ok { 210 return buf 211 } 212 v.seen[dt.TypeName] = buf 213 case 
*design.UserTypeDefinition: 214 if buf, ok := v.seen[dt.TypeName]; ok { 215 return buf 216 } 217 v.seen[dt.TypeName] = buf 218 } 219 220 if o := att.Type.ToObject(); o != nil { 221 if ds, ok := att.Type.(design.DataStructure); ok { 222 att = ds.Definition() 223 } 224 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 225 if validation != "" { 226 buf.WriteString(validation) 227 first = false 228 } 229 o.IterateAttributes(func(n string, catt *design.AttributeDefinition) error { 230 validation := v.recurseAttribute(att, catt, n, target, context, depth, private) 231 if validation != "" { 232 if !first { 233 buf.WriteByte('\n') 234 } else { 235 first = false 236 } 237 buf.WriteString(validation) 238 } 239 return nil 240 }) 241 } else if a := att.Type.ToArray(); a != nil { 242 buf.Write(v.arrayValCode(att, nonzero, required, hasDefault, target, context, depth, private)) 243 } else if h := att.Type.ToHash(); h != nil { 244 buf.Write(v.hashValCode(att, nonzero, required, hasDefault, target, context, depth, private)) 245 } else { 246 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 247 if validation != "" { 248 buf.WriteString(validation) 249 } 250 } 251 return buf 252 } 253 254 func (v *Validator) recurseAttribute(att, catt *design.AttributeDefinition, n, target, context string, depth int, private bool) string { 255 var validation string 256 if ds, ok := catt.Type.(design.DataStructure); ok { 257 // We need to check empirically whether there are validations to be 258 // generated, we can't just generate and check whether something was 259 // generated to avoid infinite recursions. 
260 hasValidations := false 261 done := errors.New("done") 262 ds.Walk(func(a *design.AttributeDefinition) error { 263 if a.Validation != nil { 264 if private { 265 hasValidations = true 266 return done 267 } 268 // For public data structures there is a case where 269 // there is validation but no actual validation 270 // code: if the validation is a required validation 271 // that applies to attributes that cannot be nil or 272 // empty string i.e. primitive types other than 273 // string. 274 if !a.Validation.HasRequiredOnly() { 275 hasValidations = true 276 return done 277 } 278 for _, name := range a.Validation.Required { 279 att := a.Type.ToObject()[name] 280 if att != nil && (!att.Type.IsPrimitive() || att.Type.Kind() == design.StringKind) { 281 hasValidations = true 282 return done 283 } 284 } 285 } 286 return nil 287 }) 288 if hasValidations { 289 validation = RunTemplate(v.userValT, map[string]interface{}{ 290 "depth": depth, 291 "target": fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)), 292 }) 293 } 294 } else { 295 dp := depth 296 if catt.Type.IsObject() { 297 dp++ 298 } 299 validation = v.recurse( 300 catt, 301 att.IsNonZero(n), 302 att.IsRequired(n), 303 att.HasDefaultValue(n), 304 fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)), 305 fmt.Sprintf("%s.%s", context, n), 306 dp, 307 private, 308 ).String() 309 } 310 if validation != "" { 311 if catt.Type.IsObject() { 312 validation = fmt.Sprintf("%sif %s.%s != nil {\n%s\n%s}", 313 Tabs(depth), target, GoifyAtt(catt, n, true), validation, Tabs(depth)) 314 } 315 } 316 return validation 317 } 318 319 // ValidationChecker produces Go code that runs the validation defined in the given attribute 320 // definition against the content of the variable named target recursively. 321 // context is used to keep track of recursion to produce helpful error messages in case of type 322 // validation error. 323 // The generated code assumes that there is a pre-existing "err" variable of type 324 // error. 
It initializes that variable in case a validation fails. 325 // Note: we do not want to recurse here, recursion is done by the marshaler/unmarshaler code. 326 func ValidationChecker(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string { 327 if att.Validation == nil { 328 return "" 329 } 330 t := target 331 isPointer := private || (!required && !hasDefault && !nonzero) 332 if isPointer && att.Type.IsPrimitive() { 333 t = "*" + t 334 } 335 data := map[string]interface{}{ 336 "attribute": att, 337 "isPointer": private || isPointer, 338 "nonzero": nonzero, 339 "context": context, 340 "target": target, 341 "targetVal": t, 342 "string": att.Type.Kind() == design.StringKind, 343 "array": att.Type.IsArray(), 344 "hash": att.Type.IsHash(), 345 "depth": depth, 346 "private": private, 347 } 348 res := validationsCode(att, data) 349 return strings.Join(res, "\n") 350 } 351 352 func validationsCode(att *design.AttributeDefinition, data map[string]interface{}) (res []string) { 353 validation := att.Validation 354 if values := validation.Values; values != nil { 355 data["values"] = values 356 if val := RunTemplate(enumValT, data); val != "" { 357 res = append(res, val) 358 } 359 } 360 if format := validation.Format; format != "" { 361 data["format"] = format 362 if val := RunTemplate(formatValT, data); val != "" { 363 res = append(res, val) 364 } 365 } 366 if pattern := validation.Pattern; pattern != "" { 367 data["pattern"] = pattern 368 if val := RunTemplate(patternValT, data); val != "" { 369 res = append(res, val) 370 } 371 } 372 if min := validation.Minimum; min != nil { 373 if att.Type == design.Integer { 374 data["min"] = renderInteger(*min) 375 } else { 376 data["min"] = fmt.Sprintf("%f", *min) 377 } 378 data["isMin"] = true 379 delete(data, "max") 380 if val := RunTemplate(minMaxValT, data); val != "" { 381 res = append(res, val) 382 } 383 } 384 if max := validation.Maximum; max != nil { 385 if 
att.Type == design.Integer { 386 data["max"] = renderInteger(*max) 387 } else { 388 data["max"] = fmt.Sprintf("%f", *max) 389 } 390 data["isMin"] = false 391 delete(data, "min") 392 if val := RunTemplate(minMaxValT, data); val != "" { 393 res = append(res, val) 394 } 395 } 396 if minLength := validation.MinLength; minLength != nil { 397 data["minLength"] = minLength 398 data["isMinLength"] = true 399 delete(data, "maxLength") 400 if val := RunTemplate(lengthValT, data); val != "" { 401 res = append(res, val) 402 } 403 } 404 if maxLength := validation.MaxLength; maxLength != nil { 405 data["maxLength"] = maxLength 406 data["isMinLength"] = false 407 delete(data, "minLength") 408 if val := RunTemplate(lengthValT, data); val != "" { 409 res = append(res, val) 410 } 411 } 412 if required := validation.Required; len(required) > 0 { 413 var val string 414 for i, r := range required { 415 if i > 0 { 416 val += "\n" 417 } 418 data["required"] = r 419 val += RunTemplate(requiredValT, data) 420 } 421 res = append(res, val) 422 } 423 return 424 } 425 426 // renderInteger renders a max or min value properly, taking into account 427 // overflows due to casting from a float value. 428 func renderInteger(f float64) string { 429 if f > math.Nextafter(float64(math.MaxInt64), 0) { 430 return fmt.Sprintf("%d", int64(math.MaxInt64)) 431 } 432 if f < math.Nextafter(float64(math.MinInt64), 0) { 433 return fmt.Sprintf("%d", int64(math.MinInt64)) 434 } 435 return fmt.Sprintf("%d", int64(f)) 436 } 437 438 // oneof produces code that compares target with each element of vals and ORs 439 // the result, e.g. "target == 1 || target == 2". 440 func oneof(target string, vals []interface{}) string { 441 elems := make([]string, len(vals)) 442 for i, v := range vals { 443 elems[i] = fmt.Sprintf("%s == %#v", target, v) 444 } 445 return strings.Join(elems, " || ") 446 } 447 448 // constant returns the Go constant name of the format with the given value. 
// constant panics with a message naming formatName when it is not one of the
// formats handled below.
func constant(formatName string) string {
	switch formatName {
	case "date":
		return "goa.FormatDate"
	case "date-time":
		return "goa.FormatDateTime"
	case "email":
		return "goa.FormatEmail"
	case "hostname":
		return "goa.FormatHostname"
	case "ipv4":
		return "goa.FormatIPv4"
	case "ipv6":
		return "goa.FormatIPv6"
	case "ip":
		return "goa.FormatIP"
	case "uri":
		return "goa.FormatURI"
	case "mac":
		return "goa.FormatMAC"
	case "cidr":
		return "goa.FormatCIDR"
	case "regexp":
		return "goa.FormatRegexp"
	case "rfc1123":
		return "goa.FormatRFC1123"
	}
	// Reaching this point is a bug; name the offending format so that the
	// panic can be traced back to its origin.
	panic("unknown format: " + formatName)
}

const (
	// arrayValTmpl renders a loop over the elements "e" of .target whose
	// body is .validation.
	arrayValTmpl = `{{ tabs .depth }}for _, e := range {{ .target }} {
{{ .validation }}
{{ tabs .depth }}}`

	// hashValTmpl renders a loop over .target. The key "k" and the element
	// "e" are only declared when .keyValidation, respectively
	// .elemValidation, is set; "_" is used otherwise.
	hashValTmpl = `{{ tabs .depth }}for {{ if .keyValidation }}k{{ else }}_{{ end }}, {{ if .elemValidation }}e{{ else }}_{{ end }} := range {{ .target }} {
{{- if .keyValidation }}
{{ .keyValidation }}{{ end }}{{ if .elemValidation }}
{{ .elemValidation }}{{ end }}
{{ tabs .depth }}}`

	// userValTmpl renders a call to the Validate method of .target and
	// merges the returned error, if any, into "err".
	userValTmpl = `{{ tabs .depth }}if err2 := {{ .target }}.Validate(); err2 != nil {
{{ tabs .depth }}	err = goa.MergeErrors(err, err2)
{{ tabs .depth }}}`

	// enumValTmpl renders a check that .targetVal equals one of .values,
	// wrapped in a nil check on .target when .isPointer is set.
	enumValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if !({{ oneof .targetVal .values }}) {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidEnumValueError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ slice .values }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// patternValTmpl renders a check of .targetVal against the regular
	// expression .pattern, wrapped in a nil check on .target when .isPointer
	// is set.
	patternValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if ok := goa.ValidatePattern(` + "`{{ .pattern }}`" + `, {{ .targetVal }}); !ok {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidPatternError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, ` + "`{{ .pattern }}`" + `))
{{ tabs $depth }}}{{ if .isPointer }}
{{ tabs .depth }}}{{ end }}`

	// formatValTmpl renders a check of .targetVal against the format
	// constant for .format, wrapped in a nil check on .target when
	// .isPointer is set.
	formatValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if err2 := goa.ValidateFormat({{ constant .format }}, {{ .targetVal }}); err2 != nil {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidFormatError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ constant .format }}, err2))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// minMaxValTmpl renders a check of .targetVal against .min when .isMin
	// is set and against .max otherwise, wrapped in a nil check on .target
	// when .isPointer is set.
	minMaxValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs .depth }}	if {{ .targetVal }} {{ if .isMin }}<{{ else }}>{{ end }} {{ if .isMin }}{{ .min }}{{ else }}{{ .max }}{{ end }} {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidRangeError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ if .isMin }}{{ .min }}, true{{ else }}{{ .max }}, false{{ end }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// lengthValTmpl renders a check of the length of the target against
	// .minLength when .isMinLength is set and against .maxLength otherwise.
	// Strings are measured in runes, other values with len. Arrays, hashes
	// and non-zero values use .target, other values use .targetVal.
	lengthValTmpl = `{{$depth := or (and .isPointer (add .depth 1)) .depth}}{{/*
*/}}{{$target := or (and (or (or .array .hash) .nonzero) .target) .targetVal}}{{/*
*/}}{{if .isPointer}}{{tabs .depth}}if {{.target}} != nil {
{{end}}{{tabs .depth}}	if {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}} {{if .isMinLength}}<{{else}}>{{end}} {{if .isMinLength}}{{.minLength}}{{else}}{{.maxLength}}{{end}} {
{{tabs $depth}}	err = goa.MergeErrors(err, goa.InvalidLengthError(` + "`" + `{{.context}}` + "`" + `, {{$target}}, {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}}, {{if .isMinLength}}{{.minLength}}, true{{else}}{{.maxLength}}, false{{end}}))
{{if .isPointer}}{{tabs $depth}}}
{{end}}{{tabs .depth}}}`

	// requiredValTmpl renders the check for the required attribute named
	// .required: public attributes whose type kind is 4 are compared with
	// the empty string, private or non-primitive attributes are compared
	// with nil and nothing is rendered for the other attributes.
	requiredValTmpl = `{{ $att := index $.attribute.Type.ToObject .required }}{{/*
*/}}{{ if and (not $.private) (eq $att.Type.Kind 4) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == "" {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{ .required }}"))
{{ tabs $.depth }}}{{ else if or $.private (not $att.Type.IsPrimitive) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == nil {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{ .required }}"))
{{ tabs $.depth }}}{{ end }}`
)