github.com/moitias/moq@v0.0.0-20240223074357-5eb0f0ba4054/internal/registry/registry.go (about)

     1  package registry
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/types"
     8  	"log"
     9  	"path/filepath"
    10  	"sort"
    11  	"strings"
    12  
    13  	"golang.org/x/tools/go/packages"
    14  )
    15  
    16  // Registry encapsulates types information for the source and mock
    17  // destination package. For the mock package, it tracks the list of
    18  // imports and ensures there are no conflicts in the imported package
    19  // qualifiers.
    20  type Registry struct {
    21  	srcPkgName  string
    22  	srcPkgTypes *types.Package
    23  	moqPkgPath  string
    24  	aliases     map[string]string
    25  	imports     map[string]*Package
    26  }
    27  
    28  // New loads the source package info and returns a new instance of
    29  // Registry.
    30  func New(srcDir, moqPkg string) (*Registry, error) {
    31  	srcPkg, err := pkgInfoFromPath(
    32  		srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes,
    33  	)
    34  	fmt.Println("GOT SRCPKG", srcPkg, err)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("couldn't load source package: %s", err)
    37  	}
    38  
    39  	return &Registry{
    40  		srcPkgName:  srcPkg.Name,
    41  		srcPkgTypes: srcPkg.Types,
    42  		moqPkgPath:  findPkgPath(moqPkg, srcPkg.PkgPath),
    43  		aliases:     parseImportsAliases(srcPkg.Syntax),
    44  		imports:     make(map[string]*Package),
    45  	}, nil
    46  }
    47  
    48  // SrcPkg returns the types info for the source package.
    49  func (r Registry) SrcPkg() *types.Package {
    50  	return r.srcPkgTypes
    51  }
    52  
    53  // SrcPkgName returns the name of the source package.
    54  func (r Registry) SrcPkgName() string {
    55  	return r.srcPkgName
    56  }
    57  
    58  // LookupInterface returns the underlying interface definition of the
    59  // given interface name.
    60  func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) {
    61  	obj := r.SrcPkg().Scope().Lookup(name)
    62  	if obj == nil {
    63  		return nil, nil, fmt.Errorf("interface not found: %s", name)
    64  	}
    65  
    66  	if !types.IsInterface(obj.Type()) {
    67  		return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
    68  	}
    69  
    70  	var tparams *types.TypeParamList
    71  	named, ok := obj.Type().(*types.Named)
    72  	if ok {
    73  		tparams = named.TypeParams()
    74  	}
    75  
    76  	return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil
    77  }
    78  
    79  // MethodScope returns a new MethodScope.
    80  func (r *Registry) MethodScope() *MethodScope {
    81  	return &MethodScope{
    82  		registry:   r,
    83  		moqPkgPath: r.moqPkgPath,
    84  		conflicted: map[string]bool{},
    85  	}
    86  }
    87  
    88  // AddImport adds the given package to the set of imports. It generates a
    89  // suitable alias if there are any conflicts with previously imported
    90  // packages.
    91  func (r *Registry) AddImport(pkg *types.Package) *Package {
    92  	path := stripVendorPath(pkg.Path())
    93  	if path == r.moqPkgPath {
    94  		return nil
    95  	}
    96  
    97  	if imprt, ok := r.imports[path]; ok {
    98  		return imprt
    99  	}
   100  
   101  	imprt := Package{pkg: pkg, Alias: r.aliases[path]}
   102  
   103  	if conflict, ok := r.searchImport(imprt.Qualifier()); ok {
   104  		r.resolveImportConflict(&imprt, conflict, 0)
   105  	}
   106  
   107  	r.imports[path] = &imprt
   108  	return &imprt
   109  }
   110  
   111  // Imports returns the list of imported packages. The list is sorted by
   112  // path.
   113  func (r Registry) Imports() []*Package {
   114  	imports := make([]*Package, 0, len(r.imports))
   115  	for _, imprt := range r.imports {
   116  		imports = append(imports, imprt)
   117  	}
   118  	sort.Slice(imports, func(i, j int) bool {
   119  		return imports[i].Path() < imports[j].Path()
   120  	})
   121  	return imports
   122  }
   123  
   124  func (r Registry) searchImport(name string) (*Package, bool) {
   125  	for _, imprt := range r.imports {
   126  		if imprt.Qualifier() == name {
   127  			return imprt, true
   128  		}
   129  	}
   130  
   131  	return nil, false
   132  }
   133  
   134  // resolveImportConflict generates and assigns a unique alias for
   135  // packages with conflicting qualifiers.
   136  func (r Registry) resolveImportConflict(a, b *Package, lvl int) {
   137  	if a.uniqueName(lvl) == b.uniqueName(lvl) {
   138  		r.resolveImportConflict(a, b, lvl+1)
   139  		return
   140  	}
   141  
   142  	for _, p := range []*Package{a, b} {
   143  		name := p.uniqueName(lvl)
   144  		// Even though the name is not conflicting with the other package we
   145  		// got, the new name we want to pick might already be taken. So check
   146  		// again for conflicts and resolve them as well. Since the name for
   147  		// this package would also get set in the recursive function call, skip
   148  		// setting the alias after it.
   149  		if conflict, ok := r.searchImport(name); ok && conflict != p {
   150  			r.resolveImportConflict(p, conflict, lvl+1)
   151  			continue
   152  		}
   153  
   154  		p.Alias = name
   155  	}
   156  }
   157  
   158  func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
   159  	log.Println("LOAD PACKAGES", mode, "dir", srcDir)
   160  	pkgs, err := packages.Load(&packages.Config{
   161  		Mode: mode,
   162  		Dir:  srcDir,
   163  	})
   164  	log.Println("GOT PACKAGES", pkgs, len(pkgs), err)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	if len(pkgs) == 0 {
   169  		return nil, errors.New("package not found")
   170  	}
   171  	if len(pkgs) > 1 {
   172  		return nil, errors.New("found more than one package")
   173  	}
   174  	if errs := pkgs[0].Errors; len(errs) != 0 {
   175  		log.Println("GOT 1 pKG WITH >0 ERRS", errs)
   176  		if len(errs) == 1 {
   177  			return nil, errs[0]
   178  		}
   179  		return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1)
   180  	}
   181  	return pkgs[0], nil
   182  }
   183  
   184  func findPkgPath(pkgInputVal string, srcPkgPath string) string {
   185  	if pkgInputVal == "" {
   186  		return srcPkgPath
   187  	}
   188  	if pkgInDir(srcPkgPath, pkgInputVal) {
   189  		return srcPkgPath
   190  	}
   191  	subdirectoryPath := filepath.Join(srcPkgPath, pkgInputVal)
   192  	if pkgInDir(subdirectoryPath, pkgInputVal) {
   193  		return subdirectoryPath
   194  	}
   195  	return ""
   196  }
   197  
   198  func pkgInDir(pkgName, dir string) bool {
   199  	currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
   200  	if err != nil {
   201  		return false
   202  	}
   203  	return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
   204  }
   205  
   206  func parseImportsAliases(syntaxTree []*ast.File) map[string]string {
   207  	aliases := make(map[string]string)
   208  	for _, syntax := range syntaxTree {
   209  		for _, imprt := range syntax.Imports {
   210  			if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" {
   211  				aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name
   212  			}
   213  		}
   214  	}
   215  	return aliases
   216  }