github.com/ManabuSeki/goa-v1@v1.4.3/goagen/gen_app/test_generator.go (about)

     1  package genapp
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"text/template"
    12  
    13  	"github.com/goadesign/goa/design"
    14  	"github.com/goadesign/goa/goagen/codegen"
    15  )
    16  
    17  func makeTestDir(g *Generator, apiName string) (outDir string, err error) {
    18  	outDir = filepath.Join(g.OutDir, "test")
    19  	if err = os.RemoveAll(outDir); err != nil {
    20  		return
    21  	}
    22  	if err = os.MkdirAll(outDir, 0755); err != nil {
    23  		return
    24  	}
    25  	g.genfiles = append(g.genfiles, outDir)
    26  	return
    27  }
    28  
    29  // TestMethod structure
    30  type TestMethod struct {
    31  	Name              string
    32  	Comment           string
    33  	ResourceName      string
    34  	ActionName        string
    35  	ControllerName    string
    36  	ContextVarName    string
    37  	ContextType       string
    38  	RouteVerb         string
    39  	FullPath          string
    40  	Status            int
    41  	ReturnType        *ObjectType
    42  	ReturnsErrorMedia bool
    43  	Params            []*ObjectType
    44  	QueryParams       []*ObjectType
    45  	Headers           []*ObjectType
    46  	Payload           *ObjectType
    47  	reservedNames     map[string]bool
    48  }
    49  
    50  // Escape escapes given string.
    51  func (t *TestMethod) Escape(s string) string {
    52  	if ok := t.reservedNames[s]; ok {
    53  		s = t.Escape("_" + s)
    54  	}
    55  	t.reservedNames[s] = true
    56  	return s
    57  }
    58  
    59  // ObjectType structure
    60  type ObjectType struct {
    61  	Label       string
    62  	Name        string
    63  	Type        string
    64  	Pointer     string
    65  	Validatable bool
    66  }
    67  
    68  func (g *Generator) generateResourceTest() error {
    69  	if len(g.API.Resources) == 0 {
    70  		return nil
    71  	}
    72  	funcs := template.FuncMap{
    73  		"isSlice": isSlice,
    74  	}
    75  	testTmpl := template.Must(template.New("test").Funcs(funcs).Parse(testTmpl))
    76  	outDir, err := makeTestDir(g, g.API.Name)
    77  	if err != nil {
    78  		return err
    79  	}
    80  	appPkg, err := codegen.PackagePath(g.OutDir)
    81  	if err != nil {
    82  		return err
    83  	}
    84  	imports := []*codegen.ImportSpec{
    85  		codegen.SimpleImport("bytes"),
    86  		codegen.SimpleImport("fmt"),
    87  		codegen.SimpleImport("io"),
    88  		codegen.SimpleImport("log"),
    89  		codegen.SimpleImport("net/http"),
    90  		codegen.SimpleImport("net/http/httptest"),
    91  		codegen.SimpleImport("net/url"),
    92  		codegen.SimpleImport("strconv"),
    93  		codegen.SimpleImport("strings"),
    94  		codegen.SimpleImport("time"),
    95  		codegen.SimpleImport(appPkg),
    96  		codegen.SimpleImport("github.com/goadesign/goa"),
    97  		codegen.SimpleImport("github.com/goadesign/goa/goatest"),
    98  		codegen.SimpleImport("context"),
    99  		codegen.NewImport("uuid", "github.com/gofrs/uuid"),
   100  	}
   101  
   102  	return g.API.IterateResources(func(res *design.ResourceDefinition) (err error) {
   103  		filename := filepath.Join(outDir, codegen.SnakeCase(res.Name)+"_testing.go")
   104  		var file *codegen.SourceFile
   105  		file, err = codegen.SourceFileFor(filename)
   106  		if err != nil {
   107  			return err
   108  		}
   109  		defer func() {
   110  			file.Close()
   111  			if err == nil {
   112  				err = file.FormatCode()
   113  			}
   114  		}()
   115  		title := fmt.Sprintf("%s: %s TestHelpers", g.API.Context(), res.Name)
   116  		if err = file.WriteHeader(title, "test", imports); err != nil {
   117  			return err
   118  		}
   119  
   120  		var methods []*TestMethod
   121  
   122  		if err = res.IterateActions(func(action *design.ActionDefinition) error {
   123  			if err := action.IterateResponses(func(response *design.ResponseDefinition) error {
   124  				if response.Status == 101 { // SwitchingProtocols, Don't currently handle WebSocket endpoints
   125  					return nil
   126  				}
   127  				for routeIndex, route := range action.Routes {
   128  					mediaType := design.Design.MediaTypeWithIdentifier(response.MediaType)
   129  					if mediaType == nil {
   130  						methods = append(methods, g.createTestMethod(res, action, response, route, routeIndex, nil, nil))
   131  					} else {
   132  						if err := mediaType.IterateViews(func(view *design.ViewDefinition) error {
   133  							methods = append(methods, g.createTestMethod(res, action, response, route, routeIndex, mediaType, view))
   134  							return nil
   135  						}); err != nil {
   136  							return err
   137  						}
   138  					}
   139  				}
   140  				return nil
   141  			}); err != nil {
   142  				return err
   143  			}
   144  			return nil
   145  		}); err != nil {
   146  			return err
   147  		}
   148  		g.genfiles = append(g.genfiles, filename)
   149  		err = testTmpl.Execute(file, methods)
   150  		return
   151  	})
   152  }
   153  
   154  func (g *Generator) createTestMethod(resource *design.ResourceDefinition, action *design.ActionDefinition,
   155  	response *design.ResponseDefinition, route *design.RouteDefinition, routeIndex int,
   156  	mediaType *design.MediaTypeDefinition, view *design.ViewDefinition) *TestMethod {
   157  
   158  	var (
   159  		actionName, ctrlName, varName                string
   160  		routeQualifier, viewQualifier, respQualifier string
   161  		comment                                      string
   162  		path                                         []*ObjectType
   163  		query                                        []*ObjectType
   164  		header                                       []*ObjectType
   165  		returnType                                   *ObjectType
   166  		payload                                      *ObjectType
   167  	)
   168  
   169  	actionName = codegen.Goify(action.Name, true)
   170  	ctrlName = codegen.Goify(resource.Name, true)
   171  	varName = codegen.Goify(action.Name, false)
   172  	routeQualifier = suffixRoute(action.Routes, routeIndex)
   173  	if view != nil && view.Name != "default" {
   174  		viewQualifier = codegen.Goify(view.Name, true)
   175  	}
   176  	respQualifier = codegen.Goify(response.Name, true)
   177  	hasReturnValue := view != nil && mediaType != nil
   178  
   179  	if hasReturnValue {
   180  		p, _, err := mediaType.Project(view.Name)
   181  		if err != nil {
   182  			panic(err) // bug
   183  		}
   184  		tmp := codegen.GoTypeName(p, nil, 0, false)
   185  		if !p.IsError() {
   186  			tmp = fmt.Sprintf("%s.%s", g.Target, tmp)
   187  		}
   188  		validate := g.validator.Code(p.AttributeDefinition, false, false, false, "payload", "raw", 1, false)
   189  		returnType = &ObjectType{}
   190  		returnType.Type = tmp
   191  		if p.IsObject() && !p.IsError() {
   192  			returnType.Pointer = "*"
   193  		}
   194  		returnType.Validatable = validate != ""
   195  	}
   196  
   197  	comment = "runs the method " + actionName + " of the given controller with the given parameters"
   198  	if action.Payload != nil {
   199  		comment += " and payload"
   200  	}
   201  	comment += ".\n// It returns the response writer so it's possible to inspect the response headers"
   202  	if hasReturnValue {
   203  		comment += " and the media type struct written to the response"
   204  	}
   205  	comment += "."
   206  
   207  	path = pathParams(action, route)
   208  	query = queryParams(action)
   209  	header = headers(action, resource.Headers)
   210  
   211  	if action.Payload != nil {
   212  		payload = &ObjectType{}
   213  		payload.Name = "payload"
   214  		payload.Type = fmt.Sprintf("%s.%s", g.Target, codegen.Goify(action.Payload.TypeName, true))
   215  		if !action.Payload.IsPrimitive() && !action.Payload.IsArray() && !action.Payload.IsHash() {
   216  			payload.Pointer = "*"
   217  		}
   218  
   219  		validate := g.validator.Code(action.Payload.AttributeDefinition, false, false, false, "payload", "raw", 1, false)
   220  		if validate != "" {
   221  			payload.Validatable = true
   222  		}
   223  	}
   224  
   225  	return &TestMethod{
   226  		Name:              fmt.Sprintf("%s%s%s%s%s", actionName, ctrlName, respQualifier, routeQualifier, viewQualifier),
   227  		ActionName:        actionName,
   228  		ResourceName:      ctrlName,
   229  		Comment:           comment,
   230  		Params:            path,
   231  		QueryParams:       query,
   232  		Headers:           header,
   233  		Payload:           payload,
   234  		ReturnType:        returnType,
   235  		ReturnsErrorMedia: mediaType == design.ErrorMedia,
   236  		ControllerName:    fmt.Sprintf("%s.%sController", g.Target, ctrlName),
   237  		ContextVarName:    fmt.Sprintf("%sCtx", varName),
   238  		ContextType:       fmt.Sprintf("%s.New%s%sContext", g.Target, actionName, ctrlName),
   239  		RouteVerb:         route.Verb,
   240  		Status:            response.Status,
   241  		FullPath:          goPathFormat(route.FullPath()),
   242  		reservedNames:     reservedNames(path, query, header, payload, returnType),
   243  	}
   244  }
   245  
   246  // pathParams returns the path params for the given action and route.
   247  func pathParams(action *design.ActionDefinition, route *design.RouteDefinition) []*ObjectType {
   248  	return paramFromNames(action, route.Params())
   249  }
   250  
   251  // headers builds the template data structure needed to proprely render the code
   252  // for setting the headers for the given action.
   253  func headers(action *design.ActionDefinition, headers *design.AttributeDefinition) []*ObjectType {
   254  	hds := &design.AttributeDefinition{
   255  		Type: design.Object{},
   256  	}
   257  	if headers != nil {
   258  		hds.Merge(headers)
   259  		hds.Validation = headers.Validation
   260  	}
   261  	if action.Headers != nil {
   262  		hds.Merge(action.Headers)
   263  		hds.Validation = action.Headers.Validation
   264  	}
   265  
   266  	if hds == nil {
   267  		return nil
   268  	}
   269  	var headrs []string
   270  	for header := range hds.Type.ToObject() {
   271  		headrs = append(headrs, header)
   272  	}
   273  	sort.Strings(headrs)
   274  	objs := make([]*ObjectType, len(headrs))
   275  	for i, name := range headrs {
   276  		objs[i] = attToObject(name, hds, hds.Type.ToObject()[name])
   277  		objs[i].Label = http.CanonicalHeaderKey(objs[i].Label)
   278  	}
   279  	return objs
   280  }
   281  
   282  // queryParams returns the query string params for the given action.
   283  func queryParams(action *design.ActionDefinition) []*ObjectType {
   284  	var qparams []string
   285  	if qps := action.QueryParams; qps != nil {
   286  		for pname := range qps.Type.ToObject() {
   287  			qparams = append(qparams, pname)
   288  		}
   289  	}
   290  	sort.Strings(qparams)
   291  	return paramFromNames(action, qparams)
   292  }
   293  
   294  func paramFromNames(action *design.ActionDefinition, names []string) (params []*ObjectType) {
   295  	obj := action.Params.Type.ToObject()
   296  	for _, name := range names {
   297  		params = append(params, attToObject(name, action.Params, obj[name]))
   298  	}
   299  	return
   300  }
   301  
   302  func reservedNames(params, queryParams, headers []*ObjectType, payload, returnType *ObjectType) map[string]bool {
   303  	var names = make(map[string]bool)
   304  	for _, param := range params {
   305  		names[param.Name] = true
   306  	}
   307  	for _, param := range queryParams {
   308  		names[param.Name] = true
   309  	}
   310  	for _, header := range headers {
   311  		names[header.Name] = true
   312  	}
   313  	if payload != nil {
   314  		names[payload.Name] = true
   315  	}
   316  	if returnType != nil {
   317  		names[returnType.Name] = true
   318  	}
   319  	return names
   320  }
   321  
   322  func attToObject(name string, parent, att *design.AttributeDefinition) *ObjectType {
   323  	obj := &ObjectType{}
   324  	obj.Label = name
   325  	obj.Name = codegen.Goify(name, false)
   326  	obj.Type = codegen.GoTypeRef(att.Type, nil, 0, false)
   327  	if att.Type.IsPrimitive() && parent.IsPrimitivePointer(name) {
   328  		obj.Pointer = "*"
   329  	}
   330  	return obj
   331  }
   332  
   333  func goPathFormat(path string) string {
   334  	return design.WildcardRegex.ReplaceAllLiteralString(path, "/%v")
   335  }
   336  
   337  func suffixRoute(routes []*design.RouteDefinition, currIndex int) string {
   338  	if len(routes) > 1 && currIndex > 0 {
   339  		return strconv.Itoa(currIndex)
   340  	}
   341  	return ""
   342  }
   343  
   344  func isSlice(typeName string) bool {
   345  	return strings.HasPrefix(typeName, "[]")
   346  }
   347  
// convertParamTmpl is the body of the "convertParam" sub-template used by
// testTmpl. Given an ObjectType, it emits Go code that declares a local
// sliceVal []string holding the string form of the value:
//   - "string" values are used as-is;
//   - "int" values go through strconv.Itoa;
//   - "[]string" values are reused directly;
//   - other slice types are formatted element by element with fmt's %v;
//   - "time.Time" values are formatted with time.RFC3339;
//   - any other type is formatted with fmt's %v.
// For the non-slice cases, a value with a non-empty Pointer is dereferenced
// first. The {{/* ... */}} template comments only swallow the line breaks
// between branches so that no stray newlines are emitted.
var convertParamTmpl = `{{ if eq .Type "string" }}		sliceVal := []string{ {{ if .Pointer }}*{{ end }}{{ .Name }}}{{/*
*/}}{{ else if eq .Type "int" }}		sliceVal := []string{strconv.Itoa({{ if .Pointer }}*{{ end }}{{ .Name }})}{{/*
*/}}{{ else if eq .Type "[]string" }}		sliceVal := {{ .Name }}{{/*
*/}}{{ else if (isSlice .Type) }}		sliceVal := make([]string, len({{ .Name }}))
		for i, v := range {{ .Name }} {
			sliceVal[i] = fmt.Sprintf("%v", v)
		}{{/*
*/}}{{ else if eq .Type "time.Time" }}		sliceVal := []string{ {{ if .Pointer }}(*{{ end }}{{ .Name }}{{ if .Pointer }}){{ end }}.Format(time.RFC3339)}{{/*
*/}}{{ else }}		sliceVal := []string{fmt.Sprintf("%v", {{ if .Pointer }}*{{ end }}{{ .Name }})}{{ end }}`
   357  
   358  var testTmpl = `{{ define "convertParam" }}` + convertParamTmpl + `{{ end }}` + `
   359  {{ range $test := . }}
   360  // {{ $test.Name }} {{ $test.Comment }}
   361  // If ctx is nil then context.Background() is used.
   362  // If service is nil then a default service is created.
   363  func {{ $test.Name }}(t goatest.TInterface, ctx context.Context, service *goa.Service, ctrl {{ $test.ControllerName}}{{/*
   364  */}}{{ range $param := $test.Params }}, {{ $param.Name }} {{ $param.Pointer }}{{ $param.Type }}{{ end }}{{/*
   365  */}}{{ range $param := $test.QueryParams }}, {{ $param.Name }} {{ $param.Pointer }}{{ $param.Type }}{{ end }}{{/*
   366  */}}{{ range $header := $test.Headers }}, {{ $header.Name }} {{ $header.Pointer }}{{ $header.Type }}{{ end }}{{/*
   367  */}}{{ if $test.Payload }}, {{ $test.Payload.Name }} {{ $test.Payload.Pointer }}{{ $test.Payload.Type }}{{ end }}){{/*
   368  */}} (http.ResponseWriter{{ if $test.ReturnType }}, {{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }}{{ end }}) {
   369  	// Setup service
   370  	var (
   371  		{{ $logBuf := $test.Escape "logBuf" }}{{ $logBuf }} bytes.Buffer
   372  		{{ $resp := $test.Escape "resp" }}{{ if $test.ReturnType }}{{ $resp }}   interface{}{{ end }}
   373  
   374  		{{ $respSetter := $test.Escape "respSetter" }}{{ $respSetter }} goatest.ResponseSetterFunc = func(r interface{}) { {{ if $test.ReturnType }}{{ $resp }} = r{{ end }} }
   375  	)
   376  	if service == nil {
   377  		service = goatest.Service(&{{ $logBuf }}, {{ $respSetter }})
   378  	} else {
   379  		{{ $logger := $test.Escape "logger" }}{{ $logger }} := log.New(&{{ $logBuf }}, "", log.Ltime)
   380  		service.WithLogger(goa.NewLogger({{ $logger }}))
   381  		{{ $newEncoder := $test.Escape "newEncoder" }}{{ $newEncoder }} := func(io.Writer) goa.Encoder { return  {{ $respSetter }} }
   382  		service.Encoder = goa.NewHTTPEncoder() // Make sure the code ends up using this decoder
   383  		service.Encoder.Register({{ $newEncoder }}, "*/*")
   384  	}
   385  {{ if $test.Payload }}{{ if $test.Payload.Validatable }}
   386  	// Validate payload
   387  	{{ $err := $test.Escape "err" }}{{ $err }} := {{ $test.Payload.Name }}.Validate()
   388  	if {{ $err }} != nil {
   389  		{{ $e := $test.Escape "e" }}{{ $e }}, {{ $ok := $test.Escape "ok" }}{{ $ok }} := {{ $err }}.(goa.ServiceError)
   390  		if !{{ $ok }} {
   391  			panic({{ $err }}) // bug
   392  		}
   393  {{ if not $test.ReturnsErrorMedia }}		t.Errorf("unexpected payload validation error: %+v", {{ $e }})
   394  {{ end }}{{ if $test.ReturnType }}		return nil, {{ if $test.ReturnsErrorMedia }}{{ $e }}{{ else }}nil{{ end }}{{ else }}return nil{{ end }}
   395  	}
   396  {{ end }}{{ end }}
   397  	// Setup request context
   398  	{{ $rw := $test.Escape "rw" }}{{ $rw }} := httptest.NewRecorder()
   399  {{ $query := $test.Escape "query" }}{{ if $test.QueryParams}}	{{ $query }} := url.Values{}
   400  {{ range $param := $test.QueryParams }}{{ if $param.Pointer }}	if {{ $param.Name }} != nil {{ end }}{
   401  {{ template "convertParam" $param }}
   402  		{{ $query }}[{{ printf "%q" $param.Label }}] = sliceVal
   403  	}
   404  {{ end }}{{ end }}	{{ $u := $test.Escape "u" }}{{ $u }}:= &url.URL{
   405  		Path: fmt.Sprintf({{ printf "%q" $test.FullPath }}{{ range $param := $test.Params }}, {{ $param.Name }}{{ end }}),
   406  {{ if $test.QueryParams }}		RawQuery: {{ $query }}.Encode(),
   407  {{ end }}	}
   408  	{{ $req := $test.Escape "req" }}{{ $req }}, {{ $err := $test.Escape "err" }}{{ $err }}:= http.NewRequest("{{ $test.RouteVerb }}", {{ $u }}.String(), nil)
   409  	if {{ $err }} != nil {
   410  		panic("invalid test " + {{ $err }}.Error()) // bug
   411  	}
   412  {{ range $header := $test.Headers }}{{ if $header.Pointer }}	if {{ $header.Name }} != nil {{ end }}{
   413  {{ template "convertParam" $header }}
   414  		{{ $req }}.Header[{{ printf "%q" $header.Label }}] = sliceVal
   415  	}
   416  {{ end }} {{ $prms := $test.Escape "prms" }}{{ $prms }} := url.Values{}
   417  {{ range $param := $test.Params }}	{{ $prms }}["{{ $param.Label }}"] = []string{fmt.Sprintf("%v",{{ $param.Name}})}
   418  {{ end }}{{ range $param := $test.QueryParams }}{{ if $param.Pointer }} if {{ $param.Name }} != nil {{ end }} {
   419  {{ template "convertParam" $param }}
   420  		{{ $prms }}[{{ printf "%q" $param.Label }}] = sliceVal
   421  	}
   422  {{ end }}	if ctx == nil {
   423  		ctx = context.Background()
   424  	}
   425  	{{ $goaCtx := $test.Escape "goaCtx" }}{{ $goaCtx }} := goa.NewContext(goa.WithAction(ctx, "{{ $test.ResourceName }}Test"), {{ $rw }}, {{ $req }}, {{ $prms }})
   426  	{{ $test.ContextVarName }}, {{ $err := $test.Escape "err" }}{{ $err }} := {{ $test.ContextType }}({{ $goaCtx }}, {{ $req }}, service)
   427  	if {{ $err }} != nil {
   428  		{{ $e := $test.Escape "e" }}{{ $e }}, {{ $ok := $test.Escape "ok" }}{{ $ok }} := {{ $err }}.(goa.ServiceError)
   429  		if !{{ $ok }} {
   430  			panic("invalid test data " + {{ $err }}.Error()) // bug
   431  		}
   432  {{ if not $test.ReturnsErrorMedia }}		t.Errorf("unexpected parameter validation error: %+v", {{ $e }})
   433  {{ end }}{{ if $test.ReturnType }}		return nil, {{ if $test.ReturnsErrorMedia }}{{ $e }}{{ else }}nil{{ end }}{{ else }}return nil{{ end }}
   434  	}
   435  	{{ if $test.Payload }}{{ $test.ContextVarName }}.Payload = {{ $test.Payload.Name }}{{ end }}
   436  
   437  	// Perform action
   438  	{{ $err }} = ctrl.{{ $test.ActionName}}({{ $test.ContextVarName }})
   439  
   440  	// Validate response
   441  	if {{ $err }} != nil {
   442  		t.Fatalf("controller returned %+v, logs:\n%s", {{ $err }}, {{ $logBuf }}.String())
   443  	}
   444  	if {{ $rw }}.Code != {{ $test.Status }} {
   445  		t.Errorf("invalid response status code: got %+v, expected {{ $test.Status }}", {{ $rw }}.Code)
   446  	}
   447  {{ if $test.ReturnType }}	var mt {{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }}
   448  	if {{ $resp }} != nil {
   449  		var {{ $ok := $test.Escape "ok" }}{{ $ok }} bool
   450  		mt, {{ $ok }} = {{ $resp }}.({{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }})
   451  		if !{{ $ok }} {
   452  			t.Fatalf("invalid response media: got variable of type %T, value %+v, expected instance of {{ $test.ReturnType.Type }}", {{ $resp }}, {{ $resp }})
   453  		}
   454  {{ if $test.ReturnType.Validatable }}		{{ $err }} = mt.Validate()
   455  		if {{ $err }} != nil {
   456  			t.Errorf("invalid response media type: %s", {{ $err }})
   457  		}
   458  {{ end }}	}
   459  {{ end }}
   460  	// Return results
   461  	return {{ $rw }}{{ if $test.ReturnType }}, mt{{ end }}
   462  }
   463  {{ end }}`