github.com/ManabuSeki/goa-v1@v1.4.3/goagen/gen_app/test_generator.go (about) 1 package genapp 2 3 import ( 4 "fmt" 5 "net/http" 6 "os" 7 "path/filepath" 8 "sort" 9 "strconv" 10 "strings" 11 "text/template" 12 13 "github.com/goadesign/goa/design" 14 "github.com/goadesign/goa/goagen/codegen" 15 ) 16 17 func makeTestDir(g *Generator, apiName string) (outDir string, err error) { 18 outDir = filepath.Join(g.OutDir, "test") 19 if err = os.RemoveAll(outDir); err != nil { 20 return 21 } 22 if err = os.MkdirAll(outDir, 0755); err != nil { 23 return 24 } 25 g.genfiles = append(g.genfiles, outDir) 26 return 27 } 28 29 // TestMethod structure 30 type TestMethod struct { 31 Name string 32 Comment string 33 ResourceName string 34 ActionName string 35 ControllerName string 36 ContextVarName string 37 ContextType string 38 RouteVerb string 39 FullPath string 40 Status int 41 ReturnType *ObjectType 42 ReturnsErrorMedia bool 43 Params []*ObjectType 44 QueryParams []*ObjectType 45 Headers []*ObjectType 46 Payload *ObjectType 47 reservedNames map[string]bool 48 } 49 50 // Escape escapes given string. 51 func (t *TestMethod) Escape(s string) string { 52 if ok := t.reservedNames[s]; ok { 53 s = t.Escape("_" + s) 54 } 55 t.reservedNames[s] = true 56 return s 57 } 58 59 // ObjectType structure 60 type ObjectType struct { 61 Label string 62 Name string 63 Type string 64 Pointer string 65 Validatable bool 66 } 67 68 func (g *Generator) generateResourceTest() error { 69 if len(g.API.Resources) == 0 { 70 return nil 71 } 72 funcs := template.FuncMap{ 73 "isSlice": isSlice, 74 } 75 testTmpl := template.Must(template.New("test").Funcs(funcs).Parse(testTmpl)) 76 outDir, err := makeTestDir(g, g.API.Name) 77 if err != nil { 78 return err 79 } 80 appPkg, err := codegen.PackagePath(g.OutDir) 81 if err != nil { 82 return err 83 } 84 imports := []*codegen.ImportSpec{ 85 codegen.SimpleImport("bytes"), 86 codegen.SimpleImport("fmt"), 87 codegen.SimpleImport("io"), 88 codegen.SimpleImport("log"), 89 codegen.SimpleImport("net/http"), 90 codegen.SimpleImport("net/http/httptest"), 91 codegen.SimpleImport("net/url"), 92 codegen.SimpleImport("strconv"), 93 codegen.SimpleImport("strings"), 94 codegen.SimpleImport("time"), 95 codegen.SimpleImport(appPkg), 96 codegen.SimpleImport("github.com/goadesign/goa"), 97 codegen.SimpleImport("github.com/goadesign/goa/goatest"), 98 codegen.SimpleImport("context"), 99 codegen.NewImport("uuid", "github.com/gofrs/uuid"), 100 } 101 102 return g.API.IterateResources(func(res *design.ResourceDefinition) (err error) { 103 filename := filepath.Join(outDir, codegen.SnakeCase(res.Name)+"_testing.go") 104 var file *codegen.SourceFile 105 file, err = codegen.SourceFileFor(filename) 106 if err != nil { 107 return err 108 } 109 defer func() { 110 file.Close() 111 if err == nil { 112 err = file.FormatCode() 113 } 114 }() 115 title := fmt.Sprintf("%s: %s TestHelpers", g.API.Context(), res.Name) 116 if err = file.WriteHeader(title, "test", imports); err != nil { 117 return err 118 } 119 120 var methods []*TestMethod 121 122 if err = res.IterateActions(func(action *design.ActionDefinition) error { 123 if err := action.IterateResponses(func(response *design.ResponseDefinition) error { 124 if response.Status == 101 { // SwitchingProtocols, Don't currently handle WebSocket endpoints 125 return nil 126 } 127 for routeIndex, route := range action.Routes { 128 mediaType := design.Design.MediaTypeWithIdentifier(response.MediaType) 129 if mediaType == nil { 130 methods = append(methods, g.createTestMethod(res, action, response, route, routeIndex, nil, nil)) 131 } else { 132 if err := mediaType.IterateViews(func(view *design.ViewDefinition) error { 133 methods = append(methods, g.createTestMethod(res, action, response, route, routeIndex, mediaType, view)) 134 return nil 135 }); err != nil { 136 return err 137 } 138 } 139 } 140 return nil 141 }); err != nil { 142 return err 143 } 144 return nil 145 }); err != nil { 146 return err 147 } 148 g.genfiles = append(g.genfiles, filename) 149 err = testTmpl.Execute(file, methods) 150 return 151 }) 152 } 153 154 func (g *Generator) createTestMethod(resource *design.ResourceDefinition, action *design.ActionDefinition, 155 response *design.ResponseDefinition, route *design.RouteDefinition, routeIndex int, 156 mediaType *design.MediaTypeDefinition, view *design.ViewDefinition) *TestMethod { 157 158 var ( 159 actionName, ctrlName, varName string 160 routeQualifier, viewQualifier, respQualifier string 161 comment string 162 path []*ObjectType 163 query []*ObjectType 164 header []*ObjectType 165 returnType *ObjectType 166 payload *ObjectType 167 ) 168 169 actionName = codegen.Goify(action.Name, true) 170 ctrlName = codegen.Goify(resource.Name, true) 171 varName = codegen.Goify(action.Name, false) 172 routeQualifier = suffixRoute(action.Routes, routeIndex) 173 if view != nil && view.Name != "default" { 174 viewQualifier = codegen.Goify(view.Name, true) 175 } 176 respQualifier = codegen.Goify(response.Name, true) 177 hasReturnValue := view != nil && mediaType != nil 178 179 if hasReturnValue { 180 p, _, err := mediaType.Project(view.Name) 181 if err != nil { 182 panic(err) // bug 183 } 184 tmp := codegen.GoTypeName(p, nil, 0, false) 185 if !p.IsError() { 186 tmp = fmt.Sprintf("%s.%s", g.Target, tmp) 187 } 188 validate := g.validator.Code(p.AttributeDefinition, false, false, false, "payload", "raw", 1, false) 189 returnType = &ObjectType{} 190 returnType.Type = tmp 191 if p.IsObject() && !p.IsError() { 192 returnType.Pointer = "*" 193 } 194 returnType.Validatable = validate != "" 195 } 196 197 comment = "runs the method " + actionName + " of the given controller with the given parameters" 198 if action.Payload != nil { 199 comment += " and payload" 200 } 201 comment += ".\n// It returns the response writer so it's possible to inspect the response headers" 202 if hasReturnValue { 203 comment += " and the media type struct written to the response" 204 } 205 comment += "." 206 207 path = pathParams(action, route) 208 query = queryParams(action) 209 header = headers(action, resource.Headers) 210 211 if action.Payload != nil { 212 payload = &ObjectType{} 213 payload.Name = "payload" 214 payload.Type = fmt.Sprintf("%s.%s", g.Target, codegen.Goify(action.Payload.TypeName, true)) 215 if !action.Payload.IsPrimitive() && !action.Payload.IsArray() && !action.Payload.IsHash() { 216 payload.Pointer = "*" 217 } 218 219 validate := g.validator.Code(action.Payload.AttributeDefinition, false, false, false, "payload", "raw", 1, false) 220 if validate != "" { 221 payload.Validatable = true 222 } 223 } 224 225 return &TestMethod{ 226 Name: fmt.Sprintf("%s%s%s%s%s", actionName, ctrlName, respQualifier, routeQualifier, viewQualifier), 227 ActionName: actionName, 228 ResourceName: ctrlName, 229 Comment: comment, 230 Params: path, 231 QueryParams: query, 232 Headers: header, 233 Payload: payload, 234 ReturnType: returnType, 235 ReturnsErrorMedia: mediaType == design.ErrorMedia, 236 ControllerName: fmt.Sprintf("%s.%sController", g.Target, ctrlName), 237 ContextVarName: fmt.Sprintf("%sCtx", varName), 238 ContextType: fmt.Sprintf("%s.New%s%sContext", g.Target, actionName, ctrlName), 239 RouteVerb: route.Verb, 240 Status: response.Status, 241 FullPath: goPathFormat(route.FullPath()), 242 reservedNames: reservedNames(path, query, header, payload, returnType), 243 } 244 } 245 246 // pathParams returns the path params for the given action and route. 247 func pathParams(action *design.ActionDefinition, route *design.RouteDefinition) []*ObjectType { 248 return paramFromNames(action, route.Params()) 249 } 250 251 // headers builds the template data structure needed to proprely render the code 252 // for setting the headers for the given action. 253 func headers(action *design.ActionDefinition, headers *design.AttributeDefinition) []*ObjectType { 254 hds := &design.AttributeDefinition{ 255 Type: design.Object{}, 256 } 257 if headers != nil { 258 hds.Merge(headers) 259 hds.Validation = headers.Validation 260 } 261 if action.Headers != nil { 262 hds.Merge(action.Headers) 263 hds.Validation = action.Headers.Validation 264 } 265 266 if hds == nil { 267 return nil 268 } 269 var headrs []string 270 for header := range hds.Type.ToObject() { 271 headrs = append(headrs, header) 272 } 273 sort.Strings(headrs) 274 objs := make([]*ObjectType, len(headrs)) 275 for i, name := range headrs { 276 objs[i] = attToObject(name, hds, hds.Type.ToObject()[name]) 277 objs[i].Label = http.CanonicalHeaderKey(objs[i].Label) 278 } 279 return objs 280 } 281 282 // queryParams returns the query string params for the given action. 283 func queryParams(action *design.ActionDefinition) []*ObjectType { 284 var qparams []string 285 if qps := action.QueryParams; qps != nil { 286 for pname := range qps.Type.ToObject() { 287 qparams = append(qparams, pname) 288 } 289 } 290 sort.Strings(qparams) 291 return paramFromNames(action, qparams) 292 } 293 294 func paramFromNames(action *design.ActionDefinition, names []string) (params []*ObjectType) { 295 obj := action.Params.Type.ToObject() 296 for _, name := range names { 297 params = append(params, attToObject(name, action.Params, obj[name])) 298 } 299 return 300 } 301 302 func reservedNames(params, queryParams, headers []*ObjectType, payload, returnType *ObjectType) map[string]bool { 303 var names = make(map[string]bool) 304 for _, param := range params { 305 names[param.Name] = true 306 } 307 for _, param := range queryParams { 308 names[param.Name] = true 309 } 310 for _, header := range headers { 311 names[header.Name] = true 312 } 313 if payload != nil { 314 names[payload.Name] = true 315 } 316 if returnType != nil { 317 names[returnType.Name] = true 318 } 319 return names 320 } 321 322 func attToObject(name string, parent, att *design.AttributeDefinition) *ObjectType { 323 obj := &ObjectType{} 324 obj.Label = name 325 obj.Name = codegen.Goify(name, false) 326 obj.Type = codegen.GoTypeRef(att.Type, nil, 0, false) 327 if att.Type.IsPrimitive() && parent.IsPrimitivePointer(name) { 328 obj.Pointer = "*" 329 } 330 return obj 331 } 332 333 func goPathFormat(path string) string { 334 return design.WildcardRegex.ReplaceAllLiteralString(path, "/%v") 335 } 336 337 func suffixRoute(routes []*design.RouteDefinition, currIndex int) string { 338 if len(routes) > 1 && currIndex > 0 { 339 return strconv.Itoa(currIndex) 340 } 341 return "" 342 } 343 344 func isSlice(typeName string) bool { 345 return strings.HasPrefix(typeName, "[]") 346 } 347 348 var convertParamTmpl = `{{ if eq .Type "string" }} sliceVal := []string{ {{ if .Pointer }}*{{ end }}{{ .Name }}}{{/* 349 */}}{{ else if eq .Type "int" }} sliceVal := []string{strconv.Itoa({{ if .Pointer }}*{{ end }}{{ .Name }})}{{/* 350 */}}{{ else if eq .Type "[]string" }} sliceVal := {{ .Name }}{{/* 351 */}}{{ else if (isSlice .Type) }} sliceVal := make([]string, len({{ .Name }})) 352 for i, v := range {{ .Name }} { 353 sliceVal[i] = fmt.Sprintf("%v", v) 354 }{{/* 355 */}}{{ else if eq .Type "time.Time" }} sliceVal := []string{ {{ if .Pointer }}(*{{ end }}{{ .Name }}{{ if .Pointer }}){{ end }}.Format(time.RFC3339)}{{/* 356 */}}{{ else }} sliceVal := []string{fmt.Sprintf("%v", {{ if .Pointer }}*{{ end }}{{ .Name }})}{{ end }}` 357 358 var testTmpl = `{{ define "convertParam" }}` + convertParamTmpl + `{{ end }}` + ` 359 {{ range $test := . }} 360 // {{ $test.Name }} {{ $test.Comment }} 361 // If ctx is nil then context.Background() is used. 362 // If service is nil then a default service is created. 363 func {{ $test.Name }}(t goatest.TInterface, ctx context.Context, service *goa.Service, ctrl {{ $test.ControllerName}}{{/* 364 */}}{{ range $param := $test.Params }}, {{ $param.Name }} {{ $param.Pointer }}{{ $param.Type }}{{ end }}{{/* 365 */}}{{ range $param := $test.QueryParams }}, {{ $param.Name }} {{ $param.Pointer }}{{ $param.Type }}{{ end }}{{/* 366 */}}{{ range $header := $test.Headers }}, {{ $header.Name }} {{ $header.Pointer }}{{ $header.Type }}{{ end }}{{/* 367 */}}{{ if $test.Payload }}, {{ $test.Payload.Name }} {{ $test.Payload.Pointer }}{{ $test.Payload.Type }}{{ end }}){{/* 368 */}} (http.ResponseWriter{{ if $test.ReturnType }}, {{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }}{{ end }}) { 369 // Setup service 370 var ( 371 {{ $logBuf := $test.Escape "logBuf" }}{{ $logBuf }} bytes.Buffer 372 {{ $resp := $test.Escape "resp" }}{{ if $test.ReturnType }}{{ $resp }} interface{}{{ end }} 373 374 {{ $respSetter := $test.Escape "respSetter" }}{{ $respSetter }} goatest.ResponseSetterFunc = func(r interface{}) { {{ if $test.ReturnType }}{{ $resp }} = r{{ end }} } 375 ) 376 if service == nil { 377 service = goatest.Service(&{{ $logBuf }}, {{ $respSetter }}) 378 } else { 379 {{ $logger := $test.Escape "logger" }}{{ $logger }} := log.New(&{{ $logBuf }}, "", log.Ltime) 380 service.WithLogger(goa.NewLogger({{ $logger }})) 381 {{ $newEncoder := $test.Escape "newEncoder" }}{{ $newEncoder }} := func(io.Writer) goa.Encoder { return {{ $respSetter }} } 382 service.Encoder = goa.NewHTTPEncoder() // Make sure the code ends up using this decoder 383 service.Encoder.Register({{ $newEncoder }}, "*/*") 384 } 385 {{ if $test.Payload }}{{ if $test.Payload.Validatable }} 386 // Validate payload 387 {{ $err := $test.Escape "err" }}{{ $err }} := {{ $test.Payload.Name }}.Validate() 388 if {{ $err }} != nil { 389 {{ $e := $test.Escape "e" }}{{ $e }}, {{ $ok := $test.Escape "ok" }}{{ $ok }} := {{ $err }}.(goa.ServiceError) 390 if !{{ $ok }} { 391 panic({{ $err }}) // bug 392 } 393 {{ if not $test.ReturnsErrorMedia }} t.Errorf("unexpected payload validation error: %+v", {{ $e }}) 394 {{ end }}{{ if $test.ReturnType }} return nil, {{ if $test.ReturnsErrorMedia }}{{ $e }}{{ else }}nil{{ end }}{{ else }}return nil{{ end }} 395 } 396 {{ end }}{{ end }} 397 // Setup request context 398 {{ $rw := $test.Escape "rw" }}{{ $rw }} := httptest.NewRecorder() 399 {{ $query := $test.Escape "query" }}{{ if $test.QueryParams}} {{ $query }} := url.Values{} 400 {{ range $param := $test.QueryParams }}{{ if $param.Pointer }} if {{ $param.Name }} != nil {{ end }}{ 401 {{ template "convertParam" $param }} 402 {{ $query }}[{{ printf "%q" $param.Label }}] = sliceVal 403 } 404 {{ end }}{{ end }} {{ $u := $test.Escape "u" }}{{ $u }}:= &url.URL{ 405 Path: fmt.Sprintf({{ printf "%q" $test.FullPath }}{{ range $param := $test.Params }}, {{ $param.Name }}{{ end }}), 406 {{ if $test.QueryParams }} RawQuery: {{ $query }}.Encode(), 407 {{ end }} } 408 {{ $req := $test.Escape "req" }}{{ $req }}, {{ $err := $test.Escape "err" }}{{ $err }}:= http.NewRequest("{{ $test.RouteVerb }}", {{ $u }}.String(), nil) 409 if {{ $err }} != nil { 410 panic("invalid test " + {{ $err }}.Error()) // bug 411 } 412 {{ range $header := $test.Headers }}{{ if $header.Pointer }} if {{ $header.Name }} != nil {{ end }}{ 413 {{ template "convertParam" $header }} 414 {{ $req }}.Header[{{ printf "%q" $header.Label }}] = sliceVal 415 } 416 {{ end }} {{ $prms := $test.Escape "prms" }}{{ $prms }} := url.Values{} 417 {{ range $param := $test.Params }} {{ $prms }}["{{ $param.Label }}"] = []string{fmt.Sprintf("%v",{{ $param.Name}})} 418 {{ end }}{{ range $param := $test.QueryParams }}{{ if $param.Pointer }} if {{ $param.Name }} != nil {{ end }} { 419 {{ template "convertParam" $param }} 420 {{ $prms }}[{{ printf "%q" $param.Label }}] = sliceVal 421 } 422 {{ end }} if ctx == nil { 423 ctx = context.Background() 424 } 425 {{ $goaCtx := $test.Escape "goaCtx" }}{{ $goaCtx }} := goa.NewContext(goa.WithAction(ctx, "{{ $test.ResourceName }}Test"), {{ $rw }}, {{ $req }}, {{ $prms }}) 426 {{ $test.ContextVarName }}, {{ $err := $test.Escape "err" }}{{ $err }} := {{ $test.ContextType }}({{ $goaCtx }}, {{ $req }}, service) 427 if {{ $err }} != nil { 428 {{ $e := $test.Escape "e" }}{{ $e }}, {{ $ok := $test.Escape "ok" }}{{ $ok }} := {{ $err }}.(goa.ServiceError) 429 if !{{ $ok }} { 430 panic("invalid test data " + {{ $err }}.Error()) // bug 431 } 432 {{ if not $test.ReturnsErrorMedia }} t.Errorf("unexpected parameter validation error: %+v", {{ $e }}) 433 {{ end }}{{ if $test.ReturnType }} return nil, {{ if $test.ReturnsErrorMedia }}{{ $e }}{{ else }}nil{{ end }}{{ else }}return nil{{ end }} 434 } 435 {{ if $test.Payload }}{{ $test.ContextVarName }}.Payload = {{ $test.Payload.Name }}{{ end }} 436 437 // Perform action 438 {{ $err }} = ctrl.{{ $test.ActionName}}({{ $test.ContextVarName }}) 439 440 // Validate response 441 if {{ $err }} != nil { 442 t.Fatalf("controller returned %+v, logs:\n%s", {{ $err }}, {{ $logBuf }}.String()) 443 } 444 if {{ $rw }}.Code != {{ $test.Status }} { 445 t.Errorf("invalid response status code: got %+v, expected {{ $test.Status }}", {{ $rw }}.Code) 446 } 447 {{ if $test.ReturnType }} var mt {{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }} 448 if {{ $resp }} != nil { 449 var {{ $ok := $test.Escape "ok" }}{{ $ok }} bool 450 mt, {{ $ok }} = {{ $resp }}.({{ $test.ReturnType.Pointer }}{{ $test.ReturnType.Type }}) 451 if !{{ $ok }} { 452 t.Fatalf("invalid response media: got variable of type %T, value %+v, expected instance of {{ $test.ReturnType.Type }}", {{ $resp }}, {{ $resp }}) 453 } 454 {{ if $test.ReturnType.Validatable }} {{ $err }} = mt.Validate() 455 if {{ $err }} != nil { 456 t.Errorf("invalid response media type: %s", {{ $err }}) 457 } 458 {{ end }} } 459 {{ end }} 460 // Return results 461 return {{ $rw }}{{ if $test.ReturnType }}, mt{{ end }} 462 } 463 {{ end }}`