github.com/golang/mock@v1.6.0/mockgen/mockgen.go (about) 1 // Copyright 2010 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // MockGen generates mock implementations of Go interfaces. 16 package main 17 18 // TODO: This does not support recursive embedded interfaces. 19 // TODO: This does not support embedding package-local interfaces in a separate file. 20 21 import ( 22 "bytes" 23 "encoding/json" 24 "flag" 25 "fmt" 26 "go/token" 27 "io" 28 "io/ioutil" 29 "log" 30 "os" 31 "os/exec" 32 "path" 33 "path/filepath" 34 "sort" 35 "strconv" 36 "strings" 37 "unicode" 38 39 "github.com/golang/mock/mockgen/model" 40 41 "golang.org/x/mod/modfile" 42 toolsimports "golang.org/x/tools/imports" 43 ) 44 45 const ( 46 gomockImportPath = "github.com/golang/mock/gomock" 47 ) 48 49 var ( 50 version = "" 51 commit = "none" 52 date = "unknown" 53 ) 54 55 var ( 56 source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") 57 destination = flag.String("destination", "", "Output file; defaults to stdout.") 58 mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") 59 packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") 60 selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") 61 writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") 62 copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") 63 64 debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") 65 showVersion = flag.Bool("version", false, "Print version.") 66 ) 67 68 func main() { 69 flag.Usage = usage 70 flag.Parse() 71 72 if *showVersion { 73 printVersion() 74 return 75 } 76 77 var pkg *model.Package 78 var err error 79 var packageName string 80 if *source != "" { 81 pkg, err = sourceMode(*source) 82 } else { 83 if flag.NArg() != 2 { 84 usage() 85 log.Fatal("Expected exactly two arguments") 86 } 87 packageName = flag.Arg(0) 88 interfaces := strings.Split(flag.Arg(1), ",") 89 if packageName == "." { 90 dir, err := os.Getwd() 91 if err != nil { 92 log.Fatalf("Get current directory failed: %v", err) 93 } 94 packageName, err = packageNameOfDir(dir) 95 if err != nil { 96 log.Fatalf("Parse package name failed: %v", err) 97 } 98 } 99 pkg, err = reflectMode(packageName, interfaces) 100 } 101 if err != nil { 102 log.Fatalf("Loading input failed: %v", err) 103 } 104 105 if *debugParser { 106 pkg.Print(os.Stdout) 107 return 108 } 109 110 dst := os.Stdout 111 if len(*destination) > 0 { 112 if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { 113 log.Fatalf("Unable to create directory: %v", err) 114 } 115 f, err := os.Create(*destination) 116 if err != nil { 117 log.Fatalf("Failed opening destination file: %v", err) 118 } 119 defer f.Close() 120 dst = f 121 } 122 123 outputPackageName := *packageOut 124 if outputPackageName == "" { 125 // pkg.Name in reflect mode is the base name of the import path, 126 // which might have characters that are illegal to have in package names. 127 outputPackageName = "mock_" + sanitize(pkg.Name) 128 } 129 130 // outputPackagePath represents the fully qualified name of the package of 131 // the generated code. Its purposes are to prevent the module from importing 132 // itself and to prevent qualifying type names that come from its own 133 // package (i.e. if there is a type called X then we want to print "X" not 134 // "package.X" since "package" is this package). This can happen if the mock 135 // is output into an already existing package. 136 outputPackagePath := *selfPackage 137 if outputPackagePath == "" && *destination != "" { 138 dstPath, err := filepath.Abs(filepath.Dir(*destination)) 139 if err == nil { 140 pkgPath, err := parsePackageImport(dstPath) 141 if err == nil { 142 outputPackagePath = pkgPath 143 } else { 144 log.Println("Unable to infer -self_package from destination file path:", err) 145 } 146 } else { 147 log.Println("Unable to determine destination file path:", err) 148 } 149 } 150 151 g := new(generator) 152 if *source != "" { 153 g.filename = *source 154 } else { 155 g.srcPackage = packageName 156 g.srcInterfaces = flag.Arg(1) 157 } 158 g.destination = *destination 159 160 if *mockNames != "" { 161 g.mockNames = parseMockNames(*mockNames) 162 } 163 if *copyrightFile != "" { 164 header, err := ioutil.ReadFile(*copyrightFile) 165 if err != nil { 166 log.Fatalf("Failed reading copyright file: %v", err) 167 } 168 169 g.copyrightHeader = string(header) 170 } 171 if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { 172 log.Fatalf("Failed generating mock: %v", err) 173 } 174 if _, err := dst.Write(g.Output()); err != nil { 175 log.Fatalf("Failed writing to destination: %v", err) 176 } 177 } 178 179 func parseMockNames(names string) map[string]string { 180 mocksMap := make(map[string]string) 181 for _, kv := range strings.Split(names, ",") { 182 parts := strings.SplitN(kv, "=", 2) 183 if len(parts) != 2 || parts[1] == "" { 184 log.Fatalf("bad mock names spec: %v", kv) 185 } 186 mocksMap[parts[0]] = parts[1] 187 } 188 return mocksMap 189 } 190 191 func usage() { 192 _, _ = io.WriteString(os.Stderr, usageText) 193 flag.PrintDefaults() 194 } 195 196 const usageText = `mockgen has two modes of operation: source and reflect. 197 198 Source mode generates mock interfaces from a source file. 199 It is enabled by using the -source flag. Other flags that 200 may be useful in this mode are -imports and -aux_files. 201 Example: 202 mockgen -source=foo.go [other options] 203 204 Reflect mode generates mock interfaces by building a program 205 that uses reflection to understand interfaces. It is enabled 206 by passing two non-flag arguments: an import path, and a 207 comma-separated list of symbols. 208 Example: 209 mockgen database/sql/driver Conn,Driver 210 211 ` 212 213 type generator struct { 214 buf bytes.Buffer 215 indent string 216 mockNames map[string]string // may be empty 217 filename string // may be empty 218 destination string // may be empty 219 srcPackage, srcInterfaces string // may be empty 220 copyrightHeader string 221 222 packageMap map[string]string // map from import path to package name 223 } 224 225 func (g *generator) p(format string, args ...interface{}) { 226 fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) 227 } 228 229 func (g *generator) in() { 230 g.indent += "\t" 231 } 232 233 func (g *generator) out() { 234 if len(g.indent) > 0 { 235 g.indent = g.indent[0 : len(g.indent)-1] 236 } 237 } 238 239 // sanitize cleans up a string to make a suitable package name. 240 func sanitize(s string) string { 241 t := "" 242 for _, r := range s { 243 if t == "" { 244 if unicode.IsLetter(r) || r == '_' { 245 t += string(r) 246 continue 247 } 248 } else { 249 if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { 250 t += string(r) 251 continue 252 } 253 } 254 t += "_" 255 } 256 if t == "_" { 257 t = "x" 258 } 259 return t 260 } 261 262 func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { 263 if outputPkgName != pkg.Name && *selfPackage == "" { 264 // reset outputPackagePath if it's not passed in through -self_package 265 outputPackagePath = "" 266 } 267 268 if g.copyrightHeader != "" { 269 lines := strings.Split(g.copyrightHeader, "\n") 270 for _, line := range lines { 271 g.p("// %s", line) 272 } 273 g.p("") 274 } 275 276 g.p("// Code generated by MockGen. DO NOT EDIT.") 277 if g.filename != "" { 278 g.p("// Source: %v", g.filename) 279 } else { 280 g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) 281 } 282 g.p("") 283 284 // Get all required imports, and generate unique names for them all. 285 im := pkg.Imports() 286 im[gomockImportPath] = true 287 288 // Only import reflect if it's used. We only use reflect in mocked methods 289 // so only import if any of the mocked interfaces have methods. 290 for _, intf := range pkg.Interfaces { 291 if len(intf.Methods) > 0 { 292 im["reflect"] = true 293 break 294 } 295 } 296 297 // Sort keys to make import alias generation predictable 298 sortedPaths := make([]string, len(im)) 299 x := 0 300 for pth := range im { 301 sortedPaths[x] = pth 302 x++ 303 } 304 sort.Strings(sortedPaths) 305 306 packagesName := createPackageMap(sortedPaths) 307 308 g.packageMap = make(map[string]string, len(im)) 309 localNames := make(map[string]bool, len(im)) 310 for _, pth := range sortedPaths { 311 base, ok := packagesName[pth] 312 if !ok { 313 base = sanitize(path.Base(pth)) 314 } 315 316 // Local names for an imported package can usually be the basename of the import path. 317 // A couple of situations don't permit that, such as duplicate local names 318 // (e.g. importing "html/template" and "text/template"), or where the basename is 319 // a keyword (e.g. "foo/case"). 320 // try base0, base1, ... 321 pkgName := base 322 i := 0 323 for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { 324 pkgName = base + strconv.Itoa(i) 325 i++ 326 } 327 328 // Avoid importing package if source pkg == output pkg 329 if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { 330 continue 331 } 332 333 g.packageMap[pth] = pkgName 334 localNames[pkgName] = true 335 } 336 337 if *writePkgComment { 338 g.p("// Package %v is a generated GoMock package.", outputPkgName) 339 } 340 g.p("package %v", outputPkgName) 341 g.p("") 342 g.p("import (") 343 g.in() 344 for pkgPath, pkgName := range g.packageMap { 345 if pkgPath == outputPackagePath { 346 continue 347 } 348 g.p("%v %q", pkgName, pkgPath) 349 } 350 for _, pkgPath := range pkg.DotImports { 351 g.p(". %q", pkgPath) 352 } 353 g.out() 354 g.p(")") 355 356 for _, intf := range pkg.Interfaces { 357 if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { 358 return err 359 } 360 } 361 362 return nil 363 } 364 365 // The name of the mock type to use for the given interface identifier. 366 func (g *generator) mockName(typeName string) string { 367 if mockName, ok := g.mockNames[typeName]; ok { 368 return mockName 369 } 370 371 return "Mock" + typeName 372 } 373 374 func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { 375 mockType := g.mockName(intf.Name) 376 377 g.p("") 378 g.p("// %v is a mock of %v interface.", mockType, intf.Name) 379 g.p("type %v struct {", mockType) 380 g.in() 381 g.p("ctrl *gomock.Controller") 382 g.p("recorder *%vMockRecorder", mockType) 383 g.out() 384 g.p("}") 385 g.p("") 386 387 g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) 388 g.p("type %vMockRecorder struct {", mockType) 389 g.in() 390 g.p("mock *%v", mockType) 391 g.out() 392 g.p("}") 393 g.p("") 394 395 g.p("// New%v creates a new mock instance.", mockType) 396 g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) 397 g.in() 398 g.p("mock := &%v{ctrl: ctrl}", mockType) 399 g.p("mock.recorder = &%vMockRecorder{mock}", mockType) 400 g.p("return mock") 401 g.out() 402 g.p("}") 403 g.p("") 404 405 // XXX: possible name collision here if someone has EXPECT in their interface. 406 g.p("// EXPECT returns an object that allows the caller to indicate expected use.") 407 g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) 408 g.in() 409 g.p("return m.recorder") 410 g.out() 411 g.p("}") 412 413 g.GenerateMockMethods(mockType, intf, outputPackagePath) 414 415 return nil 416 } 417 418 type byMethodName []*model.Method 419 420 func (b byMethodName) Len() int { return len(b) } 421 func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } 422 func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } 423 424 func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { 425 sort.Sort(byMethodName(intf.Methods)) 426 for _, m := range intf.Methods { 427 g.p("") 428 _ = g.GenerateMockMethod(mockType, m, pkgOverride) 429 g.p("") 430 _ = g.GenerateMockRecorderMethod(mockType, m) 431 } 432 } 433 434 func makeArgString(argNames, argTypes []string) string { 435 args := make([]string, len(argNames)) 436 for i, name := range argNames { 437 // specify the type only once for consecutive args of the same type 438 if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] { 439 args[i] = name 440 } else { 441 args[i] = name + " " + argTypes[i] 442 } 443 } 444 return strings.Join(args, ", ") 445 } 446 447 // GenerateMockMethod generates a mock method implementation. 448 // If non-empty, pkgOverride is the package in which unqualified types reside. 449 func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { 450 argNames := g.getArgNames(m) 451 argTypes := g.getArgTypes(m, pkgOverride) 452 argString := makeArgString(argNames, argTypes) 453 454 rets := make([]string, len(m.Out)) 455 for i, p := range m.Out { 456 rets[i] = p.Type.String(g.packageMap, pkgOverride) 457 } 458 retString := strings.Join(rets, ", ") 459 if len(rets) > 1 { 460 retString = "(" + retString + ")" 461 } 462 if retString != "" { 463 retString = " " + retString 464 } 465 466 ia := newIdentifierAllocator(argNames) 467 idRecv := ia.allocateIdentifier("m") 468 469 g.p("// %v mocks base method.", m.Name) 470 g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) 471 g.in() 472 g.p("%s.ctrl.T.Helper()", idRecv) 473 474 var callArgs string 475 if m.Variadic == nil { 476 if len(argNames) > 0 { 477 callArgs = ", " + strings.Join(argNames, ", ") 478 } 479 } else { 480 // Non-trivial. The generated code must build a []interface{}, 481 // but the variadic argument may be any type. 482 idVarArgs := ia.allocateIdentifier("varargs") 483 idVArg := ia.allocateIdentifier("a") 484 g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) 485 g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) 486 g.in() 487 g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) 488 g.out() 489 g.p("}") 490 callArgs = ", " + idVarArgs + "..." 491 } 492 if len(m.Out) == 0 { 493 g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) 494 } else { 495 idRet := ia.allocateIdentifier("ret") 496 g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) 497 498 // Go does not allow "naked" type assertions on nil values, so we use the two-value form here. 499 // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. 500 // Happily, this coincides with the semantics we want here. 501 retNames := make([]string, len(rets)) 502 for i, t := range rets { 503 retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) 504 g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) 505 } 506 g.p("return " + strings.Join(retNames, ", ")) 507 } 508 509 g.out() 510 g.p("}") 511 return nil 512 } 513 514 func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { 515 argNames := g.getArgNames(m) 516 517 var argString string 518 if m.Variadic == nil { 519 argString = strings.Join(argNames, ", ") 520 } else { 521 argString = strings.Join(argNames[:len(argNames)-1], ", ") 522 } 523 if argString != "" { 524 argString += " interface{}" 525 } 526 527 if m.Variadic != nil { 528 if argString != "" { 529 argString += ", " 530 } 531 argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) 532 } 533 534 ia := newIdentifierAllocator(argNames) 535 idRecv := ia.allocateIdentifier("mr") 536 537 g.p("// %v indicates an expected call of %v.", m.Name, m.Name) 538 g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) 539 g.in() 540 g.p("%s.mock.ctrl.T.Helper()", idRecv) 541 542 var callArgs string 543 if m.Variadic == nil { 544 if len(argNames) > 0 { 545 callArgs = ", " + strings.Join(argNames, ", ") 546 } 547 } else { 548 if len(argNames) == 1 { 549 // Easy: just use ... to push the arguments through. 550 callArgs = ", " + argNames[0] + "..." 551 } else { 552 // Hard: create a temporary slice. 553 idVarArgs := ia.allocateIdentifier("varargs") 554 g.p("%s := append([]interface{}{%s}, %s...)", 555 idVarArgs, 556 strings.Join(argNames[:len(argNames)-1], ", "), 557 argNames[len(argNames)-1]) 558 callArgs = ", " + idVarArgs + "..." 559 } 560 } 561 g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) 562 563 g.out() 564 g.p("}") 565 return nil 566 } 567 568 func (g *generator) getArgNames(m *model.Method) []string { 569 argNames := make([]string, len(m.In)) 570 for i, p := range m.In { 571 name := p.Name 572 if name == "" || name == "_" { 573 name = fmt.Sprintf("arg%d", i) 574 } 575 argNames[i] = name 576 } 577 if m.Variadic != nil { 578 name := m.Variadic.Name 579 if name == "" { 580 name = fmt.Sprintf("arg%d", len(m.In)) 581 } 582 argNames = append(argNames, name) 583 } 584 return argNames 585 } 586 587 func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { 588 argTypes := make([]string, len(m.In)) 589 for i, p := range m.In { 590 argTypes[i] = p.Type.String(g.packageMap, pkgOverride) 591 } 592 if m.Variadic != nil { 593 argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) 594 } 595 return argTypes 596 } 597 598 type identifierAllocator map[string]struct{} 599 600 func newIdentifierAllocator(taken []string) identifierAllocator { 601 a := make(identifierAllocator, len(taken)) 602 for _, s := range taken { 603 a[s] = struct{}{} 604 } 605 return a 606 } 607 608 func (o identifierAllocator) allocateIdentifier(want string) string { 609 id := want 610 for i := 2; ; i++ { 611 if _, ok := o[id]; !ok { 612 o[id] = struct{}{} 613 return id 614 } 615 id = want + "_" + strconv.Itoa(i) 616 } 617 } 618 619 // Output returns the generator's output, formatted in the standard Go style. 620 func (g *generator) Output() []byte { 621 src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil) 622 if err != nil { 623 log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) 624 } 625 return src 626 } 627 628 // createPackageMap returns a map of import path to package name 629 // for specified importPaths. 630 func createPackageMap(importPaths []string) map[string]string { 631 var pkg struct { 632 Name string 633 ImportPath string 634 } 635 pkgMap := make(map[string]string) 636 b := bytes.NewBuffer(nil) 637 args := []string{"list", "-json"} 638 args = append(args, importPaths...) 639 cmd := exec.Command("go", args...) 640 cmd.Stdout = b 641 cmd.Run() 642 dec := json.NewDecoder(b) 643 for dec.More() { 644 err := dec.Decode(&pkg) 645 if err != nil { 646 log.Printf("failed to decode 'go list' output: %v", err) 647 continue 648 } 649 pkgMap[pkg.ImportPath] = pkg.Name 650 } 651 return pkgMap 652 } 653 654 func printVersion() { 655 if version != "" { 656 fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date) 657 } else { 658 printModuleVersion() 659 } 660 } 661 662 // parseImportPackage get package import path via source file 663 // an alternative implementation is to use: 664 // cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} 665 // pkgs, err := packages.Load(cfg, "file="+source) 666 // However, it will call "go list" and slow down the performance 667 func parsePackageImport(srcDir string) (string, error) { 668 moduleMode := os.Getenv("GO111MODULE") 669 // trying to find the module 670 if moduleMode != "off" { 671 currentDir := srcDir 672 for { 673 dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) 674 if os.IsNotExist(err) { 675 if currentDir == filepath.Dir(currentDir) { 676 // at the root 677 break 678 } 679 currentDir = filepath.Dir(currentDir) 680 continue 681 } else if err != nil { 682 return "", err 683 } 684 modulePath := modfile.ModulePath(dat) 685 return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil 686 } 687 } 688 // fall back to GOPATH mode 689 goPaths := os.Getenv("GOPATH") 690 if goPaths == "" { 691 return "", fmt.Errorf("GOPATH is not set") 692 } 693 goPathList := strings.Split(goPaths, string(os.PathListSeparator)) 694 for _, goPath := range goPathList { 695 sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) 696 if strings.HasPrefix(srcDir, sourceRoot) { 697 return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil 698 } 699 } 700 return "", errOutsideGoPath 701 }