github.com/rancher/moq@v0.0.0-20200712062324-13d1f37d2d77/pkg/moq/moq.go (about) 1 package moq 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "go/build" 8 "go/types" 9 "io" 10 "os" 11 "path" 12 "path/filepath" 13 "strconv" 14 "strings" 15 "text/template" 16 17 "golang.org/x/tools/go/packages" 18 ) 19 20 // Mocker can generate mock structs. 21 type Mocker struct { 22 srcPkg *packages.Package 23 tmpl *template.Template 24 pkgName string 25 pkgPath string 26 fmter func(src []byte) ([]byte, error) 27 28 importToAlias map[string]string 29 importAliases map[string]bool 30 importLines map[string]bool 31 } 32 33 // Config specifies details about how interfaces should be mocked. 34 // SrcDir is the only field which needs be specified. 35 type Config struct { 36 SrcDir string 37 PkgName string 38 Formatter string 39 } 40 41 // New makes a new Mocker for the specified package directory. 42 func New(conf Config) (*Mocker, error) { 43 srcPkg, err := pkgInfoFromPath(conf.SrcDir, packages.NeedName|packages.NeedTypes|packages.NeedTypesInfo|packages.NeedSyntax) 44 if err != nil { 45 return nil, fmt.Errorf("couldn't load source package: %s", err) 46 } 47 48 pkgName := conf.PkgName 49 if pkgName == "" { 50 pkgName = srcPkg.Name 51 } 52 53 pkgPath, err := findPkgPath(conf.PkgName, srcPkg) 54 if err != nil { 55 return nil, fmt.Errorf("couldn't load mock package: %s", err) 56 } 57 58 tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate) 59 if err != nil { 60 return nil, err 61 } 62 63 fmter := gofmt 64 if conf.Formatter == "goimports" { 65 fmter = goimports 66 } 67 68 importAliases := make(map[string]bool) 69 importToAlias := make(map[string]string) 70 71 // Attempt to preserve original aliases for prettiness 72 for _, syntax := range srcPkg.Syntax { 73 for _, importSpec := range syntax.Imports { 74 if importSpec.Name != nil && importSpec.Path != nil { 75 importAliases[importSpec.Name.Name] = true 76 importToAlias[strings.Trim(importSpec.Path.Value, "\"")] = 
importSpec.Name.Name 77 } 78 } 79 } 80 81 return &Mocker{ 82 tmpl: tmpl, 83 srcPkg: srcPkg, 84 pkgName: pkgName, 85 pkgPath: pkgPath, 86 fmter: fmter, 87 importLines: make(map[string]bool), 88 importAliases: importAliases, 89 importToAlias: importToAlias, 90 }, nil 91 } 92 93 func findPkgPath(pkgInputVal string, srcPkg *packages.Package) (string, error) { 94 if pkgInputVal == "" { 95 return srcPkg.PkgPath, nil 96 } 97 if pkgInDir(".", pkgInputVal) { 98 return ".", nil 99 } 100 if pkgInDir(srcPkg.PkgPath, pkgInputVal) { 101 return srcPkg.PkgPath, nil 102 } 103 subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal) 104 if pkgInDir(subdirectoryPath, pkgInputVal) { 105 return subdirectoryPath, nil 106 } 107 return "", nil 108 } 109 110 func pkgInDir(pkgName, dir string) bool { 111 currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) 112 if err != nil { 113 return false 114 } 115 return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName 116 } 117 118 // Mock generates a mock for the specified interface name. 
119 func (m *Mocker) Mock(w io.Writer, names ...string) error { 120 if len(names) == 0 { 121 return errors.New("must specify one interface") 122 } 123 124 doc := doc{ 125 PackageName: m.pkgName, 126 } 127 128 mocksMethods := false 129 130 paramCache := make(map[string][]*param) 131 132 tpkg := m.srcPkg.Types 133 for _, name := range names { 134 n, mockName := parseInterfaceName(name) 135 iface := tpkg.Scope().Lookup(n) 136 if iface == nil { 137 return fmt.Errorf("cannot find interface %s", n) 138 } 139 if !types.IsInterface(iface.Type()) { 140 return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) 141 } 142 iiface := iface.Type().Underlying().(*types.Interface).Complete() 143 obj := obj{ 144 InterfaceName: n, 145 MockName: mockName, 146 } 147 for i := 0; i < iiface.NumMethods(); i++ { 148 mocksMethods = true 149 meth := iiface.Method(i) 150 sig := meth.Type().(*types.Signature) 151 method := &method{ 152 Name: meth.Name(), 153 } 154 obj.Methods = append(obj.Methods, method) 155 method.Params, method.Returns = m.extractArgs(sig) 156 157 for _, param := range method.Params { 158 paramCache[param.Name] = append(paramCache[param.Name], param) 159 } 160 } 161 doc.Objects = append(doc.Objects, obj) 162 } 163 164 if mocksMethods { 165 _, importLine := m.qualifierAndImportLine("sync", "sync") 166 doc.Imports = append(doc.Imports, importLine) 167 } 168 169 for pkgToImport := range m.importLines { 170 doc.Imports = append(doc.Imports, pkgToImport) 171 } 172 173 if tpkg.Name() != m.pkgName { 174 qualifier, importLine := m.qualifierAndImportLine(tpkg.Path(), tpkg.Name()) 175 doc.SourcePackagePrefix = qualifier + "." 
176 doc.Imports = append(doc.Imports, importLine) 177 } 178 179 for pkg := range m.importAliases { 180 if params, hasConflict := paramCache[pkg]; hasConflict { 181 for _, param := range params { 182 param.LocalName = fmt.Sprintf("%sMoqParam", param.LocalName) 183 } 184 } 185 } 186 187 var buf bytes.Buffer 188 err := m.tmpl.Execute(&buf, doc) 189 if err != nil { 190 return err 191 } 192 formatted, err := m.fmter(buf.Bytes()) 193 if err != nil { 194 return err 195 } 196 if _, err := w.Write(formatted); err != nil { 197 return err 198 } 199 return nil 200 } 201 202 func (m *Mocker) allocAlias(path string, pkgName string) string { 203 suffix := 0 204 attemptedName := pkgName 205 for { 206 if _, taken := m.importAliases[attemptedName]; taken { 207 suffix++ 208 attemptedName = fmt.Sprintf("%s%d", pkgName, suffix) 209 continue 210 } 211 212 m.importAliases[attemptedName] = true 213 m.importToAlias[path] = attemptedName 214 215 // Don't alias packages that don't require an alias 216 if attemptedName == pkgName { 217 m.importToAlias[path] = "" 218 return "" 219 } 220 221 return attemptedName 222 } 223 } 224 225 func (m *Mocker) getAlias(path string, pkgName string) string { 226 alias, aliasSet := m.importToAlias[path] 227 if !aliasSet { 228 alias = m.allocAlias(path, pkgName) 229 } 230 return alias 231 } 232 233 func (m *Mocker) qualifierAndImportLine(pkg, pkgName string) (string, string) { 234 pkg = stripVendorPath(pkg) 235 alias := m.getAlias(pkg, pkgName) 236 importLine := quoteImport(alias, pkg) 237 if alias == "" { 238 return pkgName, importLine 239 } 240 return alias, importLine 241 } 242 243 func quoteImport(alias, pkg string) string { 244 if alias == "" { 245 return fmt.Sprintf("\"%s\"", pkg) 246 } 247 return fmt.Sprintf("%s \"%s\"", alias, pkg) 248 } 249 250 func (m *Mocker) packageQualifier(pkg *types.Package) string { 251 if m.pkgPath != "" && m.pkgPath == pkg.Path() { 252 return "" 253 } 254 path := pkg.Path() 255 if pkg.Path() == "." 
{ 256 wd, err := os.Getwd() 257 if err == nil { 258 path = stripGopath(wd) 259 } 260 } 261 262 qualifier, importLine := m.qualifierAndImportLine(path, pkg.Name()) 263 m.importLines[importLine] = true 264 return qualifier 265 } 266 267 func (m *Mocker) extractArgs(sig *types.Signature) (params, results []*param) { 268 pp := sig.Params() 269 for i := 0; i < pp.Len(); i++ { 270 p := m.buildParam(pp.At(i), "in"+strconv.Itoa(i+1)) 271 // check for final variadic argument 272 p.Variadic = sig.Variadic() && i == pp.Len()-1 && p.Type[0:2] == "[]" 273 params = append(params, p) 274 } 275 276 rr := sig.Results() 277 for i := 0; i < rr.Len(); i++ { 278 results = append(results, m.buildParam(rr.At(i), "out"+strconv.Itoa(i+1))) 279 } 280 281 return 282 } 283 284 func (m *Mocker) buildParam(v *types.Var, fallbackName string) *param { 285 name := v.Name() 286 if name == "" { 287 name = fallbackName 288 } 289 typ := types.TypeString(v.Type(), m.packageQualifier) 290 return ¶m{Name: name, LocalName: name, Type: typ} 291 } 292 293 func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) { 294 pkgs, err := packages.Load(&packages.Config{ 295 Mode: mode, 296 Dir: srcDir, 297 }) 298 if err != nil { 299 return nil, err 300 } 301 if len(pkgs) == 0 { 302 return nil, errors.New("No packages found") 303 } 304 if len(pkgs) > 1 { 305 return nil, errors.New("More than one package was found") 306 } 307 return pkgs[0], nil 308 } 309 310 func parseInterfaceName(name string) (ifaceName, mockName string) { 311 parts := strings.SplitN(name, ":", 2) 312 ifaceName = parts[0] 313 mockName = ifaceName + "Mock" 314 if len(parts) == 2 { 315 mockName = parts[1] 316 } 317 return 318 } 319 320 type doc struct { 321 PackageName string 322 SourcePackagePrefix string 323 Objects []obj 324 Imports []string 325 } 326 327 type obj struct { 328 InterfaceName string 329 MockName string 330 Methods []*method 331 } 332 type method struct { 333 Name string 334 Params []*param 335 Returns 
[]*param 336 } 337 338 func (m *method) Arglist() string { 339 params := make([]string, len(m.Params)) 340 for i, p := range m.Params { 341 params[i] = p.String() 342 } 343 return strings.Join(params, ", ") 344 } 345 346 func (m *method) ArgCallList() string { 347 params := make([]string, len(m.Params)) 348 for i, p := range m.Params { 349 params[i] = p.CallName() 350 } 351 return strings.Join(params, ", ") 352 } 353 354 func (m *method) ReturnArglist() string { 355 params := make([]string, len(m.Returns)) 356 for i, p := range m.Returns { 357 params[i] = p.TypeString() 358 } 359 if len(m.Returns) > 1 { 360 return fmt.Sprintf("(%s)", strings.Join(params, ", ")) 361 } 362 return strings.Join(params, ", ") 363 } 364 365 type param struct { 366 Name string 367 LocalName string 368 Type string 369 Variadic bool 370 } 371 372 func (p param) String() string { 373 return fmt.Sprintf("%s %s", p.LocalName, p.TypeString()) 374 } 375 376 func (p param) CallName() string { 377 if p.Variadic { 378 return p.LocalName + "..." 379 } 380 return p.LocalName 381 } 382 383 func (p param) TypeString() string { 384 if p.Variadic { 385 return "..." + p.Type[2:] 386 } 387 return p.Type 388 } 389 390 var templateFuncs = template.FuncMap{ 391 "Exported": func(s string) string { 392 if s == "" { 393 return "" 394 } 395 for _, initialism := range golintInitialisms { 396 if strings.ToUpper(s) == initialism { 397 return initialism 398 } 399 } 400 return strings.ToUpper(s[0:1]) + s[1:] 401 }, 402 } 403 404 // stripVendorPath strips the vendor dir prefix from a package path. 405 // For example we might encounter an absolute path like 406 // github.com/foo/bar/vendor/github.com/pkg/errors which is resolved 407 // to github.com/pkg/errors. 
func stripVendorPath(p string) string {
	// Only the innermost vendor directory determines the import path: a
	// package vendored inside a vendored dependency (a/vendor/b/vendor/c) is
	// imported as "c". path.Join cleans the remainder ("" stays "").
	const vendorSep = "/vendor/"
	if i := strings.LastIndex(p, vendorSep); i >= 0 {
		return strings.TrimLeft(path.Join(p[i+len(vendorSep):]), "/")
	}
	// A package under a top-level vendor directory of a source root.
	if strings.HasPrefix(p, "vendor/") {
		return strings.TrimLeft(path.Join(p[len("vendor/"):]), "/")
	}
	return p
}

// stripGopath takes the directory of a package and removes the first
// matching source root (from build.Default.SrcDirs, i.e. $GOROOT/src and
// $GOPATH/src) to get the canonical, slash-separated package path. The
// directory is returned unchanged when it lies under no source root.
func stripGopath(p string) string {
	for _, srcDir := range build.Default.SrcDirs() {
		rel, err := filepath.Rel(srcDir, p)
		// Only reject paths that actually climb out of srcDir; a directory
		// literally named "..foo" is still inside it.
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			continue
		}
		return filepath.ToSlash(rel)
	}
	return p
}