// github.com/apipluspower/gqlgen@v0.15.2/codegen/templates/templates.go

package templates

import (
	"bytes"
	"fmt"
	"go/types"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"text/template"
	"unicode"

	"github.com/apipluspower/gqlgen/internal/code"

	"github.com/apipluspower/gqlgen/internal/imports"
)

// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
//
// Render sets it at the start of a render and panics if it is already non-nil,
// so renders must not be nested or run concurrently.
var CurrentImports *Imports

// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory where you wrote the plugin
	// (more precisely: under the directory of the source file that called
	// Render; files whose name ends in "_.gotpl" are skipped).
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	// Its directory is also handed to the import tracker as its destDir.
	Filename string
	// RegionTags, when true, wraps the output of each root template in
	// "// region" / "// endregion" marker comments.
	RegionTags bool
	// GeneratedHeader, when true, prepends the
	// "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT." line.
	GeneratedHeader bool
	// PackageDoc is documentation written above the package line
	PackageDoc string
	// FileNotice is notice written below the package line
	FileNotice string
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs are extra template functions. They are merged over the defaults
	// returned by Funcs(), so an entry with the same name replaces the built-in.
	Funcs template.FuncMap

	// Packages is the package cache (the one also found on config.Config).
	// Render passes it to the import tracker and to the file writer, and
	// evicts the output directory's package from it after writing.
	Packages *code.Packages
}

// Render renders a gql plugin template from the given Options.
Render is an 57 // abstraction of the text/template package that makes it easier to write gqlgen 58 // plugins. If Options.Template is empty, the Render function will look for `.gotpl` 59 // files inside the directory where you wrote the plugin. 60 func Render(cfg Options) error { 61 if CurrentImports != nil { 62 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) 63 } 64 CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} 65 66 // load path relative to calling source file 67 _, callerFile, _, _ := runtime.Caller(1) 68 rootDir := filepath.Dir(callerFile) 69 70 funcs := Funcs() 71 for n, f := range cfg.Funcs { 72 funcs[n] = f 73 } 74 t := template.New("").Funcs(funcs) 75 76 var roots []string 77 if cfg.Template != "" { 78 var err error 79 t, err = t.New("template.gotpl").Parse(cfg.Template) 80 if err != nil { 81 return fmt.Errorf("error with provided template: %w", err) 82 } 83 roots = append(roots, "template.gotpl") 84 } else { 85 // load all the templates in the directory 86 err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error { 87 if err != nil { 88 return err 89 } 90 name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator))) 91 if !strings.HasSuffix(info.Name(), ".gotpl") { 92 return nil 93 } 94 // omit any templates with "_" at the end of their name, which are meant for specific contexts only 95 if strings.HasSuffix(info.Name(), "_.gotpl") { 96 return nil 97 } 98 b, err := ioutil.ReadFile(path) 99 if err != nil { 100 return err 101 } 102 103 t, err = t.New(name).Parse(string(b)) 104 if err != nil { 105 return fmt.Errorf("%s: %w", cfg.Filename, err) 106 } 107 108 roots = append(roots, name) 109 110 return nil 111 }) 112 if err != nil { 113 return fmt.Errorf("locating templates: %w", err) 114 } 115 } 116 117 // then execute all the important looking ones in order, adding them to the same file 118 sort.Slice(roots, func(i, j int) bool { 119 // 
important files go first 120 if strings.HasSuffix(roots[i], "!.gotpl") { 121 return true 122 } 123 if strings.HasSuffix(roots[j], "!.gotpl") { 124 return false 125 } 126 return roots[i] < roots[j] 127 }) 128 var buf bytes.Buffer 129 for _, root := range roots { 130 if cfg.RegionTags { 131 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n") 132 } 133 err := t.Lookup(root).Execute(&buf, cfg.Data) 134 if err != nil { 135 return fmt.Errorf("%s: %w", root, err) 136 } 137 if cfg.RegionTags { 138 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n") 139 } 140 } 141 142 var result bytes.Buffer 143 if cfg.GeneratedHeader { 144 result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n") 145 } 146 if cfg.PackageDoc != "" { 147 result.WriteString(cfg.PackageDoc + "\n") 148 } 149 result.WriteString("package ") 150 result.WriteString(cfg.PackageName) 151 result.WriteString("\n\n") 152 if cfg.FileNotice != "" { 153 result.WriteString(cfg.FileNotice) 154 result.WriteString("\n\n") 155 } 156 result.WriteString("import (\n") 157 result.WriteString(CurrentImports.String()) 158 result.WriteString(")\n") 159 _, err := buf.WriteTo(&result) 160 if err != nil { 161 return err 162 } 163 CurrentImports = nil 164 165 err = write(cfg.Filename, result.Bytes(), cfg.Packages) 166 if err != nil { 167 return err 168 } 169 170 cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename))) 171 return nil 172 } 173 174 func center(width int, pad string, s string) string { 175 if len(s)+2 > width { 176 return s 177 } 178 lpad := (width - len(s)) / 2 179 rpad := width - (lpad + len(s)) 180 return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad) 181 } 182 183 func Funcs() template.FuncMap { 184 return template.FuncMap{ 185 "ucFirst": UcFirst, 186 "lcFirst": LcFirst, 187 "quote": strconv.Quote, 188 "rawQuote": rawQuote, 189 "dump": Dump, 190 "ref": ref, 191 "ts": TypeIdentifier, 192 "call": Call, 193 "prefixLines": 
prefixLines, 194 "notNil": notNil, 195 "reserveImport": CurrentImports.Reserve, 196 "lookupImport": CurrentImports.Lookup, 197 "go": ToGo, 198 "goPrivate": ToGoPrivate, 199 "add": func(a, b int) int { 200 return a + b 201 }, 202 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) { 203 return render(resolveName(filename, 0), tpldata) 204 }, 205 } 206 } 207 208 func UcFirst(s string) string { 209 if s == "" { 210 return "" 211 } 212 r := []rune(s) 213 r[0] = unicode.ToUpper(r[0]) 214 return string(r) 215 } 216 217 func LcFirst(s string) string { 218 if s == "" { 219 return "" 220 } 221 222 r := []rune(s) 223 r[0] = unicode.ToLower(r[0]) 224 return string(r) 225 } 226 227 func isDelimiter(c rune) bool { 228 return c == '-' || c == '_' || unicode.IsSpace(c) 229 } 230 231 func ref(p types.Type) string { 232 return CurrentImports.LookupType(p) 233 } 234 235 var pkgReplacer = strings.NewReplacer( 236 "/", "ᚋ", 237 ".", "ᚗ", 238 "-", "ᚑ", 239 "~", "א", 240 ) 241 242 func TypeIdentifier(t types.Type) string { 243 res := "" 244 for { 245 switch it := t.(type) { 246 case *types.Pointer: 247 t.Underlying() 248 res += "ᚖ" 249 t = it.Elem() 250 case *types.Slice: 251 res += "ᚕ" 252 t = it.Elem() 253 case *types.Named: 254 res += pkgReplacer.Replace(it.Obj().Pkg().Path()) 255 res += "ᚐ" 256 res += it.Obj().Name() 257 return res 258 case *types.Basic: 259 res += it.Name() 260 return res 261 case *types.Map: 262 res += "map" 263 return res 264 case *types.Interface: 265 res += "interface" 266 return res 267 default: 268 panic(fmt.Errorf("unexpected type %T", it)) 269 } 270 } 271 } 272 273 func Call(p *types.Func) string { 274 pkg := CurrentImports.Lookup(p.Pkg().Path()) 275 276 if pkg != "" { 277 pkg += "." 278 } 279 280 if p.Type() != nil { 281 // make sure the returned type is listed in our imports. 
282 ref(p.Type().(*types.Signature).Results().At(0).Type()) 283 } 284 285 return pkg + p.Name() 286 } 287 288 func ToGo(name string) string { 289 if name == "_" { 290 return "_" 291 } 292 runes := make([]rune, 0, len(name)) 293 294 wordWalker(name, func(info *wordInfo) { 295 word := info.Word 296 if info.MatchCommonInitial { 297 word = strings.ToUpper(word) 298 } else if !info.HasCommonInitial { 299 if strings.ToUpper(word) == word || strings.ToLower(word) == word { 300 // FOO or foo → Foo 301 // FOo → FOo 302 word = UcFirst(strings.ToLower(word)) 303 } 304 } 305 runes = append(runes, []rune(word)...) 306 }) 307 308 return string(runes) 309 } 310 311 func ToGoPrivate(name string) string { 312 if name == "_" { 313 return "_" 314 } 315 runes := make([]rune, 0, len(name)) 316 317 first := true 318 wordWalker(name, func(info *wordInfo) { 319 word := info.Word 320 switch { 321 case first: 322 if strings.ToUpper(word) == word || strings.ToLower(word) == word { 323 // ID → id, CAMEL → camel 324 word = strings.ToLower(info.Word) 325 } else { 326 // ITicket → iTicket 327 word = LcFirst(info.Word) 328 } 329 first = false 330 case info.MatchCommonInitial: 331 word = strings.ToUpper(word) 332 case !info.HasCommonInitial: 333 word = UcFirst(strings.ToLower(word)) 334 } 335 runes = append(runes, []rune(word)...) 336 }) 337 338 return sanitizeKeywords(string(runes)) 339 } 340 341 type wordInfo struct { 342 Word string 343 MatchCommonInitial bool 344 HasCommonInitial bool 345 } 346 347 // This function is based on the following code. 
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
//
// wordWalker splits str into words and calls f once per word, in order.
// Leading and trailing delimiters ('-', '_', whitespace) are trimmed first.
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters. The run is dropped, except that a single
//     delimiter is kept when the run sits between two digits; the kept
//     delimiter becomes the first rune of the next word;
//   - at a lower-case → non-lower-case transition ("fooBar" → "foo", "Bar");
//   - right after a common initialism that is followed by a non-lower-case
//     rune ("IDFoo" → "ID", "Foo", but "URLs" stays one word). This mid-word
//     lookup is case-sensitive.
//
// f receives a freshly allocated wordInfo per word (see wordInfo for fields).
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(strings.TrimFunc(str, isDelimiter))
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		switch {
		case i+1 == len(runes):
			eow = true
		case isDelimiter(runes[i+1]):
			// A delimiter ('-', '_' or whitespace) follows, so the word ends
			// here. Drop the whole run of delimiters by shifting the remainder
			// of runes forward over it.
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Leave exactly one delimiter if the run is between two digits,
			// so "1_2" does not collapse into "12".
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// Fall through and emit: an initialism followed by a
			// non-lower-case rune is a word of its own.
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// Still inside a word. Remember whether a leading part of it
			// was an initialism, so callers keep its casing.
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		hasCommonInitial = false
		w = i
	}
}

// keywords lists the Go keywords, plus the blank identifier "_", that
// sanitizeKeywords rewrites.
var keywords = []string{
	"break",
	"default",
	"func",
	"interface",
	"select",
	"case",
	"defer",
	"go",
	"map",
	"struct",
	"chan",
	"else",
	"goto",
	"package",
	"switch",
	"const",
	"fallthrough",
	"if",
	"range",
	"type",
	"continue",
	"for",
	"import",
	"return",
	"var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions.
// It appends "Arg" when name exactly (case-sensitively) equals an entry of
// keywords, and returns name unchanged otherwise.
func sanitizeKeywords(name string) string {
	for _, k := range keywords {
		if name == k {
			return name + "Arg"
		}
	}
	return name
}

// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
// Keys are written in upper case; wordWalker upper-cases a finished word
// before looking it up.
450 var commonInitialisms = map[string]bool{ 451 "ACL": true, 452 "API": true, 453 "ASCII": true, 454 "CPU": true, 455 "CSS": true, 456 "CSV": true, 457 "DNS": true, 458 "EOF": true, 459 "GUID": true, 460 "HTML": true, 461 "HTTP": true, 462 "HTTPS": true, 463 "ICMP": true, 464 "ID": true, 465 "IP": true, 466 "JSON": true, 467 "KVK": true, 468 "LHS": true, 469 "PDF": true, 470 "PGP": true, 471 "QPS": true, 472 "QR": true, 473 "RAM": true, 474 "RHS": true, 475 "RPC": true, 476 "SLA": true, 477 "SMTP": true, 478 "SQL": true, 479 "SSH": true, 480 "SVG": true, 481 "TCP": true, 482 "TLS": true, 483 "TTL": true, 484 "UDP": true, 485 "UI": true, 486 "UID": true, 487 "URI": true, 488 "URL": true, 489 "UTF8": true, 490 "UUID": true, 491 "VM": true, 492 "XML": true, 493 "XMPP": true, 494 "XSRF": true, 495 "XSS": true, 496 } 497 498 func rawQuote(s string) string { 499 return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`" 500 } 501 502 func notNil(field string, data interface{}) bool { 503 v := reflect.ValueOf(data) 504 505 if v.Kind() == reflect.Ptr { 506 v = v.Elem() 507 } 508 if v.Kind() != reflect.Struct { 509 return false 510 } 511 val := v.FieldByName(field) 512 513 return val.IsValid() && !val.IsNil() 514 } 515 516 func Dump(val interface{}) string { 517 switch val := val.(type) { 518 case int: 519 return strconv.Itoa(val) 520 case int64: 521 return fmt.Sprintf("%d", val) 522 case float64: 523 return fmt.Sprintf("%f", val) 524 case string: 525 return strconv.Quote(val) 526 case bool: 527 return strconv.FormatBool(val) 528 case nil: 529 return "nil" 530 case []interface{}: 531 var parts []string 532 for _, part := range val { 533 parts = append(parts, Dump(part)) 534 } 535 return "[]interface{}{" + strings.Join(parts, ",") + "}" 536 case map[string]interface{}: 537 buf := bytes.Buffer{} 538 buf.WriteString("map[string]interface{}{") 539 var keys []string 540 for key := range val { 541 keys = append(keys, key) 542 } 543 sort.Strings(keys) 544 545 for _, key := range 
keys { 546 data := val[key] 547 548 buf.WriteString(strconv.Quote(key)) 549 buf.WriteString(":") 550 buf.WriteString(Dump(data)) 551 buf.WriteString(",") 552 } 553 buf.WriteString("}") 554 return buf.String() 555 default: 556 panic(fmt.Errorf("unsupported type %T", val)) 557 } 558 } 559 560 func prefixLines(prefix, s string) string { 561 return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix) 562 } 563 564 func resolveName(name string, skip int) string { 565 if name[0] == '.' { 566 // load path relative to calling source file 567 _, callerFile, _, _ := runtime.Caller(skip + 1) 568 return filepath.Join(filepath.Dir(callerFile), name[1:]) 569 } 570 571 // load path relative to this directory 572 _, callerFile, _, _ := runtime.Caller(0) 573 return filepath.Join(filepath.Dir(callerFile), name) 574 } 575 576 func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { 577 t := template.New("").Funcs(Funcs()) 578 579 b, err := ioutil.ReadFile(filename) 580 if err != nil { 581 return nil, err 582 } 583 584 t, err = t.New(filepath.Base(filename)).Parse(string(b)) 585 if err != nil { 586 panic(err) 587 } 588 589 buf := &bytes.Buffer{} 590 return buf, t.Execute(buf, tpldata) 591 } 592 593 func write(filename string, b []byte, packages *code.Packages) error { 594 err := os.MkdirAll(filepath.Dir(filename), 0o755) 595 if err != nil { 596 return fmt.Errorf("failed to create directory: %w", err) 597 } 598 599 formatted, err := imports.Prune(filename, b, packages) 600 if err != nil { 601 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) 602 formatted = b 603 } 604 605 err = ioutil.WriteFile(filename, formatted, 0o644) 606 if err != nil { 607 return fmt.Errorf("failed to write %s: %w", filename, err) 608 } 609 610 return nil 611 }