github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/boilingcore/boilingcore.go (about) 1 // Package boilingcore has types and methods useful for generating code that 2 // acts as a fully dynamic ORM might. 3 package boilingcore 4 5 import ( 6 "encoding/json" 7 "fmt" 8 "io/fs" 9 "os" 10 "path/filepath" 11 "regexp" 12 "sort" 13 "strings" 14 15 "github.com/friendsofgo/errors" 16 "github.com/volatiletech/strmangle" 17 18 "github.com/volatiletech/sqlboiler/v4/drivers" 19 "github.com/volatiletech/sqlboiler/v4/importers" 20 boiltemplates "github.com/volatiletech/sqlboiler/v4/templates" 21 ) 22 23 var ( 24 // Tags must be in a format like: json, xml, etc. 25 rgxValidTag = regexp.MustCompile(`[a-zA-Z_\.]+`) 26 // Column names must be in format column_name or table_name.column_name 27 rgxValidTableColumn = regexp.MustCompile(`^[\w]+\.[\w]+$|^[\w]+$`) 28 ) 29 30 // State holds the global data needed by most pieces to run 31 type State struct { 32 Config *Config 33 34 Driver drivers.Interface 35 Schema string 36 Tables []drivers.Table 37 Dialect drivers.Dialect 38 39 Templates *templateList 40 TestTemplates *templateList 41 } 42 43 // New creates a new state based off of the config 44 func New(config *Config) (*State, error) { 45 s := &State{ 46 Config: config, 47 } 48 49 var templates []lazyTemplate 50 51 defer func() { 52 if s.Config.Debug { 53 debugOut := struct { 54 Config *Config `json:"config"` 55 DriverConfig drivers.Config `json:"driver_config"` 56 Schema string `json:"schema"` 57 Dialect drivers.Dialect `json:"dialect"` 58 Tables []drivers.Table `json:"tables"` 59 Templates []lazyTemplate `json:"templates"` 60 }{ 61 Config: s.Config, 62 DriverConfig: s.Config.DriverConfig, 63 Schema: s.Schema, 64 Dialect: s.Dialect, 65 Tables: s.Tables, 66 Templates: templates, 67 } 68 69 b, err := json.Marshal(debugOut) 70 if err != nil { 71 panic(err) 72 } 73 fmt.Printf("%s\n", b) 74 } 75 }() 76 77 if len(config.Version) > 0 { 78 noEditDisclaimer = []byte( 79 fmt.Sprintf(noEditDisclaimerFmt, " "+config.Version+" "), 80 ) 81 } 82 83 s.Driver = drivers.GetDriver(config.DriverName) 84 s.initInflections() 85 86 err := s.initDBInfo(config.DriverConfig) 87 if err != nil { 88 return nil, errors.Wrap(err, "unable to initialize tables") 89 } 90 91 if err := s.mergeDriverImports(); err != nil { 92 return nil, errors.Wrap(err, "unable to merge imports from driver") 93 } 94 95 if s.Config.AddEnumTypes { 96 s.mergeEnumImports() 97 } 98 99 if !s.Config.NoContext { 100 s.Config.Imports.All.Standard = append(s.Config.Imports.All.Standard, `"context"`) 101 s.Config.Imports.Test.Standard = append(s.Config.Imports.Test.Standard, `"context"`) 102 } 103 104 if err := s.processTypeReplacements(); err != nil { 105 return nil, err 106 } 107 108 templates, err = s.initTemplates() 109 if err != nil { 110 return nil, errors.Wrap(err, "unable to initialize templates") 111 } 112 113 err = s.initOutFolders(templates) 114 if err != nil { 115 return nil, errors.Wrap(err, "unable to initialize the output folders") 116 } 117 118 err = s.initTags(config.Tags) 119 if err != nil { 120 return nil, errors.Wrap(err, "unable to initialize struct tags") 121 } 122 123 err = s.initAliases(&config.Aliases) 124 if err != nil { 125 return nil, errors.Wrap(err, "unable to initialize aliases") 126 } 127 128 return s, nil 129 } 130 131 // Run executes the sqlboiler templates and outputs them to files based on the 132 // state given. 133 func (s *State) Run() error { 134 data := &templateData{ 135 Tables: s.Tables, 136 Aliases: s.Config.Aliases, 137 DriverName: s.Config.DriverName, 138 PkgName: s.Config.PkgName, 139 AddGlobal: s.Config.AddGlobal, 140 AddPanic: s.Config.AddPanic, 141 AddSoftDeletes: s.Config.AddSoftDeletes, 142 AddEnumTypes: s.Config.AddEnumTypes, 143 EnumNullPrefix: s.Config.EnumNullPrefix, 144 NoContext: s.Config.NoContext, 145 NoHooks: s.Config.NoHooks, 146 NoAutoTimestamps: s.Config.NoAutoTimestamps, 147 NoRowsAffected: s.Config.NoRowsAffected, 148 NoDriverTemplates: s.Config.NoDriverTemplates, 149 NoBackReferencing: s.Config.NoBackReferencing, 150 AlwaysWrapErrors: s.Config.AlwaysWrapErrors, 151 StructTagCasing: s.Config.StructTagCasing, 152 TagIgnore: make(map[string]struct{}), 153 Tags: s.Config.Tags, 154 RelationTag: s.Config.RelationTag, 155 Dialect: s.Dialect, 156 Schema: s.Schema, 157 LQ: strmangle.QuoteCharacter(s.Dialect.LQ), 158 RQ: strmangle.QuoteCharacter(s.Dialect.RQ), 159 OutputDirDepth: s.Config.OutputDirDepth(), 160 161 DBTypes: make(once), 162 StringFuncs: templateStringMappers, 163 AutoColumns: s.Config.AutoColumns, 164 } 165 166 for _, v := range s.Config.TagIgnore { 167 if !rgxValidTableColumn.MatchString(v) { 168 return errors.New("Invalid column name %q supplied, only specify column name or table.column, eg: created_at, user.password") 169 } 170 data.TagIgnore[v] = struct{}{} 171 } 172 173 if err := generateSingletonOutput(s, data); err != nil { 174 return errors.Wrap(err, "singleton template output") 175 } 176 177 if !s.Config.NoTests { 178 if err := generateSingletonTestOutput(s, data); err != nil { 179 return errors.Wrap(err, "unable to generate singleton test template output") 180 } 181 } 182 183 var regularDirExtMap, testDirExtMap dirExtMap 184 regularDirExtMap = groupTemplates(s.Templates) 185 if !s.Config.NoTests { 186 testDirExtMap = groupTemplates(s.TestTemplates) 187 } 188 189 for _, table := range s.Tables { 190 if table.IsJoinTable { 191 continue 192 } 193 194 data.Table = table 195 196 // Generate the regular templates 197 if err := generateOutput(s, regularDirExtMap, data); err != nil { 198 return errors.Wrap(err, "unable to generate output") 199 } 200 201 // Generate the test templates 202 if !s.Config.NoTests && !table.IsView { 203 if err := generateTestOutput(s, testDirExtMap, data); err != nil { 204 return errors.Wrap(err, "unable to generate test output") 205 } 206 } 207 } 208 209 return nil 210 } 211 212 // Cleanup closes any resources that must be closed 213 func (s *State) Cleanup() error { 214 // Nothing here atm, used to close the driver 215 return nil 216 } 217 218 // initTemplates loads all template folders into the state object. 219 // 220 // If TemplateDirs is set it uses those, else it pulls from assets. 221 // Then it allows drivers to override, followed by replacements. Any 222 // user functions passed in by library users will be merged into the 223 // template.FuncMap. 224 // 225 // Because there's the chance for windows paths to jumped in 226 // all paths are converted to the native OS's slash style. 227 // 228 // Later, in order to properly look up imports the paths will 229 // be forced back to linux style paths. 230 func (s *State) initTemplates() ([]lazyTemplate, error) { 231 var err error 232 233 templates := make(map[string]templateLoader) 234 if len(s.Config.TemplateDirs) != 0 { 235 for _, dir := range s.Config.TemplateDirs { 236 abs, err := filepath.Abs(dir) 237 if err != nil { 238 return nil, errors.Wrap(err, "could not find abs dir of templates directory") 239 } 240 241 base := filepath.Base(abs) 242 root := filepath.Dir(abs) 243 tpls, err := findTemplates(root, base) 244 if err != nil { 245 return nil, err 246 } 247 248 mergeTemplates(templates, tpls) 249 } 250 } else { 251 defaultTemplates := s.Config.DefaultTemplates 252 if defaultTemplates == nil { 253 defaultTemplates = boiltemplates.Builtin 254 } 255 256 err := fs.WalkDir(defaultTemplates, ".", func(path string, entry fs.DirEntry, err error) error { 257 if err != nil { 258 return err 259 } 260 261 if entry.IsDir() { 262 return nil 263 } 264 265 name := entry.Name() 266 if filepath.Ext(name) == ".tpl" { 267 templates[normalizeSlashes(path)] = assetLoader{fs: defaultTemplates, name: path} 268 } 269 270 return nil 271 }) 272 if err != nil { 273 return nil, err 274 } 275 } 276 277 if !s.Config.NoDriverTemplates { 278 driverTemplates, err := s.Driver.Templates() 279 if err != nil { 280 return nil, err 281 } 282 283 for template, contents := range driverTemplates { 284 templates[normalizeSlashes(template)] = base64Loader(contents) 285 } 286 } 287 288 for _, replace := range s.Config.Replacements { 289 splits := strings.Split(replace, ";") 290 if len(splits) != 2 { 291 return nil, errors.Errorf("replace parameters must have 2 arguments, given: %s", replace) 292 } 293 294 original, replacement := normalizeSlashes(splits[0]), splits[1] 295 296 _, ok := templates[original] 297 if !ok { 298 return nil, errors.Errorf("replace can only replace existing templates, %s does not exist", original) 299 } 300 301 templates[original] = fileLoader(replacement) 302 } 303 304 // For stability, sort keys to traverse the map and turn it into a slice 305 keys := make([]string, 0, len(templates)) 306 for k := range templates { 307 keys = append(keys, k) 308 } 309 sort.Strings(keys) 310 311 lazyTemplates := make([]lazyTemplate, 0, len(templates)) 312 for _, k := range keys { 313 lazyTemplates = append(lazyTemplates, lazyTemplate{ 314 Name: k, 315 Loader: templates[k], 316 }) 317 } 318 319 s.Templates, err = loadTemplates(lazyTemplates, false, s.Config.CustomTemplateFuncs) 320 if err != nil { 321 return nil, err 322 } 323 324 if !s.Config.NoTests { 325 s.TestTemplates, err = loadTemplates(lazyTemplates, true, s.Config.CustomTemplateFuncs) 326 if err != nil { 327 return nil, err 328 } 329 } 330 331 return lazyTemplates, nil 332 } 333 334 type dirExtMap map[string]map[string][]string 335 336 // groupTemplates takes templates and groups them according to their output directory 337 // and file extension. 338 func groupTemplates(templates *templateList) dirExtMap { 339 tplNames := templates.Templates() 340 dirs := make(map[string]map[string][]string) 341 for _, tplName := range tplNames { 342 normalized, isSingleton, _, _ := outputFilenameParts(tplName) 343 if isSingleton { 344 continue 345 } 346 347 dir := filepath.Dir(normalized) 348 if dir == "." { 349 dir = "" 350 } 351 352 extensions, ok := dirs[dir] 353 if !ok { 354 extensions = make(map[string][]string) 355 dirs[dir] = extensions 356 } 357 358 ext := getLongExt(tplName) 359 ext = strings.TrimSuffix(ext, ".tpl") 360 slice := extensions[ext] 361 extensions[ext] = append(slice, tplName) 362 } 363 364 return dirs 365 } 366 367 // findTemplates uses a root path: (/home/user/gopath/src/../sqlboiler/) 368 // and a base path: /templates 369 // to create a bunch of file loaders of the form: 370 // templates/00_struct.tpl -> /absolute/path/to/that/file 371 // Note the missing leading slash, this is important for the --replace argument 372 func findTemplates(root, base string) (map[string]templateLoader, error) { 373 templates := make(map[string]templateLoader) 374 rootBase := filepath.Join(root, base) 375 err := filepath.Walk(rootBase, func(path string, fi os.FileInfo, err error) error { 376 if err != nil { 377 return err 378 } 379 380 if fi.IsDir() { 381 return nil 382 } 383 384 ext := filepath.Ext(path) 385 if ext != ".tpl" { 386 return nil 387 } 388 389 relative, err := filepath.Rel(root, path) 390 if err != nil { 391 return errors.Wrapf(err, "could not find relative path to base root: %s", rootBase) 392 } 393 394 relative = strings.TrimLeft(relative, string(os.PathSeparator)) 395 templates[relative] = fileLoader(path) 396 return nil 397 }) 398 399 if err != nil { 400 return nil, err 401 } 402 403 return templates, nil 404 } 405 406 // initDBInfo retrieves information about the database 407 func (s *State) initDBInfo(config map[string]interface{}) error { 408 dbInfo, err := s.Driver.Assemble(config) 409 if err != nil { 410 return errors.Wrap(err, "unable to fetch table data") 411 } 412 413 if len(dbInfo.Tables) == 0 { 414 return errors.New("no tables found in database") 415 } 416 417 if err := checkPKeys(dbInfo.Tables); err != nil { 418 return err 419 } 420 421 s.Schema = dbInfo.Schema 422 s.Tables = dbInfo.Tables 423 s.Dialect = dbInfo.Dialect 424 425 return nil 426 } 427 428 // mergeDriverImports calls the driver and asks for its set 429 // of imports, then merges it into the current configuration's 430 // imports. 431 func (s *State) mergeDriverImports() error { 432 drivers, err := s.Driver.Imports() 433 if err != nil { 434 return errors.Wrap(err, "failed to fetch driver's imports") 435 } 436 437 s.Config.Imports = importers.Merge(s.Config.Imports, drivers) 438 return nil 439 } 440 441 // mergeEnumImports merges imports for nullable enum types 442 // into the current configuration's imports if tables returned 443 // from the driver have nullable enum columns. 444 func (s *State) mergeEnumImports() { 445 if drivers.TablesHaveNullableEnums(s.Tables) { 446 s.Config.Imports = importers.Merge(s.Config.Imports, importers.NullableEnumImports()) 447 } 448 } 449 450 // processTypeReplacements checks the config for type replacements 451 // and performs them. 452 func (s *State) processTypeReplacements() error { 453 for _, r := range s.Config.TypeReplaces { 454 455 for i := range s.Tables { 456 t := s.Tables[i] 457 458 if !shouldReplaceInTable(t, r) { 459 continue 460 } 461 462 for j := range t.Columns { 463 c := t.Columns[j] 464 if matchColumn(c, r.Match) { 465 t.Columns[j] = columnMerge(c, r.Replace) 466 467 if len(r.Imports.Standard) != 0 || len(r.Imports.ThirdParty) != 0 { 468 s.Config.Imports.BasedOnType[t.Columns[j].Type] = importers.Set{ 469 Standard: r.Imports.Standard, 470 ThirdParty: r.Imports.ThirdParty, 471 } 472 } 473 } 474 } 475 } 476 } 477 478 return nil 479 } 480 481 // matchColumn checks if a column 'c' matches specifiers in 'm'. 482 // Anything defined in m is checked against a's values, the 483 // match is a done using logical and (all specifiers must match). 484 // Bool fields are only checked if a string type field matched first 485 // and if a string field matched they are always checked (must be defined). 486 // 487 // Doesn't care about Unique columns since those can vary independent of type. 488 func matchColumn(c, m drivers.Column) bool { 489 matchedSomething := false 490 491 // return true if we matched, or we don't have to match 492 // if we actually matched against something, then additionally set 493 // matchedSomething so we can check boolean values too. 494 matches := func(matcher, value string) bool { 495 if len(matcher) != 0 && matcher != value { 496 return false 497 } 498 matchedSomething = true 499 return true 500 } 501 502 if !matches(m.Name, c.Name) { 503 return false 504 } 505 if !matches(m.Type, c.Type) { 506 return false 507 } 508 if !matches(m.DBType, c.DBType) { 509 return false 510 } 511 if !matches(m.UDTName, c.UDTName) { 512 return false 513 } 514 if !matches(m.FullDBType, c.FullDBType) { 515 return false 516 } 517 if m.ArrType != nil && (c.ArrType == nil || !matches(*m.ArrType, *c.ArrType)) { 518 return false 519 } 520 if m.DomainName != nil && (c.DomainName == nil || !matches(*m.DomainName, *c.DomainName)) { 521 return false 522 } 523 524 if !matchedSomething { 525 return false 526 } 527 528 if m.AutoGenerated != c.AutoGenerated { 529 return false 530 } 531 if m.Nullable != c.Nullable { 532 return false 533 } 534 535 return true 536 } 537 538 // columnMerge merges values from src into dst. Bools are copied regardless 539 // strings are copied if they have values. Name is excluded because it doesn't make 540 // sense to non-programatically replace a name. 541 func columnMerge(dst, src drivers.Column) drivers.Column { 542 ret := dst 543 if len(src.Type) != 0 { 544 ret.Type = src.Type 545 } 546 if len(src.DBType) != 0 { 547 ret.DBType = src.DBType 548 } 549 if len(src.UDTName) != 0 { 550 ret.UDTName = src.UDTName 551 } 552 if len(src.FullDBType) != 0 { 553 ret.FullDBType = src.FullDBType 554 } 555 if src.ArrType != nil && len(*src.ArrType) != 0 { 556 ret.ArrType = new(string) 557 *ret.ArrType = *src.ArrType 558 } 559 560 return ret 561 } 562 563 // shouldReplaceInTable checks if tables were specified in types.match in the config. 564 // If tables were set, it checks if the given table is among the specified tables. 565 func shouldReplaceInTable(t drivers.Table, r TypeReplace) bool { 566 if len(r.Tables) == 0 { 567 return true 568 } 569 570 for _, replaceInTable := range r.Tables { 571 if replaceInTable == t.Name { 572 return true 573 } 574 } 575 576 return false 577 } 578 579 // initOutFolders creates the folders that will hold the generated output. 580 func (s *State) initOutFolders(lazyTemplates []lazyTemplate) error { 581 if s.Config.Wipe { 582 if err := os.RemoveAll(s.Config.OutFolder); err != nil { 583 return err 584 } 585 } 586 587 newDirs := make(map[string]struct{}) 588 for _, t := range lazyTemplates { 589 // templates/js/00_struct.js.tpl 590 // templates/js/singleton/00_struct.js.tpl 591 // we want the js part only 592 fragments := strings.Split(t.Name, string(os.PathSeparator)) 593 594 // Throw away the root dir and filename 595 fragments = fragments[1 : len(fragments)-1] 596 if len(fragments) != 0 && fragments[len(fragments)-1] == "singleton" { 597 fragments = fragments[:len(fragments)-1] 598 } 599 600 if len(fragments) == 0 { 601 continue 602 } 603 604 newDirs[strings.Join(fragments, string(os.PathSeparator))] = struct{}{} 605 } 606 607 if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil { 608 return err 609 } 610 611 for d := range newDirs { 612 if err := os.MkdirAll(filepath.Join(s.Config.OutFolder, d), os.ModePerm); err != nil { 613 return err 614 } 615 } 616 617 return nil 618 } 619 620 // initInflections adds custom inflections to strmangle's ruleset 621 func (s *State) initInflections() { 622 ruleset := strmangle.GetBoilRuleset() 623 624 for k, v := range s.Config.Inflections.Plural { 625 ruleset.AddPlural(k, v) 626 } 627 for k, v := range s.Config.Inflections.PluralExact { 628 ruleset.AddPluralExact(k, v, true) 629 } 630 631 for k, v := range s.Config.Inflections.Singular { 632 ruleset.AddSingular(k, v) 633 } 634 for k, v := range s.Config.Inflections.SingularExact { 635 ruleset.AddSingularExact(k, v, true) 636 } 637 638 for k, v := range s.Config.Inflections.Irregular { 639 ruleset.AddIrregular(k, v) 640 } 641 } 642 643 // initTags removes duplicate tags and validates the format 644 // of all user tags are simple strings without quotes: [a-zA-Z_\.]+ 645 func (s *State) initTags(tags []string) error { 646 s.Config.Tags = strmangle.RemoveDuplicates(s.Config.Tags) 647 for _, v := range s.Config.Tags { 648 if !rgxValidTag.MatchString(v) { 649 return errors.New("Invalid tag format %q supplied, only specify name, eg: xml") 650 } 651 } 652 653 return nil 654 } 655 656 func (s *State) initAliases(a *Aliases) error { 657 FillAliases(a, s.Tables) 658 return nil 659 } 660 661 // checkPKeys ensures every table has a primary key column 662 func checkPKeys(tables []drivers.Table) error { 663 var missingPkey []string 664 for _, t := range tables { 665 if !t.IsView && t.PKey == nil { 666 missingPkey = append(missingPkey, t.Name) 667 } 668 } 669 670 if len(missingPkey) != 0 { 671 return errors.Errorf("primary key missing in tables (%s)", strings.Join(missingPkey, ", ")) 672 } 673 674 return nil 675 } 676 677 func mergeTemplates(dst, src map[string]templateLoader) { 678 for k, v := range src { 679 dst[k] = v 680 } 681 } 682 683 // normalizeSlashes takes a path that was made on linux or windows and converts it 684 // to a native path. 685 func normalizeSlashes(path string) string { 686 path = strings.ReplaceAll(path, `/`, string(os.PathSeparator)) 687 path = strings.ReplaceAll(path, `\`, string(os.PathSeparator)) 688 return path 689 } 690 691 // denormalizeSlashes takes any backslashes and converts them to linux style slashes 692 func denormalizeSlashes(path string) string { 693 path = strings.ReplaceAll(path, `\`, `/`) 694 return path 695 }