github.com/6543-forks/go-swagger@v0.26.0/generator/template_repo.go (about) 1 package generator 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "os" 9 "path" 10 "path/filepath" 11 "strings" 12 "text/template" 13 "text/template/parse" 14 "unicode" 15 16 "log" 17 18 "github.com/go-openapi/inflect" 19 "github.com/go-openapi/swag" 20 "github.com/kr/pretty" 21 ) 22 23 var ( 24 assets map[string][]byte 25 protectedTemplates map[string]bool 26 27 // FuncMapFunc yields a map with all functions for templates 28 FuncMapFunc func(*LanguageOpts) template.FuncMap 29 30 templates *Repository 31 ) 32 33 func initTemplateRepo() { 34 FuncMapFunc = DefaultFuncMap 35 36 // this makes the ToGoName func behave with the special 37 // prefixing rule above 38 swag.GoNamePrefixFunc = prefixForName 39 40 assets = defaultAssets() 41 protectedTemplates = defaultProtectedTemplates() 42 templates = NewRepository(FuncMapFunc(DefaultLanguageFunc())) 43 } 44 45 // DefaultFuncMap yields a map with default functions for use n the templates. 46 // These are available in every template 47 func DefaultFuncMap(lang *LanguageOpts) template.FuncMap { 48 return template.FuncMap(map[string]interface{}{ 49 "pascalize": pascalize, 50 "camelize": swag.ToJSONName, 51 "varname": lang.MangleVarName, 52 "humanize": swag.ToHumanNameLower, 53 "snakize": lang.MangleFileName, 54 "toPackagePath": func(name string) string { 55 return filepath.FromSlash(lang.ManglePackagePath(name, "")) 56 }, 57 "toPackage": func(name string) string { 58 return lang.ManglePackagePath(name, "") 59 }, 60 "toPackageName": func(name string) string { 61 return lang.ManglePackageName(name, "") 62 }, 63 "dasherize": swag.ToCommandName, 64 "pluralizeFirstWord": pluralizeFirstWord, 65 "json": asJSON, 66 "prettyjson": asPrettyJSON, 67 "hasInsecure": func(arg []string) bool { 68 return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws") 69 }, 70 "hasSecure": func(arg []string) bool { 71 return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss") 72 }, 73 "dropPackage": dropPackage, 74 "upper": strings.ToUpper, 75 "contains": swag.ContainsStrings, 76 "padSurround": padSurround, 77 "joinFilePath": filepath.Join, 78 "comment": padComment, 79 "blockcomment": blockComment, 80 "inspect": pretty.Sprint, 81 "cleanPath": path.Clean, 82 "mediaTypeName": mediaMime, 83 "arrayInitializer": lang.arrayInitializer, 84 "hasPrefix": strings.HasPrefix, 85 "stringContains": strings.Contains, 86 "imports": lang.imports, 87 "dict": dict, 88 }) 89 } 90 91 func defaultAssets() map[string][]byte { 92 return map[string][]byte{ 93 // schema validation templates 94 "validation/primitive.gotmpl": MustAsset("templates/validation/primitive.gotmpl"), 95 "validation/customformat.gotmpl": MustAsset("templates/validation/customformat.gotmpl"), 96 "validation/structfield.gotmpl": MustAsset("templates/validation/structfield.gotmpl"), 97 "structfield.gotmpl": MustAsset("templates/structfield.gotmpl"), 98 "schemavalidator.gotmpl": MustAsset("templates/schemavalidator.gotmpl"), 99 "schemapolymorphic.gotmpl": MustAsset("templates/schemapolymorphic.gotmpl"), 100 "schemaembedded.gotmpl": MustAsset("templates/schemaembedded.gotmpl"), 101 102 // schema serialization templates 103 "additionalpropertiesserializer.gotmpl": MustAsset("templates/serializers/additionalpropertiesserializer.gotmpl"), 104 "aliasedserializer.gotmpl": MustAsset("templates/serializers/aliasedserializer.gotmpl"), 105 "allofserializer.gotmpl": MustAsset("templates/serializers/allofserializer.gotmpl"), 106 "basetypeserializer.gotmpl": MustAsset("templates/serializers/basetypeserializer.gotmpl"), 107 "marshalbinaryserializer.gotmpl": MustAsset("templates/serializers/marshalbinaryserializer.gotmpl"), 108 "schemaserializer.gotmpl": MustAsset("templates/serializers/schemaserializer.gotmpl"), 109 "subtypeserializer.gotmpl": MustAsset("templates/serializers/subtypeserializer.gotmpl"), 110 "tupleserializer.gotmpl": MustAsset("templates/serializers/tupleserializer.gotmpl"), 111 112 // schema generation template 113 "docstring.gotmpl": MustAsset("templates/docstring.gotmpl"), 114 "schematype.gotmpl": MustAsset("templates/schematype.gotmpl"), 115 "schemabody.gotmpl": MustAsset("templates/schemabody.gotmpl"), 116 "schema.gotmpl": MustAsset("templates/schema.gotmpl"), 117 "model.gotmpl": MustAsset("templates/model.gotmpl"), 118 "header.gotmpl": MustAsset("templates/header.gotmpl"), 119 120 "swagger_json_embed.gotmpl": MustAsset("templates/swagger_json_embed.gotmpl"), 121 122 // server templates 123 "server/parameter.gotmpl": MustAsset("templates/server/parameter.gotmpl"), 124 "server/urlbuilder.gotmpl": MustAsset("templates/server/urlbuilder.gotmpl"), 125 "server/responses.gotmpl": MustAsset("templates/server/responses.gotmpl"), 126 "server/operation.gotmpl": MustAsset("templates/server/operation.gotmpl"), 127 "server/builder.gotmpl": MustAsset("templates/server/builder.gotmpl"), 128 "server/server.gotmpl": MustAsset("templates/server/server.gotmpl"), 129 "server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"), 130 "server/main.gotmpl": MustAsset("templates/server/main.gotmpl"), 131 "server/doc.gotmpl": MustAsset("templates/server/doc.gotmpl"), 132 133 // client templates 134 "client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"), 135 "client/response.gotmpl": MustAsset("templates/client/response.gotmpl"), 136 "client/client.gotmpl": MustAsset("templates/client/client.gotmpl"), 137 "client/facade.gotmpl": MustAsset("templates/client/facade.gotmpl"), 138 } 139 } 140 141 func defaultProtectedTemplates() map[string]bool { 142 return map[string]bool{ 143 "dereffedSchemaType": true, 144 "docstring": true, 145 "header": true, 146 "mapvalidator": true, 147 "model": true, 148 "modelvalidator": true, 149 "objectvalidator": true, 150 "primitivefieldvalidator": true, 151 "privstructfield": true, 152 "privtuplefield": true, 153 "propertyValidationDocString": true, 154 "propertyvalidator": true, 155 "schema": true, 156 "schemaBody": true, 157 "schemaType": true, 158 "schemabody": true, 159 "schematype": true, 160 "schemavalidator": true, 161 "serverDoc": true, 162 "slicevalidator": true, 163 "structfield": true, 164 "structfieldIface": true, 165 "subTypeBody": true, 166 "swaggerJsonEmbed": true, 167 "tuplefield": true, 168 "tuplefieldIface": true, 169 "typeSchemaType": true, 170 "validationCustomformat": true, 171 "validationPrimitive": true, 172 "validationStructfield": true, 173 "withBaseTypeBody": true, 174 "withoutBaseTypeBody": true, 175 176 // all serializers TODO(fred) 177 "additionalPropertiesSerializer": true, 178 "tupleSerializer": true, 179 "schemaSerializer": true, 180 "hasDiscriminatedSerializer": true, 181 "discriminatedSerializer": true, 182 } 183 } 184 185 // AddFile adds a file to the default repository. It will create a new template based on the filename. 186 // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip 187 // directory separators and Camelcase the next letter. 188 // e.g validation/primitive.gotmpl will become validationPrimitive 189 // 190 // If the file contains a definition for a template that is protected the whole file will not be added 191 func AddFile(name, data string) error { 192 return templates.addFile(name, data, false) 193 } 194 195 // NewRepository creates a new template repository with the provided functions defined 196 func NewRepository(funcs template.FuncMap) *Repository { 197 repo := Repository{ 198 files: make(map[string]string), 199 templates: make(map[string]*template.Template), 200 funcs: funcs, 201 } 202 203 if repo.funcs == nil { 204 repo.funcs = make(template.FuncMap) 205 } 206 207 return &repo 208 } 209 210 // Repository is the repository for the generator templates 211 type Repository struct { 212 files map[string]string 213 templates map[string]*template.Template 214 funcs template.FuncMap 215 allowOverride bool 216 } 217 218 // LoadDefaults will load the embedded templates 219 func (t *Repository) LoadDefaults() { 220 221 for name, asset := range assets { 222 if err := t.addFile(name, string(asset), true); err != nil { 223 log.Fatal(err) 224 } 225 } 226 } 227 228 // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository 229 func (t *Repository) LoadDir(templatePath string) error { 230 err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error { 231 232 if strings.HasSuffix(path, ".gotmpl") { 233 if assetName, e := filepath.Rel(templatePath, path); e == nil { 234 if data, e := ioutil.ReadFile(path); e == nil { 235 if ee := t.AddFile(assetName, string(data)); ee != nil { 236 return fmt.Errorf("could not add template: %v", ee) 237 } 238 } 239 // Non-readable files are skipped 240 } 241 } 242 if err != nil { 243 return err 244 } 245 // Non-template files are skipped 246 return nil 247 }) 248 if err != nil { 249 return fmt.Errorf("could not complete template processing in directory \"%s\": %v", templatePath, err) 250 } 251 return nil 252 } 253 254 // LoadContrib loads template from contrib directory 255 func (t *Repository) LoadContrib(name string) error { 256 log.Printf("loading contrib %s", name) 257 const pathPrefix = "templates/contrib/" 258 basePath := pathPrefix + name 259 filesAdded := 0 260 for _, aname := range AssetNames() { 261 if !strings.HasSuffix(aname, ".gotmpl") { 262 continue 263 } 264 if strings.HasPrefix(aname, basePath) { 265 target := aname[len(basePath)+1:] 266 err := t.addFile(target, string(MustAsset(aname)), true) 267 if err != nil { 268 return err 269 } 270 log.Printf("added contributed template %s from %s", target, aname) 271 filesAdded++ 272 } 273 } 274 if filesAdded == 0 { 275 return fmt.Errorf("no files added from template: %s", name) 276 } 277 return nil 278 } 279 280 func (t *Repository) addFile(name, data string, allowOverride bool) error { 281 fileName := name 282 name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl")) 283 284 templ, err := template.New(name).Funcs(t.funcs).Parse(data) 285 286 if err != nil { 287 return fmt.Errorf("failed to load template %s: %v", name, err) 288 } 289 290 // check if any protected templates are defined 291 if !allowOverride && !t.allowOverride { 292 for _, template := range templ.Templates() { 293 if protectedTemplates[template.Name()] { 294 return fmt.Errorf("cannot overwrite protected template %s", template.Name()) 295 } 296 } 297 } 298 299 // Add each defined template into the cache 300 for _, template := range templ.Templates() { 301 302 t.files[template.Name()] = fileName 303 t.templates[template.Name()] = template.Lookup(template.Name()) 304 } 305 306 return nil 307 } 308 309 // MustGet a template by name, panics when fails 310 func (t *Repository) MustGet(name string) *template.Template { 311 tpl, err := t.Get(name) 312 if err != nil { 313 panic(err) 314 } 315 return tpl 316 } 317 318 // AddFile adds a file to the repository. It will create a new template based on the filename. 319 // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip 320 // directory separators and Camelcase the next letter. 321 // e.g validation/primitive.gotmpl will become validationPrimitive 322 // 323 // If the file contains a definition for a template that is protected the whole file will not be added 324 func (t *Repository) AddFile(name, data string) error { 325 return t.addFile(name, data, false) 326 } 327 328 // SetAllowOverride allows setting allowOverride after the Repository was initialized 329 func (t *Repository) SetAllowOverride(value bool) { 330 t.allowOverride = value 331 } 332 333 func findDependencies(n parse.Node) []string { 334 335 var deps []string 336 depMap := make(map[string]bool) 337 338 if n == nil { 339 return deps 340 } 341 342 switch node := n.(type) { 343 case *parse.ListNode: 344 if node != nil && node.Nodes != nil { 345 for _, nn := range node.Nodes { 346 for _, dep := range findDependencies(nn) { 347 depMap[dep] = true 348 } 349 } 350 } 351 case *parse.IfNode: 352 for _, dep := range findDependencies(node.BranchNode.List) { 353 depMap[dep] = true 354 } 355 for _, dep := range findDependencies(node.BranchNode.ElseList) { 356 depMap[dep] = true 357 } 358 359 case *parse.RangeNode: 360 for _, dep := range findDependencies(node.BranchNode.List) { 361 depMap[dep] = true 362 } 363 for _, dep := range findDependencies(node.BranchNode.ElseList) { 364 depMap[dep] = true 365 } 366 367 case *parse.WithNode: 368 for _, dep := range findDependencies(node.BranchNode.List) { 369 depMap[dep] = true 370 } 371 for _, dep := range findDependencies(node.BranchNode.ElseList) { 372 depMap[dep] = true 373 } 374 375 case *parse.TemplateNode: 376 depMap[node.Name] = true 377 } 378 379 for dep := range depMap { 380 deps = append(deps, dep) 381 } 382 383 return deps 384 385 } 386 387 func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool { 388 if dependencies == nil { 389 dependencies = make(map[string]bool) 390 } 391 392 deps := findDependencies(templ.Tree.Root) 393 394 for _, d := range deps { 395 if _, found := dependencies[d]; !found { 396 397 dependencies[d] = true 398 399 if tt := t.templates[d]; tt != nil { 400 dependencies = t.flattenDependencies(tt, dependencies) 401 } 402 } 403 404 dependencies[d] = true 405 406 } 407 408 return dependencies 409 410 } 411 412 func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) { 413 414 name := templ.Name() 415 416 deps := t.flattenDependencies(templ, nil) 417 418 for dep := range deps { 419 420 if dep == "" { 421 continue 422 } 423 424 tt := templ.Lookup(dep) 425 426 // Check if we have it 427 if tt == nil { 428 tt = t.templates[dep] 429 430 // Still don't have it, return an error 431 if tt == nil { 432 return templ, fmt.Errorf("could not find template %s", dep) 433 } 434 var err error 435 436 // Add it to the parse tree 437 templ, err = templ.AddParseTree(dep, tt.Tree) 438 439 if err != nil { 440 return templ, fmt.Errorf("dependency error: %v", err) 441 } 442 443 } 444 } 445 return templ.Lookup(name), nil 446 } 447 448 // Get will return the named template from the repository, ensuring that all dependent templates are loaded. 449 // It will return an error if a dependent template is not defined in the repository. 450 func (t *Repository) Get(name string) (*template.Template, error) { 451 templ, found := t.templates[name] 452 453 if !found { 454 return templ, fmt.Errorf("template doesn't exist %s", name) 455 } 456 457 return t.addDependencies(templ) 458 } 459 460 // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are. 461 func (t *Repository) DumpTemplates() { 462 buf := bytes.NewBuffer(nil) 463 fmt.Fprintln(buf, "\n# Templates") 464 for name, templ := range t.templates { 465 fmt.Fprintf(buf, "## %s\n", name) 466 fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name]) 467 468 if deps := findDependencies(templ.Tree.Root); len(deps) > 0 { 469 470 fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - ")) 471 } 472 fmt.Fprintln(buf, "\n---") 473 } 474 log.Println(buf.String()) 475 } 476 477 // FuncMap functions 478 479 func asJSON(data interface{}) (string, error) { 480 b, err := json.Marshal(data) 481 if err != nil { 482 return "", err 483 } 484 return string(b), nil 485 } 486 487 func asPrettyJSON(data interface{}) (string, error) { 488 b, err := json.MarshalIndent(data, "", " ") 489 if err != nil { 490 return "", err 491 } 492 return string(b), nil 493 } 494 495 func pluralizeFirstWord(arg string) string { 496 sentence := strings.Split(arg, " ") 497 if len(sentence) == 1 { 498 return inflect.Pluralize(arg) 499 } 500 501 return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ") 502 } 503 504 func dropPackage(str string) string { 505 parts := strings.Split(str, ".") 506 return parts[len(parts)-1] 507 } 508 509 func padSurround(entry, padWith string, i, ln int) string { 510 var res []string 511 if i > 0 { 512 for j := 0; j < i; j++ { 513 res = append(res, padWith) 514 } 515 } 516 res = append(res, entry) 517 tot := ln - i - 1 518 for j := 0; j < tot; j++ { 519 res = append(res, padWith) 520 } 521 return strings.Join(res, ",") 522 } 523 524 func padComment(str string, pads ...string) string { 525 // pads specifes padding to indent multi line comments.Defaults to one space 526 pad := " " 527 lines := strings.Split(str, "\n") 528 if len(pads) > 0 { 529 pad = strings.Join(pads, "") 530 } 531 return (strings.Join(lines, "\n//"+pad)) 532 } 533 534 func blockComment(str string) string { 535 return strings.Replace(str, "*/", "[*]/", -1) 536 } 537 538 func pascalize(arg string) string { 539 runes := []rune(arg) 540 switch len(runes) { 541 case 0: 542 return "Empty" 543 case 1: // handle special case when we have a single rune that is not handled by swag.ToGoName 544 switch runes[0] { 545 case '+', '-', '#', '_': // those cases are handled differently than swag utility 546 return prefixForName(arg) 547 } 548 } 549 return swag.ToGoName(swag.ToGoName(arg)) // want to remove spaces 550 } 551 552 func prefixForName(arg string) string { 553 first := []rune(arg)[0] 554 if len(arg) == 0 || unicode.IsLetter(first) { 555 return "" 556 } 557 switch first { 558 case '+': 559 return "Plus" 560 case '-': 561 return "Minus" 562 case '#': 563 return "HashTag" 564 // other cases ($,@ etc..) handled by swag.ToGoName 565 } 566 return "Nr" 567 } 568 569 func dict(values ...interface{}) (map[string]interface{}, error) { 570 if len(values)%2 != 0 { 571 return nil, fmt.Errorf("expected even number of arguments, got %d", len(values)) 572 } 573 dict := make(map[string]interface{}, len(values)/2) 574 for i := 0; i < len(values); i += 2 { 575 key, ok := values[i].(string) 576 if !ok { 577 return nil, fmt.Errorf("expected string key, got %+v", values[i]) 578 } 579 dict[key] = values[i+1] 580 } 581 return dict, nil 582 }