github.com/shogo82148/goa-v1@v1.6.2/goagen/codegen/validation.go (about)

     1  package codegen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math"
     7  	"strings"
     8  	"text/template"
     9  
    10  	"github.com/shogo82148/goa-v1/design"
    11  )
    12  
    13  var (
    14  	validationFuncs = template.FuncMap{
    15  		"tabs":     Tabs,
    16  		"slice":    toSlice,
    17  		"oneof":    oneof,
    18  		"constant": constant,
    19  		"goifyAtt": GoifyAtt,
    20  		"add":      Add,
    21  	}
    22  	enumValT     = template.Must(template.New("enum").Funcs(validationFuncs).Parse(enumValTmpl))
    23  	formatValT   = template.Must(template.New("format").Funcs(validationFuncs).Parse(formatValTmpl))
    24  	patternValT  = template.Must(template.New("pattern").Funcs(validationFuncs).Parse(patternValTmpl))
    25  	minMaxValT   = template.Must(template.New("minMax").Funcs(validationFuncs).Parse(minMaxValTmpl))
    26  	lengthValT   = template.Must(template.New("length").Funcs(validationFuncs).Parse(lengthValTmpl))
    27  	requiredValT = template.Must(template.New("required").Funcs(validationFuncs).Parse(requiredValTmpl))
    28  )
    29  
    30  // Validator is the code generator for the 'Validate' type methods.
    31  type Validator struct {
    32  	arrayValT *template.Template
    33  	hashValT  *template.Template
    34  	userValT  *template.Template
    35  	seen      map[string]*bytes.Buffer
    36  }
    37  
    38  // NewValidator instantiates a validate code generator.
    39  func NewValidator() *Validator {
    40  	v := &Validator{seen: map[string]*bytes.Buffer{}}
    41  	fm := template.FuncMap{
    42  		"tabs":             Tabs,
    43  		"slice":            toSlice,
    44  		"oneof":            oneof,
    45  		"constant":         constant,
    46  		"goifyAtt":         GoifyAtt,
    47  		"add":              Add,
    48  		"recurseAttribute": v.recurseAttribute,
    49  	}
    50  	v.arrayValT = template.Must(template.New("array").Funcs(fm).Parse(arrayValTmpl))
    51  	v.hashValT = template.Must(template.New("hash").Funcs(fm).Parse(hashValTmpl))
    52  	v.userValT = template.Must(template.New("user").Funcs(fm).Parse(userValTmpl))
    53  	return v
    54  }
    55  
    56  // Code produces Go code that runs the validation checks recursively over the given attribute.
    57  func (v *Validator) Code(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string {
    58  	buf := v.recurse(att, nonzero, required, hasDefault, target, context, depth, private)
    59  	return buf.String()
    60  }
    61  
    62  func (v *Validator) arrayValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte {
    63  	a := att.Type.ToArray()
    64  	if a == nil {
    65  		return nil
    66  	}
    67  
    68  	var buf bytes.Buffer
    69  
    70  	// Perform any validation on the array type such as MinLength, MaxLength, etc.
    71  	validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
    72  	first := true
    73  	if validation != "" {
    74  		buf.WriteString(validation)
    75  		first = false
    76  	}
    77  	val := v.Code(a.ElemType, true, false, false, "e", context+"[*]", depth+1, false)
    78  	if val != "" {
    79  		switch a.ElemType.Type.(type) {
    80  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
    81  			// For user and media types, call the Validate method
    82  			val = RunTemplate(v.userValT, map[string]interface{}{
    83  				"depth":  depth + 2,
    84  				"target": "e",
    85  			})
    86  			val = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), val, Tabs(depth+1))
    87  		}
    88  		data := map[string]interface{}{
    89  			"elemType":   a.ElemType,
    90  			"context":    context,
    91  			"target":     target,
    92  			"depth":      1,
    93  			"private":    private,
    94  			"validation": val,
    95  		}
    96  		validation = RunTemplate(v.arrayValT, data)
    97  		if !first {
    98  			buf.WriteByte('\n')
    99  		} else {
   100  			first = false
   101  		}
   102  		buf.WriteString(validation)
   103  	}
   104  	_ = first // suppress: ineffectual assignment to first (ineffassign)
   105  	return buf.Bytes()
   106  }
   107  
   108  func (v *Validator) hashValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte {
   109  	h := att.Type.ToHash()
   110  	if h == nil {
   111  		return nil
   112  	}
   113  
   114  	var buf bytes.Buffer
   115  
   116  	// Perform any validation on the hash type such as MinLength, MaxLength, etc.
   117  	validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   118  	first := true
   119  	if validation != "" {
   120  		buf.WriteString(validation)
   121  		first = false
   122  	}
   123  	keyVal := v.Code(h.KeyType, true, false, false, "k", context+"[*]", depth+1, false)
   124  	if keyVal != "" {
   125  		switch h.KeyType.Type.(type) {
   126  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
   127  			// For user and media types, call the Validate method
   128  			keyVal = RunTemplate(v.userValT, map[string]interface{}{
   129  				"depth":  depth + 2,
   130  				"target": "k",
   131  			})
   132  			keyVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), keyVal, Tabs(depth+1))
   133  		}
   134  	}
   135  	elemVal := v.Code(h.ElemType, true, false, false, "e", context+"[*]", depth+1, false)
   136  	if elemVal != "" {
   137  		switch h.ElemType.Type.(type) {
   138  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
   139  			// For user and media types, call the Validate method
   140  			elemVal = RunTemplate(v.userValT, map[string]interface{}{
   141  				"depth":  depth + 2,
   142  				"target": "e",
   143  			})
   144  			elemVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), elemVal, Tabs(depth+1))
   145  		}
   146  	}
   147  	if keyVal != "" || elemVal != "" {
   148  		data := map[string]interface{}{
   149  			"depth":          1,
   150  			"target":         target,
   151  			"keyValidation":  keyVal,
   152  			"elemValidation": elemVal,
   153  		}
   154  		validation = RunTemplate(v.hashValT, data)
   155  		if !first {
   156  			buf.WriteByte('\n')
   157  		} else {
   158  			first = false
   159  		}
   160  		buf.WriteString(validation)
   161  	}
   162  	_ = first // suppress: ineffectual assignment to first (ineffassign)
   163  	return buf.Bytes()
   164  }
   165  
   166  func (v *Validator) recurse(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) *bytes.Buffer {
   167  	var (
   168  		buf   = new(bytes.Buffer)
   169  		first = true
   170  	)
   171  
   172  	// Break infinite recursions
   173  	switch dt := att.Type.(type) {
   174  	case *design.MediaTypeDefinition:
   175  		if buf, ok := v.seen[dt.TypeName]; ok {
   176  			return buf
   177  		}
   178  		v.seen[dt.TypeName] = buf
   179  	case *design.UserTypeDefinition:
   180  		if buf, ok := v.seen[dt.TypeName]; ok {
   181  			return buf
   182  		}
   183  		v.seen[dt.TypeName] = buf
   184  	}
   185  
   186  	if o := att.Type.ToObject(); o != nil {
   187  		if ds, ok := att.Type.(design.DataStructure); ok {
   188  			att = ds.Definition()
   189  		}
   190  		validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   191  		if validation != "" {
   192  			buf.WriteString(validation)
   193  			first = false
   194  		}
   195  		o.IterateAttributes(func(n string, catt *design.AttributeDefinition) error {
   196  			validation := v.recurseAttribute(att, catt, n, target, context, depth, private)
   197  			if validation != "" {
   198  				if !first {
   199  					buf.WriteByte('\n')
   200  				} else {
   201  					first = false
   202  				}
   203  				buf.WriteString(validation)
   204  			}
   205  			return nil
   206  		})
   207  	} else if a := att.Type.ToArray(); a != nil {
   208  		buf.Write(v.arrayValCode(att, nonzero, required, hasDefault, target, context, depth, private))
   209  	} else if h := att.Type.ToHash(); h != nil {
   210  		buf.Write(v.hashValCode(att, nonzero, required, hasDefault, target, context, depth, private))
   211  	} else {
   212  		validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   213  		if validation != "" {
   214  			buf.WriteString(validation)
   215  		}
   216  	}
   217  	return buf
   218  }
   219  
   220  func (v *Validator) recurseAttribute(att, catt *design.AttributeDefinition, n, target, context string, depth int, private bool) string {
   221  	var validation string
   222  	if _, ok := catt.Type.(design.DataStructure); ok {
   223  		validation = RunTemplate(v.userValT, map[string]interface{}{
   224  			"depth":  depth,
   225  			"target": fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)),
   226  		})
   227  	} else {
   228  		dp := depth
   229  		if catt.Type.IsObject() {
   230  			dp++
   231  		}
   232  		validation = v.recurse(
   233  			catt,
   234  			att.IsNonZero(n),
   235  			att.IsRequired(n),
   236  			att.HasDefaultValue(n),
   237  			fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)),
   238  			fmt.Sprintf("%s.%s", context, n),
   239  			dp,
   240  			private,
   241  		).String()
   242  	}
   243  	if validation != "" {
   244  		if catt.Type.IsObject() {
   245  			validation = fmt.Sprintf("%sif %s.%s != nil {\n%s\n%s}",
   246  				Tabs(depth), target, GoifyAtt(catt, n, true), validation, Tabs(depth))
   247  		}
   248  	}
   249  	return validation
   250  }
   251  
   252  // ValidationChecker produces Go code that runs the validation defined in the given attribute
   253  // definition against the content of the variable named target recursively.
   254  // context is used to keep track of recursion to produce helpful error messages in case of type
   255  // validation error.
   256  // The generated code assumes that there is a pre-existing "err" variable of type
   257  // error. It initializes that variable in case a validation fails.
   258  // Note: we do not want to recurse here, recursion is done by the marshaler/unmarshaler code.
   259  func ValidationChecker(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string {
   260  	if att.Validation == nil {
   261  		return ""
   262  	}
   263  	t := target
   264  	isPointer := private || (!required && !hasDefault && !nonzero)
   265  	if isPointer && att.Type.IsPrimitive() {
   266  		t = "*" + t
   267  	}
   268  	data := map[string]interface{}{
   269  		"attribute": att,
   270  		"isPointer": private || isPointer,
   271  		"nonzero":   nonzero,
   272  		"context":   context,
   273  		"target":    target,
   274  		"targetVal": t,
   275  		"string":    att.Type.Kind() == design.StringKind,
   276  		"array":     att.Type.IsArray(),
   277  		"hash":      att.Type.IsHash(),
   278  		"depth":     depth,
   279  		"private":   private,
   280  	}
   281  	res := validationsCode(att, data)
   282  	return strings.Join(res, "\n")
   283  }
   284  
   285  func validationsCode(att *design.AttributeDefinition, data map[string]interface{}) (res []string) {
   286  	validation := att.Validation
   287  	if values := validation.Values; values != nil {
   288  		data["values"] = values
   289  		if val := RunTemplate(enumValT, data); val != "" {
   290  			res = append(res, val)
   291  		}
   292  	}
   293  	if format := validation.Format; format != "" {
   294  		data["format"] = format
   295  		if val := RunTemplate(formatValT, data); val != "" {
   296  			res = append(res, val)
   297  		}
   298  	}
   299  	if pattern := validation.Pattern; pattern != "" {
   300  		data["pattern"] = pattern
   301  		if val := RunTemplate(patternValT, data); val != "" {
   302  			res = append(res, val)
   303  		}
   304  	}
   305  	if min := validation.Minimum; min != nil {
   306  		if att.Type == design.Integer {
   307  			data["min"] = renderInteger(*min)
   308  		} else {
   309  			data["min"] = fmt.Sprintf("%f", *min)
   310  		}
   311  		data["isMin"] = true
   312  		delete(data, "max")
   313  		if val := RunTemplate(minMaxValT, data); val != "" {
   314  			res = append(res, val)
   315  		}
   316  	}
   317  	if max := validation.Maximum; max != nil {
   318  		if att.Type == design.Integer {
   319  			data["max"] = renderInteger(*max)
   320  		} else {
   321  			data["max"] = fmt.Sprintf("%f", *max)
   322  		}
   323  		data["isMin"] = false
   324  		delete(data, "min")
   325  		if val := RunTemplate(minMaxValT, data); val != "" {
   326  			res = append(res, val)
   327  		}
   328  	}
   329  	if minLength := validation.MinLength; minLength != nil {
   330  		data["minLength"] = minLength
   331  		data["isMinLength"] = true
   332  		delete(data, "maxLength")
   333  		if val := RunTemplate(lengthValT, data); val != "" {
   334  			res = append(res, val)
   335  		}
   336  	}
   337  	if maxLength := validation.MaxLength; maxLength != nil {
   338  		data["maxLength"] = maxLength
   339  		data["isMinLength"] = false
   340  		delete(data, "minLength")
   341  		if val := RunTemplate(lengthValT, data); val != "" {
   342  			res = append(res, val)
   343  		}
   344  	}
   345  	if required := validation.Required; len(required) > 0 {
   346  		var val string
   347  		for i, r := range required {
   348  			if i > 0 {
   349  				val += "\n"
   350  			}
   351  			data["required"] = r
   352  			val += RunTemplate(requiredValT, data)
   353  		}
   354  		res = append(res, val)
   355  	}
   356  	return
   357  }
   358  
   359  // renderInteger renders a max or min value properly, taking into account
   360  // overflows due to casting from a float value.
   361  func renderInteger(f float64) string {
   362  	if f > math.Nextafter(float64(math.MaxInt64), 0) {
   363  		return fmt.Sprintf("%d", int64(math.MaxInt64))
   364  	}
   365  	if f < math.Nextafter(float64(math.MinInt64), 0) {
   366  		return fmt.Sprintf("%d", int64(math.MinInt64))
   367  	}
   368  	return fmt.Sprintf("%d", int64(f))
   369  }
   370  
   371  // oneof produces code that compares target with each element of vals and ORs
   372  // the result, e.g. "target == 1 || target == 2".
   373  func oneof(target string, vals []interface{}) string {
   374  	elems := make([]string, len(vals))
   375  	for i, v := range vals {
   376  		elems[i] = fmt.Sprintf("%s == %#v", target, v)
   377  	}
   378  	return strings.Join(elems, " || ")
   379  }
   380  
   381  // constant returns the Go constant name of the format with the given value.
   382  func constant(formatName string) string {
   383  	switch formatName {
   384  	case "date":
   385  		return "goa.FormatDate"
   386  	case "date-time":
   387  		return "goa.FormatDateTime"
   388  	case "email":
   389  		return "goa.FormatEmail"
   390  	case "hostname":
   391  		return "goa.FormatHostname"
   392  	case "ipv4":
   393  		return "goa.FormatIPv4"
   394  	case "ipv6":
   395  		return "goa.FormatIPv6"
   396  	case "ip":
   397  		return "goa.FormatIP"
   398  	case "uri":
   399  		return "goa.FormatURI"
   400  	case "mac":
   401  		return "goa.FormatMAC"
   402  	case "cidr":
   403  		return "goa.FormatCIDR"
   404  	case "regexp":
   405  		return "goa.FormatRegexp"
   406  	case "rfc1123":
   407  		return "goa.FormatRFC1123"
   408  	}
   409  	panic("unknown format") // bug
   410  }
   411  
const (
	// arrayValTmpl renders a range loop over .target that binds each element
	// to "e" and contains the element validation code given in .validation.
	arrayValTmpl = `{{ tabs .depth }}for _, e := range {{ .target }} {
{{ .validation }}
{{ tabs .depth }}}`

	// hashValTmpl renders a range loop over the map .target. The key is bound
	// to "k" only when .keyValidation is non-empty and the value to "e" only
	// when .elemValidation is non-empty; otherwise "_" is used so the
	// generated loop has no unused variables.
	hashValTmpl = `{{ tabs .depth }}for {{ if .keyValidation }}k{{ else }}_{{ end }}, {{ if .elemValidation }}e{{ else }}_{{ end }} := range {{ .target }} {
{{- if .keyValidation }}
{{ .keyValidation }}{{ end }}{{ if .elemValidation }}
{{ .elemValidation }}{{ end }}
{{ tabs .depth }}}`

	// userValTmpl calls the Validate method of .target and merges any
	// returned error into the pre-existing "err" variable.
	userValTmpl = `{{ tabs .depth }}if err2 := {{ .target }}.Validate(); err2 != nil {
{{ tabs .depth }}	err = goa.MergeErrors(err, err2)
{{ tabs .depth }}}`

	// The check templates below share a layout. When .isPointer is set, the
	// check is wrapped in an "if .target != nil" guard and indented one level
	// deeper ($depth). The check reads the value to validate from .targetVal
	// (lengthValTmpl may use .target instead) and merges a goa error into the
	// pre-existing "err" variable on failure.

	// enumValTmpl checks that the value is one of .values.
	enumValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if !({{ oneof .targetVal .values }}) {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidEnumValueError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ slice .values }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// patternValTmpl checks the value against the regular expression in
	// .pattern using goa.ValidatePattern.
	patternValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if ok := goa.ValidatePattern(` + "`{{ .pattern }}`" + `, {{ .targetVal }}); !ok {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidPatternError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, ` + "`{{ .pattern }}`" + `))
{{ tabs $depth }}}{{ if .isPointer }}
{{ tabs .depth }}}{{ end }}`

	// formatValTmpl checks the value with goa.ValidateFormat against the goa
	// constant for .format (see the "constant" template function).
	formatValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if err2 := goa.ValidateFormat({{ constant .format }}, {{ .targetVal }}); err2 != nil {
{{ tabs $depth }}		err = goa.MergeErrors(err, goa.InvalidFormatError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ constant .format }}, err2))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// minMaxValTmpl checks the value against .min when .isMin is true
	// ("<" comparison) and against .max otherwise (">" comparison).
	// NOTE(review): the comparison line is indented with "tabs .depth" plus a
	// literal tab instead of "tabs $depth", so it gets one extra tab when
	// .isPointer is false. This only affects the layout of the generated code.
	minMaxValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs .depth }}	if {{ .targetVal }} {{ if .isMin }}<{{ else }}>{{ end }} {{ if .isMin }}{{ .min }}{{ else }}{{ .max }}{{ end }} {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidRangeError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ if .isMin }}{{ .min }}, true{{ else }}{{ .max }}, false{{ end }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// lengthValTmpl checks the length against .minLength when .isMinLength is
	// true and against .maxLength otherwise. Strings are measured in runes
	// with utf8.RuneCountInString, everything else with len. Arrays, maps and
	// non-zero values are read from .target, other values from .targetVal.
	// NOTE(review): same "tabs .depth" plus tab indentation quirk as
	// minMaxValTmpl on the comparison line.
	lengthValTmpl = `{{$depth := or (and .isPointer (add .depth 1)) .depth}}{{/*
*/}}{{$target := or (and (or (or .array .hash) .nonzero) .target) .targetVal}}{{/*
*/}}{{if .isPointer}}{{tabs .depth}}if {{.target}} != nil {
{{end}}{{tabs .depth}}	if {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}} {{if .isMinLength}}<{{else}}>{{end}} {{if .isMinLength}}{{.minLength}}{{else}}{{.maxLength}}{{end}} {
{{tabs $depth}}	err = goa.MergeErrors(err, goa.InvalidLengthError(` + "`" + `{{.context}}` + "`" + `, {{$target}}, {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}}, {{if .isMinLength}}{{.minLength}}, true{{else}}{{.maxLength}}, false{{end}}))
{{if .isPointer}}{{tabs $depth}}}
{{end}}{{tabs .depth}}}`

	// requiredValTmpl emits a nil check for the required attribute named
	// .required and merges a goa.MissingAttributeError into "err" when it is
	// missing. The check is only generated for private types or non-primitive
	// attributes; otherwise the template renders nothing.
	requiredValTmpl = `{{ $att := index $.attribute.Type.ToObject .required }}{{/*
*/}}{{ if or $.private (not $att.Type.IsPrimitive) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == nil {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{ .required }}"))
{{ tabs $.depth }}}{{ end }}`
)