github.com/djui/moq@v0.3.3/internal/registry/registry.go (about) 1 package registry 2 3 import ( 4 "errors" 5 "fmt" 6 "go/types" 7 "path/filepath" 8 "sort" 9 "strings" 10 11 "golang.org/x/tools/go/packages" 12 ) 13 14 // Registry encapsulates types information for the source and mock 15 // destination package. For the mock package, it tracks the list of 16 // imports and ensures there are no conflicts in the imported package 17 // qualifiers. 18 type Registry struct { 19 srcPkg *packages.Package 20 moqPkgPath string 21 aliases map[string]string 22 imports map[string]*Package 23 } 24 25 // New loads the source package info and returns a new instance of 26 // Registry. 27 func New(srcDir, moqPkg string) (*Registry, error) { 28 srcPkg, err := pkgInfoFromPath( 29 srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo|packages.NeedDeps, 30 ) 31 if err != nil { 32 return nil, fmt.Errorf("couldn't load source package: %s", err) 33 } 34 35 return &Registry{ 36 srcPkg: srcPkg, 37 moqPkgPath: findPkgPath(moqPkg, srcPkg), 38 aliases: parseImportsAliases(srcPkg), 39 imports: make(map[string]*Package), 40 }, nil 41 } 42 43 // SrcPkg returns the types info for the source package. 44 func (r Registry) SrcPkg() *types.Package { 45 return r.srcPkg.Types 46 } 47 48 // SrcPkgName returns the name of the source package. 49 func (r Registry) SrcPkgName() string { 50 return r.srcPkg.Name 51 } 52 53 // LookupInterface returns the underlying interface definition of the 54 // given interface name. 55 func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) { 56 obj := r.SrcPkg().Scope().Lookup(name) 57 if obj == nil { 58 return nil, nil, fmt.Errorf("interface not found: %s", name) 59 } 60 61 if !types.IsInterface(obj.Type()) { 62 return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type()) 63 } 64 65 var tparams *types.TypeParamList 66 named, ok := obj.Type().(*types.Named) 67 if ok { 68 tparams = named.TypeParams() 69 } 70 71 return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil 72 } 73 74 // MethodScope returns a new MethodScope. 75 func (r *Registry) MethodScope() *MethodScope { 76 return &MethodScope{ 77 registry: r, 78 moqPkgPath: r.moqPkgPath, 79 conflicted: map[string]bool{}, 80 } 81 } 82 83 // AddImport adds the given package to the set of imports. It generates a 84 // suitable alias if there are any conflicts with previously imported 85 // packages. 86 func (r *Registry) AddImport(pkg *types.Package) *Package { 87 path := stripVendorPath(pkg.Path()) 88 if path == r.moqPkgPath { 89 return nil 90 } 91 92 if imprt, ok := r.imports[path]; ok { 93 return imprt 94 } 95 96 imprt := Package{pkg: pkg, Alias: r.aliases[path]} 97 98 if conflict, ok := r.searchImport(imprt.Qualifier()); ok { 99 resolveImportConflict(&imprt, conflict, 0) 100 } 101 102 r.imports[path] = &imprt 103 return &imprt 104 } 105 106 // Imports returns the list of imported packages. The list is sorted by 107 // path. 108 func (r Registry) Imports() []*Package { 109 imports := make([]*Package, 0, len(r.imports)) 110 for _, imprt := range r.imports { 111 imports = append(imports, imprt) 112 } 113 sort.Slice(imports, func(i, j int) bool { 114 return imports[i].Path() < imports[j].Path() 115 }) 116 return imports 117 } 118 119 func (r Registry) searchImport(name string) (*Package, bool) { 120 for _, imprt := range r.imports { 121 if imprt.Qualifier() == name { 122 return imprt, true 123 } 124 } 125 126 return nil, false 127 } 128 129 func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) { 130 pkgs, err := packages.Load(&packages.Config{ 131 Mode: mode, 132 Dir: srcDir, 133 }) 134 if err != nil { 135 return nil, err 136 } 137 if len(pkgs) == 0 { 138 return nil, errors.New("package not found") 139 } 140 if len(pkgs) > 1 { 141 return nil, errors.New("found more than one package") 142 } 143 if errs := pkgs[0].Errors; len(errs) != 0 { 144 if len(errs) == 1 { 145 return nil, errs[0] 146 } 147 return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1) 148 } 149 return pkgs[0], nil 150 } 151 152 func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string { 153 if pkgInputVal == "" { 154 return srcPkg.PkgPath 155 } 156 if pkgInDir(srcPkg.PkgPath, pkgInputVal) { 157 return srcPkg.PkgPath 158 } 159 subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal) 160 if pkgInDir(subdirectoryPath, pkgInputVal) { 161 return subdirectoryPath 162 } 163 return "" 164 } 165 166 func pkgInDir(pkgName, dir string) bool { 167 currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) 168 if err != nil { 169 return false 170 } 171 return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName 172 } 173 174 func parseImportsAliases(pkg *packages.Package) map[string]string { 175 aliases := make(map[string]string) 176 for _, syntax := range pkg.Syntax { 177 for _, imprt := range syntax.Imports { 178 if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" { 179 aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name 180 } 181 } 182 } 183 return aliases 184 } 185 186 // resolveImportConflict generates and assigns a unique alias for 187 // packages with conflicting qualifiers. 188 func resolveImportConflict(a, b *Package, lvl int) { 189 u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl) 190 if u1 != u2 { 191 a.Alias, b.Alias = u1, u2 192 return 193 } 194 195 resolveImportConflict(a, b, lvl+1) 196 }