github.com/relnod/pegomock@v2.0.1+incompatible/mockgen/mockgen.go (about) 1 // Copyright 2015 Peter Goetz 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Based on the work done in 16 // https://github.com/golang/mock/blob/d581abfc04272f381d7a05e4b80163ea4e2b9447/mockgen/mockgen.go 17 18 // MockGen generates mock implementations of Go interfaces. 19 package mockgen 20 21 // TODO: This does not support recursive embedded interfaces. 22 // TODO: This does not support embedding package-local interfaces in a separate file. 23 24 import ( 25 "bytes" 26 "fmt" 27 "go/format" 28 "go/token" 29 "path" 30 "strconv" 31 "strings" 32 "unicode" 33 34 "github.com/petergtz/pegomock/model" 35 ) 36 37 const mockFrameworkImportPath = "github.com/petergtz/pegomock" 38 39 func GenerateOutput(ast *model.Package, source, packageOut, selfPackage string) ([]byte, map[string]string) { 40 g := generator{typesSet: make(map[string]string)} 41 g.generateCode(source, ast, packageOut, selfPackage) 42 return g.formattedOutput(), g.typesSet 43 } 44 45 type generator struct { 46 buf bytes.Buffer 47 packageMap map[string]string // map from import path to package name 48 typesSet map[string]string 49 } 50 51 func (g *generator) generateCode(source string, pkg *model.Package, pkgName, selfPackage string) { 52 g.p("// Code generated by pegomock. DO NOT EDIT.") 53 g.p("// Source: %v", source) 54 g.emptyLine() 55 56 importPaths := pkg.Imports() 57 importPaths[mockFrameworkImportPath] = true 58 packageMap, nonVendorPackageMap := generateUniquePackageNamesFor(importPaths) 59 g.packageMap = packageMap 60 61 g.p("package %v", pkgName) 62 g.emptyLine() 63 g.p("import (") 64 g.p("\"reflect\"") 65 g.p("\"time\"") 66 for packagePath, packageName := range nonVendorPackageMap { 67 if packagePath != selfPackage && packagePath != "time" && packagePath != "reflect" { 68 g.p("%v %q", packageName, packagePath) 69 } 70 } 71 for _, packagePath := range pkg.DotImports { 72 g.p(". %q", packagePath) 73 } 74 g.p(")") 75 76 for _, iface := range pkg.Interfaces { 77 g.generateMockFor(iface, selfPackage) 78 } 79 } 80 81 func generateUniquePackageNamesFor(importPaths map[string]bool) (packageMap, nonVendorPackageMap map[string]string) { 82 packageMap = make(map[string]string, len(importPaths)) 83 nonVendorPackageMap = make(map[string]string, len(importPaths)) 84 packageNamesAlreadyUsed := make(map[string]bool, len(importPaths)) 85 for importPath := range importPaths { 86 sanitizedPackagePathBaseName := sanitize(path.Base(importPath)) 87 88 // Local names for an imported package can usually be the basename of the import path. 89 // A couple of situations don't permit that, such as duplicate local names 90 // (e.g. importing "html/template" and "text/template"), or where the basename is 91 // a keyword (e.g. "foo/case"). 92 // try base0, base1, ... 93 packageName := sanitizedPackagePathBaseName 94 for i := 0; packageNamesAlreadyUsed[packageName] || token.Lookup(packageName).IsKeyword(); i++ { 95 packageName = sanitizedPackagePathBaseName + strconv.Itoa(i) 96 } 97 98 packageMap[importPath] = packageName 99 packageNamesAlreadyUsed[packageName] = true 100 101 nonVendorPackageMap[vendorCleaned(importPath)] = packageName 102 } 103 return 104 } 105 106 func vendorCleaned(importPath string) string { 107 if split := strings.Split(importPath, "/vendor/"); len(split) > 1 { 108 return split[1] 109 } 110 return importPath 111 } 112 113 // sanitize cleans up a string to make a suitable package name. 114 // pkgName in reflect mode is the base name of the import path, 115 // which might have characters that are illegal to have in package names. 116 func sanitize(s string) string { 117 t := "" 118 for _, r := range s { 119 if t == "" { 120 if unicode.IsLetter(r) || r == '_' { 121 t += string(r) 122 continue 123 } 124 } else { 125 if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { 126 t += string(r) 127 continue 128 } 129 } 130 t += "_" 131 } 132 if t == "_" { 133 t = "x" 134 } 135 return t 136 } 137 138 func (g *generator) generateMockFor(iface *model.Interface, selfPackage string) { 139 mockTypeName := "Mock" + iface.Name 140 g.generateMockType(mockTypeName) 141 for _, method := range iface.Methods { 142 g.generateMockMethod(mockTypeName, method, selfPackage) 143 g.emptyLine() 144 145 addTypesFromMethodParamsTo(g.typesSet, method.In, g.packageMap) 146 addTypesFromMethodParamsTo(g.typesSet, method.Out, g.packageMap) 147 } 148 g.generateMockVerifyMethods(iface.Name) 149 g.generateVerifierType(iface.Name) 150 for _, method := range iface.Methods { 151 ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", iface.Name, method.Name) 152 args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage) 153 g.generateVerifierMethod(iface.Name, method, selfPackage, ongoingVerificationTypeName, args, argNames) 154 g.generateOngoingVerificationType(iface.Name, ongoingVerificationTypeName) 155 g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes) 156 g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, argTypes, method.Variadic != nil) 157 } 158 } 159 160 func (g *generator) generateMockType(mockTypeName string) { 161 g. 162 emptyLine(). 163 p("type %v struct {", mockTypeName). 164 p(" fail func(message string, callerSkip ...int)"). 165 p("}"). 166 emptyLine(). 167 p("func New%v(options ...pegomock.Option) *%v {", mockTypeName, mockTypeName). 168 p(" mock := &%v{}", mockTypeName). 169 p(" for _, option := range options {"). 170 p(" option.Apply(mock)"). 171 p(" }"). 172 p(" return mock"). 173 p("}"). 174 emptyLine(). 175 p("func (mock *%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName). 176 p("func (mock *%v) FailHandler() pegomock.FailHandler { return mock.fail }", mockTypeName). 177 emptyLine() 178 } 179 180 // If non-empty, pkgOverride is the package in which unqualified types reside. 181 func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator { 182 args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride) 183 g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(returnTypes)) 184 g.p("if mock == nil {"). 185 p(" panic(\"mock must not be nil. Use myMock := New%v().\")", mockType). 186 p("}") 187 g.GenerateParamsDeclaration(argNames, method.Variadic != nil) 188 reflectReturnTypes := make([]string, len(returnTypes)) 189 for i, returnType := range returnTypes { 190 reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType) 191 } 192 resultAssignment := "" 193 if len(method.Out) > 0 { 194 resultAssignment = "result :=" 195 } 196 g.p("%v pegomock.GetGenericMockFrom(mock).Invoke(\"%v\", params, []reflect.Type{%v})", 197 resultAssignment, method.Name, strings.Join(reflectReturnTypes, ", ")) 198 if len(method.Out) > 0 { 199 // TODO: translate LastInvocation into a Matcher so it can be used as key for Stubbings 200 for i, returnType := range returnTypes { 201 g.p("var ret%v %v", i, returnType) 202 } 203 g.p("if len(result) != 0 {") 204 returnValues := make([]string, len(returnTypes)) 205 for i, returnType := range returnTypes { 206 g.p("if result[%v] != nil {", i) 207 g.p("ret%v = result[%v].(%v)", i, i, returnType) 208 g.p("}") 209 returnValues[i] = fmt.Sprintf("ret%v", i) 210 } 211 g.p("}") 212 g.p("return %v", strings.Join(returnValues, ", ")) 213 } 214 g.p("}") 215 return g 216 } 217 218 func (g *generator) generateVerifierType(interfaceName string) *generator { 219 return g. 220 p("type Verifier%v struct {", interfaceName). 221 p(" mock *Mock%v", interfaceName). 222 p(" invocationCountMatcher pegomock.Matcher"). 223 p(" inOrderContext *pegomock.InOrderContext"). 224 p(" timeout time.Duration"). 225 p("}"). 226 emptyLine() 227 } 228 229 func (g *generator) generateMockVerifyMethods(interfaceName string) { 230 g. 231 p("func (mock *Mock%v) VerifyWasCalledOnce() *Verifier%v {", interfaceName, interfaceName). 232 p(" return &Verifier%v{", interfaceName). 233 p(" mock: mock,"). 234 p(" invocationCountMatcher: pegomock.Times(1),"). 235 p(" }"). 236 p("}"). 237 emptyLine(). 238 p("func (mock *Mock%v) VerifyWasCalled(invocationCountMatcher pegomock.Matcher) *Verifier%v {", interfaceName, interfaceName). 239 p(" return &Verifier%v{", interfaceName). 240 p(" mock: mock,"). 241 p(" invocationCountMatcher: invocationCountMatcher,"). 242 p(" }"). 243 p("}"). 244 emptyLine(). 245 p("func (mock *Mock%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.Matcher, inOrderContext *pegomock.InOrderContext) *Verifier%v {", interfaceName, interfaceName). 246 p(" return &Verifier%v{", interfaceName). 247 p(" mock: mock,"). 248 p(" invocationCountMatcher: invocationCountMatcher,"). 249 p(" inOrderContext: inOrderContext,"). 250 p(" }"). 251 p("}"). 252 emptyLine(). 253 p("func (mock *Mock%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.Matcher, timeout time.Duration) *Verifier%v {", interfaceName, interfaceName). 254 p(" return &Verifier%v{", interfaceName). 255 p(" mock: mock,"). 256 p(" invocationCountMatcher: invocationCountMatcher,"). 257 p(" timeout: timeout,"). 258 p(" }"). 259 p("}"). 260 emptyLine() 261 } 262 263 func (g *generator) generateVerifierMethod(interfaceName string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator { 264 return g. 265 p("func (verifier *Verifier%v) %v(%v) *%v {", interfaceName, method.Name, join(args), returnTypeString). 266 GenerateParamsDeclaration(argNames, method.Variadic != nil). 267 p("methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, \"%v\", params, verifier.timeout)", method.Name). 268 p("return &%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString). 269 p("}") 270 } 271 272 func (g *generator) GenerateParamsDeclaration(argNames []string, isVariadic bool) *generator { 273 if isVariadic { 274 return g. 275 p("params := []pegomock.Param{%v}", strings.Join(argNames[0:len(argNames)-1], ", ")). 276 p("for _, param := range %v {", argNames[len(argNames)-1]). 277 p("params = append(params, param)"). 278 p("}") 279 } else { 280 return g.p("params := []pegomock.Param{%v}", join(argNames)) 281 } 282 } 283 284 func (g *generator) generateOngoingVerificationType(interfaceName string, ongoingVerificationStructName string) *generator { 285 return g. 286 p("type %v struct {", ongoingVerificationStructName). 287 p("mock *Mock%v", interfaceName). 288 p(" methodInvocations []pegomock.MethodInvocation"). 289 p("}"). 290 emptyLine() 291 } 292 293 func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string) *generator { 294 g.p("func (c *%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, join(argTypes)) 295 if len(argNames) > 0 { 296 indexedArgNames := make([]string, len(argNames)) 297 for i, argName := range argNames { 298 indexedArgNames[i] = argName + "[len(" + argName + ")-1]" 299 } 300 g.p("%v := c.GetAllCapturedArguments()", join(argNames)) 301 g.p("return %v", strings.Join(indexedArgNames, ", ")) 302 } 303 g.p("}") 304 g.emptyLine() 305 return g 306 } 307 308 func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, argTypes []string, isVariadic bool) *generator { 309 argsAsArray := make([]string, len(argTypes)) 310 for i, argType := range argTypes { 311 argsAsArray[i] = fmt.Sprintf("_param%v []%v", i, argType) 312 } 313 g.p("func (c *%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, strings.Join(argsAsArray, ", ")) 314 if len(argTypes) > 0 { 315 g.p("params := pegomock.GetGenericMockFrom(c.mock).GetInvocationParams(c.methodInvocations)") 316 g.p("if len(params) > 0 {") 317 for i, argType := range argTypes { 318 if isVariadic && i == len(argTypes)-1 { 319 variadicBasicType := strings.Replace(argType, "[]", "", 1) 320 g. 321 p("_param%v = make([]%v, len(params[%v]))", i, argType, i). 322 p("for u := range params[0] {"). // the number of invocations and hence len(params[x]) is equal for all x 323 p("_param%v[u] = make([]%v, len(params)-%v)", i, variadicBasicType, i). 324 p("for x := %v; x < len(params); x++ {", i). 325 p("if params[x][u] != nil {"). 326 p("_param%v[u][x-%v] = params[x][u].(%v)", i, i, variadicBasicType). 327 p("}"). 328 p("}"). 329 p("}") 330 break 331 } else { 332 g.p("_param%v = make([]%v, len(params[%v]))", i, argType, i) 333 g.p("for u, param := range params[%v] {", i) 334 g.p("_param%v[u]=param.(%v)", i, argType) 335 g.p("}") 336 } 337 } 338 g.p("}") 339 g.p("return") 340 } 341 g.p("}") 342 g.emptyLine() 343 return g 344 } 345 346 func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride string) ( 347 args []string, 348 argNames []string, 349 argTypes []string, 350 returnTypes []string, 351 ) { 352 args = make([]string, len(method.In)) 353 argNames = make([]string, len(method.In)) 354 argTypes = make([]string, len(args)) 355 for i, arg := range method.In { 356 argName := arg.Name 357 if argName == "" { 358 argName = fmt.Sprintf("_param%d", i) 359 } 360 argType := arg.Type.String(packageMap, pkgOverride) 361 args[i] = argName + " " + argType 362 argNames[i] = argName 363 argTypes[i] = argType 364 } 365 if method.Variadic != nil { 366 argName := method.Variadic.Name 367 if argName == "" { 368 argName = fmt.Sprintf("_param%d", len(method.In)) 369 } 370 argType := method.Variadic.Type.String(packageMap, pkgOverride) 371 args = append(args, argName+" ..."+argType) 372 argNames = append(argNames, argName) 373 argTypes = append(argTypes, "[]"+argType) 374 } 375 returnTypes = make([]string, len(method.Out)) 376 for i, ret := range method.Out { 377 returnTypes[i] = ret.Type.String(packageMap, pkgOverride) 378 } 379 return 380 } 381 382 func addTypesFromMethodParamsTo(typesSet map[string]string, params []*model.Parameter, packageMap map[string]string) { 383 for _, param := range params { 384 switch typedType := param.Type.(type) { 385 case *model.NamedType, *model.PointerType, *model.ArrayType, *model.MapType, *model.ChanType: 386 if _, exists := typesSet[underscoreNameFor(typedType, packageMap)]; !exists { 387 typesSet[underscoreNameFor(typedType, packageMap)] = generateMatcherSourceCode(typedType, packageMap) 388 } 389 case *model.FuncType: 390 // matcher generation for funcs not supported yet 391 // TODO implement 392 case model.PredeclaredType: 393 // skip. These come as part of pegomock. 394 default: 395 panic("Should not get here") 396 } 397 } 398 } 399 400 func generateMatcherSourceCode(t model.Type, packageMap map[string]string) string { 401 return fmt.Sprintf(`// Code generated by pegomock. DO NOT EDIT. 402 package matchers 403 404 import ( 405 "reflect" 406 "github.com/petergtz/pegomock" 407 %v 408 ) 409 410 func Any%v() %v { 411 pegomock.RegisterMatcher(pegomock.NewAnyMatcher(reflect.TypeOf((*(%v))(nil)).Elem())) 412 var nullValue %v 413 return nullValue 414 } 415 416 func Eq%v(value %v) %v { 417 pegomock.RegisterMatcher(&pegomock.EqMatcher{Value: value}) 418 var nullValue %v 419 return nullValue 420 } 421 `, 422 optionalPackageOf(t, packageMap), 423 camelcaseNameFor(t, packageMap), 424 t.String(packageMap, ""), 425 t.String(packageMap, ""), 426 t.String(packageMap, ""), 427 428 camelcaseNameFor(t, packageMap), 429 t.String(packageMap, ""), 430 t.String(packageMap, ""), 431 t.String(packageMap, ""), 432 ) 433 } 434 435 func optionalPackageOf(t model.Type, packageMap map[string]string) string { 436 switch typedType := t.(type) { 437 case model.PredeclaredType: 438 return "" 439 case *model.NamedType: 440 return fmt.Sprintf("%v \"%v\"", packageMap[typedType.Package], vendorCleaned(typedType.Package)) 441 case *model.PointerType: 442 return optionalPackageOf(typedType.Type, packageMap) 443 case *model.ArrayType: 444 return optionalPackageOf(typedType.Type, packageMap) 445 case *model.MapType: 446 return optionalPackageOf(typedType.Key, packageMap) + "\n" + optionalPackageOf(typedType.Value, packageMap) 447 case *model.ChanType: 448 return optionalPackageOf(typedType.Type, packageMap) 449 // TODO: 450 // case *model.FuncType: 451 default: 452 panic(fmt.Sprintf("TODO implement optionalPackageOf for: %v\nis type of %T\n", typedType, typedType)) 453 } 454 } 455 456 func spaceSeparatedNameFor(t model.Type, packageMap map[string]string) string { 457 switch typedType := t.(type) { 458 case model.PredeclaredType: 459 tt := typedType.String(packageMap, "") 460 if tt == "interface{}" { 461 // if a predeclared type is interface 462 // return a string type without curly brackets 463 return "interface" 464 } 465 return tt 466 case *model.NamedType: 467 return strings.Replace((typedType.String(packageMap, "")), ".", " ", -1) 468 case *model.PointerType: 469 return "ptr to " + spaceSeparatedNameFor(typedType.Type, packageMap) 470 case *model.ArrayType: 471 if typedType.Len == -1 { 472 return "slice of " + spaceSeparatedNameFor(typedType.Type, packageMap) 473 } else { 474 return "array of " + spaceSeparatedNameFor(typedType.Type, packageMap) 475 } 476 case *model.MapType: 477 return "map of " + spaceSeparatedNameFor(typedType.Key, packageMap) + " to " + spaceSeparatedNameFor(typedType.Value, packageMap) 478 case *model.ChanType: 479 return "chan of " + spaceSeparatedNameFor(typedType.Type, packageMap) 480 // TODO: 481 // case *model.FuncType: 482 default: 483 return fmt.Sprintf("TODO implement matcher for: %v\nis type of %T\n", typedType, typedType) 484 } 485 } 486 487 func camelcaseNameFor(t model.Type, packageMap map[string]string) string { 488 return strings.Replace(strings.Title(strings.Replace(spaceSeparatedNameFor(t, packageMap), "_", " ", -1)), " ", "", -1) 489 } 490 491 func underscoreNameFor(t model.Type, packageMap map[string]string) string { 492 return strings.ToLower(strings.Replace(spaceSeparatedNameFor(t, packageMap), " ", "_", -1)) 493 } 494 495 func (g *generator) p(format string, args ...interface{}) *generator { 496 fmt.Fprintf(&g.buf, format+"\n", args...) 497 return g 498 } 499 500 func (g *generator) emptyLine() *generator { return g.p("") } 501 502 func (g *generator) formattedOutput() []byte { 503 src, err := format.Source(g.buf.Bytes()) 504 if err != nil { 505 panic(fmt.Errorf("Failed to format generated source code: %s\n%s", err, g.buf.String())) 506 } 507 return src 508 } 509 510 func join(s []string) string { return strings.Join(s, ", ") }