github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/codegen/templates/templates.go (about) 1 package templates 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/types" 7 "io/fs" 8 "os" 9 "path/filepath" 10 "reflect" 11 "regexp" 12 "runtime" 13 "sort" 14 "strconv" 15 "strings" 16 "sync" 17 "text/template" 18 "unicode" 19 20 "github.com/geneva/gqlgen/codegen/config" 21 "github.com/geneva/gqlgen/internal/code" 22 "github.com/geneva/gqlgen/internal/imports" 23 ) 24 25 // CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin. 26 // this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually. 27 var CurrentImports *Imports 28 29 // Options specify various parameters to rendering a template. 30 type Options struct { 31 // PackageName is a helper that specifies the package header declaration. 32 // In other words, when you write the template you don't need to specify `package X` 33 // at the top of the file. By providing PackageName in the Options, the Render 34 // function will do that for you. 35 PackageName string 36 // Template is a string of the entire template that 37 // will be parsed and rendered. If it's empty, 38 // the plugin processor will look for .gotpl files 39 // in the same directory of where you wrote the plugin. 40 Template string 41 42 // Use the go:embed API to collect all the template files you want to pass into Render 43 // this is an alternative to passing the Template option 44 TemplateFS fs.FS 45 46 // Filename is the name of the file that will be 47 // written to the system disk once the template is rendered. 48 Filename string 49 RegionTags bool 50 GeneratedHeader bool 51 // PackageDoc is documentation written above the package line 52 PackageDoc string 53 // FileNotice is notice written below the package line 54 FileNotice string 55 // Data will be passed to the template execution. 56 Data interface{} 57 Funcs template.FuncMap 58 59 // Packages cache, you can find me on config.Config 60 Packages *code.Packages 61 } 62 63 var ( 64 modelNamesMu sync.Mutex 65 modelNames = make(map[string]string, 0) 66 goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]") 67 ) 68 69 // Render renders a gql plugin template from the given Options. Render is an 70 // abstraction of the text/template package that makes it easier to write gqlgen 71 // plugins. If Options.Template is empty, the Render function will look for `.gotpl` 72 // files inside the directory where you wrote the plugin. 73 func Render(cfg Options) error { 74 if CurrentImports != nil { 75 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) 76 } 77 CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} 78 79 funcs := Funcs() 80 for n, f := range cfg.Funcs { 81 funcs[n] = f 82 } 83 84 t := template.New("").Funcs(funcs) 85 t, err := parseTemplates(cfg, t) 86 if err != nil { 87 return err 88 } 89 90 roots := make([]string, 0, len(t.Templates())) 91 for _, template := range t.Templates() { 92 // templates that end with _.gotpl are special files we don't want to include 93 if strings.HasSuffix(template.Name(), "_.gotpl") || 94 // filter out templates added with {{ template xxx }} syntax inside the template file 95 !strings.HasSuffix(template.Name(), ".gotpl") { 96 continue 97 } 98 99 roots = append(roots, template.Name()) 100 } 101 102 // then execute all the important looking ones in order, adding them to the same file 103 sort.Slice(roots, func(i, j int) bool { 104 // important files go first 105 if strings.HasSuffix(roots[i], "!.gotpl") { 106 return true 107 } 108 if strings.HasSuffix(roots[j], "!.gotpl") { 109 return false 110 } 111 return roots[i] < roots[j] 112 }) 113 114 var buf bytes.Buffer 115 for _, root := range roots { 116 if cfg.RegionTags { 117 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n") 118 } 119 err := t.Lookup(root).Execute(&buf, cfg.Data) 120 if err != nil { 121 return fmt.Errorf("%s: %w", root, err) 122 } 123 if cfg.RegionTags { 124 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n") 125 } 126 } 127 128 var result bytes.Buffer 129 if cfg.GeneratedHeader { 130 result.WriteString("// Code generated by github.com/geneva/gqlgen, DO NOT EDIT.\n\n") 131 } 132 if cfg.PackageDoc != "" { 133 result.WriteString(cfg.PackageDoc + "\n") 134 } 135 result.WriteString("package ") 136 result.WriteString(cfg.PackageName) 137 result.WriteString("\n\n") 138 if cfg.FileNotice != "" { 139 result.WriteString(cfg.FileNotice) 140 result.WriteString("\n\n") 141 } 142 result.WriteString("import (\n") 143 result.WriteString(CurrentImports.String()) 144 result.WriteString(")\n") 145 _, err = buf.WriteTo(&result) 146 if err != nil { 147 return err 148 } 149 CurrentImports = nil 150 151 err = write(cfg.Filename, result.Bytes(), cfg.Packages) 152 if err != nil { 153 return err 154 } 155 156 cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename))) 157 return nil 158 } 159 160 func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) { 161 if cfg.Template != "" { 162 var err error 163 t, err = t.New("template.gotpl").Parse(cfg.Template) 164 if err != nil { 165 return nil, fmt.Errorf("error with provided template: %w", err) 166 } 167 return t, nil 168 } 169 170 var fileSystem fs.FS 171 if cfg.TemplateFS != nil { 172 fileSystem = cfg.TemplateFS 173 } else { 174 // load path relative to calling source file 175 _, callerFile, _, _ := runtime.Caller(2) 176 rootDir := filepath.Dir(callerFile) 177 fileSystem = os.DirFS(rootDir) 178 } 179 180 t, err := t.ParseFS(fileSystem, "*.gotpl") 181 if err != nil { 182 return nil, fmt.Errorf("locating templates: %w", err) 183 } 184 185 return t, nil 186 } 187 188 func center(width int, pad string, s string) string { 189 if len(s)+2 > width { 190 return s 191 } 192 lpad := (width - len(s)) / 2 193 rpad := width - (lpad + len(s)) 194 return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad) 195 } 196 197 func Funcs() template.FuncMap { 198 return template.FuncMap{ 199 "ucFirst": UcFirst, 200 "lcFirst": LcFirst, 201 "quote": strconv.Quote, 202 "rawQuote": rawQuote, 203 "dump": Dump, 204 "ref": ref, 205 "ts": config.TypeIdentifier, 206 "call": Call, 207 "prefixLines": prefixLines, 208 "notNil": notNil, 209 "reserveImport": CurrentImports.Reserve, 210 "lookupImport": CurrentImports.Lookup, 211 "go": ToGo, 212 "goPrivate": ToGoPrivate, 213 "goModelName": ToGoModelName, 214 "goPrivateModelName": ToGoPrivateModelName, 215 "add": func(a, b int) int { 216 return a + b 217 }, 218 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) { 219 return render(resolveName(filename, 0), tpldata) 220 }, 221 } 222 } 223 224 func UcFirst(s string) string { 225 if s == "" { 226 return "" 227 } 228 r := []rune(s) 229 r[0] = unicode.ToUpper(r[0]) 230 return string(r) 231 } 232 233 func LcFirst(s string) string { 234 if s == "" { 235 return "" 236 } 237 238 r := []rune(s) 239 r[0] = unicode.ToLower(r[0]) 240 return string(r) 241 } 242 243 func isDelimiter(c rune) bool { 244 return c == '-' || c == '_' || unicode.IsSpace(c) 245 } 246 247 func ref(p types.Type) string { 248 return CurrentImports.LookupType(p) 249 } 250 251 func Call(p *types.Func) string { 252 pkg := CurrentImports.Lookup(p.Pkg().Path()) 253 254 if pkg != "" { 255 pkg += "." 256 } 257 258 if p.Type() != nil { 259 // make sure the returned type is listed in our imports. 260 ref(p.Type().(*types.Signature).Results().At(0).Type()) 261 } 262 263 return pkg + p.Name() 264 } 265 266 func resetModelNames() { 267 modelNamesMu.Lock() 268 defer modelNamesMu.Unlock() 269 modelNames = make(map[string]string, 0) 270 } 271 272 func buildGoModelNameKey(parts []string) string { 273 const sep = ":" 274 return strings.Join(parts, sep) 275 } 276 277 func goModelName(primaryToGoFunc func(string) string, parts []string) string { 278 modelNamesMu.Lock() 279 defer modelNamesMu.Unlock() 280 281 var ( 282 goNameKey string 283 partLen int 284 285 nameExists = func(n string) bool { 286 for _, v := range modelNames { 287 if n == v { 288 return true 289 } 290 } 291 return false 292 } 293 294 applyToGoFunc = func(parts []string) string { 295 var out string 296 switch len(parts) { 297 case 0: 298 return "" 299 case 1: 300 return primaryToGoFunc(parts[0]) 301 default: 302 out = primaryToGoFunc(parts[0]) 303 } 304 for _, p := range parts[1:] { 305 out = fmt.Sprintf("%s%s", out, ToGo(p)) 306 } 307 return out 308 } 309 310 applyValidGoName = func(parts []string) string { 311 var out string 312 for _, p := range parts { 313 out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p)) 314 } 315 return out 316 } 317 ) 318 319 // build key for this entity 320 goNameKey = buildGoModelNameKey(parts) 321 322 // determine if we've seen this entity before, and reuse if so 323 if goName, ok := modelNames[goNameKey]; ok { 324 return goName 325 } 326 327 // attempt first pass 328 if goName := applyToGoFunc(parts); !nameExists(goName) { 329 modelNames[goNameKey] = goName 330 return goName 331 } 332 333 // determine number of parts 334 partLen = len(parts) 335 336 // if there is only 1 part, append incrementing number until no conflict 337 if partLen == 1 { 338 base := applyToGoFunc(parts) 339 for i := 0; ; i++ { 340 tmp := fmt.Sprintf("%s%d", base, i) 341 if !nameExists(tmp) { 342 modelNames[goNameKey] = tmp 343 return tmp 344 } 345 } 346 } 347 348 // best effort "pretty" name 349 for i := partLen - 1; i >= 1; i-- { 350 tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:])) 351 if !nameExists(tmp) { 352 modelNames[goNameKey] = tmp 353 return tmp 354 } 355 } 356 357 // finally, fallback to just adding an incrementing number 358 base := applyToGoFunc(parts) 359 for i := 0; ; i++ { 360 tmp := fmt.Sprintf("%s%d", base, i) 361 if !nameExists(tmp) { 362 modelNames[goNameKey] = tmp 363 return tmp 364 } 365 } 366 } 367 368 func ToGoModelName(parts ...string) string { 369 return goModelName(ToGo, parts) 370 } 371 372 func ToGoPrivateModelName(parts ...string) string { 373 return goModelName(ToGoPrivate, parts) 374 } 375 376 func replaceInvalidCharacters(in string) string { 377 return goNameRe.ReplaceAllLiteralString(in, "_") 378 } 379 380 func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) { 381 return func(info *wordInfo) { 382 word := info.Word 383 384 switch { 385 case private && info.WordOffset == 0: 386 if strings.ToUpper(word) == word || strings.ToLower(word) == word { 387 // ID → id, CAMEL → camel 388 word = strings.ToLower(info.Word) 389 } else { 390 // ITicket → iTicket 391 word = LcFirst(info.Word) 392 } 393 394 case info.MatchCommonInitial: 395 word = strings.ToUpper(word) 396 397 case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word): 398 // FOO or foo → Foo 399 // FOo → FOo 400 word = UcFirst(strings.ToLower(word)) 401 } 402 403 *nameRunes = append(*nameRunes, []rune(word)...) 404 } 405 } 406 407 func ToGo(name string) string { 408 if name == "_" { 409 return "_" 410 } 411 runes := make([]rune, 0, len(name)) 412 413 wordWalker(name, wordWalkerFunc(false, &runes)) 414 415 return string(runes) 416 } 417 418 func ToGoPrivate(name string) string { 419 if name == "_" { 420 return "_" 421 } 422 runes := make([]rune, 0, len(name)) 423 424 wordWalker(name, wordWalkerFunc(true, &runes)) 425 426 return sanitizeKeywords(string(runes)) 427 } 428 429 type wordInfo struct { 430 WordOffset int 431 Word string 432 MatchCommonInitial bool 433 HasCommonInitial bool 434 } 435 436 // This function is based on the following code. 437 // https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679 438 func wordWalker(str string, f func(*wordInfo)) { 439 runes := []rune(strings.TrimFunc(str, isDelimiter)) 440 w, i, wo := 0, 0, 0 // index of start of word, scan, word offset 441 hasCommonInitial := false 442 for i+1 <= len(runes) { 443 eow := false // whether we hit the end of a word 444 switch { 445 case i+1 == len(runes): 446 eow = true 447 case isDelimiter(runes[i+1]): 448 // underscore; shift the remainder forward over any run of underscores 449 eow = true 450 n := 1 451 for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) { 452 n++ 453 } 454 455 // Leave at most one underscore if the underscore is between two digits 456 if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) { 457 n-- 458 } 459 460 copy(runes[i+1:], runes[i+n+1:]) 461 runes = runes[:len(runes)-n] 462 case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]): 463 // lower->non-lower 464 eow = true 465 } 466 i++ 467 468 initialisms := config.GetInitialisms() 469 // [w,i) is a word. 470 word := string(runes[w:i]) 471 if !eow && initialisms[word] && !unicode.IsLower(runes[i]) { 472 // through 473 // split IDFoo → ID, Foo 474 // but URLs → URLs 475 } else if !eow { 476 if initialisms[word] { 477 hasCommonInitial = true 478 } 479 continue 480 } 481 482 matchCommonInitial := false 483 upperWord := strings.ToUpper(word) 484 if initialisms[upperWord] { 485 // If the uppercase word (string(runes[w:i]) is "ID" or "IP" 486 // AND 487 // the word is the first two characters of the str 488 // AND 489 // that is not the end of the word 490 // AND 491 // the length of the string is greater than 3 492 // AND 493 // the third rune is an uppercase one 494 // THEN 495 // do NOT count this as an initialism. 496 switch upperWord { 497 case "ID", "IP": 498 if word == str[:2] && !eow && len(str) > 3 && unicode.IsUpper(runes[3]) { 499 continue 500 } 501 } 502 hasCommonInitial = true 503 matchCommonInitial = true 504 } 505 506 f(&wordInfo{ 507 WordOffset: wo, 508 Word: word, 509 MatchCommonInitial: matchCommonInitial, 510 HasCommonInitial: hasCommonInitial, 511 }) 512 hasCommonInitial = false 513 w = i 514 wo++ 515 } 516 } 517 518 var keywords = []string{ 519 "break", 520 "default", 521 "func", 522 "interface", 523 "select", 524 "case", 525 "defer", 526 "go", 527 "map", 528 "struct", 529 "chan", 530 "else", 531 "goto", 532 "package", 533 "switch", 534 "const", 535 "fallthrough", 536 "if", 537 "range", 538 "type", 539 "continue", 540 "for", 541 "import", 542 "return", 543 "var", 544 "_", 545 } 546 547 // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions 548 func sanitizeKeywords(name string) string { 549 for _, k := range keywords { 550 if name == k { 551 return name + "Arg" 552 } 553 } 554 return name 555 } 556 557 func rawQuote(s string) string { 558 return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`" 559 } 560 561 func notNil(field string, data interface{}) bool { 562 v := reflect.ValueOf(data) 563 564 if v.Kind() == reflect.Ptr { 565 v = v.Elem() 566 } 567 if v.Kind() != reflect.Struct { 568 return false 569 } 570 val := v.FieldByName(field) 571 572 return val.IsValid() && !val.IsNil() 573 } 574 575 func Dump(val interface{}) string { 576 switch val := val.(type) { 577 case int: 578 return strconv.Itoa(val) 579 case int64: 580 return fmt.Sprintf("%d", val) 581 case float64: 582 return fmt.Sprintf("%f", val) 583 case string: 584 return strconv.Quote(val) 585 case bool: 586 return strconv.FormatBool(val) 587 case nil: 588 return "nil" 589 case []interface{}: 590 var parts []string 591 for _, part := range val { 592 parts = append(parts, Dump(part)) 593 } 594 return "[]interface{}{" + strings.Join(parts, ",") + "}" 595 case map[string]interface{}: 596 buf := bytes.Buffer{} 597 buf.WriteString("map[string]interface{}{") 598 var keys []string 599 for key := range val { 600 keys = append(keys, key) 601 } 602 sort.Strings(keys) 603 604 for _, key := range keys { 605 data := val[key] 606 607 buf.WriteString(strconv.Quote(key)) 608 buf.WriteString(":") 609 buf.WriteString(Dump(data)) 610 buf.WriteString(",") 611 } 612 buf.WriteString("}") 613 return buf.String() 614 default: 615 panic(fmt.Errorf("unsupported type %T", val)) 616 } 617 } 618 619 func prefixLines(prefix, s string) string { 620 return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix) 621 } 622 623 func resolveName(name string, skip int) string { 624 if name[0] == '.' { 625 // load path relative to calling source file 626 _, callerFile, _, _ := runtime.Caller(skip + 1) 627 return filepath.Join(filepath.Dir(callerFile), name[1:]) 628 } 629 630 // load path relative to this directory 631 _, callerFile, _, _ := runtime.Caller(0) 632 return filepath.Join(filepath.Dir(callerFile), name) 633 } 634 635 func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { 636 t := template.New("").Funcs(Funcs()) 637 638 b, err := os.ReadFile(filename) 639 if err != nil { 640 return nil, err 641 } 642 643 t, err = t.New(filepath.Base(filename)).Parse(string(b)) 644 if err != nil { 645 panic(err) 646 } 647 648 buf := &bytes.Buffer{} 649 return buf, t.Execute(buf, tpldata) 650 } 651 652 func write(filename string, b []byte, packages *code.Packages) error { 653 err := os.MkdirAll(filepath.Dir(filename), 0o755) 654 if err != nil { 655 return fmt.Errorf("failed to create directory: %w", err) 656 } 657 658 formatted, err := imports.Prune(filename, b, packages) 659 if err != nil { 660 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) 661 formatted = b 662 } 663 664 err = os.WriteFile(filename, formatted, 0o644) 665 if err != nil { 666 return fmt.Errorf("failed to write %s: %w", filename, err) 667 } 668 669 return nil 670 }