github.com/djui/moq@v0.3.3/internal/registry/registry.go (about)

     1  package registry
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/types"
     7  	"path/filepath"
     8  	"sort"
     9  	"strings"
    10  
    11  	"golang.org/x/tools/go/packages"
    12  )
    13  
    14  // Registry encapsulates types information for the source and mock
    15  // destination package. For the mock package, it tracks the list of
    16  // imports and ensures there are no conflicts in the imported package
    17  // qualifiers.
    18  type Registry struct {
    19  	srcPkg     *packages.Package
    20  	moqPkgPath string
    21  	aliases    map[string]string
    22  	imports    map[string]*Package
    23  }
    24  
    25  // New loads the source package info and returns a new instance of
    26  // Registry.
    27  func New(srcDir, moqPkg string) (*Registry, error) {
    28  	srcPkg, err := pkgInfoFromPath(
    29  		srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo|packages.NeedDeps,
    30  	)
    31  	if err != nil {
    32  		return nil, fmt.Errorf("couldn't load source package: %s", err)
    33  	}
    34  
    35  	return &Registry{
    36  		srcPkg:     srcPkg,
    37  		moqPkgPath: findPkgPath(moqPkg, srcPkg),
    38  		aliases:    parseImportsAliases(srcPkg),
    39  		imports:    make(map[string]*Package),
    40  	}, nil
    41  }
    42  
    43  // SrcPkg returns the types info for the source package.
    44  func (r Registry) SrcPkg() *types.Package {
    45  	return r.srcPkg.Types
    46  }
    47  
    48  // SrcPkgName returns the name of the source package.
    49  func (r Registry) SrcPkgName() string {
    50  	return r.srcPkg.Name
    51  }
    52  
    53  // LookupInterface returns the underlying interface definition of the
    54  // given interface name.
    55  func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) {
    56  	obj := r.SrcPkg().Scope().Lookup(name)
    57  	if obj == nil {
    58  		return nil, nil, fmt.Errorf("interface not found: %s", name)
    59  	}
    60  
    61  	if !types.IsInterface(obj.Type()) {
    62  		return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
    63  	}
    64  
    65  	var tparams *types.TypeParamList
    66  	named, ok := obj.Type().(*types.Named)
    67  	if ok {
    68  		tparams = named.TypeParams()
    69  	}
    70  
    71  	return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil
    72  }
    73  
    74  // MethodScope returns a new MethodScope.
    75  func (r *Registry) MethodScope() *MethodScope {
    76  	return &MethodScope{
    77  		registry:   r,
    78  		moqPkgPath: r.moqPkgPath,
    79  		conflicted: map[string]bool{},
    80  	}
    81  }
    82  
    83  // AddImport adds the given package to the set of imports. It generates a
    84  // suitable alias if there are any conflicts with previously imported
    85  // packages.
    86  func (r *Registry) AddImport(pkg *types.Package) *Package {
    87  	path := stripVendorPath(pkg.Path())
    88  	if path == r.moqPkgPath {
    89  		return nil
    90  	}
    91  
    92  	if imprt, ok := r.imports[path]; ok {
    93  		return imprt
    94  	}
    95  
    96  	imprt := Package{pkg: pkg, Alias: r.aliases[path]}
    97  
    98  	if conflict, ok := r.searchImport(imprt.Qualifier()); ok {
    99  		resolveImportConflict(&imprt, conflict, 0)
   100  	}
   101  
   102  	r.imports[path] = &imprt
   103  	return &imprt
   104  }
   105  
   106  // Imports returns the list of imported packages. The list is sorted by
   107  // path.
   108  func (r Registry) Imports() []*Package {
   109  	imports := make([]*Package, 0, len(r.imports))
   110  	for _, imprt := range r.imports {
   111  		imports = append(imports, imprt)
   112  	}
   113  	sort.Slice(imports, func(i, j int) bool {
   114  		return imports[i].Path() < imports[j].Path()
   115  	})
   116  	return imports
   117  }
   118  
   119  func (r Registry) searchImport(name string) (*Package, bool) {
   120  	for _, imprt := range r.imports {
   121  		if imprt.Qualifier() == name {
   122  			return imprt, true
   123  		}
   124  	}
   125  
   126  	return nil, false
   127  }
   128  
   129  func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
   130  	pkgs, err := packages.Load(&packages.Config{
   131  		Mode: mode,
   132  		Dir:  srcDir,
   133  	})
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	if len(pkgs) == 0 {
   138  		return nil, errors.New("package not found")
   139  	}
   140  	if len(pkgs) > 1 {
   141  		return nil, errors.New("found more than one package")
   142  	}
   143  	if errs := pkgs[0].Errors; len(errs) != 0 {
   144  		if len(errs) == 1 {
   145  			return nil, errs[0]
   146  		}
   147  		return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1)
   148  	}
   149  	return pkgs[0], nil
   150  }
   151  
   152  func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string {
   153  	if pkgInputVal == "" {
   154  		return srcPkg.PkgPath
   155  	}
   156  	if pkgInDir(srcPkg.PkgPath, pkgInputVal) {
   157  		return srcPkg.PkgPath
   158  	}
   159  	subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal)
   160  	if pkgInDir(subdirectoryPath, pkgInputVal) {
   161  		return subdirectoryPath
   162  	}
   163  	return ""
   164  }
   165  
   166  func pkgInDir(pkgName, dir string) bool {
   167  	currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
   168  	if err != nil {
   169  		return false
   170  	}
   171  	return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
   172  }
   173  
   174  func parseImportsAliases(pkg *packages.Package) map[string]string {
   175  	aliases := make(map[string]string)
   176  	for _, syntax := range pkg.Syntax {
   177  		for _, imprt := range syntax.Imports {
   178  			if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" {
   179  				aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name
   180  			}
   181  		}
   182  	}
   183  	return aliases
   184  }
   185  
   186  // resolveImportConflict generates and assigns a unique alias for
   187  // packages with conflicting qualifiers.
   188  func resolveImportConflict(a, b *Package, lvl int) {
   189  	u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl)
   190  	if u1 != u2 {
   191  		a.Alias, b.Alias = u1, u2
   192  		return
   193  	}
   194  
   195  	resolveImportConflict(a, b, lvl+1)
   196  }