github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/codegen/templates/templates.go (about) 1 package templates 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/types" 7 "io/fs" 8 "os" 9 "path/filepath" 10 "reflect" 11 "regexp" 12 "runtime" 13 "sort" 14 "strconv" 15 "strings" 16 "sync" 17 "text/template" 18 "unicode" 19 20 "github.com/spread-ai/gqlgen/internal/code" 21 22 "github.com/spread-ai/gqlgen/internal/imports" 23 ) 24 25 // CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin. 26 // this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually. 27 var CurrentImports *Imports 28 29 // Options specify various parameters to rendering a template. 30 type Options struct { 31 // PackageName is a helper that specifies the package header declaration. 32 // In other words, when you write the template you don't need to specify `package X` 33 // at the top of the file. By providing PackageName in the Options, the Render 34 // function will do that for you. 35 PackageName string 36 // Template is a string of the entire template that 37 // will be parsed and rendered. If it's empty, 38 // the plugin processor will look for .gotpl files 39 // in the same directory of where you wrote the plugin. 40 Template string 41 42 // Use the go:embed API to collect all the template files you want to pass into Render 43 // this is an alternative to passing the Template option 44 TemplateFS fs.FS 45 46 // Filename is the name of the file that will be 47 // written to the system disk once the template is rendered. 48 Filename string 49 RegionTags bool 50 GeneratedHeader bool 51 // PackageDoc is documentation written above the package line 52 PackageDoc string 53 // FileNotice is notice written below the package line 54 FileNotice string 55 // Data will be passed to the template execution. 56 Data interface{} 57 Funcs template.FuncMap 58 59 // Packages cache, you can find me on config.Config 60 Packages *code.Packages 61 } 62 63 var ( 64 modelNamesMu sync.Mutex 65 modelNames = make(map[string]string, 0) 66 goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]") 67 ) 68 69 // Render renders a gql plugin template from the given Options. Render is an 70 // abstraction of the text/template package that makes it easier to write gqlgen 71 // plugins. If Options.Template is empty, the Render function will look for `.gotpl` 72 // files inside the directory where you wrote the plugin. 73 func Render(cfg Options) error { 74 if CurrentImports != nil { 75 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) 76 } 77 CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} 78 79 funcs := Funcs() 80 for n, f := range cfg.Funcs { 81 funcs[n] = f 82 } 83 84 t := template.New("").Funcs(funcs) 85 t, err := parseTemplates(cfg, t) 86 if err != nil { 87 return err 88 } 89 90 roots := make([]string, 0, len(t.Templates())) 91 for _, template := range t.Templates() { 92 // templates that end with _.gotpl are special files we don't want to include 93 if strings.HasSuffix(template.Name(), "_.gotpl") || 94 // filter out templates added with {{ template xxx }} syntax inside the template file 95 !strings.HasSuffix(template.Name(), ".gotpl") { 96 continue 97 } 98 99 roots = append(roots, template.Name()) 100 } 101 102 // then execute all the important looking ones in order, adding them to the same file 103 sort.Slice(roots, func(i, j int) bool { 104 // important files go first 105 if strings.HasSuffix(roots[i], "!.gotpl") { 106 return true 107 } 108 if strings.HasSuffix(roots[j], "!.gotpl") { 109 return false 110 } 111 return roots[i] < roots[j] 112 }) 113 114 var buf bytes.Buffer 115 for _, root := range roots { 116 if cfg.RegionTags { 117 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n") 118 } 119 err := t.Lookup(root).Execute(&buf, cfg.Data) 120 if err != nil { 121 return fmt.Errorf("%s: %w", root, err) 122 } 123 if cfg.RegionTags { 124 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n") 125 } 126 } 127 128 var result bytes.Buffer 129 if cfg.GeneratedHeader { 130 result.WriteString("// Code generated by github.com/spread-ai/gqlgen, DO NOT EDIT.\n\n") 131 } 132 if cfg.PackageDoc != "" { 133 result.WriteString(cfg.PackageDoc + "\n") 134 } 135 result.WriteString("package ") 136 result.WriteString(cfg.PackageName) 137 result.WriteString("\n\n") 138 if cfg.FileNotice != "" { 139 result.WriteString(cfg.FileNotice) 140 result.WriteString("\n\n") 141 } 142 result.WriteString("import (\n") 143 result.WriteString(CurrentImports.String()) 144 result.WriteString(")\n") 145 _, err = buf.WriteTo(&result) 146 if err != nil { 147 return err 148 } 149 CurrentImports = nil 150 151 err = write(cfg.Filename, result.Bytes(), cfg.Packages) 152 if err != nil { 153 return err 154 } 155 156 cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename))) 157 return nil 158 } 159 160 func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) { 161 if cfg.Template != "" { 162 var err error 163 t, err = t.New("template.gotpl").Parse(cfg.Template) 164 if err != nil { 165 return nil, fmt.Errorf("error with provided template: %w", err) 166 } 167 return t, nil 168 } 169 170 var fileSystem fs.FS 171 if cfg.TemplateFS != nil { 172 fileSystem = cfg.TemplateFS 173 } else { 174 // load path relative to calling source file 175 _, callerFile, _, _ := runtime.Caller(1) 176 rootDir := filepath.Dir(callerFile) 177 fileSystem = os.DirFS(rootDir) 178 } 179 180 t, err := t.ParseFS(fileSystem, "*.gotpl") 181 if err != nil { 182 return nil, fmt.Errorf("locating templates: %w", err) 183 } 184 185 return t, nil 186 } 187 188 func center(width int, pad string, s string) string { 189 if len(s)+2 > width { 190 return s 191 } 192 lpad := (width - len(s)) / 2 193 rpad := width - (lpad + len(s)) 194 return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad) 195 } 196 197 func Funcs() template.FuncMap { 198 return template.FuncMap{ 199 "ucFirst": UcFirst, 200 "lcFirst": LcFirst, 201 "quote": strconv.Quote, 202 "rawQuote": rawQuote, 203 "dump": Dump, 204 "ref": ref, 205 "ts": TypeIdentifier, 206 "call": Call, 207 "prefixLines": prefixLines, 208 "notNil": notNil, 209 "reserveImport": CurrentImports.Reserve, 210 "lookupImport": CurrentImports.Lookup, 211 "go": ToGo, 212 "goPrivate": ToGoPrivate, 213 "goModelName": ToGoModelName, 214 "goPrivateModelName": ToGoPrivateModelName, 215 "add": func(a, b int) int { 216 return a + b 217 }, 218 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) { 219 return render(resolveName(filename, 0), tpldata) 220 }, 221 } 222 } 223 224 func UcFirst(s string) string { 225 if s == "" { 226 return "" 227 } 228 r := []rune(s) 229 r[0] = unicode.ToUpper(r[0]) 230 return string(r) 231 } 232 233 func LcFirst(s string) string { 234 if s == "" { 235 return "" 236 } 237 238 r := []rune(s) 239 r[0] = unicode.ToLower(r[0]) 240 return string(r) 241 } 242 243 func isDelimiter(c rune) bool { 244 return c == '-' || c == '_' || unicode.IsSpace(c) 245 } 246 247 func ref(p types.Type) string { 248 return CurrentImports.LookupType(p) 249 } 250 251 var pkgReplacer = strings.NewReplacer( 252 "/", "ᚋ", 253 ".", "ᚗ", 254 "-", "ᚑ", 255 "~", "א", 256 ) 257 258 func TypeIdentifier(t types.Type) string { 259 res := "" 260 for { 261 switch it := t.(type) { 262 case *types.Pointer: 263 t.Underlying() 264 res += "ᚖ" 265 t = it.Elem() 266 case *types.Slice: 267 res += "ᚕ" 268 t = it.Elem() 269 case *types.Named: 270 res += pkgReplacer.Replace(it.Obj().Pkg().Path()) 271 res += "ᚐ" 272 res += it.Obj().Name() 273 return res 274 case *types.Basic: 275 res += it.Name() 276 return res 277 case *types.Map: 278 res += "map" 279 return res 280 case *types.Interface: 281 res += "interface" 282 return res 283 default: 284 panic(fmt.Errorf("unexpected type %T", it)) 285 } 286 } 287 } 288 289 func Call(p *types.Func) string { 290 pkg := CurrentImports.Lookup(p.Pkg().Path()) 291 292 if pkg != "" { 293 pkg += "." 294 } 295 296 if p.Type() != nil { 297 // make sure the returned type is listed in our imports. 298 ref(p.Type().(*types.Signature).Results().At(0).Type()) 299 } 300 301 return pkg + p.Name() 302 } 303 304 func resetModelNames() { 305 modelNamesMu.Lock() 306 defer modelNamesMu.Unlock() 307 modelNames = make(map[string]string, 0) 308 } 309 310 func buildGoModelNameKey(parts []string) string { 311 const sep = ":" 312 return strings.Join(parts, sep) 313 } 314 315 func goModelName(primaryToGoFunc func(string) string, parts []string) string { 316 modelNamesMu.Lock() 317 defer modelNamesMu.Unlock() 318 319 var ( 320 goNameKey string 321 partLen int 322 323 nameExists = func(n string) bool { 324 for _, v := range modelNames { 325 if n == v { 326 return true 327 } 328 } 329 return false 330 } 331 332 applyToGoFunc = func(parts []string) string { 333 var out string 334 switch len(parts) { 335 case 0: 336 return "" 337 case 1: 338 return primaryToGoFunc(parts[0]) 339 default: 340 out = primaryToGoFunc(parts[0]) 341 } 342 for _, p := range parts[1:] { 343 out = fmt.Sprintf("%s%s", out, ToGo(p)) 344 } 345 return out 346 } 347 348 applyValidGoName = func(parts []string) string { 349 var out string 350 for _, p := range parts { 351 out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p)) 352 } 353 return out 354 } 355 ) 356 357 // build key for this entity 358 goNameKey = buildGoModelNameKey(parts) 359 360 // determine if we've seen this entity before, and reuse if so 361 if goName, ok := modelNames[goNameKey]; ok { 362 return goName 363 } 364 365 // attempt first pass 366 if goName := applyToGoFunc(parts); !nameExists(goName) { 367 modelNames[goNameKey] = goName 368 return goName 369 } 370 371 // determine number of parts 372 partLen = len(parts) 373 374 // if there is only 1 part, append incrementing number until no conflict 375 if partLen == 1 { 376 base := applyToGoFunc(parts) 377 for i := 0; ; i++ { 378 tmp := fmt.Sprintf("%s%d", base, i) 379 if !nameExists(tmp) { 380 modelNames[goNameKey] = tmp 381 return tmp 382 } 383 } 384 } 385 386 // best effort "pretty" name 387 for i := partLen - 1; i >= 1; i-- { 388 tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:])) 389 if !nameExists(tmp) { 390 modelNames[goNameKey] = tmp 391 return tmp 392 } 393 } 394 395 // finally, fallback to just adding an incrementing number 396 base := applyToGoFunc(parts) 397 for i := 0; ; i++ { 398 tmp := fmt.Sprintf("%s%d", base, i) 399 if !nameExists(tmp) { 400 modelNames[goNameKey] = tmp 401 return tmp 402 } 403 } 404 } 405 406 func ToGoModelName(parts ...string) string { 407 return goModelName(ToGo, parts) 408 } 409 410 func ToGoPrivateModelName(parts ...string) string { 411 return goModelName(ToGoPrivate, parts) 412 } 413 414 func replaceInvalidCharacters(in string) string { 415 return goNameRe.ReplaceAllLiteralString(in, "_") 416 } 417 418 func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) { 419 return func(info *wordInfo) { 420 word := info.Word 421 422 switch { 423 case private && info.WordOffset == 0: 424 if strings.ToUpper(word) == word || strings.ToLower(word) == word { 425 // ID → id, CAMEL → camel 426 word = strings.ToLower(info.Word) 427 } else { 428 // ITicket → iTicket 429 word = LcFirst(info.Word) 430 } 431 432 case info.MatchCommonInitial: 433 word = strings.ToUpper(word) 434 435 case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word): 436 // FOO or foo → Foo 437 // FOo → FOo 438 word = UcFirst(strings.ToLower(word)) 439 } 440 441 *nameRunes = append(*nameRunes, []rune(word)...) 442 } 443 } 444 445 func ToGo(name string) string { 446 if name == "_" { 447 return "_" 448 } 449 runes := make([]rune, 0, len(name)) 450 451 wordWalker(name, wordWalkerFunc(false, &runes)) 452 453 return string(runes) 454 } 455 456 func ToGoPrivate(name string) string { 457 if name == "_" { 458 return "_" 459 } 460 runes := make([]rune, 0, len(name)) 461 462 wordWalker(name, wordWalkerFunc(true, &runes)) 463 464 return sanitizeKeywords(string(runes)) 465 } 466 467 type wordInfo struct { 468 WordOffset int 469 Word string 470 MatchCommonInitial bool 471 HasCommonInitial bool 472 } 473 474 // This function is based on the following code. 475 // https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679 476 func wordWalker(str string, f func(*wordInfo)) { 477 runes := []rune(strings.TrimFunc(str, isDelimiter)) 478 w, i, wo := 0, 0, 0 // index of start of word, scan, word offset 479 hasCommonInitial := false 480 for i+1 <= len(runes) { 481 eow := false // whether we hit the end of a word 482 switch { 483 case i+1 == len(runes): 484 eow = true 485 case isDelimiter(runes[i+1]): 486 // underscore; shift the remainder forward over any run of underscores 487 eow = true 488 n := 1 489 for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) { 490 n++ 491 } 492 493 // Leave at most one underscore if the underscore is between two digits 494 if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) { 495 n-- 496 } 497 498 copy(runes[i+1:], runes[i+n+1:]) 499 runes = runes[:len(runes)-n] 500 case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]): 501 // lower->non-lower 502 eow = true 503 } 504 i++ 505 506 // [w,i) is a word. 507 word := string(runes[w:i]) 508 if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) { 509 // through 510 // split IDFoo → ID, Foo 511 // but URLs → URLs 512 } else if !eow { 513 if commonInitialisms[word] { 514 hasCommonInitial = true 515 } 516 continue 517 } 518 519 matchCommonInitial := false 520 if commonInitialisms[strings.ToUpper(word)] { 521 hasCommonInitial = true 522 matchCommonInitial = true 523 } 524 525 f(&wordInfo{ 526 WordOffset: wo, 527 Word: word, 528 MatchCommonInitial: matchCommonInitial, 529 HasCommonInitial: hasCommonInitial, 530 }) 531 hasCommonInitial = false 532 w = i 533 wo++ 534 } 535 } 536 537 var keywords = []string{ 538 "break", 539 "default", 540 "func", 541 "interface", 542 "select", 543 "case", 544 "defer", 545 "go", 546 "map", 547 "struct", 548 "chan", 549 "else", 550 "goto", 551 "package", 552 "switch", 553 "const", 554 "fallthrough", 555 "if", 556 "range", 557 "type", 558 "continue", 559 "for", 560 "import", 561 "return", 562 "var", 563 "_", 564 } 565 566 // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions 567 func sanitizeKeywords(name string) string { 568 for _, k := range keywords { 569 if name == k { 570 return name + "Arg" 571 } 572 } 573 return name 574 } 575 576 // commonInitialisms is a set of common initialisms. 577 // Only add entries that are highly unlikely to be non-initialisms. 578 // For instance, "ID" is fine (Freudian code is rare), but "AND" is not. 579 var commonInitialisms = map[string]bool{ 580 "ACL": true, 581 "API": true, 582 "ASCII": true, 583 "CPU": true, 584 "CSS": true, 585 "CSV": true, 586 "DNS": true, 587 "EOF": true, 588 "GUID": true, 589 "HTML": true, 590 "HTTP": true, 591 "HTTPS": true, 592 "ICMP": true, 593 "ID": true, 594 "IP": true, 595 "JSON": true, 596 "KVK": true, 597 "LHS": true, 598 "PDF": true, 599 "PGP": true, 600 "QPS": true, 601 "QR": true, 602 "RAM": true, 603 "RHS": true, 604 "RPC": true, 605 "SLA": true, 606 "SMTP": true, 607 "SQL": true, 608 "SSH": true, 609 "SVG": true, 610 "TCP": true, 611 "TLS": true, 612 "TTL": true, 613 "UDP": true, 614 "UI": true, 615 "UID": true, 616 "URI": true, 617 "URL": true, 618 "UTF8": true, 619 "UUID": true, 620 "VM": true, 621 "XML": true, 622 "XMPP": true, 623 "XSRF": true, 624 "XSS": true, 625 } 626 627 func rawQuote(s string) string { 628 return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`" 629 } 630 631 func notNil(field string, data interface{}) bool { 632 v := reflect.ValueOf(data) 633 634 if v.Kind() == reflect.Ptr { 635 v = v.Elem() 636 } 637 if v.Kind() != reflect.Struct { 638 return false 639 } 640 val := v.FieldByName(field) 641 642 return val.IsValid() && !val.IsNil() 643 } 644 645 func Dump(val interface{}) string { 646 switch val := val.(type) { 647 case int: 648 return strconv.Itoa(val) 649 case int64: 650 return fmt.Sprintf("%d", val) 651 case float64: 652 return fmt.Sprintf("%f", val) 653 case string: 654 return strconv.Quote(val) 655 case bool: 656 return strconv.FormatBool(val) 657 case nil: 658 return "nil" 659 case []interface{}: 660 var parts []string 661 for _, part := range val { 662 parts = append(parts, Dump(part)) 663 } 664 return "[]interface{}{" + strings.Join(parts, ",") + "}" 665 case map[string]interface{}: 666 buf := bytes.Buffer{} 667 buf.WriteString("map[string]interface{}{") 668 var keys []string 669 for key := range val { 670 keys = append(keys, key) 671 } 672 sort.Strings(keys) 673 674 for _, key := range keys { 675 data := val[key] 676 677 buf.WriteString(strconv.Quote(key)) 678 buf.WriteString(":") 679 buf.WriteString(Dump(data)) 680 buf.WriteString(",") 681 } 682 buf.WriteString("}") 683 return buf.String() 684 default: 685 panic(fmt.Errorf("unsupported type %T", val)) 686 } 687 } 688 689 func prefixLines(prefix, s string) string { 690 return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix) 691 } 692 693 func resolveName(name string, skip int) string { 694 if name[0] == '.' { 695 // load path relative to calling source file 696 _, callerFile, _, _ := runtime.Caller(skip + 1) 697 return filepath.Join(filepath.Dir(callerFile), name[1:]) 698 } 699 700 // load path relative to this directory 701 _, callerFile, _, _ := runtime.Caller(0) 702 return filepath.Join(filepath.Dir(callerFile), name) 703 } 704 705 func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { 706 t := template.New("").Funcs(Funcs()) 707 708 b, err := os.ReadFile(filename) 709 if err != nil { 710 return nil, err 711 } 712 713 t, err = t.New(filepath.Base(filename)).Parse(string(b)) 714 if err != nil { 715 panic(err) 716 } 717 718 buf := &bytes.Buffer{} 719 return buf, t.Execute(buf, tpldata) 720 } 721 722 func write(filename string, b []byte, packages *code.Packages) error { 723 err := os.MkdirAll(filepath.Dir(filename), 0o755) 724 if err != nil { 725 return fmt.Errorf("failed to create directory: %w", err) 726 } 727 728 formatted, err := imports.Prune(filename, b, packages) 729 if err != nil { 730 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) 731 formatted = b 732 } 733 734 err = os.WriteFile(filename, formatted, 0o644) 735 if err != nil { 736 return fmt.Errorf("failed to write %s: %w", filename, err) 737 } 738 739 return nil 740 }