github.com/moitias/moq@v0.0.0-20240223074357-5eb0f0ba4054/internal/registry/registry.go (about) 1 package registry 2 3 import ( 4 "errors" 5 "fmt" 6 "go/ast" 7 "go/types" 8 "log" 9 "path/filepath" 10 "sort" 11 "strings" 12 13 "golang.org/x/tools/go/packages" 14 ) 15 16 // Registry encapsulates types information for the source and mock 17 // destination package. For the mock package, it tracks the list of 18 // imports and ensures there are no conflicts in the imported package 19 // qualifiers. 20 type Registry struct { 21 srcPkgName string 22 srcPkgTypes *types.Package 23 moqPkgPath string 24 aliases map[string]string 25 imports map[string]*Package 26 } 27 28 // New loads the source package info and returns a new instance of 29 // Registry. 30 func New(srcDir, moqPkg string) (*Registry, error) { 31 srcPkg, err := pkgInfoFromPath( 32 srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes, 33 ) 34 fmt.Println("GOT SRCPKG", srcPkg, err) 35 if err != nil { 36 return nil, fmt.Errorf("couldn't load source package: %s", err) 37 } 38 39 return &Registry{ 40 srcPkgName: srcPkg.Name, 41 srcPkgTypes: srcPkg.Types, 42 moqPkgPath: findPkgPath(moqPkg, srcPkg.PkgPath), 43 aliases: parseImportsAliases(srcPkg.Syntax), 44 imports: make(map[string]*Package), 45 }, nil 46 } 47 48 // SrcPkg returns the types info for the source package. 49 func (r Registry) SrcPkg() *types.Package { 50 return r.srcPkgTypes 51 } 52 53 // SrcPkgName returns the name of the source package. 54 func (r Registry) SrcPkgName() string { 55 return r.srcPkgName 56 } 57 58 // LookupInterface returns the underlying interface definition of the 59 // given interface name. 60 func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) { 61 obj := r.SrcPkg().Scope().Lookup(name) 62 if obj == nil { 63 return nil, nil, fmt.Errorf("interface not found: %s", name) 64 } 65 66 if !types.IsInterface(obj.Type()) { 67 return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type()) 68 } 69 70 var tparams *types.TypeParamList 71 named, ok := obj.Type().(*types.Named) 72 if ok { 73 tparams = named.TypeParams() 74 } 75 76 return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil 77 } 78 79 // MethodScope returns a new MethodScope. 80 func (r *Registry) MethodScope() *MethodScope { 81 return &MethodScope{ 82 registry: r, 83 moqPkgPath: r.moqPkgPath, 84 conflicted: map[string]bool{}, 85 } 86 } 87 88 // AddImport adds the given package to the set of imports. It generates a 89 // suitable alias if there are any conflicts with previously imported 90 // packages. 91 func (r *Registry) AddImport(pkg *types.Package) *Package { 92 path := stripVendorPath(pkg.Path()) 93 if path == r.moqPkgPath { 94 return nil 95 } 96 97 if imprt, ok := r.imports[path]; ok { 98 return imprt 99 } 100 101 imprt := Package{pkg: pkg, Alias: r.aliases[path]} 102 103 if conflict, ok := r.searchImport(imprt.Qualifier()); ok { 104 r.resolveImportConflict(&imprt, conflict, 0) 105 } 106 107 r.imports[path] = &imprt 108 return &imprt 109 } 110 111 // Imports returns the list of imported packages. The list is sorted by 112 // path. 113 func (r Registry) Imports() []*Package { 114 imports := make([]*Package, 0, len(r.imports)) 115 for _, imprt := range r.imports { 116 imports = append(imports, imprt) 117 } 118 sort.Slice(imports, func(i, j int) bool { 119 return imports[i].Path() < imports[j].Path() 120 }) 121 return imports 122 } 123 124 func (r Registry) searchImport(name string) (*Package, bool) { 125 for _, imprt := range r.imports { 126 if imprt.Qualifier() == name { 127 return imprt, true 128 } 129 } 130 131 return nil, false 132 } 133 134 // resolveImportConflict generates and assigns a unique alias for 135 // packages with conflicting qualifiers. 136 func (r Registry) resolveImportConflict(a, b *Package, lvl int) { 137 if a.uniqueName(lvl) == b.uniqueName(lvl) { 138 r.resolveImportConflict(a, b, lvl+1) 139 return 140 } 141 142 for _, p := range []*Package{a, b} { 143 name := p.uniqueName(lvl) 144 // Even though the name is not conflicting with the other package we 145 // got, the new name we want to pick might already be taken. So check 146 // again for conflicts and resolve them as well. Since the name for 147 // this package would also get set in the recursive function call, skip 148 // setting the alias after it. 149 if conflict, ok := r.searchImport(name); ok && conflict != p { 150 r.resolveImportConflict(p, conflict, lvl+1) 151 continue 152 } 153 154 p.Alias = name 155 } 156 } 157 158 func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) { 159 log.Println("LOAD PACKAGES", mode, "dir", srcDir) 160 pkgs, err := packages.Load(&packages.Config{ 161 Mode: mode, 162 Dir: srcDir, 163 }) 164 log.Println("GOT PACKAGES", pkgs, len(pkgs), err) 165 if err != nil { 166 return nil, err 167 } 168 if len(pkgs) == 0 { 169 return nil, errors.New("package not found") 170 } 171 if len(pkgs) > 1 { 172 return nil, errors.New("found more than one package") 173 } 174 if errs := pkgs[0].Errors; len(errs) != 0 { 175 log.Println("GOT 1 pKG WITH >0 ERRS", errs) 176 if len(errs) == 1 { 177 return nil, errs[0] 178 } 179 return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1) 180 } 181 return pkgs[0], nil 182 } 183 184 func findPkgPath(pkgInputVal string, srcPkgPath string) string { 185 if pkgInputVal == "" { 186 return srcPkgPath 187 } 188 if pkgInDir(srcPkgPath, pkgInputVal) { 189 return srcPkgPath 190 } 191 subdirectoryPath := filepath.Join(srcPkgPath, pkgInputVal) 192 if pkgInDir(subdirectoryPath, pkgInputVal) { 193 return subdirectoryPath 194 } 195 return "" 196 } 197 198 func pkgInDir(pkgName, dir string) bool { 199 currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) 200 if err != nil { 201 return false 202 } 203 return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName 204 } 205 206 func parseImportsAliases(syntaxTree []*ast.File) map[string]string { 207 aliases := make(map[string]string) 208 for _, syntax := range syntaxTree { 209 for _, imprt := range syntax.Imports { 210 if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" { 211 aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name 212 } 213 } 214 } 215 return aliases 216 }