github.com/shogo82148/goa-v1@v1.6.2/goagen/codegen/validation.go (about) 1 package codegen 2 3 import ( 4 "bytes" 5 "fmt" 6 "math" 7 "strings" 8 "text/template" 9 10 "github.com/shogo82148/goa-v1/design" 11 ) 12 13 var ( 14 validationFuncs = template.FuncMap{ 15 "tabs": Tabs, 16 "slice": toSlice, 17 "oneof": oneof, 18 "constant": constant, 19 "goifyAtt": GoifyAtt, 20 "add": Add, 21 } 22 enumValT = template.Must(template.New("enum").Funcs(validationFuncs).Parse(enumValTmpl)) 23 formatValT = template.Must(template.New("format").Funcs(validationFuncs).Parse(formatValTmpl)) 24 patternValT = template.Must(template.New("pattern").Funcs(validationFuncs).Parse(patternValTmpl)) 25 minMaxValT = template.Must(template.New("minMax").Funcs(validationFuncs).Parse(minMaxValTmpl)) 26 lengthValT = template.Must(template.New("length").Funcs(validationFuncs).Parse(lengthValTmpl)) 27 requiredValT = template.Must(template.New("required").Funcs(validationFuncs).Parse(requiredValTmpl)) 28 ) 29 30 // Validator is the code generator for the 'Validate' type methods. 31 type Validator struct { 32 arrayValT *template.Template 33 hashValT *template.Template 34 userValT *template.Template 35 seen map[string]*bytes.Buffer 36 } 37 38 // NewValidator instantiates a validate code generator. 39 func NewValidator() *Validator { 40 v := &Validator{seen: map[string]*bytes.Buffer{}} 41 fm := template.FuncMap{ 42 "tabs": Tabs, 43 "slice": toSlice, 44 "oneof": oneof, 45 "constant": constant, 46 "goifyAtt": GoifyAtt, 47 "add": Add, 48 "recurseAttribute": v.recurseAttribute, 49 } 50 v.arrayValT = template.Must(template.New("array").Funcs(fm).Parse(arrayValTmpl)) 51 v.hashValT = template.Must(template.New("hash").Funcs(fm).Parse(hashValTmpl)) 52 v.userValT = template.Must(template.New("user").Funcs(fm).Parse(userValTmpl)) 53 return v 54 } 55 56 // Code produces Go code that runs the validation checks recursively over the given attribute. 
57 func (v *Validator) Code(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string { 58 buf := v.recurse(att, nonzero, required, hasDefault, target, context, depth, private) 59 return buf.String() 60 } 61 62 func (v *Validator) arrayValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte { 63 a := att.Type.ToArray() 64 if a == nil { 65 return nil 66 } 67 68 var buf bytes.Buffer 69 70 // Perform any validation on the array type such as MinLength, MaxLength, etc. 71 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 72 first := true 73 if validation != "" { 74 buf.WriteString(validation) 75 first = false 76 } 77 val := v.Code(a.ElemType, true, false, false, "e", context+"[*]", depth+1, false) 78 if val != "" { 79 switch a.ElemType.Type.(type) { 80 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 81 // For user and media types, call the Validate method 82 val = RunTemplate(v.userValT, map[string]interface{}{ 83 "depth": depth + 2, 84 "target": "e", 85 }) 86 val = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), val, Tabs(depth+1)) 87 } 88 data := map[string]interface{}{ 89 "elemType": a.ElemType, 90 "context": context, 91 "target": target, 92 "depth": 1, 93 "private": private, 94 "validation": val, 95 } 96 validation = RunTemplate(v.arrayValT, data) 97 if !first { 98 buf.WriteByte('\n') 99 } else { 100 first = false 101 } 102 buf.WriteString(validation) 103 } 104 _ = first // suppress: ineffectual assignment to first (ineffassign) 105 return buf.Bytes() 106 } 107 108 func (v *Validator) hashValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte { 109 h := att.Type.ToHash() 110 if h == nil { 111 return nil 112 } 113 114 var buf bytes.Buffer 115 116 // Perform any validation on the 
hash type such as MinLength, MaxLength, etc. 117 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 118 first := true 119 if validation != "" { 120 buf.WriteString(validation) 121 first = false 122 } 123 keyVal := v.Code(h.KeyType, true, false, false, "k", context+"[*]", depth+1, false) 124 if keyVal != "" { 125 switch h.KeyType.Type.(type) { 126 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 127 // For user and media types, call the Validate method 128 keyVal = RunTemplate(v.userValT, map[string]interface{}{ 129 "depth": depth + 2, 130 "target": "k", 131 }) 132 keyVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), keyVal, Tabs(depth+1)) 133 } 134 } 135 elemVal := v.Code(h.ElemType, true, false, false, "e", context+"[*]", depth+1, false) 136 if elemVal != "" { 137 switch h.ElemType.Type.(type) { 138 case *design.UserTypeDefinition, *design.MediaTypeDefinition: 139 // For user and media types, call the Validate method 140 elemVal = RunTemplate(v.userValT, map[string]interface{}{ 141 "depth": depth + 2, 142 "target": "e", 143 }) 144 elemVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), elemVal, Tabs(depth+1)) 145 } 146 } 147 if keyVal != "" || elemVal != "" { 148 data := map[string]interface{}{ 149 "depth": 1, 150 "target": target, 151 "keyValidation": keyVal, 152 "elemValidation": elemVal, 153 } 154 validation = RunTemplate(v.hashValT, data) 155 if !first { 156 buf.WriteByte('\n') 157 } else { 158 first = false 159 } 160 buf.WriteString(validation) 161 } 162 _ = first // suppress: ineffectual assignment to first (ineffassign) 163 return buf.Bytes() 164 } 165 166 func (v *Validator) recurse(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) *bytes.Buffer { 167 var ( 168 buf = new(bytes.Buffer) 169 first = true 170 ) 171 172 // Break infinite recursions 173 switch dt := att.Type.(type) { 174 case 
*design.MediaTypeDefinition: 175 if buf, ok := v.seen[dt.TypeName]; ok { 176 return buf 177 } 178 v.seen[dt.TypeName] = buf 179 case *design.UserTypeDefinition: 180 if buf, ok := v.seen[dt.TypeName]; ok { 181 return buf 182 } 183 v.seen[dt.TypeName] = buf 184 } 185 186 if o := att.Type.ToObject(); o != nil { 187 if ds, ok := att.Type.(design.DataStructure); ok { 188 att = ds.Definition() 189 } 190 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 191 if validation != "" { 192 buf.WriteString(validation) 193 first = false 194 } 195 o.IterateAttributes(func(n string, catt *design.AttributeDefinition) error { 196 validation := v.recurseAttribute(att, catt, n, target, context, depth, private) 197 if validation != "" { 198 if !first { 199 buf.WriteByte('\n') 200 } else { 201 first = false 202 } 203 buf.WriteString(validation) 204 } 205 return nil 206 }) 207 } else if a := att.Type.ToArray(); a != nil { 208 buf.Write(v.arrayValCode(att, nonzero, required, hasDefault, target, context, depth, private)) 209 } else if h := att.Type.ToHash(); h != nil { 210 buf.Write(v.hashValCode(att, nonzero, required, hasDefault, target, context, depth, private)) 211 } else { 212 validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private) 213 if validation != "" { 214 buf.WriteString(validation) 215 } 216 } 217 return buf 218 } 219 220 func (v *Validator) recurseAttribute(att, catt *design.AttributeDefinition, n, target, context string, depth int, private bool) string { 221 var validation string 222 if _, ok := catt.Type.(design.DataStructure); ok { 223 validation = RunTemplate(v.userValT, map[string]interface{}{ 224 "depth": depth, 225 "target": fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)), 226 }) 227 } else { 228 dp := depth 229 if catt.Type.IsObject() { 230 dp++ 231 } 232 validation = v.recurse( 233 catt, 234 att.IsNonZero(n), 235 att.IsRequired(n), 236 att.HasDefaultValue(n), 237 
fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)), 238 fmt.Sprintf("%s.%s", context, n), 239 dp, 240 private, 241 ).String() 242 } 243 if validation != "" { 244 if catt.Type.IsObject() { 245 validation = fmt.Sprintf("%sif %s.%s != nil {\n%s\n%s}", 246 Tabs(depth), target, GoifyAtt(catt, n, true), validation, Tabs(depth)) 247 } 248 } 249 return validation 250 } 251 252 // ValidationChecker produces Go code that runs the validation defined in the given attribute 253 // definition against the content of the variable named target recursively. 254 // context is used to keep track of recursion to produce helpful error messages in case of type 255 // validation error. 256 // The generated code assumes that there is a pre-existing "err" variable of type 257 // error. It initializes that variable in case a validation fails. 258 // Note: we do not want to recurse here, recursion is done by the marshaler/unmarshaler code. 259 func ValidationChecker(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string { 260 if att.Validation == nil { 261 return "" 262 } 263 t := target 264 isPointer := private || (!required && !hasDefault && !nonzero) 265 if isPointer && att.Type.IsPrimitive() { 266 t = "*" + t 267 } 268 data := map[string]interface{}{ 269 "attribute": att, 270 "isPointer": private || isPointer, 271 "nonzero": nonzero, 272 "context": context, 273 "target": target, 274 "targetVal": t, 275 "string": att.Type.Kind() == design.StringKind, 276 "array": att.Type.IsArray(), 277 "hash": att.Type.IsHash(), 278 "depth": depth, 279 "private": private, 280 } 281 res := validationsCode(att, data) 282 return strings.Join(res, "\n") 283 } 284 285 func validationsCode(att *design.AttributeDefinition, data map[string]interface{}) (res []string) { 286 validation := att.Validation 287 if values := validation.Values; values != nil { 288 data["values"] = values 289 if val := RunTemplate(enumValT, data); val != "" { 290 res 
= append(res, val) 291 } 292 } 293 if format := validation.Format; format != "" { 294 data["format"] = format 295 if val := RunTemplate(formatValT, data); val != "" { 296 res = append(res, val) 297 } 298 } 299 if pattern := validation.Pattern; pattern != "" { 300 data["pattern"] = pattern 301 if val := RunTemplate(patternValT, data); val != "" { 302 res = append(res, val) 303 } 304 } 305 if min := validation.Minimum; min != nil { 306 if att.Type == design.Integer { 307 data["min"] = renderInteger(*min) 308 } else { 309 data["min"] = fmt.Sprintf("%f", *min) 310 } 311 data["isMin"] = true 312 delete(data, "max") 313 if val := RunTemplate(minMaxValT, data); val != "" { 314 res = append(res, val) 315 } 316 } 317 if max := validation.Maximum; max != nil { 318 if att.Type == design.Integer { 319 data["max"] = renderInteger(*max) 320 } else { 321 data["max"] = fmt.Sprintf("%f", *max) 322 } 323 data["isMin"] = false 324 delete(data, "min") 325 if val := RunTemplate(minMaxValT, data); val != "" { 326 res = append(res, val) 327 } 328 } 329 if minLength := validation.MinLength; minLength != nil { 330 data["minLength"] = minLength 331 data["isMinLength"] = true 332 delete(data, "maxLength") 333 if val := RunTemplate(lengthValT, data); val != "" { 334 res = append(res, val) 335 } 336 } 337 if maxLength := validation.MaxLength; maxLength != nil { 338 data["maxLength"] = maxLength 339 data["isMinLength"] = false 340 delete(data, "minLength") 341 if val := RunTemplate(lengthValT, data); val != "" { 342 res = append(res, val) 343 } 344 } 345 if required := validation.Required; len(required) > 0 { 346 var val string 347 for i, r := range required { 348 if i > 0 { 349 val += "\n" 350 } 351 data["required"] = r 352 val += RunTemplate(requiredValT, data) 353 } 354 res = append(res, val) 355 } 356 return 357 } 358 359 // renderInteger renders a max or min value properly, taking into account 360 // overflows due to casting from a float value. 
//
// Values beyond the int64 range are clamped to math.MaxInt64 or
// math.MinInt64. Each bound is moved one representable float64 toward zero
// because float64(math.MaxInt64) rounds up to 2^63. In-range values are
// truncated toward zero.
func renderInteger(f float64) string {
	switch {
	case f > math.Nextafter(float64(math.MaxInt64), 0):
		return fmt.Sprintf("%d", int64(math.MaxInt64))
	case f < math.Nextafter(float64(math.MinInt64), 0):
		return fmt.Sprintf("%d", int64(math.MinInt64))
	default:
		return fmt.Sprintf("%d", int64(f))
	}
}

// oneof produces code that compares target with each element of vals and ORs
// the result, e.g. "target == 1 || target == 2". Values are rendered with %#v,
// so strings come out quoted. An empty vals yields "".
func oneof(target string, vals []interface{}) string {
	elems := make([]string, len(vals))
	for i, v := range vals {
		elems[i] = fmt.Sprintf("%s == %#v", target, v)
	}
	return strings.Join(elems, " || ")
}

// constant returns the Go constant name of the format with the given value.
// It panics on an unknown format name. The design layer is expected to reject
// such formats earlier, so reaching the panic indicates a bug.
func constant(formatName string) string {
	switch formatName {
	case "date":
		return "goa.FormatDate"
	case "date-time":
		return "goa.FormatDateTime"
	case "email":
		return "goa.FormatEmail"
	case "hostname":
		return "goa.FormatHostname"
	case "ipv4":
		return "goa.FormatIPv4"
	case "ipv6":
		return "goa.FormatIPv6"
	case "ip":
		return "goa.FormatIP"
	case "uri":
		return "goa.FormatURI"
	case "mac":
		return "goa.FormatMAC"
	case "cidr":
		return "goa.FormatCIDR"
	case "regexp":
		return "goa.FormatRegexp"
	case "rfc1123":
		return "goa.FormatRFC1123"
	}
	// Name the offending format so the failure can be traced to the design.
	panic(fmt.Sprintf("unknown format %q", formatName)) // bug
}

// Templates used to render the generated validation code. The "{{/*\n*/}}"
// pairs are empty template comments that let a single template span several
// source lines without emitting the newlines.
const (
	// arrayValTmpl loops over the elements of an array and inlines their validation.
	arrayValTmpl = `{{ tabs .depth }}for _, e := range {{ .target }} {
{{ .validation }}
{{ tabs .depth }}}`

	// hashValTmpl loops over a hash, binding only the variables that are validated.
	hashValTmpl = `{{ tabs .depth }}for {{ if .keyValidation }}k{{ else }}_{{ end }}, {{ if .elemValidation }}e{{ else }}_{{ end }} := range {{ .target }} {
{{- if .keyValidation }}
{{ .keyValidation }}{{ end }}{{ if .elemValidation }}
{{ .elemValidation }}{{ end }}
{{ tabs .depth }}}`

	// userValTmpl delegates to the Validate method of a nested type.
	userValTmpl = `{{ tabs .depth }}if err2 := {{ .target }}.Validate(); err2 != nil {
{{ tabs .depth }}	err = goa.MergeErrors(err, err2)
{{ tabs .depth }}}`

	enumValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if !({{ oneof .targetVal .values }}) {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidEnumValueError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ slice .values }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	patternValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if ok := goa.ValidatePattern(` + "`{{ .pattern }}`" + `, {{ .targetVal }}); !ok {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidPatternError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, ` + "`{{ .pattern }}`" + `))
{{ tabs $depth }}}{{ if .isPointer }}
{{ tabs .depth }}}{{ end }}`

	formatValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if err2 := goa.ValidateFormat({{ constant .format }}, {{ .targetVal }}); err2 != nil {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidFormatError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ constant .format }}, err2))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	minMaxValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs .depth }}	if {{ .targetVal }} {{ if .isMin }}<{{ else }}>{{ end }} {{ if .isMin }}{{ .min }}{{ else }}{{ .max }}{{ end }} {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidRangeError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ if .isMin }}{{ .min }}, true{{ else }}{{ .max }}, false{{ end }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	lengthValTmpl = `{{$depth := or (and .isPointer (add .depth 1)) .depth}}{{/*
*/}}{{$target := or (and (or (or .array .hash) .nonzero) .target) .targetVal}}{{/*
*/}}{{if .isPointer}}{{tabs .depth}}if {{.target}} != nil {
{{end}}{{tabs .depth}}	if {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}} {{if .isMinLength}}<{{else}}>{{end}} {{if .isMinLength}}{{.minLength}}{{else}}{{.maxLength}}{{end}} {
{{tabs $depth}}	err = goa.MergeErrors(err, goa.InvalidLengthError(` + "`" + `{{.context}}` + "`" + `, {{$target}}, {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}}, {{if .isMinLength}}{{.minLength}}, true{{else}}{{.maxLength}}, false{{end}}))
{{if .isPointer}}{{tabs $depth}}}
{{end}}{{tabs .depth}}}`

	requiredValTmpl = `{{ $att := index $.attribute.Type.ToObject .required }}{{/*
*/}}{{ if or $.private (not $att.Type.IsPrimitive) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == nil {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{ .required }}"))
{{ tabs $.depth }}}{{ end }}`
)