github.com/6543-forks/go-swagger@v0.26.0/generator/template_repo.go (about)

     1  package generator
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"path"
    10  	"path/filepath"
    11  	"strings"
    12  	"text/template"
    13  	"text/template/parse"
    14  	"unicode"
    15  
    16  	"log"
    17  
    18  	"github.com/go-openapi/inflect"
    19  	"github.com/go-openapi/swag"
    20  	"github.com/kr/pretty"
    21  )
    22  
    23  var (
    24  	assets             map[string][]byte
    25  	protectedTemplates map[string]bool
    26  
    27  	// FuncMapFunc yields a map with all functions for templates
    28  	FuncMapFunc func(*LanguageOpts) template.FuncMap
    29  
    30  	templates *Repository
    31  )
    32  
    33  func initTemplateRepo() {
    34  	FuncMapFunc = DefaultFuncMap
    35  
    36  	// this makes the ToGoName func behave with the special
    37  	// prefixing rule above
    38  	swag.GoNamePrefixFunc = prefixForName
    39  
    40  	assets = defaultAssets()
    41  	protectedTemplates = defaultProtectedTemplates()
    42  	templates = NewRepository(FuncMapFunc(DefaultLanguageFunc()))
    43  }
    44  
    45  // DefaultFuncMap yields a map with default functions for use n the templates.
    46  // These are available in every template
    47  func DefaultFuncMap(lang *LanguageOpts) template.FuncMap {
    48  	return template.FuncMap(map[string]interface{}{
    49  		"pascalize": pascalize,
    50  		"camelize":  swag.ToJSONName,
    51  		"varname":   lang.MangleVarName,
    52  		"humanize":  swag.ToHumanNameLower,
    53  		"snakize":   lang.MangleFileName,
    54  		"toPackagePath": func(name string) string {
    55  			return filepath.FromSlash(lang.ManglePackagePath(name, ""))
    56  		},
    57  		"toPackage": func(name string) string {
    58  			return lang.ManglePackagePath(name, "")
    59  		},
    60  		"toPackageName": func(name string) string {
    61  			return lang.ManglePackageName(name, "")
    62  		},
    63  		"dasherize":          swag.ToCommandName,
    64  		"pluralizeFirstWord": pluralizeFirstWord,
    65  		"json":               asJSON,
    66  		"prettyjson":         asPrettyJSON,
    67  		"hasInsecure": func(arg []string) bool {
    68  			return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws")
    69  		},
    70  		"hasSecure": func(arg []string) bool {
    71  			return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss")
    72  		},
    73  		"dropPackage":      dropPackage,
    74  		"upper":            strings.ToUpper,
    75  		"contains":         swag.ContainsStrings,
    76  		"padSurround":      padSurround,
    77  		"joinFilePath":     filepath.Join,
    78  		"comment":          padComment,
    79  		"blockcomment":     blockComment,
    80  		"inspect":          pretty.Sprint,
    81  		"cleanPath":        path.Clean,
    82  		"mediaTypeName":    mediaMime,
    83  		"arrayInitializer": lang.arrayInitializer,
    84  		"hasPrefix":        strings.HasPrefix,
    85  		"stringContains":   strings.Contains,
    86  		"imports":          lang.imports,
    87  		"dict":             dict,
    88  	})
    89  }
    90  
    91  func defaultAssets() map[string][]byte {
    92  	return map[string][]byte{
    93  		// schema validation templates
    94  		"validation/primitive.gotmpl":    MustAsset("templates/validation/primitive.gotmpl"),
    95  		"validation/customformat.gotmpl": MustAsset("templates/validation/customformat.gotmpl"),
    96  		"validation/structfield.gotmpl":  MustAsset("templates/validation/structfield.gotmpl"),
    97  		"structfield.gotmpl":             MustAsset("templates/structfield.gotmpl"),
    98  		"schemavalidator.gotmpl":         MustAsset("templates/schemavalidator.gotmpl"),
    99  		"schemapolymorphic.gotmpl":       MustAsset("templates/schemapolymorphic.gotmpl"),
   100  		"schemaembedded.gotmpl":          MustAsset("templates/schemaembedded.gotmpl"),
   101  
   102  		// schema serialization templates
   103  		"additionalpropertiesserializer.gotmpl": MustAsset("templates/serializers/additionalpropertiesserializer.gotmpl"),
   104  		"aliasedserializer.gotmpl":              MustAsset("templates/serializers/aliasedserializer.gotmpl"),
   105  		"allofserializer.gotmpl":                MustAsset("templates/serializers/allofserializer.gotmpl"),
   106  		"basetypeserializer.gotmpl":             MustAsset("templates/serializers/basetypeserializer.gotmpl"),
   107  		"marshalbinaryserializer.gotmpl":        MustAsset("templates/serializers/marshalbinaryserializer.gotmpl"),
   108  		"schemaserializer.gotmpl":               MustAsset("templates/serializers/schemaserializer.gotmpl"),
   109  		"subtypeserializer.gotmpl":              MustAsset("templates/serializers/subtypeserializer.gotmpl"),
   110  		"tupleserializer.gotmpl":                MustAsset("templates/serializers/tupleserializer.gotmpl"),
   111  
   112  		// schema generation template
   113  		"docstring.gotmpl":  MustAsset("templates/docstring.gotmpl"),
   114  		"schematype.gotmpl": MustAsset("templates/schematype.gotmpl"),
   115  		"schemabody.gotmpl": MustAsset("templates/schemabody.gotmpl"),
   116  		"schema.gotmpl":     MustAsset("templates/schema.gotmpl"),
   117  		"model.gotmpl":      MustAsset("templates/model.gotmpl"),
   118  		"header.gotmpl":     MustAsset("templates/header.gotmpl"),
   119  
   120  		"swagger_json_embed.gotmpl": MustAsset("templates/swagger_json_embed.gotmpl"),
   121  
   122  		// server templates
   123  		"server/parameter.gotmpl":    MustAsset("templates/server/parameter.gotmpl"),
   124  		"server/urlbuilder.gotmpl":   MustAsset("templates/server/urlbuilder.gotmpl"),
   125  		"server/responses.gotmpl":    MustAsset("templates/server/responses.gotmpl"),
   126  		"server/operation.gotmpl":    MustAsset("templates/server/operation.gotmpl"),
   127  		"server/builder.gotmpl":      MustAsset("templates/server/builder.gotmpl"),
   128  		"server/server.gotmpl":       MustAsset("templates/server/server.gotmpl"),
   129  		"server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"),
   130  		"server/main.gotmpl":         MustAsset("templates/server/main.gotmpl"),
   131  		"server/doc.gotmpl":          MustAsset("templates/server/doc.gotmpl"),
   132  
   133  		// client templates
   134  		"client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"),
   135  		"client/response.gotmpl":  MustAsset("templates/client/response.gotmpl"),
   136  		"client/client.gotmpl":    MustAsset("templates/client/client.gotmpl"),
   137  		"client/facade.gotmpl":    MustAsset("templates/client/facade.gotmpl"),
   138  	}
   139  }
   140  
   141  func defaultProtectedTemplates() map[string]bool {
   142  	return map[string]bool{
   143  		"dereffedSchemaType":          true,
   144  		"docstring":                   true,
   145  		"header":                      true,
   146  		"mapvalidator":                true,
   147  		"model":                       true,
   148  		"modelvalidator":              true,
   149  		"objectvalidator":             true,
   150  		"primitivefieldvalidator":     true,
   151  		"privstructfield":             true,
   152  		"privtuplefield":              true,
   153  		"propertyValidationDocString": true,
   154  		"propertyvalidator":           true,
   155  		"schema":                      true,
   156  		"schemaBody":                  true,
   157  		"schemaType":                  true,
   158  		"schemabody":                  true,
   159  		"schematype":                  true,
   160  		"schemavalidator":             true,
   161  		"serverDoc":                   true,
   162  		"slicevalidator":              true,
   163  		"structfield":                 true,
   164  		"structfieldIface":            true,
   165  		"subTypeBody":                 true,
   166  		"swaggerJsonEmbed":            true,
   167  		"tuplefield":                  true,
   168  		"tuplefieldIface":             true,
   169  		"typeSchemaType":              true,
   170  		"validationCustomformat":      true,
   171  		"validationPrimitive":         true,
   172  		"validationStructfield":       true,
   173  		"withBaseTypeBody":            true,
   174  		"withoutBaseTypeBody":         true,
   175  
   176  		// all serializers TODO(fred)
   177  		"additionalPropertiesSerializer": true,
   178  		"tupleSerializer":                true,
   179  		"schemaSerializer":               true,
   180  		"hasDiscriminatedSerializer":     true,
   181  		"discriminatedSerializer":        true,
   182  	}
   183  }
   184  
   185  // AddFile adds a file to the default repository. It will create a new template based on the filename.
   186  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   187  // directory separators and Camelcase the next letter.
   188  // e.g validation/primitive.gotmpl will become validationPrimitive
   189  //
   190  // If the file contains a definition for a template that is protected the whole file will not be added
   191  func AddFile(name, data string) error {
   192  	return templates.addFile(name, data, false)
   193  }
   194  
   195  // NewRepository creates a new template repository with the provided functions defined
   196  func NewRepository(funcs template.FuncMap) *Repository {
   197  	repo := Repository{
   198  		files:     make(map[string]string),
   199  		templates: make(map[string]*template.Template),
   200  		funcs:     funcs,
   201  	}
   202  
   203  	if repo.funcs == nil {
   204  		repo.funcs = make(template.FuncMap)
   205  	}
   206  
   207  	return &repo
   208  }
   209  
   210  // Repository is the repository for the generator templates
   211  type Repository struct {
   212  	files         map[string]string
   213  	templates     map[string]*template.Template
   214  	funcs         template.FuncMap
   215  	allowOverride bool
   216  }
   217  
   218  // LoadDefaults will load the embedded templates
   219  func (t *Repository) LoadDefaults() {
   220  
   221  	for name, asset := range assets {
   222  		if err := t.addFile(name, string(asset), true); err != nil {
   223  			log.Fatal(err)
   224  		}
   225  	}
   226  }
   227  
   228  // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository
   229  func (t *Repository) LoadDir(templatePath string) error {
   230  	err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error {
   231  
   232  		if strings.HasSuffix(path, ".gotmpl") {
   233  			if assetName, e := filepath.Rel(templatePath, path); e == nil {
   234  				if data, e := ioutil.ReadFile(path); e == nil {
   235  					if ee := t.AddFile(assetName, string(data)); ee != nil {
   236  						return fmt.Errorf("could not add template: %v", ee)
   237  					}
   238  				}
   239  				// Non-readable files are skipped
   240  			}
   241  		}
   242  		if err != nil {
   243  			return err
   244  		}
   245  		// Non-template files are skipped
   246  		return nil
   247  	})
   248  	if err != nil {
   249  		return fmt.Errorf("could not complete template processing in directory \"%s\": %v", templatePath, err)
   250  	}
   251  	return nil
   252  }
   253  
   254  // LoadContrib loads template from contrib directory
   255  func (t *Repository) LoadContrib(name string) error {
   256  	log.Printf("loading contrib %s", name)
   257  	const pathPrefix = "templates/contrib/"
   258  	basePath := pathPrefix + name
   259  	filesAdded := 0
   260  	for _, aname := range AssetNames() {
   261  		if !strings.HasSuffix(aname, ".gotmpl") {
   262  			continue
   263  		}
   264  		if strings.HasPrefix(aname, basePath) {
   265  			target := aname[len(basePath)+1:]
   266  			err := t.addFile(target, string(MustAsset(aname)), true)
   267  			if err != nil {
   268  				return err
   269  			}
   270  			log.Printf("added contributed template %s from %s", target, aname)
   271  			filesAdded++
   272  		}
   273  	}
   274  	if filesAdded == 0 {
   275  		return fmt.Errorf("no files added from template: %s", name)
   276  	}
   277  	return nil
   278  }
   279  
   280  func (t *Repository) addFile(name, data string, allowOverride bool) error {
   281  	fileName := name
   282  	name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl"))
   283  
   284  	templ, err := template.New(name).Funcs(t.funcs).Parse(data)
   285  
   286  	if err != nil {
   287  		return fmt.Errorf("failed to load template %s: %v", name, err)
   288  	}
   289  
   290  	// check if any protected templates are defined
   291  	if !allowOverride && !t.allowOverride {
   292  		for _, template := range templ.Templates() {
   293  			if protectedTemplates[template.Name()] {
   294  				return fmt.Errorf("cannot overwrite protected template %s", template.Name())
   295  			}
   296  		}
   297  	}
   298  
   299  	// Add each defined template into the cache
   300  	for _, template := range templ.Templates() {
   301  
   302  		t.files[template.Name()] = fileName
   303  		t.templates[template.Name()] = template.Lookup(template.Name())
   304  	}
   305  
   306  	return nil
   307  }
   308  
   309  // MustGet a template by name, panics when fails
   310  func (t *Repository) MustGet(name string) *template.Template {
   311  	tpl, err := t.Get(name)
   312  	if err != nil {
   313  		panic(err)
   314  	}
   315  	return tpl
   316  }
   317  
   318  // AddFile adds a file to the repository. It will create a new template based on the filename.
   319  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   320  // directory separators and Camelcase the next letter.
   321  // e.g validation/primitive.gotmpl will become validationPrimitive
   322  //
   323  // If the file contains a definition for a template that is protected the whole file will not be added
   324  func (t *Repository) AddFile(name, data string) error {
   325  	return t.addFile(name, data, false)
   326  }
   327  
   328  // SetAllowOverride allows setting allowOverride after the Repository was initialized
   329  func (t *Repository) SetAllowOverride(value bool) {
   330  	t.allowOverride = value
   331  }
   332  
   333  func findDependencies(n parse.Node) []string {
   334  
   335  	var deps []string
   336  	depMap := make(map[string]bool)
   337  
   338  	if n == nil {
   339  		return deps
   340  	}
   341  
   342  	switch node := n.(type) {
   343  	case *parse.ListNode:
   344  		if node != nil && node.Nodes != nil {
   345  			for _, nn := range node.Nodes {
   346  				for _, dep := range findDependencies(nn) {
   347  					depMap[dep] = true
   348  				}
   349  			}
   350  		}
   351  	case *parse.IfNode:
   352  		for _, dep := range findDependencies(node.BranchNode.List) {
   353  			depMap[dep] = true
   354  		}
   355  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   356  			depMap[dep] = true
   357  		}
   358  
   359  	case *parse.RangeNode:
   360  		for _, dep := range findDependencies(node.BranchNode.List) {
   361  			depMap[dep] = true
   362  		}
   363  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   364  			depMap[dep] = true
   365  		}
   366  
   367  	case *parse.WithNode:
   368  		for _, dep := range findDependencies(node.BranchNode.List) {
   369  			depMap[dep] = true
   370  		}
   371  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   372  			depMap[dep] = true
   373  		}
   374  
   375  	case *parse.TemplateNode:
   376  		depMap[node.Name] = true
   377  	}
   378  
   379  	for dep := range depMap {
   380  		deps = append(deps, dep)
   381  	}
   382  
   383  	return deps
   384  
   385  }
   386  
   387  func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool {
   388  	if dependencies == nil {
   389  		dependencies = make(map[string]bool)
   390  	}
   391  
   392  	deps := findDependencies(templ.Tree.Root)
   393  
   394  	for _, d := range deps {
   395  		if _, found := dependencies[d]; !found {
   396  
   397  			dependencies[d] = true
   398  
   399  			if tt := t.templates[d]; tt != nil {
   400  				dependencies = t.flattenDependencies(tt, dependencies)
   401  			}
   402  		}
   403  
   404  		dependencies[d] = true
   405  
   406  	}
   407  
   408  	return dependencies
   409  
   410  }
   411  
   412  func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) {
   413  
   414  	name := templ.Name()
   415  
   416  	deps := t.flattenDependencies(templ, nil)
   417  
   418  	for dep := range deps {
   419  
   420  		if dep == "" {
   421  			continue
   422  		}
   423  
   424  		tt := templ.Lookup(dep)
   425  
   426  		// Check if we have it
   427  		if tt == nil {
   428  			tt = t.templates[dep]
   429  
   430  			// Still don't have it, return an error
   431  			if tt == nil {
   432  				return templ, fmt.Errorf("could not find template %s", dep)
   433  			}
   434  			var err error
   435  
   436  			// Add it to the parse tree
   437  			templ, err = templ.AddParseTree(dep, tt.Tree)
   438  
   439  			if err != nil {
   440  				return templ, fmt.Errorf("dependency error: %v", err)
   441  			}
   442  
   443  		}
   444  	}
   445  	return templ.Lookup(name), nil
   446  }
   447  
   448  // Get will return the named template from the repository, ensuring that all dependent templates are loaded.
   449  // It will return an error if a dependent template is not defined in the repository.
   450  func (t *Repository) Get(name string) (*template.Template, error) {
   451  	templ, found := t.templates[name]
   452  
   453  	if !found {
   454  		return templ, fmt.Errorf("template doesn't exist %s", name)
   455  	}
   456  
   457  	return t.addDependencies(templ)
   458  }
   459  
   460  // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are.
   461  func (t *Repository) DumpTemplates() {
   462  	buf := bytes.NewBuffer(nil)
   463  	fmt.Fprintln(buf, "\n# Templates")
   464  	for name, templ := range t.templates {
   465  		fmt.Fprintf(buf, "## %s\n", name)
   466  		fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name])
   467  
   468  		if deps := findDependencies(templ.Tree.Root); len(deps) > 0 {
   469  
   470  			fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - "))
   471  		}
   472  		fmt.Fprintln(buf, "\n---")
   473  	}
   474  	log.Println(buf.String())
   475  }
   476  
   477  // FuncMap functions
   478  
   479  func asJSON(data interface{}) (string, error) {
   480  	b, err := json.Marshal(data)
   481  	if err != nil {
   482  		return "", err
   483  	}
   484  	return string(b), nil
   485  }
   486  
   487  func asPrettyJSON(data interface{}) (string, error) {
   488  	b, err := json.MarshalIndent(data, "", "  ")
   489  	if err != nil {
   490  		return "", err
   491  	}
   492  	return string(b), nil
   493  }
   494  
   495  func pluralizeFirstWord(arg string) string {
   496  	sentence := strings.Split(arg, " ")
   497  	if len(sentence) == 1 {
   498  		return inflect.Pluralize(arg)
   499  	}
   500  
   501  	return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ")
   502  }
   503  
   504  func dropPackage(str string) string {
   505  	parts := strings.Split(str, ".")
   506  	return parts[len(parts)-1]
   507  }
   508  
   509  func padSurround(entry, padWith string, i, ln int) string {
   510  	var res []string
   511  	if i > 0 {
   512  		for j := 0; j < i; j++ {
   513  			res = append(res, padWith)
   514  		}
   515  	}
   516  	res = append(res, entry)
   517  	tot := ln - i - 1
   518  	for j := 0; j < tot; j++ {
   519  		res = append(res, padWith)
   520  	}
   521  	return strings.Join(res, ",")
   522  }
   523  
   524  func padComment(str string, pads ...string) string {
   525  	// pads specifes padding to indent multi line comments.Defaults to one space
   526  	pad := " "
   527  	lines := strings.Split(str, "\n")
   528  	if len(pads) > 0 {
   529  		pad = strings.Join(pads, "")
   530  	}
   531  	return (strings.Join(lines, "\n//"+pad))
   532  }
   533  
   534  func blockComment(str string) string {
   535  	return strings.Replace(str, "*/", "[*]/", -1)
   536  }
   537  
   538  func pascalize(arg string) string {
   539  	runes := []rune(arg)
   540  	switch len(runes) {
   541  	case 0:
   542  		return "Empty"
   543  	case 1: // handle special case when we have a single rune that is not handled by swag.ToGoName
   544  		switch runes[0] {
   545  		case '+', '-', '#', '_': // those cases are handled differently than swag utility
   546  			return prefixForName(arg)
   547  		}
   548  	}
   549  	return swag.ToGoName(swag.ToGoName(arg)) // want to remove spaces
   550  }
   551  
   552  func prefixForName(arg string) string {
   553  	first := []rune(arg)[0]
   554  	if len(arg) == 0 || unicode.IsLetter(first) {
   555  		return ""
   556  	}
   557  	switch first {
   558  	case '+':
   559  		return "Plus"
   560  	case '-':
   561  		return "Minus"
   562  	case '#':
   563  		return "HashTag"
   564  		// other cases ($,@ etc..) handled by swag.ToGoName
   565  	}
   566  	return "Nr"
   567  }
   568  
   569  func dict(values ...interface{}) (map[string]interface{}, error) {
   570  	if len(values)%2 != 0 {
   571  		return nil, fmt.Errorf("expected even number of arguments, got %d", len(values))
   572  	}
   573  	dict := make(map[string]interface{}, len(values)/2)
   574  	for i := 0; i < len(values); i += 2 {
   575  		key, ok := values[i].(string)
   576  		if !ok {
   577  			return nil, fmt.Errorf("expected string key, got %+v", values[i])
   578  		}
   579  		dict[key] = values[i+1]
   580  	}
   581  	return dict, nil
   582  }