github.com/99designs/gqlgen@v0.17.45/codegen/templates/templates.go (about) 1 package templates 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/types" 7 "io/fs" 8 "os" 9 "path/filepath" 10 "reflect" 11 "regexp" 12 "runtime" 13 "sort" 14 "strconv" 15 "strings" 16 "sync" 17 "text/template" 18 "unicode" 19 20 "github.com/99designs/gqlgen/internal/code" 21 "github.com/99designs/gqlgen/internal/imports" 22 ) 23 24 // CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin. 25 // this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually. 26 var CurrentImports *Imports 27 28 // Options specify various parameters to rendering a template. 29 type Options struct { 30 // PackageName is a helper that specifies the package header declaration. 31 // In other words, when you write the template you don't need to specify `package X` 32 // at the top of the file. By providing PackageName in the Options, the Render 33 // function will do that for you. 34 PackageName string 35 // Template is a string of the entire template that 36 // will be parsed and rendered. If it's empty, 37 // the plugin processor will look for .gotpl files 38 // in the same directory of where you wrote the plugin. 39 Template string 40 41 // Use the go:embed API to collect all the template files you want to pass into Render 42 // this is an alternative to passing the Template option 43 TemplateFS fs.FS 44 45 // Filename is the name of the file that will be 46 // written to the system disk once the template is rendered. 47 Filename string 48 RegionTags bool 49 GeneratedHeader bool 50 // PackageDoc is documentation written above the package line 51 PackageDoc string 52 // FileNotice is notice written below the package line 53 FileNotice string 54 // Data will be passed to the template execution. 55 Data interface{} 56 Funcs template.FuncMap 57 58 // Packages cache, you can find me on config.Config 59 Packages *code.Packages 60 } 61 62 var ( 63 modelNamesMu sync.Mutex 64 modelNames = make(map[string]string, 0) 65 goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]") 66 ) 67 68 // Render renders a gql plugin template from the given Options. Render is an 69 // abstraction of the text/template package that makes it easier to write gqlgen 70 // plugins. If Options.Template is empty, the Render function will look for `.gotpl` 71 // files inside the directory where you wrote the plugin. 72 func Render(cfg Options) error { 73 if CurrentImports != nil { 74 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) 75 } 76 CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} 77 78 funcs := Funcs() 79 for n, f := range cfg.Funcs { 80 funcs[n] = f 81 } 82 83 t := template.New("").Funcs(funcs) 84 t, err := parseTemplates(cfg, t) 85 if err != nil { 86 return err 87 } 88 89 roots := make([]string, 0, len(t.Templates())) 90 for _, template := range t.Templates() { 91 // templates that end with _.gotpl are special files we don't want to include 92 if strings.HasSuffix(template.Name(), "_.gotpl") || 93 // filter out templates added with {{ template xxx }} syntax inside the template file 94 !strings.HasSuffix(template.Name(), ".gotpl") { 95 continue 96 } 97 98 roots = append(roots, template.Name()) 99 } 100 101 // then execute all the important looking ones in order, adding them to the same file 102 sort.Slice(roots, func(i, j int) bool { 103 // important files go first 104 if strings.HasSuffix(roots[i], "!.gotpl") { 105 return true 106 } 107 if strings.HasSuffix(roots[j], "!.gotpl") { 108 return false 109 } 110 return roots[i] < roots[j] 111 }) 112 113 var buf bytes.Buffer 114 for _, root := range roots { 115 if cfg.RegionTags { 116 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n") 117 } 118 err := t.Lookup(root).Execute(&buf, cfg.Data) 119 if err != nil { 120 return fmt.Errorf("%s: %w", root, err) 121 } 122 if cfg.RegionTags { 123 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n") 124 } 125 } 126 127 var result bytes.Buffer 128 if cfg.GeneratedHeader { 129 result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n") 130 } 131 if cfg.PackageDoc != "" { 132 result.WriteString(cfg.PackageDoc + "\n") 133 } 134 result.WriteString("package ") 135 result.WriteString(cfg.PackageName) 136 result.WriteString("\n\n") 137 if cfg.FileNotice != "" { 138 result.WriteString(cfg.FileNotice) 139 result.WriteString("\n\n") 140 } 141 result.WriteString("import (\n") 142 result.WriteString(CurrentImports.String()) 143 result.WriteString(")\n") 144 _, err = buf.WriteTo(&result) 145 if err != nil { 146 return err 147 } 148 CurrentImports = nil 149 150 err = write(cfg.Filename, result.Bytes(), cfg.Packages) 151 if err != nil { 152 return err 153 } 154 155 cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename))) 156 return nil 157 } 158 159 func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) { 160 if cfg.Template != "" { 161 var err error 162 t, err = t.New("template.gotpl").Parse(cfg.Template) 163 if err != nil { 164 return nil, fmt.Errorf("error with provided template: %w", err) 165 } 166 return t, nil 167 } 168 169 var fileSystem fs.FS 170 if cfg.TemplateFS != nil { 171 fileSystem = cfg.TemplateFS 172 } else { 173 // load path relative to calling source file 174 _, callerFile, _, _ := runtime.Caller(2) 175 rootDir := filepath.Dir(callerFile) 176 fileSystem = os.DirFS(rootDir) 177 } 178 179 t, err := t.ParseFS(fileSystem, "*.gotpl") 180 if err != nil { 181 return nil, fmt.Errorf("locating templates: %w", err) 182 } 183 184 return t, nil 185 } 186 187 func center(width int, pad string, s string) string { 188 if len(s)+2 > width { 189 return s 190 } 191 lpad := (width - len(s)) / 2 192 rpad := width - (lpad + len(s)) 193 return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad) 194 } 195 196 func Funcs() template.FuncMap { 197 return template.FuncMap{ 198 "ucFirst": UcFirst, 199 "lcFirst": LcFirst, 200 "quote": strconv.Quote, 201 "rawQuote": rawQuote, 202 "dump": Dump, 203 "ref": ref, 204 "ts": TypeIdentifier, 205 "call": Call, 206 "prefixLines": prefixLines, 207 "notNil": notNil, 208 "reserveImport": CurrentImports.Reserve, 209 "lookupImport": CurrentImports.Lookup, 210 "go": ToGo, 211 "goPrivate": ToGoPrivate, 212 "goModelName": ToGoModelName, 213 "goPrivateModelName": ToGoPrivateModelName, 214 "add": func(a, b int) int { 215 return a + b 216 }, 217 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) { 218 return render(resolveName(filename, 0), tpldata) 219 }, 220 } 221 } 222 223 func UcFirst(s string) string { 224 if s == "" { 225 return "" 226 } 227 r := []rune(s) 228 r[0] = unicode.ToUpper(r[0]) 229 return string(r) 230 } 231 232 func LcFirst(s string) string { 233 if s == "" { 234 return "" 235 } 236 237 r := []rune(s) 238 r[0] = unicode.ToLower(r[0]) 239 return string(r) 240 } 241 242 func isDelimiter(c rune) bool { 243 return c == '-' || c == '_' || unicode.IsSpace(c) 244 } 245 246 func ref(p types.Type) string { 247 return CurrentImports.LookupType(p) 248 } 249 250 func Call(p *types.Func) string { 251 pkg := CurrentImports.Lookup(p.Pkg().Path()) 252 253 if pkg != "" { 254 pkg += "." 255 } 256 257 if p.Type() != nil { 258 // make sure the returned type is listed in our imports. 259 ref(p.Type().(*types.Signature).Results().At(0).Type()) 260 } 261 262 return pkg + p.Name() 263 } 264 265 func resetModelNames() { 266 modelNamesMu.Lock() 267 defer modelNamesMu.Unlock() 268 modelNames = make(map[string]string, 0) 269 } 270 271 func buildGoModelNameKey(parts []string) string { 272 const sep = ":" 273 return strings.Join(parts, sep) 274 } 275 276 func goModelName(primaryToGoFunc func(string) string, parts []string) string { 277 modelNamesMu.Lock() 278 defer modelNamesMu.Unlock() 279 280 var ( 281 goNameKey string 282 partLen int 283 284 nameExists = func(n string) bool { 285 for _, v := range modelNames { 286 if n == v { 287 return true 288 } 289 } 290 return false 291 } 292 293 applyToGoFunc = func(parts []string) string { 294 var out string 295 switch len(parts) { 296 case 0: 297 return "" 298 case 1: 299 return primaryToGoFunc(parts[0]) 300 default: 301 out = primaryToGoFunc(parts[0]) 302 } 303 for _, p := range parts[1:] { 304 out = fmt.Sprintf("%s%s", out, ToGo(p)) 305 } 306 return out 307 } 308 309 applyValidGoName = func(parts []string) string { 310 var out string 311 for _, p := range parts { 312 out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p)) 313 } 314 return out 315 } 316 ) 317 318 // build key for this entity 319 goNameKey = buildGoModelNameKey(parts) 320 321 // determine if we've seen this entity before, and reuse if so 322 if goName, ok := modelNames[goNameKey]; ok { 323 return goName 324 } 325 326 // attempt first pass 327 if goName := applyToGoFunc(parts); !nameExists(goName) { 328 modelNames[goNameKey] = goName 329 return goName 330 } 331 332 // determine number of parts 333 partLen = len(parts) 334 335 // if there is only 1 part, append incrementing number until no conflict 336 if partLen == 1 { 337 base := applyToGoFunc(parts) 338 for i := 0; ; i++ { 339 tmp := fmt.Sprintf("%s%d", base, i) 340 if !nameExists(tmp) { 341 modelNames[goNameKey] = tmp 342 return tmp 343 } 344 } 345 } 346 347 // best effort "pretty" name 348 for i := partLen - 1; i >= 1; i-- { 349 tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:])) 350 if !nameExists(tmp) { 351 modelNames[goNameKey] = tmp 352 return tmp 353 } 354 } 355 356 // finally, fallback to just adding an incrementing number 357 base := applyToGoFunc(parts) 358 for i := 0; ; i++ { 359 tmp := fmt.Sprintf("%s%d", base, i) 360 if !nameExists(tmp) { 361 modelNames[goNameKey] = tmp 362 return tmp 363 } 364 } 365 } 366 367 func ToGoModelName(parts ...string) string { 368 return goModelName(ToGo, parts) 369 } 370 371 func ToGoPrivateModelName(parts ...string) string { 372 return goModelName(ToGoPrivate, parts) 373 } 374 375 func replaceInvalidCharacters(in string) string { 376 return goNameRe.ReplaceAllLiteralString(in, "_") 377 } 378 379 func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) { 380 return func(info *wordInfo) { 381 word := info.Word 382 383 switch { 384 case private && info.WordOffset == 0: 385 if strings.ToUpper(word) == word || strings.ToLower(word) == word { 386 // ID → id, CAMEL → camel 387 word = strings.ToLower(info.Word) 388 } else { 389 // ITicket → iTicket 390 word = LcFirst(info.Word) 391 } 392 393 case info.MatchCommonInitial: 394 word = strings.ToUpper(word) 395 396 case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word): 397 // FOO or foo → Foo 398 // FOo → FOo 399 word = UcFirst(strings.ToLower(word)) 400 } 401 402 *nameRunes = append(*nameRunes, []rune(word)...) 403 } 404 } 405 406 func ToGo(name string) string { 407 if name == "_" { 408 return "_" 409 } 410 runes := make([]rune, 0, len(name)) 411 412 wordWalker(name, wordWalkerFunc(false, &runes)) 413 414 return string(runes) 415 } 416 417 func ToGoPrivate(name string) string { 418 if name == "_" { 419 return "_" 420 } 421 runes := make([]rune, 0, len(name)) 422 423 wordWalker(name, wordWalkerFunc(true, &runes)) 424 425 return sanitizeKeywords(string(runes)) 426 } 427 428 type wordInfo struct { 429 WordOffset int 430 Word string 431 MatchCommonInitial bool 432 HasCommonInitial bool 433 } 434 435 // This function is based on the following code. 436 // https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679 437 func wordWalker(str string, f func(*wordInfo)) { 438 runes := []rune(strings.TrimFunc(str, isDelimiter)) 439 w, i, wo := 0, 0, 0 // index of start of word, scan, word offset 440 hasCommonInitial := false 441 for i+1 <= len(runes) { 442 eow := false // whether we hit the end of a word 443 switch { 444 case i+1 == len(runes): 445 eow = true 446 case isDelimiter(runes[i+1]): 447 // underscore; shift the remainder forward over any run of underscores 448 eow = true 449 n := 1 450 for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) { 451 n++ 452 } 453 454 // Leave at most one underscore if the underscore is between two digits 455 if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) { 456 n-- 457 } 458 459 copy(runes[i+1:], runes[i+n+1:]) 460 runes = runes[:len(runes)-n] 461 case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]): 462 // lower->non-lower 463 eow = true 464 } 465 i++ 466 467 initialisms := GetInitialisms() 468 // [w,i) is a word. 469 word := string(runes[w:i]) 470 if !eow && initialisms[word] && !unicode.IsLower(runes[i]) { 471 // through 472 // split IDFoo → ID, Foo 473 // but URLs → URLs 474 } else if !eow { 475 if initialisms[word] { 476 hasCommonInitial = true 477 } 478 continue 479 } 480 481 matchCommonInitial := false 482 upperWord := strings.ToUpper(word) 483 if initialisms[upperWord] { 484 // If the uppercase word (string(runes[w:i]) is "ID" or "IP" 485 // AND 486 // the word is the first two characters of the str 487 // AND 488 // that is not the end of the word 489 // AND 490 // the length of the string is greater than 3 491 // AND 492 // the third rune is an uppercase one 493 // THEN 494 // do NOT count this as an initialism. 495 switch upperWord { 496 case "ID", "IP": 497 if word == str[:2] && !eow && len(str) > 3 && unicode.IsUpper(runes[3]) { 498 continue 499 } 500 } 501 hasCommonInitial = true 502 matchCommonInitial = true 503 } 504 505 f(&wordInfo{ 506 WordOffset: wo, 507 Word: word, 508 MatchCommonInitial: matchCommonInitial, 509 HasCommonInitial: hasCommonInitial, 510 }) 511 hasCommonInitial = false 512 w = i 513 wo++ 514 } 515 } 516 517 var keywords = []string{ 518 "break", 519 "default", 520 "func", 521 "interface", 522 "select", 523 "case", 524 "defer", 525 "go", 526 "map", 527 "struct", 528 "chan", 529 "else", 530 "goto", 531 "package", 532 "switch", 533 "const", 534 "fallthrough", 535 "if", 536 "range", 537 "type", 538 "continue", 539 "for", 540 "import", 541 "return", 542 "var", 543 "_", 544 } 545 546 // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions 547 func sanitizeKeywords(name string) string { 548 for _, k := range keywords { 549 if name == k { 550 return name + "Arg" 551 } 552 } 553 return name 554 } 555 556 func rawQuote(s string) string { 557 return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`" 558 } 559 560 func notNil(field string, data interface{}) bool { 561 v := reflect.ValueOf(data) 562 563 if v.Kind() == reflect.Ptr { 564 v = v.Elem() 565 } 566 if v.Kind() != reflect.Struct { 567 return false 568 } 569 val := v.FieldByName(field) 570 571 return val.IsValid() && !val.IsNil() 572 } 573 574 func Dump(val interface{}) string { 575 switch val := val.(type) { 576 case int: 577 return strconv.Itoa(val) 578 case int64: 579 return fmt.Sprintf("%d", val) 580 case float64: 581 return fmt.Sprintf("%f", val) 582 case string: 583 return strconv.Quote(val) 584 case bool: 585 return strconv.FormatBool(val) 586 case nil: 587 return "nil" 588 case []interface{}: 589 var parts []string 590 for _, part := range val { 591 parts = append(parts, Dump(part)) 592 } 593 return "[]interface{}{" + strings.Join(parts, ",") + "}" 594 case map[string]interface{}: 595 buf := bytes.Buffer{} 596 buf.WriteString("map[string]interface{}{") 597 var keys []string 598 for key := range val { 599 keys = append(keys, key) 600 } 601 sort.Strings(keys) 602 603 for _, key := range keys { 604 data := val[key] 605 606 buf.WriteString(strconv.Quote(key)) 607 buf.WriteString(":") 608 buf.WriteString(Dump(data)) 609 buf.WriteString(",") 610 } 611 buf.WriteString("}") 612 return buf.String() 613 default: 614 panic(fmt.Errorf("unsupported type %T", val)) 615 } 616 } 617 618 func prefixLines(prefix, s string) string { 619 return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix) 620 } 621 622 func resolveName(name string, skip int) string { 623 if name[0] == '.' { 624 // load path relative to calling source file 625 _, callerFile, _, _ := runtime.Caller(skip + 1) 626 return filepath.Join(filepath.Dir(callerFile), name[1:]) 627 } 628 629 // load path relative to this directory 630 _, callerFile, _, _ := runtime.Caller(0) 631 return filepath.Join(filepath.Dir(callerFile), name) 632 } 633 634 func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { 635 t := template.New("").Funcs(Funcs()) 636 637 b, err := os.ReadFile(filename) 638 if err != nil { 639 return nil, err 640 } 641 642 t, err = t.New(filepath.Base(filename)).Parse(string(b)) 643 if err != nil { 644 panic(err) 645 } 646 647 buf := &bytes.Buffer{} 648 return buf, t.Execute(buf, tpldata) 649 } 650 651 func write(filename string, b []byte, packages *code.Packages) error { 652 err := os.MkdirAll(filepath.Dir(filename), 0o755) 653 if err != nil { 654 return fmt.Errorf("failed to create directory: %w", err) 655 } 656 657 formatted, err := imports.Prune(filename, b, packages) 658 if err != nil { 659 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) 660 formatted = b 661 } 662 663 err = os.WriteFile(filename, formatted, 0o644) 664 if err != nil { 665 return fmt.Errorf("failed to write %s: %w", filename, err) 666 } 667 668 return nil 669 } 670 671 var pkgReplacer = strings.NewReplacer( 672 "/", "ᚋ", 673 ".", "ᚗ", 674 "-", "ᚑ", 675 "~", "א", 676 ) 677 678 func TypeIdentifier(t types.Type) string { 679 res := "" 680 for { 681 switch it := t.(type) { 682 case *types.Pointer: 683 t.Underlying() 684 res += "ᚖ" 685 t = it.Elem() 686 case *types.Slice: 687 res += "ᚕ" 688 t = it.Elem() 689 case *types.Named: 690 res += pkgReplacer.Replace(it.Obj().Pkg().Path()) 691 res += "ᚐ" 692 res += it.Obj().Name() 693 return res 694 case *types.Basic: 695 res += it.Name() 696 return res 697 case *types.Map: 698 res += "map" 699 return res 700 case *types.Interface: 701 res += "interface" 702 return res 703 default: 704 panic(fmt.Errorf("unexpected type %T", it)) 705 } 706 } 707 } 708 709 // CommonInitialisms is a set of common initialisms. 710 // Only add entries that are highly unlikely to be non-initialisms. 711 // For instance, "ID" is fine (Freudian code is rare), but "AND" is not. 712 var CommonInitialisms = map[string]bool{ 713 "ACL": true, 714 "API": true, 715 "ASCII": true, 716 "CPU": true, 717 "CSS": true, 718 "CSV": true, 719 "DNS": true, 720 "EOF": true, 721 "GUID": true, 722 "HTML": true, 723 "HTTP": true, 724 "HTTPS": true, 725 "ICMP": true, 726 "ID": true, 727 "IP": true, 728 "JSON": true, 729 "KVK": true, 730 "LHS": true, 731 "PDF": true, 732 "PGP": true, 733 "QPS": true, 734 "QR": true, 735 "RAM": true, 736 "RHS": true, 737 "RPC": true, 738 "SLA": true, 739 "SMTP": true, 740 "SQL": true, 741 "SSH": true, 742 "SVG": true, 743 "TCP": true, 744 "TLS": true, 745 "TTL": true, 746 "UDP": true, 747 "UI": true, 748 "UID": true, 749 "URI": true, 750 "URL": true, 751 "UTF8": true, 752 "UUID": true, 753 "VM": true, 754 "XML": true, 755 "XMPP": true, 756 "XSRF": true, 757 "XSS": true, 758 } 759 760 // GetInitialisms returns the initialisms to capitalize in Go names. If unchanged, default initialisms will be returned 761 var GetInitialisms = func() map[string]bool { 762 return CommonInitialisms 763 }