github.com/emreu/go-swagger@v0.22.1/generator/template_repo.go (about) 1 package generator 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "os" 9 "path" 10 "path/filepath" 11 "strings" 12 "text/template" 13 "text/template/parse" 14 15 "log" 16 17 "github.com/go-openapi/inflect" 18 "github.com/go-openapi/swag" 19 "github.com/kr/pretty" 20 ) 21 22 var templates *Repository 23 24 // FuncMap is a map with default functions for use n the templates. 25 // These are available in every template 26 var FuncMap template.FuncMap = map[string]interface{}{ 27 "pascalize": pascalize, 28 "camelize": swag.ToJSONName, 29 "varname": golang.MangleVarName, 30 "humanize": swag.ToHumanNameLower, 31 "snakize": golang.MangleFileName, 32 "toPackagePath": func(name string) string { 33 return filepath.FromSlash(golang.ManglePackagePath(name, "")) 34 }, 35 "toPackage": func(name string) string { 36 return golang.ManglePackagePath(name, "") 37 }, 38 "toPackageName": func(name string) string { 39 return golang.ManglePackageName(name, "") 40 }, 41 "dasherize": swag.ToCommandName, 42 "pluralizeFirstWord": func(arg string) string { 43 sentence := strings.Split(arg, " ") 44 if len(sentence) == 1 { 45 return inflect.Pluralize(arg) 46 } 47 48 return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ") 49 }, 50 "json": asJSON, 51 "prettyjson": asPrettyJSON, 52 "hasInsecure": func(arg []string) bool { 53 return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws") 54 }, 55 "hasSecure": func(arg []string) bool { 56 return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss") 57 }, 58 // TODO: simplify redundant functions 59 "stripPackage": func(str, pkg string) string { 60 parts := strings.Split(str, ".") 61 strlen := len(parts) 62 if strlen > 0 { 63 return parts[strlen-1] 64 } 65 return str 66 }, 67 "dropPackage": func(str string) string { 68 parts := strings.Split(str, ".") 69 strlen := len(parts) 70 if strlen > 0 { 71 return parts[strlen-1] 72 } 73 return str 74 }, 75 "upper": strings.ToUpper, 76 "contains": func(coll []string, arg string) bool { 77 for _, v := range coll { 78 if v == arg { 79 return true 80 } 81 } 82 return false 83 }, 84 "padSurround": func(entry, padWith string, i, ln int) string { 85 var res []string 86 if i > 0 { 87 for j := 0; j < i; j++ { 88 res = append(res, padWith) 89 } 90 } 91 res = append(res, entry) 92 tot := ln - i - 1 93 for j := 0; j < tot; j++ { 94 res = append(res, padWith) 95 } 96 return strings.Join(res, ",") 97 }, 98 "joinFilePath": filepath.Join, 99 "comment": func(str string) string { 100 lines := strings.Split(str, "\n") 101 return (strings.Join(lines, "\n// ")) 102 }, 103 "blockcomment": func(str string) string { 104 return strings.Replace(str, "*/", "[*]/", -1) 105 }, 106 "inspect": pretty.Sprint, 107 "cleanPath": path.Clean, 108 "mediaTypeName": func(orig string) string { 109 return strings.SplitN(orig, ";", 2)[0] 110 }, 111 "goSliceInitializer": goSliceInitializer, 112 "hasPrefix": strings.HasPrefix, 113 "stringContains": strings.Contains, 114 } 115 116 func init() { 117 templates = NewRepository(FuncMap) 118 } 119 120 var assets = map[string][]byte{ 121 "validation/primitive.gotmpl": MustAsset("templates/validation/primitive.gotmpl"), 122 "validation/customformat.gotmpl": MustAsset("templates/validation/customformat.gotmpl"), 123 "docstring.gotmpl": MustAsset("templates/docstring.gotmpl"), 124 "validation/structfield.gotmpl": MustAsset("templates/validation/structfield.gotmpl"), 125 "modelvalidator.gotmpl": MustAsset("templates/modelvalidator.gotmpl"), 126 "structfield.gotmpl": MustAsset("templates/structfield.gotmpl"), 127 "tupleserializer.gotmpl": MustAsset("templates/tupleserializer.gotmpl"), 128 "additionalpropertiesserializer.gotmpl": MustAsset("templates/additionalpropertiesserializer.gotmpl"), 129 "schematype.gotmpl": MustAsset("templates/schematype.gotmpl"), 130 "schemabody.gotmpl": MustAsset("templates/schemabody.gotmpl"), 131 "schema.gotmpl": MustAsset("templates/schema.gotmpl"), 132 "schemavalidator.gotmpl": MustAsset("templates/schemavalidator.gotmpl"), 133 "model.gotmpl": MustAsset("templates/model.gotmpl"), 134 "header.gotmpl": MustAsset("templates/header.gotmpl"), 135 "swagger_json_embed.gotmpl": MustAsset("templates/swagger_json_embed.gotmpl"), 136 137 "server/parameter.gotmpl": MustAsset("templates/server/parameter.gotmpl"), 138 "server/urlbuilder.gotmpl": MustAsset("templates/server/urlbuilder.gotmpl"), 139 "server/responses.gotmpl": MustAsset("templates/server/responses.gotmpl"), 140 "server/operation.gotmpl": MustAsset("templates/server/operation.gotmpl"), 141 "server/builder.gotmpl": MustAsset("templates/server/builder.gotmpl"), 142 "server/server.gotmpl": MustAsset("templates/server/server.gotmpl"), 143 "server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"), 144 "server/main.gotmpl": MustAsset("templates/server/main.gotmpl"), 145 "server/doc.gotmpl": MustAsset("templates/server/doc.gotmpl"), 146 147 "client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"), 148 "client/response.gotmpl": MustAsset("templates/client/response.gotmpl"), 149 "client/client.gotmpl": MustAsset("templates/client/client.gotmpl"), 150 "client/facade.gotmpl": MustAsset("templates/client/facade.gotmpl"), 151 } 152 153 var protectedTemplates = map[string]bool{ 154 "schemabody": true, 155 "privtuplefield": true, 156 "withoutBaseTypeBody": true, 157 "swaggerJsonEmbed": true, 158 "validationCustomformat": true, 159 "tuplefield": true, 160 "header": true, 161 "withBaseTypeBody": true, 162 "primitivefieldvalidator": true, 163 "mapvalidator": true, 164 "propertyValidationDocString": true, 165 "typeSchemaType": true, 166 "docstring": true, 167 "dereffedSchemaType": true, 168 "model": true, 169 "modelvalidator": true, 170 "privstructfield": true, 171 "schemavalidator": true, 172 "tuplefieldIface": true, 173 "tupleSerializer": true, 174 "tupleserializer": true, 175 "schemaSerializer": true, 176 "propertyvalidator": true, 177 "structfieldIface": true, 178 "schemaBody": true, 179 "objectvalidator": true, 180 "schematype": true, 181 "additionalpropertiesserializer": true, 182 "slicevalidator": true, 183 "validationStructfield": true, 184 "validationPrimitive": true, 185 "schemaType": true, 186 "subTypeBody": true, 187 "schema": true, 188 "additionalPropertiesSerializer": true, 189 "serverDoc": true, 190 "structfield": true, 191 "hasDiscriminatedSerializer": true, 192 "discriminatedSerializer": true, 193 } 194 195 // AddFile adds a file to the default repository. It will create a new template based on the filename. 196 // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip 197 // directory separators and Camelcase the next letter. 198 // e.g validation/primitive.gotmpl will become validationPrimitive 199 // 200 // If the file contains a definition for a template that is protected the whole file will not be added 201 func AddFile(name, data string) error { 202 return templates.addFile(name, data, false) 203 } 204 205 func asJSON(data interface{}) (string, error) { 206 b, err := json.Marshal(data) 207 if err != nil { 208 return "", err 209 } 210 return string(b), nil 211 } 212 213 func asPrettyJSON(data interface{}) (string, error) { 214 b, err := json.MarshalIndent(data, "", " ") 215 if err != nil { 216 return "", err 217 } 218 return string(b), nil 219 } 220 221 func goSliceInitializer(data interface{}) (string, error) { 222 // goSliceInitializer constructs a Go literal initializer from interface{} literals. 223 // e.g. []interface{}{"a", "b"} is transformed in {"a","b",} 224 // e.g. map[string]interface{}{ "a": "x", "b": "y"} is transformed in {"a":"x","b":"y",}. 225 // 226 // NOTE: this is currently used to construct simple slice intializers for default values. 227 // This allows for nicer slice initializers for slices of primitive types and avoid systematic use for json.Unmarshal(). 228 b, err := json.Marshal(data) 229 if err != nil { 230 return "", err 231 } 232 return strings.Replace(strings.Replace(strings.Replace(string(b), "}", ",}", -1), "[", "{", -1), "]", ",}", -1), nil 233 } 234 235 // NewRepository creates a new template repository with the provided functions defined 236 func NewRepository(funcs template.FuncMap) *Repository { 237 repo := Repository{ 238 files: make(map[string]string), 239 templates: make(map[string]*template.Template), 240 funcs: funcs, 241 } 242 243 if repo.funcs == nil { 244 repo.funcs = make(template.FuncMap) 245 } 246 247 return &repo 248 } 249 250 // Repository is the repository for the generator templates 251 type Repository struct { 252 files map[string]string 253 templates map[string]*template.Template 254 funcs template.FuncMap 255 allowOverride bool 256 } 257 258 // LoadDefaults will load the embedded templates 259 func (t *Repository) LoadDefaults() { 260 261 for name, asset := range assets { 262 if err := t.addFile(name, string(asset), true); err != nil { 263 log.Fatal(err) 264 } 265 } 266 } 267 268 // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository 269 func (t *Repository) LoadDir(templatePath string) error { 270 err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error { 271 272 if strings.HasSuffix(path, ".gotmpl") { 273 if assetName, e := filepath.Rel(templatePath, path); e == nil { 274 if data, e := ioutil.ReadFile(path); e == nil { 275 if ee := t.AddFile(assetName, string(data)); ee != nil { 276 // Fatality is decided by caller 277 // log.Fatal(ee) 278 return fmt.Errorf("could not add template: %v", ee) 279 } 280 } 281 // Non-readable files are skipped 282 } 283 } 284 if err != nil { 285 return err 286 } 287 // Non-template files are skipped 288 return nil 289 }) 290 if err != nil { 291 return fmt.Errorf("could not complete template processing in directory \"%s\": %v", templatePath, err) 292 } 293 return nil 294 } 295 296 // LoadContrib loads template from contrib directory 297 func (t *Repository) LoadContrib(name string) error { 298 log.Printf("loading contrib %s", name) 299 const pathPrefix = "templates/contrib/" 300 basePath := pathPrefix + name 301 filesAdded := 0 302 for _, aname := range AssetNames() { 303 if !strings.HasSuffix(aname, ".gotmpl") { 304 continue 305 } 306 if strings.HasPrefix(aname, basePath) { 307 target := aname[len(basePath)+1:] 308 err := t.addFile(target, string(MustAsset(aname)), true) 309 if err != nil { 310 return err 311 } 312 log.Printf("added contributed template %s from %s", target, aname) 313 filesAdded++ 314 } 315 } 316 if filesAdded == 0 { 317 return fmt.Errorf("no files added from template: %s", name) 318 } 319 return nil 320 } 321 322 func (t *Repository) addFile(name, data string, allowOverride bool) error { 323 fileName := name 324 name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl")) 325 326 templ, err := template.New(name).Funcs(t.funcs).Parse(data) 327 328 if err != nil { 329 return fmt.Errorf("failed to load template %s: %v", name, err) 330 } 331 332 // check if any protected templates are defined 333 if !allowOverride && !t.allowOverride { 334 for _, template := range templ.Templates() { 335 if protectedTemplates[template.Name()] { 336 return fmt.Errorf("cannot overwrite protected template %s", template.Name()) 337 } 338 } 339 } 340 341 // Add each defined template into the cache 342 for _, template := range templ.Templates() { 343 344 t.files[template.Name()] = fileName 345 t.templates[template.Name()] = template.Lookup(template.Name()) 346 } 347 348 return nil 349 } 350 351 // MustGet a template by name, panics when fails 352 func (t *Repository) MustGet(name string) *template.Template { 353 tpl, err := t.Get(name) 354 if err != nil { 355 panic(err) 356 } 357 return tpl 358 } 359 360 // AddFile adds a file to the repository. It will create a new template based on the filename. 361 // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip 362 // directory separators and Camelcase the next letter. 363 // e.g validation/primitive.gotmpl will become validationPrimitive 364 // 365 // If the file contains a definition for a template that is protected the whole file will not be added 366 func (t *Repository) AddFile(name, data string) error { 367 return t.addFile(name, data, false) 368 } 369 370 // SetAllowOverride allows setting allowOverride after the Repository was initialized 371 func (t *Repository) SetAllowOverride(value bool) { 372 t.allowOverride = value 373 } 374 375 func findDependencies(n parse.Node) []string { 376 377 var deps []string 378 depMap := make(map[string]bool) 379 380 if n == nil { 381 return deps 382 } 383 384 switch node := n.(type) { 385 case *parse.ListNode: 386 if node != nil && node.Nodes != nil { 387 for _, nn := range node.Nodes { 388 for _, dep := range findDependencies(nn) { 389 depMap[dep] = true 390 } 391 } 392 } 393 case *parse.IfNode: 394 for _, dep := range findDependencies(node.BranchNode.List) { 395 depMap[dep] = true 396 } 397 for _, dep := range findDependencies(node.BranchNode.ElseList) { 398 depMap[dep] = true 399 } 400 401 case *parse.RangeNode: 402 for _, dep := range findDependencies(node.BranchNode.List) { 403 depMap[dep] = true 404 } 405 for _, dep := range findDependencies(node.BranchNode.ElseList) { 406 depMap[dep] = true 407 } 408 409 case *parse.WithNode: 410 for _, dep := range findDependencies(node.BranchNode.List) { 411 depMap[dep] = true 412 } 413 for _, dep := range findDependencies(node.BranchNode.ElseList) { 414 depMap[dep] = true 415 } 416 417 case *parse.TemplateNode: 418 depMap[node.Name] = true 419 } 420 421 for dep := range depMap { 422 deps = append(deps, dep) 423 } 424 425 return deps 426 427 } 428 429 func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool { 430 if dependencies == nil { 431 dependencies = make(map[string]bool) 432 } 433 434 deps := findDependencies(templ.Tree.Root) 435 436 for _, d := range deps { 437 if _, found := dependencies[d]; !found { 438 439 dependencies[d] = true 440 441 if tt := t.templates[d]; tt != nil { 442 dependencies = t.flattenDependencies(tt, dependencies) 443 } 444 } 445 446 dependencies[d] = true 447 448 } 449 450 return dependencies 451 452 } 453 454 func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) { 455 456 name := templ.Name() 457 458 deps := t.flattenDependencies(templ, nil) 459 460 for dep := range deps { 461 462 if dep == "" { 463 continue 464 } 465 466 tt := templ.Lookup(dep) 467 468 // Check if we have it 469 if tt == nil { 470 tt = t.templates[dep] 471 472 // Still don't have it, return an error 473 if tt == nil { 474 return templ, fmt.Errorf("could not find template %s", dep) 475 } 476 var err error 477 478 // Add it to the parse tree 479 templ, err = templ.AddParseTree(dep, tt.Tree) 480 481 if err != nil { 482 return templ, fmt.Errorf("dependency error: %v", err) 483 } 484 485 } 486 } 487 return templ.Lookup(name), nil 488 } 489 490 // Get will return the named template from the repository, ensuring that all dependent templates are loaded. 491 // It will return an error if a dependent template is not defined in the repository. 492 func (t *Repository) Get(name string) (*template.Template, error) { 493 templ, found := t.templates[name] 494 495 if !found { 496 return templ, fmt.Errorf("template doesn't exist %s", name) 497 } 498 499 return t.addDependencies(templ) 500 } 501 502 // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are. 503 func (t *Repository) DumpTemplates() { 504 buf := bytes.NewBuffer(nil) 505 fmt.Fprintln(buf, "\n# Templates") 506 for name, templ := range t.templates { 507 fmt.Fprintf(buf, "## %s\n", name) 508 fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name]) 509 510 if deps := findDependencies(templ.Tree.Root); len(deps) > 0 { 511 512 fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - ")) 513 } 514 fmt.Fprintln(buf, "\n---") 515 } 516 log.Println(buf.String()) 517 }