github.com/livermorium/go-swagger@v0.19.0/generator/template_repo.go (about)

     1  package generator
     2  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"text/template"
	"text/template/parse"

	"github.com/go-openapi/inflect"
	"github.com/go-openapi/swag"
	"github.com/kr/pretty"
)
    21  
    22  var templates *Repository
    23  
    24  // FuncMap is a map with default functions for use n the templates.
    25  // These are available in every template
    26  var FuncMap template.FuncMap = map[string]interface{}{
    27  	"pascalize": pascalize,
    28  	"camelize":  swag.ToJSONName,
    29  	"varname":   golang.MangleVarName,
    30  	"humanize":  swag.ToHumanNameLower,
    31  	"snakize":   golang.MangleFileName,
    32  	"toPackagePath": func(name string) string {
    33  		return filepath.FromSlash(golang.ManglePackagePath(name, ""))
    34  	},
    35  	"toPackage": func(name string) string {
    36  		return golang.ManglePackagePath(name, "")
    37  	},
    38  	"toPackageName": func(name string) string {
    39  		return golang.ManglePackageName(name, "")
    40  	},
    41  	"dasherize": swag.ToCommandName,
    42  	"pluralizeFirstWord": func(arg string) string {
    43  		sentence := strings.Split(arg, " ")
    44  		if len(sentence) == 1 {
    45  			return inflect.Pluralize(arg)
    46  		}
    47  
    48  		return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ")
    49  	},
    50  	"json":       asJSON,
    51  	"prettyjson": asPrettyJSON,
    52  	"hasInsecure": func(arg []string) bool {
    53  		return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws")
    54  	},
    55  	"hasSecure": func(arg []string) bool {
    56  		return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss")
    57  	},
    58  	// TODO: simplify redundant functions
    59  	"stripPackage": func(str, pkg string) string {
    60  		parts := strings.Split(str, ".")
    61  		strlen := len(parts)
    62  		if strlen > 0 {
    63  			return parts[strlen-1]
    64  		}
    65  		return str
    66  	},
    67  	"dropPackage": func(str string) string {
    68  		parts := strings.Split(str, ".")
    69  		strlen := len(parts)
    70  		if strlen > 0 {
    71  			return parts[strlen-1]
    72  		}
    73  		return str
    74  	},
    75  	"upper": func(str string) string {
    76  		return strings.ToUpper(str)
    77  	},
    78  	"contains": func(coll []string, arg string) bool {
    79  		for _, v := range coll {
    80  			if v == arg {
    81  				return true
    82  			}
    83  		}
    84  		return false
    85  	},
    86  	"padSurround": func(entry, padWith string, i, ln int) string {
    87  		var res []string
    88  		if i > 0 {
    89  			for j := 0; j < i; j++ {
    90  				res = append(res, padWith)
    91  			}
    92  		}
    93  		res = append(res, entry)
    94  		tot := ln - i - 1
    95  		for j := 0; j < tot; j++ {
    96  			res = append(res, padWith)
    97  		}
    98  		return strings.Join(res, ",")
    99  	},
   100  	"joinFilePath": filepath.Join,
   101  	"comment": func(str string) string {
   102  		lines := strings.Split(str, "\n")
   103  		return (strings.Join(lines, "\n// "))
   104  	},
   105  	"blockcomment": func(str string) string {
   106  		return strings.Replace(str, "*/", "[*]/", -1)
   107  	},
   108  	"inspect":   pretty.Sprint,
   109  	"cleanPath": path.Clean,
   110  	"mediaTypeName": func(orig string) string {
   111  		return strings.SplitN(orig, ";", 2)[0]
   112  	},
   113  	"goSliceInitializer": goSliceInitializer,
   114  	"hasPrefix":          strings.HasPrefix,
   115  	"stringContains":     strings.Contains,
   116  }
   117  
   118  func init() {
   119  	templates = NewRepository(FuncMap)
   120  }
   121  
   122  var assets = map[string][]byte{
   123  	"validation/primitive.gotmpl":           MustAsset("templates/validation/primitive.gotmpl"),
   124  	"validation/customformat.gotmpl":        MustAsset("templates/validation/customformat.gotmpl"),
   125  	"docstring.gotmpl":                      MustAsset("templates/docstring.gotmpl"),
   126  	"validation/structfield.gotmpl":         MustAsset("templates/validation/structfield.gotmpl"),
   127  	"modelvalidator.gotmpl":                 MustAsset("templates/modelvalidator.gotmpl"),
   128  	"structfield.gotmpl":                    MustAsset("templates/structfield.gotmpl"),
   129  	"tupleserializer.gotmpl":                MustAsset("templates/tupleserializer.gotmpl"),
   130  	"additionalpropertiesserializer.gotmpl": MustAsset("templates/additionalpropertiesserializer.gotmpl"),
   131  	"schematype.gotmpl":                     MustAsset("templates/schematype.gotmpl"),
   132  	"schemabody.gotmpl":                     MustAsset("templates/schemabody.gotmpl"),
   133  	"schema.gotmpl":                         MustAsset("templates/schema.gotmpl"),
   134  	"schemavalidator.gotmpl":                MustAsset("templates/schemavalidator.gotmpl"),
   135  	"model.gotmpl":                          MustAsset("templates/model.gotmpl"),
   136  	"header.gotmpl":                         MustAsset("templates/header.gotmpl"),
   137  	"swagger_json_embed.gotmpl":             MustAsset("templates/swagger_json_embed.gotmpl"),
   138  
   139  	"server/parameter.gotmpl":    MustAsset("templates/server/parameter.gotmpl"),
   140  	"server/urlbuilder.gotmpl":   MustAsset("templates/server/urlbuilder.gotmpl"),
   141  	"server/responses.gotmpl":    MustAsset("templates/server/responses.gotmpl"),
   142  	"server/operation.gotmpl":    MustAsset("templates/server/operation.gotmpl"),
   143  	"server/builder.gotmpl":      MustAsset("templates/server/builder.gotmpl"),
   144  	"server/server.gotmpl":       MustAsset("templates/server/server.gotmpl"),
   145  	"server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"),
   146  	"server/main.gotmpl":         MustAsset("templates/server/main.gotmpl"),
   147  	"server/doc.gotmpl":          MustAsset("templates/server/doc.gotmpl"),
   148  
   149  	"client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"),
   150  	"client/response.gotmpl":  MustAsset("templates/client/response.gotmpl"),
   151  	"client/client.gotmpl":    MustAsset("templates/client/client.gotmpl"),
   152  	"client/facade.gotmpl":    MustAsset("templates/client/facade.gotmpl"),
   153  }
   154  
   155  var protectedTemplates = map[string]bool{
   156  	"schemabody":                     true,
   157  	"privtuplefield":                 true,
   158  	"withoutBaseTypeBody":            true,
   159  	"swaggerJsonEmbed":               true,
   160  	"validationCustomformat":         true,
   161  	"tuplefield":                     true,
   162  	"header":                         true,
   163  	"withBaseTypeBody":               true,
   164  	"primitivefieldvalidator":        true,
   165  	"mapvalidator":                   true,
   166  	"propertyValidationDocString":    true,
   167  	"typeSchemaType":                 true,
   168  	"docstring":                      true,
   169  	"dereffedSchemaType":             true,
   170  	"model":                          true,
   171  	"modelvalidator":                 true,
   172  	"privstructfield":                true,
   173  	"schemavalidator":                true,
   174  	"tuplefieldIface":                true,
   175  	"tupleSerializer":                true,
   176  	"tupleserializer":                true,
   177  	"schemaSerializer":               true,
   178  	"propertyvalidator":              true,
   179  	"structfieldIface":               true,
   180  	"schemaBody":                     true,
   181  	"objectvalidator":                true,
   182  	"schematype":                     true,
   183  	"additionalpropertiesserializer": true,
   184  	"slicevalidator":                 true,
   185  	"validationStructfield":          true,
   186  	"validationPrimitive":            true,
   187  	"schemaType":                     true,
   188  	"subTypeBody":                    true,
   189  	"schema":                         true,
   190  	"additionalPropertiesSerializer": true,
   191  	"serverDoc":                      true,
   192  	"structfield":                    true,
   193  	"hasDiscriminatedSerializer":     true,
   194  	"discriminatedSerializer":        true,
   195  }
   196  
   197  // AddFile adds a file to the default repository. It will create a new template based on the filename.
   198  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   199  // directory separators and Camelcase the next letter.
   200  // e.g validation/primitive.gotmpl will become validationPrimitive
   201  //
   202  // If the file contains a definition for a template that is protected the whole file will not be added
   203  func AddFile(name, data string) error {
   204  	return templates.addFile(name, data, false)
   205  }
   206  
   207  func asJSON(data interface{}) (string, error) {
   208  	b, err := json.Marshal(data)
   209  	if err != nil {
   210  		return "", err
   211  	}
   212  	return string(b), nil
   213  }
   214  
   215  func asPrettyJSON(data interface{}) (string, error) {
   216  	b, err := json.MarshalIndent(data, "", "  ")
   217  	if err != nil {
   218  		return "", err
   219  	}
   220  	return string(b), nil
   221  }
   222  
   223  func goSliceInitializer(data interface{}) (string, error) {
   224  	// goSliceInitializer constructs a Go literal initializer from interface{} literals.
   225  	// e.g. []interface{}{"a", "b"} is transformed in {"a","b",}
   226  	// e.g. map[string]interface{}{ "a": "x", "b": "y"} is transformed in {"a":"x","b":"y",}.
   227  	//
   228  	// NOTE: this is currently used to construct simple slice intializers for default values.
   229  	// This allows for nicer slice initializers for slices of primitive types and avoid systematic use for json.Unmarshal().
   230  	b, err := json.Marshal(data)
   231  	if err != nil {
   232  		return "", err
   233  	}
   234  	return strings.Replace(strings.Replace(strings.Replace(string(b), "}", ",}", -1), "[", "{", -1), "]", ",}", -1), nil
   235  }
   236  
   237  // NewRepository creates a new template repository with the provided functions defined
   238  func NewRepository(funcs template.FuncMap) *Repository {
   239  	repo := Repository{
   240  		files:     make(map[string]string),
   241  		templates: make(map[string]*template.Template),
   242  		funcs:     funcs,
   243  	}
   244  
   245  	if repo.funcs == nil {
   246  		repo.funcs = make(template.FuncMap)
   247  	}
   248  
   249  	return &repo
   250  }
   251  
   252  // Repository is the repository for the generator templates
   253  type Repository struct {
   254  	files     map[string]string
   255  	templates map[string]*template.Template
   256  	funcs     template.FuncMap
   257  }
   258  
   259  // LoadDefaults will load the embedded templates
   260  func (t *Repository) LoadDefaults() {
   261  
   262  	for name, asset := range assets {
   263  		if err := t.addFile(name, string(asset), true); err != nil {
   264  			log.Fatal(err)
   265  		}
   266  	}
   267  }
   268  
   269  // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository
   270  func (t *Repository) LoadDir(templatePath string) error {
   271  	err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error {
   272  
   273  		if strings.HasSuffix(path, ".gotmpl") {
   274  			if assetName, e := filepath.Rel(templatePath, path); e == nil {
   275  				if data, e := ioutil.ReadFile(path); e == nil {
   276  					if ee := t.AddFile(assetName, string(data)); ee != nil {
   277  						// Fatality is decided by caller
   278  						// log.Fatal(ee)
   279  						return fmt.Errorf("Could not add template: %v", ee)
   280  					}
   281  				}
   282  				// Non-readable files are skipped
   283  			}
   284  		}
   285  		if err != nil {
   286  			return err
   287  		}
   288  		// Non-template files are skipped
   289  		return nil
   290  	})
   291  	if err != nil {
   292  		return fmt.Errorf("Could not complete template processing in directory \"%s\": %v", templatePath, err)
   293  	}
   294  	return nil
   295  }
   296  
   297  // LoadContrib loads template from contrib directory
   298  func (t *Repository) LoadContrib(name string) error {
   299  	log.Printf("loading contrib %s", name)
   300  	const pathPrefix = "templates/contrib/"
   301  	basePath := pathPrefix + name
   302  	filesAdded := 0
   303  	for _, aname := range AssetNames() {
   304  		if !strings.HasSuffix(aname, ".gotmpl") {
   305  			continue
   306  		}
   307  		if strings.HasPrefix(aname, basePath) {
   308  			target := aname[len(basePath)+1:]
   309  			err := t.addFile(target, string(MustAsset(aname)), true)
   310  			if err != nil {
   311  				return err
   312  			}
   313  			log.Printf("added contributed template %s from %s", target, aname)
   314  			filesAdded++
   315  		}
   316  	}
   317  	if filesAdded == 0 {
   318  		return fmt.Errorf("no files added from template: %s", name)
   319  	}
   320  	return nil
   321  }
   322  
   323  func (t *Repository) addFile(name, data string, allowOverride bool) error {
   324  	fileName := name
   325  	name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl"))
   326  
   327  	templ, err := template.New(name).Funcs(t.funcs).Parse(data)
   328  
   329  	if err != nil {
   330  		return fmt.Errorf("Failed to load template %s: %v", name, err)
   331  	}
   332  
   333  	// check if any protected templates are defined
   334  	if !allowOverride {
   335  		for _, template := range templ.Templates() {
   336  			if protectedTemplates[template.Name()] {
   337  				return fmt.Errorf("Cannot overwrite protected template %s", template.Name())
   338  			}
   339  		}
   340  	}
   341  
   342  	// Add each defined template into the cache
   343  	for _, template := range templ.Templates() {
   344  
   345  		t.files[template.Name()] = fileName
   346  		t.templates[template.Name()] = template.Lookup(template.Name())
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  // MustGet a template by name, panics when fails
   353  func (t *Repository) MustGet(name string) *template.Template {
   354  	tpl, err := t.Get(name)
   355  	if err != nil {
   356  		panic(err)
   357  	}
   358  	return tpl
   359  }
   360  
   361  // AddFile adds a file to the repository. It will create a new template based on the filename.
   362  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   363  // directory separators and Camelcase the next letter.
   364  // e.g validation/primitive.gotmpl will become validationPrimitive
   365  //
   366  // If the file contains a definition for a template that is protected the whole file will not be added
   367  func (t *Repository) AddFile(name, data string) error {
   368  	return t.addFile(name, data, false)
   369  }
   370  
   371  func findDependencies(n parse.Node) []string {
   372  
   373  	var deps []string
   374  	depMap := make(map[string]bool)
   375  
   376  	if n == nil {
   377  		return deps
   378  	}
   379  
   380  	switch node := n.(type) {
   381  	case *parse.ListNode:
   382  		if node != nil && node.Nodes != nil {
   383  			for _, nn := range node.Nodes {
   384  				for _, dep := range findDependencies(nn) {
   385  					depMap[dep] = true
   386  				}
   387  			}
   388  		}
   389  	case *parse.IfNode:
   390  		for _, dep := range findDependencies(node.BranchNode.List) {
   391  			depMap[dep] = true
   392  		}
   393  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   394  			depMap[dep] = true
   395  		}
   396  
   397  	case *parse.RangeNode:
   398  		for _, dep := range findDependencies(node.BranchNode.List) {
   399  			depMap[dep] = true
   400  		}
   401  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   402  			depMap[dep] = true
   403  		}
   404  
   405  	case *parse.WithNode:
   406  		for _, dep := range findDependencies(node.BranchNode.List) {
   407  			depMap[dep] = true
   408  		}
   409  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   410  			depMap[dep] = true
   411  		}
   412  
   413  	case *parse.TemplateNode:
   414  		depMap[node.Name] = true
   415  	}
   416  
   417  	for dep := range depMap {
   418  		deps = append(deps, dep)
   419  	}
   420  
   421  	return deps
   422  
   423  }
   424  
   425  func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool {
   426  	if dependencies == nil {
   427  		dependencies = make(map[string]bool)
   428  	}
   429  
   430  	deps := findDependencies(templ.Tree.Root)
   431  
   432  	for _, d := range deps {
   433  		if _, found := dependencies[d]; !found {
   434  
   435  			dependencies[d] = true
   436  
   437  			if tt := t.templates[d]; tt != nil {
   438  				dependencies = t.flattenDependencies(tt, dependencies)
   439  			}
   440  		}
   441  
   442  		dependencies[d] = true
   443  
   444  	}
   445  
   446  	return dependencies
   447  
   448  }
   449  
   450  func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) {
   451  
   452  	name := templ.Name()
   453  
   454  	deps := t.flattenDependencies(templ, nil)
   455  
   456  	for dep := range deps {
   457  
   458  		if dep == "" {
   459  			continue
   460  		}
   461  
   462  		tt := templ.Lookup(dep)
   463  
   464  		// Check if we have it
   465  		if tt == nil {
   466  			tt = t.templates[dep]
   467  
   468  			// Still don't have it, return an error
   469  			if tt == nil {
   470  				return templ, fmt.Errorf("Could not find template %s", dep)
   471  			}
   472  			var err error
   473  
   474  			// Add it to the parse tree
   475  			templ, err = templ.AddParseTree(dep, tt.Tree)
   476  
   477  			if err != nil {
   478  				return templ, fmt.Errorf("Dependency Error: %v", err)
   479  			}
   480  
   481  		}
   482  	}
   483  	return templ.Lookup(name), nil
   484  }
   485  
   486  // Get will return the named template from the repository, ensuring that all dependent templates are loaded.
   487  // It will return an error if a dependent template is not defined in the repository.
   488  func (t *Repository) Get(name string) (*template.Template, error) {
   489  	templ, found := t.templates[name]
   490  
   491  	if !found {
   492  		return templ, fmt.Errorf("Template doesn't exist %s", name)
   493  	}
   494  
   495  	return t.addDependencies(templ)
   496  }
   497  
   498  // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are.
   499  func (t *Repository) DumpTemplates() {
   500  	buf := bytes.NewBuffer(nil)
   501  	fmt.Fprintln(buf, "\n# Templates")
   502  	for name, templ := range t.templates {
   503  		fmt.Fprintf(buf, "## %s\n", name)
   504  		fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name])
   505  
   506  		if deps := findDependencies(templ.Tree.Root); len(deps) > 0 {
   507  
   508  			fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - "))
   509  		}
   510  		fmt.Fprintln(buf, "\n---")
   511  	}
   512  	log.Println(buf.String())
   513  }