github.com/emreu/go-swagger@v0.22.1/generator/template_repo.go (about)

     1  package generator
     2  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"text/template"
	"text/template/parse"

	"github.com/go-openapi/inflect"
	"github.com/go-openapi/swag"
	"github.com/kr/pretty"
)
    21  
    22  var templates *Repository
    23  
    24  // FuncMap is a map with default functions for use n the templates.
    25  // These are available in every template
    26  var FuncMap template.FuncMap = map[string]interface{}{
    27  	"pascalize": pascalize,
    28  	"camelize":  swag.ToJSONName,
    29  	"varname":   golang.MangleVarName,
    30  	"humanize":  swag.ToHumanNameLower,
    31  	"snakize":   golang.MangleFileName,
    32  	"toPackagePath": func(name string) string {
    33  		return filepath.FromSlash(golang.ManglePackagePath(name, ""))
    34  	},
    35  	"toPackage": func(name string) string {
    36  		return golang.ManglePackagePath(name, "")
    37  	},
    38  	"toPackageName": func(name string) string {
    39  		return golang.ManglePackageName(name, "")
    40  	},
    41  	"dasherize": swag.ToCommandName,
    42  	"pluralizeFirstWord": func(arg string) string {
    43  		sentence := strings.Split(arg, " ")
    44  		if len(sentence) == 1 {
    45  			return inflect.Pluralize(arg)
    46  		}
    47  
    48  		return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ")
    49  	},
    50  	"json":       asJSON,
    51  	"prettyjson": asPrettyJSON,
    52  	"hasInsecure": func(arg []string) bool {
    53  		return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws")
    54  	},
    55  	"hasSecure": func(arg []string) bool {
    56  		return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss")
    57  	},
    58  	// TODO: simplify redundant functions
    59  	"stripPackage": func(str, pkg string) string {
    60  		parts := strings.Split(str, ".")
    61  		strlen := len(parts)
    62  		if strlen > 0 {
    63  			return parts[strlen-1]
    64  		}
    65  		return str
    66  	},
    67  	"dropPackage": func(str string) string {
    68  		parts := strings.Split(str, ".")
    69  		strlen := len(parts)
    70  		if strlen > 0 {
    71  			return parts[strlen-1]
    72  		}
    73  		return str
    74  	},
    75  	"upper": strings.ToUpper,
    76  	"contains": func(coll []string, arg string) bool {
    77  		for _, v := range coll {
    78  			if v == arg {
    79  				return true
    80  			}
    81  		}
    82  		return false
    83  	},
    84  	"padSurround": func(entry, padWith string, i, ln int) string {
    85  		var res []string
    86  		if i > 0 {
    87  			for j := 0; j < i; j++ {
    88  				res = append(res, padWith)
    89  			}
    90  		}
    91  		res = append(res, entry)
    92  		tot := ln - i - 1
    93  		for j := 0; j < tot; j++ {
    94  			res = append(res, padWith)
    95  		}
    96  		return strings.Join(res, ",")
    97  	},
    98  	"joinFilePath": filepath.Join,
    99  	"comment": func(str string) string {
   100  		lines := strings.Split(str, "\n")
   101  		return (strings.Join(lines, "\n// "))
   102  	},
   103  	"blockcomment": func(str string) string {
   104  		return strings.Replace(str, "*/", "[*]/", -1)
   105  	},
   106  	"inspect":   pretty.Sprint,
   107  	"cleanPath": path.Clean,
   108  	"mediaTypeName": func(orig string) string {
   109  		return strings.SplitN(orig, ";", 2)[0]
   110  	},
   111  	"goSliceInitializer": goSliceInitializer,
   112  	"hasPrefix":          strings.HasPrefix,
   113  	"stringContains":     strings.Contains,
   114  }
   115  
   116  func init() {
   117  	templates = NewRepository(FuncMap)
   118  }
   119  
   120  var assets = map[string][]byte{
   121  	"validation/primitive.gotmpl":           MustAsset("templates/validation/primitive.gotmpl"),
   122  	"validation/customformat.gotmpl":        MustAsset("templates/validation/customformat.gotmpl"),
   123  	"docstring.gotmpl":                      MustAsset("templates/docstring.gotmpl"),
   124  	"validation/structfield.gotmpl":         MustAsset("templates/validation/structfield.gotmpl"),
   125  	"modelvalidator.gotmpl":                 MustAsset("templates/modelvalidator.gotmpl"),
   126  	"structfield.gotmpl":                    MustAsset("templates/structfield.gotmpl"),
   127  	"tupleserializer.gotmpl":                MustAsset("templates/tupleserializer.gotmpl"),
   128  	"additionalpropertiesserializer.gotmpl": MustAsset("templates/additionalpropertiesserializer.gotmpl"),
   129  	"schematype.gotmpl":                     MustAsset("templates/schematype.gotmpl"),
   130  	"schemabody.gotmpl":                     MustAsset("templates/schemabody.gotmpl"),
   131  	"schema.gotmpl":                         MustAsset("templates/schema.gotmpl"),
   132  	"schemavalidator.gotmpl":                MustAsset("templates/schemavalidator.gotmpl"),
   133  	"model.gotmpl":                          MustAsset("templates/model.gotmpl"),
   134  	"header.gotmpl":                         MustAsset("templates/header.gotmpl"),
   135  	"swagger_json_embed.gotmpl":             MustAsset("templates/swagger_json_embed.gotmpl"),
   136  
   137  	"server/parameter.gotmpl":    MustAsset("templates/server/parameter.gotmpl"),
   138  	"server/urlbuilder.gotmpl":   MustAsset("templates/server/urlbuilder.gotmpl"),
   139  	"server/responses.gotmpl":    MustAsset("templates/server/responses.gotmpl"),
   140  	"server/operation.gotmpl":    MustAsset("templates/server/operation.gotmpl"),
   141  	"server/builder.gotmpl":      MustAsset("templates/server/builder.gotmpl"),
   142  	"server/server.gotmpl":       MustAsset("templates/server/server.gotmpl"),
   143  	"server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"),
   144  	"server/main.gotmpl":         MustAsset("templates/server/main.gotmpl"),
   145  	"server/doc.gotmpl":          MustAsset("templates/server/doc.gotmpl"),
   146  
   147  	"client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"),
   148  	"client/response.gotmpl":  MustAsset("templates/client/response.gotmpl"),
   149  	"client/client.gotmpl":    MustAsset("templates/client/client.gotmpl"),
   150  	"client/facade.gotmpl":    MustAsset("templates/client/facade.gotmpl"),
   151  }
   152  
   153  var protectedTemplates = map[string]bool{
   154  	"schemabody":                     true,
   155  	"privtuplefield":                 true,
   156  	"withoutBaseTypeBody":            true,
   157  	"swaggerJsonEmbed":               true,
   158  	"validationCustomformat":         true,
   159  	"tuplefield":                     true,
   160  	"header":                         true,
   161  	"withBaseTypeBody":               true,
   162  	"primitivefieldvalidator":        true,
   163  	"mapvalidator":                   true,
   164  	"propertyValidationDocString":    true,
   165  	"typeSchemaType":                 true,
   166  	"docstring":                      true,
   167  	"dereffedSchemaType":             true,
   168  	"model":                          true,
   169  	"modelvalidator":                 true,
   170  	"privstructfield":                true,
   171  	"schemavalidator":                true,
   172  	"tuplefieldIface":                true,
   173  	"tupleSerializer":                true,
   174  	"tupleserializer":                true,
   175  	"schemaSerializer":               true,
   176  	"propertyvalidator":              true,
   177  	"structfieldIface":               true,
   178  	"schemaBody":                     true,
   179  	"objectvalidator":                true,
   180  	"schematype":                     true,
   181  	"additionalpropertiesserializer": true,
   182  	"slicevalidator":                 true,
   183  	"validationStructfield":          true,
   184  	"validationPrimitive":            true,
   185  	"schemaType":                     true,
   186  	"subTypeBody":                    true,
   187  	"schema":                         true,
   188  	"additionalPropertiesSerializer": true,
   189  	"serverDoc":                      true,
   190  	"structfield":                    true,
   191  	"hasDiscriminatedSerializer":     true,
   192  	"discriminatedSerializer":        true,
   193  }
   194  
   195  // AddFile adds a file to the default repository. It will create a new template based on the filename.
   196  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   197  // directory separators and Camelcase the next letter.
   198  // e.g validation/primitive.gotmpl will become validationPrimitive
   199  //
   200  // If the file contains a definition for a template that is protected the whole file will not be added
   201  func AddFile(name, data string) error {
   202  	return templates.addFile(name, data, false)
   203  }
   204  
   205  func asJSON(data interface{}) (string, error) {
   206  	b, err := json.Marshal(data)
   207  	if err != nil {
   208  		return "", err
   209  	}
   210  	return string(b), nil
   211  }
   212  
   213  func asPrettyJSON(data interface{}) (string, error) {
   214  	b, err := json.MarshalIndent(data, "", "  ")
   215  	if err != nil {
   216  		return "", err
   217  	}
   218  	return string(b), nil
   219  }
   220  
   221  func goSliceInitializer(data interface{}) (string, error) {
   222  	// goSliceInitializer constructs a Go literal initializer from interface{} literals.
   223  	// e.g. []interface{}{"a", "b"} is transformed in {"a","b",}
   224  	// e.g. map[string]interface{}{ "a": "x", "b": "y"} is transformed in {"a":"x","b":"y",}.
   225  	//
   226  	// NOTE: this is currently used to construct simple slice intializers for default values.
   227  	// This allows for nicer slice initializers for slices of primitive types and avoid systematic use for json.Unmarshal().
   228  	b, err := json.Marshal(data)
   229  	if err != nil {
   230  		return "", err
   231  	}
   232  	return strings.Replace(strings.Replace(strings.Replace(string(b), "}", ",}", -1), "[", "{", -1), "]", ",}", -1), nil
   233  }
   234  
   235  // NewRepository creates a new template repository with the provided functions defined
   236  func NewRepository(funcs template.FuncMap) *Repository {
   237  	repo := Repository{
   238  		files:     make(map[string]string),
   239  		templates: make(map[string]*template.Template),
   240  		funcs:     funcs,
   241  	}
   242  
   243  	if repo.funcs == nil {
   244  		repo.funcs = make(template.FuncMap)
   245  	}
   246  
   247  	return &repo
   248  }
   249  
   250  // Repository is the repository for the generator templates
   251  type Repository struct {
   252  	files         map[string]string
   253  	templates     map[string]*template.Template
   254  	funcs         template.FuncMap
   255  	allowOverride bool
   256  }
   257  
   258  // LoadDefaults will load the embedded templates
   259  func (t *Repository) LoadDefaults() {
   260  
   261  	for name, asset := range assets {
   262  		if err := t.addFile(name, string(asset), true); err != nil {
   263  			log.Fatal(err)
   264  		}
   265  	}
   266  }
   267  
   268  // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository
   269  func (t *Repository) LoadDir(templatePath string) error {
   270  	err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error {
   271  
   272  		if strings.HasSuffix(path, ".gotmpl") {
   273  			if assetName, e := filepath.Rel(templatePath, path); e == nil {
   274  				if data, e := ioutil.ReadFile(path); e == nil {
   275  					if ee := t.AddFile(assetName, string(data)); ee != nil {
   276  						// Fatality is decided by caller
   277  						// log.Fatal(ee)
   278  						return fmt.Errorf("could not add template: %v", ee)
   279  					}
   280  				}
   281  				// Non-readable files are skipped
   282  			}
   283  		}
   284  		if err != nil {
   285  			return err
   286  		}
   287  		// Non-template files are skipped
   288  		return nil
   289  	})
   290  	if err != nil {
   291  		return fmt.Errorf("could not complete template processing in directory \"%s\": %v", templatePath, err)
   292  	}
   293  	return nil
   294  }
   295  
   296  // LoadContrib loads template from contrib directory
   297  func (t *Repository) LoadContrib(name string) error {
   298  	log.Printf("loading contrib %s", name)
   299  	const pathPrefix = "templates/contrib/"
   300  	basePath := pathPrefix + name
   301  	filesAdded := 0
   302  	for _, aname := range AssetNames() {
   303  		if !strings.HasSuffix(aname, ".gotmpl") {
   304  			continue
   305  		}
   306  		if strings.HasPrefix(aname, basePath) {
   307  			target := aname[len(basePath)+1:]
   308  			err := t.addFile(target, string(MustAsset(aname)), true)
   309  			if err != nil {
   310  				return err
   311  			}
   312  			log.Printf("added contributed template %s from %s", target, aname)
   313  			filesAdded++
   314  		}
   315  	}
   316  	if filesAdded == 0 {
   317  		return fmt.Errorf("no files added from template: %s", name)
   318  	}
   319  	return nil
   320  }
   321  
   322  func (t *Repository) addFile(name, data string, allowOverride bool) error {
   323  	fileName := name
   324  	name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl"))
   325  
   326  	templ, err := template.New(name).Funcs(t.funcs).Parse(data)
   327  
   328  	if err != nil {
   329  		return fmt.Errorf("failed to load template %s: %v", name, err)
   330  	}
   331  
   332  	// check if any protected templates are defined
   333  	if !allowOverride && !t.allowOverride {
   334  		for _, template := range templ.Templates() {
   335  			if protectedTemplates[template.Name()] {
   336  				return fmt.Errorf("cannot overwrite protected template %s", template.Name())
   337  			}
   338  		}
   339  	}
   340  
   341  	// Add each defined template into the cache
   342  	for _, template := range templ.Templates() {
   343  
   344  		t.files[template.Name()] = fileName
   345  		t.templates[template.Name()] = template.Lookup(template.Name())
   346  	}
   347  
   348  	return nil
   349  }
   350  
   351  // MustGet a template by name, panics when fails
   352  func (t *Repository) MustGet(name string) *template.Template {
   353  	tpl, err := t.Get(name)
   354  	if err != nil {
   355  		panic(err)
   356  	}
   357  	return tpl
   358  }
   359  
   360  // AddFile adds a file to the repository. It will create a new template based on the filename.
   361  // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
   362  // directory separators and Camelcase the next letter.
   363  // e.g validation/primitive.gotmpl will become validationPrimitive
   364  //
   365  // If the file contains a definition for a template that is protected the whole file will not be added
   366  func (t *Repository) AddFile(name, data string) error {
   367  	return t.addFile(name, data, false)
   368  }
   369  
   370  // SetAllowOverride allows setting allowOverride after the Repository was initialized
   371  func (t *Repository) SetAllowOverride(value bool) {
   372  	t.allowOverride = value
   373  }
   374  
   375  func findDependencies(n parse.Node) []string {
   376  
   377  	var deps []string
   378  	depMap := make(map[string]bool)
   379  
   380  	if n == nil {
   381  		return deps
   382  	}
   383  
   384  	switch node := n.(type) {
   385  	case *parse.ListNode:
   386  		if node != nil && node.Nodes != nil {
   387  			for _, nn := range node.Nodes {
   388  				for _, dep := range findDependencies(nn) {
   389  					depMap[dep] = true
   390  				}
   391  			}
   392  		}
   393  	case *parse.IfNode:
   394  		for _, dep := range findDependencies(node.BranchNode.List) {
   395  			depMap[dep] = true
   396  		}
   397  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   398  			depMap[dep] = true
   399  		}
   400  
   401  	case *parse.RangeNode:
   402  		for _, dep := range findDependencies(node.BranchNode.List) {
   403  			depMap[dep] = true
   404  		}
   405  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   406  			depMap[dep] = true
   407  		}
   408  
   409  	case *parse.WithNode:
   410  		for _, dep := range findDependencies(node.BranchNode.List) {
   411  			depMap[dep] = true
   412  		}
   413  		for _, dep := range findDependencies(node.BranchNode.ElseList) {
   414  			depMap[dep] = true
   415  		}
   416  
   417  	case *parse.TemplateNode:
   418  		depMap[node.Name] = true
   419  	}
   420  
   421  	for dep := range depMap {
   422  		deps = append(deps, dep)
   423  	}
   424  
   425  	return deps
   426  
   427  }
   428  
   429  func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool {
   430  	if dependencies == nil {
   431  		dependencies = make(map[string]bool)
   432  	}
   433  
   434  	deps := findDependencies(templ.Tree.Root)
   435  
   436  	for _, d := range deps {
   437  		if _, found := dependencies[d]; !found {
   438  
   439  			dependencies[d] = true
   440  
   441  			if tt := t.templates[d]; tt != nil {
   442  				dependencies = t.flattenDependencies(tt, dependencies)
   443  			}
   444  		}
   445  
   446  		dependencies[d] = true
   447  
   448  	}
   449  
   450  	return dependencies
   451  
   452  }
   453  
   454  func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) {
   455  
   456  	name := templ.Name()
   457  
   458  	deps := t.flattenDependencies(templ, nil)
   459  
   460  	for dep := range deps {
   461  
   462  		if dep == "" {
   463  			continue
   464  		}
   465  
   466  		tt := templ.Lookup(dep)
   467  
   468  		// Check if we have it
   469  		if tt == nil {
   470  			tt = t.templates[dep]
   471  
   472  			// Still don't have it, return an error
   473  			if tt == nil {
   474  				return templ, fmt.Errorf("could not find template %s", dep)
   475  			}
   476  			var err error
   477  
   478  			// Add it to the parse tree
   479  			templ, err = templ.AddParseTree(dep, tt.Tree)
   480  
   481  			if err != nil {
   482  				return templ, fmt.Errorf("dependency error: %v", err)
   483  			}
   484  
   485  		}
   486  	}
   487  	return templ.Lookup(name), nil
   488  }
   489  
   490  // Get will return the named template from the repository, ensuring that all dependent templates are loaded.
   491  // It will return an error if a dependent template is not defined in the repository.
   492  func (t *Repository) Get(name string) (*template.Template, error) {
   493  	templ, found := t.templates[name]
   494  
   495  	if !found {
   496  		return templ, fmt.Errorf("template doesn't exist %s", name)
   497  	}
   498  
   499  	return t.addDependencies(templ)
   500  }
   501  
   502  // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are.
   503  func (t *Repository) DumpTemplates() {
   504  	buf := bytes.NewBuffer(nil)
   505  	fmt.Fprintln(buf, "\n# Templates")
   506  	for name, templ := range t.templates {
   507  		fmt.Fprintf(buf, "## %s\n", name)
   508  		fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name])
   509  
   510  		if deps := findDependencies(templ.Tree.Root); len(deps) > 0 {
   511  
   512  			fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - "))
   513  		}
   514  		fmt.Fprintln(buf, "\n---")
   515  	}
   516  	log.Println(buf.String())
   517  }