github.com/ManabuSeki/goa-v1@v1.4.3/goagen/codegen/validation.go (about)

     1  package codegen
     2  
import (
	"bytes"
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"text/template"

	"github.com/goadesign/goa/design"
)
    13  
    14  var (
    15  	enumValT     *template.Template
    16  	formatValT   *template.Template
    17  	patternValT  *template.Template
    18  	minMaxValT   *template.Template
    19  	lengthValT   *template.Template
    20  	requiredValT *template.Template
    21  )
    22  
    23  //  init instantiates the templates.
    24  func init() {
    25  	var err error
    26  	fm := template.FuncMap{
    27  		"tabs":     Tabs,
    28  		"slice":    toSlice,
    29  		"oneof":    oneof,
    30  		"constant": constant,
    31  		"goifyAtt": GoifyAtt,
    32  		"add":      Add,
    33  	}
    34  	if enumValT, err = template.New("enum").Funcs(fm).Parse(enumValTmpl); err != nil {
    35  		panic(err)
    36  	}
    37  	if formatValT, err = template.New("format").Funcs(fm).Parse(formatValTmpl); err != nil {
    38  		panic(err)
    39  	}
    40  	if patternValT, err = template.New("pattern").Funcs(fm).Parse(patternValTmpl); err != nil {
    41  		panic(err)
    42  	}
    43  	if minMaxValT, err = template.New("minMax").Funcs(fm).Parse(minMaxValTmpl); err != nil {
    44  		panic(err)
    45  	}
    46  	if lengthValT, err = template.New("length").Funcs(fm).Parse(lengthValTmpl); err != nil {
    47  		panic(err)
    48  	}
    49  	if requiredValT, err = template.New("required").Funcs(fm).Parse(requiredValTmpl); err != nil {
    50  		panic(err)
    51  	}
    52  }
    53  
    54  // Validator is the code generator for the 'Validate' type methods.
    55  type Validator struct {
    56  	arrayValT *template.Template
    57  	hashValT  *template.Template
    58  	userValT  *template.Template
    59  	seen      map[string]*bytes.Buffer
    60  }
    61  
    62  // NewValidator instantiates a validate code generator.
    63  func NewValidator() *Validator {
    64  	var (
    65  		v   = &Validator{seen: make(map[string]*bytes.Buffer)}
    66  		err error
    67  	)
    68  	fm := template.FuncMap{
    69  		"tabs":             Tabs,
    70  		"slice":            toSlice,
    71  		"oneof":            oneof,
    72  		"constant":         constant,
    73  		"goifyAtt":         GoifyAtt,
    74  		"add":              Add,
    75  		"recurseAttribute": v.recurseAttribute,
    76  	}
    77  	v.arrayValT, err = template.New("array").Funcs(fm).Parse(arrayValTmpl)
    78  	if err != nil {
    79  		panic(err)
    80  	}
    81  	v.hashValT, err = template.New("hash").Funcs(fm).Parse(hashValTmpl)
    82  	if err != nil {
    83  		panic(err)
    84  	}
    85  	v.userValT, err = template.New("user").Funcs(fm).Parse(userValTmpl)
    86  	if err != nil {
    87  		panic(err)
    88  	}
    89  	return v
    90  }
    91  
    92  // Code produces Go code that runs the validation checks recursively over the given attribute.
    93  func (v *Validator) Code(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string {
    94  	buf := v.recurse(att, nonzero, required, hasDefault, target, context, depth, private)
    95  	return buf.String()
    96  }
    97  
    98  func (v *Validator) arrayValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte {
    99  	a := att.Type.ToArray()
   100  	if a == nil {
   101  		return nil
   102  	}
   103  
   104  	var buf bytes.Buffer
   105  
   106  	// Perform any validation on the array type such as MinLength, MaxLength, etc.
   107  	validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   108  	first := true
   109  	if validation != "" {
   110  		buf.WriteString(validation)
   111  		first = false
   112  	}
   113  	val := v.Code(a.ElemType, true, false, false, "e", context+"[*]", depth+1, false)
   114  	if val != "" {
   115  		switch a.ElemType.Type.(type) {
   116  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
   117  			// For user and media types, call the Validate method
   118  			val = RunTemplate(v.userValT, map[string]interface{}{
   119  				"depth":  depth + 2,
   120  				"target": "e",
   121  			})
   122  			val = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), val, Tabs(depth+1))
   123  		}
   124  		data := map[string]interface{}{
   125  			"elemType":   a.ElemType,
   126  			"context":    context,
   127  			"target":     target,
   128  			"depth":      1,
   129  			"private":    private,
   130  			"validation": val,
   131  		}
   132  		validation = RunTemplate(v.arrayValT, data)
   133  		if !first {
   134  			buf.WriteByte('\n')
   135  		} else {
   136  			first = false
   137  		}
   138  		buf.WriteString(validation)
   139  	}
   140  	return buf.Bytes()
   141  }
   142  
   143  func (v *Validator) hashValCode(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) []byte {
   144  	h := att.Type.ToHash()
   145  	if h == nil {
   146  		return nil
   147  	}
   148  
   149  	var buf bytes.Buffer
   150  
   151  	// Perform any validation on the hash type such as MinLength, MaxLength, etc.
   152  	validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   153  	first := true
   154  	if validation != "" {
   155  		buf.WriteString(validation)
   156  		first = false
   157  	}
   158  	keyVal := v.Code(h.KeyType, true, false, false, "k", context+"[*]", depth+1, false)
   159  	if keyVal != "" {
   160  		switch h.KeyType.Type.(type) {
   161  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
   162  			// For user and media types, call the Validate method
   163  			keyVal = RunTemplate(v.userValT, map[string]interface{}{
   164  				"depth":  depth + 2,
   165  				"target": "k",
   166  			})
   167  			keyVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), keyVal, Tabs(depth+1))
   168  		}
   169  	}
   170  	elemVal := v.Code(h.ElemType, true, false, false, "e", context+"[*]", depth+1, false)
   171  	if elemVal != "" {
   172  		switch h.ElemType.Type.(type) {
   173  		case *design.UserTypeDefinition, *design.MediaTypeDefinition:
   174  			// For user and media types, call the Validate method
   175  			elemVal = RunTemplate(v.userValT, map[string]interface{}{
   176  				"depth":  depth + 2,
   177  				"target": "e",
   178  			})
   179  			elemVal = fmt.Sprintf("%sif e != nil {\n%s\n%s}", Tabs(depth+1), elemVal, Tabs(depth+1))
   180  		}
   181  	}
   182  	if keyVal != "" || elemVal != "" {
   183  		data := map[string]interface{}{
   184  			"depth":          1,
   185  			"target":         target,
   186  			"keyValidation":  keyVal,
   187  			"elemValidation": elemVal,
   188  		}
   189  		validation = RunTemplate(v.hashValT, data)
   190  		if !first {
   191  			buf.WriteByte('\n')
   192  		} else {
   193  			first = false
   194  		}
   195  		buf.WriteString(validation)
   196  	}
   197  	return buf.Bytes()
   198  }
   199  
   200  func (v *Validator) recurse(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) *bytes.Buffer {
   201  	var (
   202  		buf   = new(bytes.Buffer)
   203  		first = true
   204  	)
   205  
   206  	// Break infinite recursions
   207  	switch dt := att.Type.(type) {
   208  	case *design.MediaTypeDefinition:
   209  		if buf, ok := v.seen[dt.TypeName]; ok {
   210  			return buf
   211  		}
   212  		v.seen[dt.TypeName] = buf
   213  	case *design.UserTypeDefinition:
   214  		if buf, ok := v.seen[dt.TypeName]; ok {
   215  			return buf
   216  		}
   217  		v.seen[dt.TypeName] = buf
   218  	}
   219  
   220  	if o := att.Type.ToObject(); o != nil {
   221  		if ds, ok := att.Type.(design.DataStructure); ok {
   222  			att = ds.Definition()
   223  		}
   224  		validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   225  		if validation != "" {
   226  			buf.WriteString(validation)
   227  			first = false
   228  		}
   229  		o.IterateAttributes(func(n string, catt *design.AttributeDefinition) error {
   230  			validation := v.recurseAttribute(att, catt, n, target, context, depth, private)
   231  			if validation != "" {
   232  				if !first {
   233  					buf.WriteByte('\n')
   234  				} else {
   235  					first = false
   236  				}
   237  				buf.WriteString(validation)
   238  			}
   239  			return nil
   240  		})
   241  	} else if a := att.Type.ToArray(); a != nil {
   242  		buf.Write(v.arrayValCode(att, nonzero, required, hasDefault, target, context, depth, private))
   243  	} else if h := att.Type.ToHash(); h != nil {
   244  		buf.Write(v.hashValCode(att, nonzero, required, hasDefault, target, context, depth, private))
   245  	} else {
   246  		validation := ValidationChecker(att, nonzero, required, hasDefault, target, context, depth, private)
   247  		if validation != "" {
   248  			buf.WriteString(validation)
   249  		}
   250  	}
   251  	return buf
   252  }
   253  
   254  func (v *Validator) recurseAttribute(att, catt *design.AttributeDefinition, n, target, context string, depth int, private bool) string {
   255  	var validation string
   256  	if ds, ok := catt.Type.(design.DataStructure); ok {
   257  		// We need to check empirically whether there are validations to be
   258  		// generated, we can't just generate and check whether something was
   259  		// generated to avoid infinite recursions.
   260  		hasValidations := false
   261  		done := errors.New("done")
   262  		ds.Walk(func(a *design.AttributeDefinition) error {
   263  			if a.Validation != nil {
   264  				if private {
   265  					hasValidations = true
   266  					return done
   267  				}
   268  				// For public data structures there is a case where
   269  				// there is validation but no actual validation
   270  				// code: if the validation is a required validation
   271  				// that applies to attributes that cannot be nil or
   272  				// empty string i.e. primitive types other than
   273  				// string.
   274  				if !a.Validation.HasRequiredOnly() {
   275  					hasValidations = true
   276  					return done
   277  				}
   278  				for _, name := range a.Validation.Required {
   279  					att := a.Type.ToObject()[name]
   280  					if att != nil && (!att.Type.IsPrimitive() || att.Type.Kind() == design.StringKind) {
   281  						hasValidations = true
   282  						return done
   283  					}
   284  				}
   285  			}
   286  			return nil
   287  		})
   288  		if hasValidations {
   289  			validation = RunTemplate(v.userValT, map[string]interface{}{
   290  				"depth":  depth,
   291  				"target": fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)),
   292  			})
   293  		}
   294  	} else {
   295  		dp := depth
   296  		if catt.Type.IsObject() {
   297  			dp++
   298  		}
   299  		validation = v.recurse(
   300  			catt,
   301  			att.IsNonZero(n),
   302  			att.IsRequired(n),
   303  			att.HasDefaultValue(n),
   304  			fmt.Sprintf("%s.%s", target, GoifyAtt(catt, n, true)),
   305  			fmt.Sprintf("%s.%s", context, n),
   306  			dp,
   307  			private,
   308  		).String()
   309  	}
   310  	if validation != "" {
   311  		if catt.Type.IsObject() {
   312  			validation = fmt.Sprintf("%sif %s.%s != nil {\n%s\n%s}",
   313  				Tabs(depth), target, GoifyAtt(catt, n, true), validation, Tabs(depth))
   314  		}
   315  	}
   316  	return validation
   317  }
   318  
   319  // ValidationChecker produces Go code that runs the validation defined in the given attribute
   320  // definition against the content of the variable named target recursively.
   321  // context is used to keep track of recursion to produce helpful error messages in case of type
   322  // validation error.
   323  // The generated code assumes that there is a pre-existing "err" variable of type
   324  // error. It initializes that variable in case a validation fails.
   325  // Note: we do not want to recurse here, recursion is done by the marshaler/unmarshaler code.
   326  func ValidationChecker(att *design.AttributeDefinition, nonzero, required, hasDefault bool, target, context string, depth int, private bool) string {
   327  	if att.Validation == nil {
   328  		return ""
   329  	}
   330  	t := target
   331  	isPointer := private || (!required && !hasDefault && !nonzero)
   332  	if isPointer && att.Type.IsPrimitive() {
   333  		t = "*" + t
   334  	}
   335  	data := map[string]interface{}{
   336  		"attribute": att,
   337  		"isPointer": private || isPointer,
   338  		"nonzero":   nonzero,
   339  		"context":   context,
   340  		"target":    target,
   341  		"targetVal": t,
   342  		"string":    att.Type.Kind() == design.StringKind,
   343  		"array":     att.Type.IsArray(),
   344  		"hash":      att.Type.IsHash(),
   345  		"depth":     depth,
   346  		"private":   private,
   347  	}
   348  	res := validationsCode(att, data)
   349  	return strings.Join(res, "\n")
   350  }
   351  
   352  func validationsCode(att *design.AttributeDefinition, data map[string]interface{}) (res []string) {
   353  	validation := att.Validation
   354  	if values := validation.Values; values != nil {
   355  		data["values"] = values
   356  		if val := RunTemplate(enumValT, data); val != "" {
   357  			res = append(res, val)
   358  		}
   359  	}
   360  	if format := validation.Format; format != "" {
   361  		data["format"] = format
   362  		if val := RunTemplate(formatValT, data); val != "" {
   363  			res = append(res, val)
   364  		}
   365  	}
   366  	if pattern := validation.Pattern; pattern != "" {
   367  		data["pattern"] = pattern
   368  		if val := RunTemplate(patternValT, data); val != "" {
   369  			res = append(res, val)
   370  		}
   371  	}
   372  	if min := validation.Minimum; min != nil {
   373  		if att.Type == design.Integer {
   374  			data["min"] = renderInteger(*min)
   375  		} else {
   376  			data["min"] = fmt.Sprintf("%f", *min)
   377  		}
   378  		data["isMin"] = true
   379  		delete(data, "max")
   380  		if val := RunTemplate(minMaxValT, data); val != "" {
   381  			res = append(res, val)
   382  		}
   383  	}
   384  	if max := validation.Maximum; max != nil {
   385  		if att.Type == design.Integer {
   386  			data["max"] = renderInteger(*max)
   387  		} else {
   388  			data["max"] = fmt.Sprintf("%f", *max)
   389  		}
   390  		data["isMin"] = false
   391  		delete(data, "min")
   392  		if val := RunTemplate(minMaxValT, data); val != "" {
   393  			res = append(res, val)
   394  		}
   395  	}
   396  	if minLength := validation.MinLength; minLength != nil {
   397  		data["minLength"] = minLength
   398  		data["isMinLength"] = true
   399  		delete(data, "maxLength")
   400  		if val := RunTemplate(lengthValT, data); val != "" {
   401  			res = append(res, val)
   402  		}
   403  	}
   404  	if maxLength := validation.MaxLength; maxLength != nil {
   405  		data["maxLength"] = maxLength
   406  		data["isMinLength"] = false
   407  		delete(data, "minLength")
   408  		if val := RunTemplate(lengthValT, data); val != "" {
   409  			res = append(res, val)
   410  		}
   411  	}
   412  	if required := validation.Required; len(required) > 0 {
   413  		var val string
   414  		for i, r := range required {
   415  			if i > 0 {
   416  				val += "\n"
   417  			}
   418  			data["required"] = r
   419  			val += RunTemplate(requiredValT, data)
   420  		}
   421  		res = append(res, val)
   422  	}
   423  	return
   424  }
   425  
   426  // renderInteger renders a max or min value properly, taking into account
   427  // overflows due to casting from a float value.
   428  func renderInteger(f float64) string {
   429  	if f > math.Nextafter(float64(math.MaxInt64), 0) {
   430  		return fmt.Sprintf("%d", int64(math.MaxInt64))
   431  	}
   432  	if f < math.Nextafter(float64(math.MinInt64), 0) {
   433  		return fmt.Sprintf("%d", int64(math.MinInt64))
   434  	}
   435  	return fmt.Sprintf("%d", int64(f))
   436  }
   437  
   438  // oneof produces code that compares target with each element of vals and ORs
   439  // the result, e.g. "target == 1 || target == 2".
   440  func oneof(target string, vals []interface{}) string {
   441  	elems := make([]string, len(vals))
   442  	for i, v := range vals {
   443  		elems[i] = fmt.Sprintf("%s == %#v", target, v)
   444  	}
   445  	return strings.Join(elems, " || ")
   446  }
   447  
   448  // constant returns the Go constant name of the format with the given value.
   449  func constant(formatName string) string {
   450  	switch formatName {
   451  	case "date":
   452  		return "goa.FormatDate"
   453  	case "date-time":
   454  		return "goa.FormatDateTime"
   455  	case "email":
   456  		return "goa.FormatEmail"
   457  	case "hostname":
   458  		return "goa.FormatHostname"
   459  	case "ipv4":
   460  		return "goa.FormatIPv4"
   461  	case "ipv6":
   462  		return "goa.FormatIPv6"
   463  	case "ip":
   464  		return "goa.FormatIP"
   465  	case "uri":
   466  		return "goa.FormatURI"
   467  	case "mac":
   468  		return "goa.FormatMAC"
   469  	case "cidr":
   470  		return "goa.FormatCIDR"
   471  	case "regexp":
   472  		return "goa.FormatRegexp"
   473  	case "rfc1123":
   474  		return "goa.FormatRFC1123"
   475  	}
   476  	panic("unknown format") // bug
   477  }
   478  
// Code templates for the generated validations. arrayValTmpl, hashValTmpl and
// userValTmpl are parsed by NewValidator; the others are parsed by init and
// rendered by validationsCode with the data map built by ValidationChecker.
const (
	// arrayValTmpl ranges over the array held in .target, binding each element
	// to "e", and inlines the per-element .validation code.
	arrayValTmpl = `{{ tabs .depth }}for _, e := range {{ .target }} {
{{ .validation }}
{{ tabs .depth }}}`

	// hashValTmpl ranges over the hash held in .target. The key and value loop
	// variables are named "k" and "e" only when the matching validation code is
	// present, and "_" otherwise, so unused loop variables are not declared.
	hashValTmpl = `{{ tabs .depth }}for {{ if .keyValidation }}k{{ else }}_{{ end }}, {{ if .elemValidation }}e{{ else }}_{{ end }} := range {{ .target }} {
{{- if .keyValidation }}
{{ .keyValidation }}{{ end }}{{ if .elemValidation }}
{{ .elemValidation }}{{ end }}
{{ tabs .depth }}}`

	// userValTmpl calls the Validate method of .target and merges any returned
	// error into err.
	userValTmpl = `{{ tabs .depth }}if err2 := {{ .target }}.Validate(); err2 != nil {
{{ tabs .depth }}	err = goa.MergeErrors(err, err2)
{{ tabs .depth }}}`

	// enumValTmpl checks that .targetVal is one of .values. When .isPointer is
	// set, the check is wrapped in a nil check on .target and indented one more
	// level ($depth).
	enumValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if !({{ oneof .targetVal .values }}) {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidEnumValueError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ slice .values }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// patternValTmpl checks .targetVal against the regular expression .pattern
	// with goa.ValidatePattern, nil-guarded when .isPointer is set.
	patternValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if ok := goa.ValidatePattern(` + "`{{ .pattern }}`" + `, {{ .targetVal }}); !ok {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidPatternError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, ` + "`{{ .pattern }}`" + `))
{{ tabs $depth }}}{{ if .isPointer }}
{{ tabs .depth }}}{{ end }}`

	// formatValTmpl checks .targetVal against .format with goa.ValidateFormat.
	// The "constant" function maps the format name to its goa constant.
	formatValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs $depth }}if err2 := goa.ValidateFormat({{ constant .format }}, {{ .targetVal }}); err2 != nil {
{{ tabs $depth }}		err = goa.MergeErrors(err, goa.InvalidFormatError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ constant .format }}, err2))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// minMaxValTmpl checks .targetVal against .min when .isMin is set, and
	// against .max otherwise.
	// NOTE(review): the comparison line is indented with ".depth" plus a tab
	// rather than "$depth", so it is over-indented when .isPointer is false;
	// presumably harmless if the generated code is gofmt'ed. Confirm.
	minMaxValTmpl = `{{ $depth := or (and .isPointer (add .depth 1)) .depth }}{{/*
*/}}{{ if .isPointer }}{{ tabs .depth }}if {{ .target }} != nil {
{{ end }}{{ tabs .depth }}	if {{ .targetVal }} {{ if .isMin }}<{{ else }}>{{ end }} {{ if .isMin }}{{ .min }}{{ else }}{{ .max }}{{ end }} {
{{ tabs $depth }}	err = goa.MergeErrors(err, goa.InvalidRangeError(` + "`" + `{{ .context }}` + "`" + `, {{ .targetVal }}, {{ if .isMin }}{{ .min }}, true{{ else }}{{ .max }}, false{{ end }}))
{{ if .isPointer }}{{ tabs $depth }}}
{{ end }}{{ tabs .depth }}}`

	// lengthValTmpl checks the length against .minLength when .isMinLength is
	// set, and against .maxLength otherwise. Strings are measured with
	// utf8.RuneCountInString; everything else uses len. $target uses .target
	// instead of .targetVal for arrays, hashes and non-zero attributes.
	lengthValTmpl = `{{$depth := or (and .isPointer (add .depth 1)) .depth}}{{/*
*/}}{{$target := or (and (or (or .array .hash) .nonzero) .target) .targetVal}}{{/*
*/}}{{if .isPointer}}{{tabs .depth}}if {{.target}} != nil {
{{end}}{{tabs .depth}}	if {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}} {{if .isMinLength}}<{{else}}>{{end}} {{if .isMinLength}}{{.minLength}}{{else}}{{.maxLength}}{{end}} {
{{tabs $depth}}	err = goa.MergeErrors(err, goa.InvalidLengthError(` + "`" + `{{.context}}` + "`" + `, {{$target}}, {{if .string}}utf8.RuneCountInString({{$target}}){{else}}len({{$target}}){{end}}, {{if .isMinLength}}{{.minLength}}, true{{else}}{{.maxLength}}, false{{end}}))
{{if .isPointer}}{{tabs $depth}}}
{{end}}{{tabs .depth}}}`

	// requiredValTmpl checks that the object attribute named .required is set:
	// - public fields of kind 4 are compared to "";
	// - private fields and non-primitive fields are compared to nil;
	// - any other public primitive produces no code.
	// NOTE(review): 4 is presumably the numeric value of design.StringKind.
	// Confirm.
	requiredValTmpl = `{{ $att := index $.attribute.Type.ToObject .required }}{{/*
*/}}{{ if and (not $.private) (eq $att.Type.Kind 4) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == "" {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{  .required  }}"))
{{ tabs $.depth }}}{{ else if or $.private (not $att.Type.IsPrimitive) }}{{ tabs $.depth }}if {{ $.target }}.{{ goifyAtt $att .required true }} == nil {
{{ tabs $.depth }}	err = goa.MergeErrors(err, goa.MissingAttributeError(` + "`" + `{{ $.context }}` + "`" + `, "{{ .required }}"))
{{ tabs $.depth }}}{{ end }}`
)