github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/boilingcore/templates.go (about)

     1  package boilingcore
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io/fs"
     9  	"os"
    10  	"path/filepath"
    11  	"sort"
    12  	"strings"
    13  	"text/template"
    14  
    15  	"github.com/Masterminds/sprig/v3"
    16  	"github.com/friendsofgo/errors"
    17  	"github.com/volatiletech/sqlboiler/v4/drivers"
    18  	"github.com/volatiletech/strmangle"
    19  )
    20  
// templateData for sqlboiler templates.
//
// Its exported fields, together with the Quotes, QuoteMap and SchemaTable
// methods, are what a template can reference through {{.Field}} or
// {{.Method ...}}.
type templateData struct {
	// NOTE(review): presumably Tables holds every table being generated and
	// Table the one the current template execution targets — confirm with callers.
	Tables  []drivers.Table
	Table   drivers.Table
	Aliases Aliases

	// Controls what names are output
	PkgName string
	Schema  string

	// Helps tune the output
	DriverName string
	Dialect    drivers.Dialect

	// LQ and RQ contain a quoted quote that allows us to write
	// the templates more easily. Quotes and SchemaTable wrap identifiers
	// with them.
	LQ string
	RQ string

	// Control various generation features
	AddGlobal         bool
	AddPanic          bool
	AddSoftDeletes    bool
	AddEnumTypes      bool
	EnumNullPrefix    string
	NoContext         bool
	NoHooks           bool
	NoAutoTimestamps  bool
	NoRowsAffected    bool
	NoDriverTemplates bool
	NoBackReferencing bool
	AlwaysWrapErrors  bool

	// Tags control which tags are added to the struct
	Tags []string

	// RelationTag controls the value of the tags for the Relationship struct
	RelationTag string

	// Generate struct tags as camelCase or snake_case
	StructTagCasing string

	// Contains field names that should have tags values set to '-'
	TagIgnore map[string]struct{}

	// OutputDirDepth is used to find sqlboiler config file
	OutputDirDepth int

	// Hacky state for where clauses to avoid having to do type-based imports
	// for singletons. It is a once set, so templates can call Put/Has on it.
	DBTypes once

	// StringFuncs are usable in templates with stringMap
	StringFuncs map[string]func(string) string

	// AutoColumns set the name of the columns for auto timestamps and soft deletes
	AutoColumns AutoColumns
}
    79  
    80  func (t templateData) Quotes(s string) string {
    81  	return fmt.Sprintf("%s%s%s", t.LQ, s, t.RQ)
    82  }
    83  
    84  func (t templateData) QuoteMap(s []string) []string {
    85  	return strmangle.StringMap(t.Quotes, s)
    86  }
    87  
    88  func (t templateData) SchemaTable(table string) string {
    89  	return strmangle.SchemaTable(t.LQ, t.RQ, t.Dialect.UseSchema, t.Schema, table)
    90  }
    91  
// templateList embeds the parsed template set that loadTemplates builds. All
// *template.Template methods are promoted, and Templates lists the ".tpl"
// entries in sorted order.
type templateList struct {
	*template.Template
}
    95  
    96  type templateNameList []string
    97  
    98  func (t templateNameList) Len() int {
    99  	return len(t)
   100  }
   101  
   102  func (t templateNameList) Swap(k, j int) {
   103  	t[k], t[j] = t[j], t[k]
   104  }
   105  
   106  func (t templateNameList) Less(k, j int) bool {
   107  	// Make sure "struct" goes to the front
   108  	if t[k] == "struct.tpl" {
   109  		return true
   110  	}
   111  
   112  	res := strings.Compare(t[k], t[j])
   113  	return res <= 0
   114  }
   115  
   116  // Templates returns the name of all the templates defined in the template list
   117  func (t templateList) Templates() []string {
   118  	tplList := t.Template.Templates()
   119  
   120  	if len(tplList) == 0 {
   121  		return nil
   122  	}
   123  
   124  	ret := make([]string, 0, len(tplList))
   125  	for _, tpl := range tplList {
   126  		if name := tpl.Name(); strings.HasSuffix(name, ".tpl") {
   127  			ret = append(ret, name)
   128  		}
   129  	}
   130  
   131  	sort.Sort(templateNameList(ret))
   132  
   133  	return ret
   134  }
   135  
   136  func loadTemplates(lazyTemplates []lazyTemplate, testTemplates bool, customFuncs template.FuncMap) (*templateList, error) {
   137  	tpl := template.New("")
   138  
   139  	for _, t := range lazyTemplates {
   140  		firstDir := strings.Split(t.Name, string(filepath.Separator))[0]
   141  		isTest := firstDir == "test" || strings.HasSuffix(firstDir, "_test")
   142  		if testTemplates && !isTest || !testTemplates && isTest {
   143  			continue
   144  		}
   145  
   146  		byt, err := t.Loader.Load()
   147  		if err != nil {
   148  			return nil, errors.Wrapf(err, "failed to load template: %s", t.Name)
   149  		}
   150  
   151  		_, err = tpl.New(t.Name).
   152  			Funcs(sprig.GenericFuncMap()).
   153  			Funcs(templateFunctions).
   154  			Funcs(customFuncs).
   155  			Parse(string(byt))
   156  		if err != nil {
   157  			return nil, errors.Wrapf(err, "failed to parse template: %s", t.Name)
   158  		}
   159  	}
   160  
   161  	return &templateList{Template: tpl}, nil
   162  }
   163  
// lazyTemplate pairs a template name with a loader for its source. The source
// is only fetched when needed: loadTemplates calls Loader.Load only for
// templates that pass its test/non-test filter.
type lazyTemplate struct {
	// Name identifies the template. loadTemplates uses its first path element
	// to decide whether it is a test template.
	Name   string         `json:"name"`
	Loader templateLoader `json:"loader"`
}
   168  
// templateLoader supplies the raw source of a template on demand via Load.
//
// It also embeds encoding.TextMarshaler, so a loader can be rendered as text,
// e.g. when a lazyTemplate, whose Loader field carries a json tag, is encoded.
// The loaders in this file (fileLoader, base64Loader, assetLoader) marshal to a
// short description rather than to the template body.
type templateLoader interface {
	encoding.TextMarshaler
	Load() ([]byte, error)
}
   173  
   174  type fileLoader string
   175  
   176  func (f fileLoader) Load() ([]byte, error) {
   177  	fname := string(f)
   178  	b, err := os.ReadFile(fname)
   179  	if err != nil {
   180  		return nil, errors.Wrapf(err, "failed to load template: %s", fname)
   181  	}
   182  	return b, nil
   183  }
   184  
   185  func (f fileLoader) MarshalText() ([]byte, error) {
   186  	return []byte(f.String()), nil
   187  }
   188  
   189  func (f fileLoader) String() string {
   190  	return "file:" + string(f)
   191  }
   192  
   193  type base64Loader string
   194  
   195  func (b base64Loader) Load() ([]byte, error) {
   196  	byt, err := base64.StdEncoding.DecodeString(string(b))
   197  	if err != nil {
   198  		return nil, errors.Wrap(err, "failed to decode driver's template, should be base64)")
   199  	}
   200  	return byt, nil
   201  }
   202  
   203  func (b base64Loader) MarshalText() ([]byte, error) {
   204  	return []byte(b.String()), nil
   205  }
   206  
   207  func (b base64Loader) String() string {
   208  	byt, err := base64.StdEncoding.DecodeString(string(b))
   209  	if err != nil {
   210  		panic("trying to debug output base64 loader, but was not proper base64!")
   211  	}
   212  	sha := sha256.Sum256(byt)
   213  	return fmt.Sprintf("base64:(sha256 of content): %x", sha)
   214  }
   215  
   216  type assetLoader struct {
   217  	fs   fs.FS
   218  	name string
   219  }
   220  
   221  func (a assetLoader) Load() ([]byte, error) {
   222  	return fs.ReadFile(a.fs, string(a.name))
   223  }
   224  
   225  func (a assetLoader) MarshalText() ([]byte, error) {
   226  	return []byte(a.String()), nil
   227  }
   228  
   229  func (a assetLoader) String() string {
   230  	return "asset:" + string(a.name)
   231  }
   232  
   233  // set is to stop duplication from named enums, allowing a template loop
   234  // to keep some state
   235  type once map[string]struct{}
   236  
   237  func newOnce() once {
   238  	return make(once)
   239  }
   240  
   241  func (o once) Has(s string) bool {
   242  	_, ok := o[s]
   243  	return ok
   244  }
   245  
   246  func (o once) Put(s string) bool {
   247  	if _, ok := o[s]; ok {
   248  		return false
   249  	}
   250  
   251  	o[s] = struct{}{}
   252  	return true
   253  }
   254  
// templateStringMappers are placed into the data to make it easy to use the
// stringMap function.
//
// Every entry has the func(string) string shape that strmangle.StringMap
// (registered as "stringMap" in templateFunctions) applies to each element of
// a []string. NOTE(review): presumably this map is what populates
// templateData.StringFuncs — confirm at the call site.
var templateStringMappers = map[string]func(string) string{
	// String ops
	//
	// quoteWrap formats with %q, so its input is escaped into a valid Go string
	// literal. Note that templateFunctions' "quoteWrap" does NOT escape.
	// safeQuoteWrap surrounds its input with the two-character sequence \" on
	// each side and does not escape the input itself.
	"quoteWrap":       func(a string) string { return fmt.Sprintf(`%q`, a) },
	"safeQuoteWrap":   func(a string) string { return fmt.Sprintf(`\"%s\"`, a) },
	"replaceReserved": strmangle.ReplaceReservedWords,

	// Casing
	"titleCase": strmangle.TitleCase,
	"camelCase": strmangle.CamelCase,
}
   267  
   268  var goVarnameReplacer = strings.NewReplacer("[", "_", "]", "_", ".", "_")
   269  
   270  // templateFunctions is a map of some helper functions that get passed into the
   271  // templates. If you wish to pass a new function into your own template,
   272  // you can add that with Config.CustomTemplateFuncs
   273  var templateFunctions = template.FuncMap{
   274  	// String ops
   275  	"quoteWrap": func(s string) string { return fmt.Sprintf(`"%s"`, s) },
   276  	"id":        strmangle.Identifier,
   277  	"goVarname": goVarnameReplacer.Replace,
   278  
   279  	// Pluralization
   280  	"singular": strmangle.Singular,
   281  	"plural":   strmangle.Plural,
   282  
   283  	// Casing
   284  	"titleCase": strmangle.TitleCase,
   285  	"camelCase": strmangle.CamelCase,
   286  	"ignore":    strmangle.Ignore,
   287  
   288  	// String Slice ops
   289  	"join":               func(sep string, slice []string) string { return strings.Join(slice, sep) },
   290  	"joinSlices":         strmangle.JoinSlices,
   291  	"stringMap":          strmangle.StringMap,
   292  	"prefixStringSlice":  strmangle.PrefixStringSlice,
   293  	"containsAny":        strmangle.ContainsAny,
   294  	"generateTags":       strmangle.GenerateTags,
   295  	"generateIgnoreTags": strmangle.GenerateIgnoreTags,
   296  
   297  	// Enum ops
   298  	"parseEnumName": strmangle.ParseEnumName,
   299  	"parseEnumVals": strmangle.ParseEnumVals,
   300  	"onceNew":       newOnce,
   301  	"oncePut":       once.Put,
   302  	"onceHas":       once.Has,
   303  	"isEnumDBType":  drivers.IsEnumDBType,
   304  
   305  	// String Map ops
   306  	"makeStringMap": strmangle.MakeStringMap,
   307  
   308  	// Set operations
   309  	"setInclude": strmangle.SetInclude,
   310  
   311  	// Database related mangling
   312  	"whereClause": strmangle.WhereClause,
   313  
   314  	// Alias and text helping
   315  	"aliasCols":              func(ta TableAlias) func(string) string { return ta.Column },
   316  	"usesPrimitives":         usesPrimitives,
   317  	"isPrimitive":            isPrimitive,
   318  	"isNullPrimitive":        isNullPrimitive,
   319  	"convertNullToPrimitive": convertNullToPrimitive,
   320  	"splitLines": func(a string) []string {
   321  		if a == "" {
   322  			return nil
   323  		}
   324  		return strings.Split(strings.TrimSpace(a), "\n")
   325  	},
   326  
   327  	// dbdrivers ops
   328  	"filterColumnsByAuto":    drivers.FilterColumnsByAuto,
   329  	"filterColumnsByDefault": drivers.FilterColumnsByDefault,
   330  	"filterColumnsByEnum":    drivers.FilterColumnsByEnum,
   331  	"sqlColDefinitions":      drivers.SQLColDefinitions,
   332  	"columnNames":            drivers.ColumnNames,
   333  	"columnDBTypes":          drivers.ColumnDBTypes,
   334  	"getTable":               drivers.GetTable,
   335  }