github.com/Capventis/moq@v0.2.6-0.20220316100624-05dd47497214/internal/registry/registry.go (about) 1 package registry 2 3 import ( 4 "errors" 5 "fmt" 6 "go/types" 7 "path/filepath" 8 "sort" 9 "strings" 10 11 "golang.org/x/tools/go/packages" 12 ) 13 14 // Registry encapsulates types information for the source and mock 15 // destination package. For the mock package, it tracks the list of 16 // imports and ensures there are no conflicts in the imported package 17 // qualifiers. 18 type Registry struct { 19 srcPkg *packages.Package 20 moqPkgPath string 21 aliases map[string]string 22 imports map[string]*Package 23 } 24 25 // New loads the source package info and returns a new instance of 26 // Registry. 27 func New(srcDir, moqPkg string) (*Registry, error) { 28 srcPkg, err := pkgInfoFromPath( 29 srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo, 30 ) 31 if err != nil { 32 return nil, fmt.Errorf("couldn't load source package: %s", err) 33 } 34 35 return &Registry{ 36 srcPkg: srcPkg, 37 moqPkgPath: findPkgPath(moqPkg, srcPkg), 38 aliases: parseImportsAliases(srcPkg), 39 imports: make(map[string]*Package), 40 }, nil 41 } 42 43 // SrcPkg returns the types info for the source package. 44 func (r Registry) SrcPkg() *types.Package { 45 return r.srcPkg.Types 46 } 47 48 // SrcPkgName returns the name of the source package. 49 func (r Registry) SrcPkgName() string { 50 return r.srcPkg.Name 51 } 52 53 // LookupInterface returns the underlying interface definition of the 54 // given interface name. 55 func (r Registry) LookupInterface(name string) (*types.Interface, error) { 56 obj := r.SrcPkg().Scope().Lookup(name) 57 if obj == nil { 58 return nil, fmt.Errorf("interface not found: %s", name) 59 } 60 61 if !types.IsInterface(obj.Type()) { 62 return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type()) 63 } 64 65 return obj.Type().Underlying().(*types.Interface).Complete(), nil 66 } 67 68 // MethodScope returns a new MethodScope. 69 func (r *Registry) MethodScope() *MethodScope { 70 return &MethodScope{ 71 registry: r, 72 moqPkgPath: r.moqPkgPath, 73 conflicted: map[string]bool{}, 74 } 75 } 76 77 // AddImport adds the given package to the set of imports. It generates a 78 // suitable alias if there are any conflicts with previously imported 79 // packages. 80 func (r *Registry) AddImport(pkg *types.Package) *Package { 81 path := stripVendorPath(pkg.Path()) 82 if path == r.moqPkgPath { 83 return nil 84 } 85 86 if imprt, ok := r.imports[path]; ok { 87 return imprt 88 } 89 90 imprt := Package{pkg: pkg, Alias: r.aliases[path]} 91 92 if conflict, ok := r.searchImport(imprt.Qualifier()); ok { 93 resolveImportConflict(&imprt, conflict, 0) 94 } 95 96 r.imports[path] = &imprt 97 return &imprt 98 } 99 100 // Imports returns the list of imported packages. The list is sorted by 101 // path. 102 func (r Registry) Imports() []*Package { 103 imports := make([]*Package, 0, len(r.imports)) 104 for _, imprt := range r.imports { 105 imports = append(imports, imprt) 106 } 107 sort.Slice(imports, func(i, j int) bool { 108 return imports[i].Path() < imports[j].Path() 109 }) 110 return imports 111 } 112 113 func (r Registry) searchImport(name string) (*Package, bool) { 114 for _, imprt := range r.imports { 115 if imprt.Qualifier() == name { 116 return imprt, true 117 } 118 } 119 120 return nil, false 121 } 122 123 func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) { 124 pkgs, err := packages.Load(&packages.Config{ 125 Mode: mode, 126 Dir: srcDir, 127 }) 128 if err != nil { 129 return nil, err 130 } 131 if len(pkgs) == 0 { 132 return nil, errors.New("package not found") 133 } 134 if len(pkgs) > 1 { 135 return nil, errors.New("found more than one package") 136 } 137 if errs := pkgs[0].Errors; len(errs) != 0 { 138 if len(errs) == 1 { 139 return nil, errs[0] 140 } 141 return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1) 142 } 143 return pkgs[0], nil 144 } 145 146 func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string { 147 if pkgInputVal == "" { 148 return srcPkg.PkgPath 149 } 150 if pkgInDir(srcPkg.PkgPath, pkgInputVal) { 151 return srcPkg.PkgPath 152 } 153 subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal) 154 if pkgInDir(subdirectoryPath, pkgInputVal) { 155 return subdirectoryPath 156 } 157 return "" 158 } 159 160 func pkgInDir(pkgName, dir string) bool { 161 currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) 162 if err != nil { 163 return false 164 } 165 return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName 166 } 167 168 func parseImportsAliases(pkg *packages.Package) map[string]string { 169 aliases := make(map[string]string) 170 for _, syntax := range pkg.Syntax { 171 for _, imprt := range syntax.Imports { 172 if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" { 173 aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name 174 } 175 } 176 } 177 return aliases 178 } 179 180 // resolveImportConflict generates and assigns a unique alias for 181 // packages with conflicting qualifiers. 182 func resolveImportConflict(a, b *Package, lvl int) { 183 u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl) 184 if u1 != u2 { 185 a.Alias, b.Alias = u1, u2 186 return 187 } 188 189 resolveImportConflict(a, b, lvl+1) 190 }