github.com/april1989/origin-go-tools@v0.0.32/internal/lsp/source/implementation.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/token"
    13  	"go/types"
    14  	"sort"
    15  
    16  	"github.com/april1989/origin-go-tools/internal/event"
    17  	"github.com/april1989/origin-go-tools/internal/lsp/protocol"
    18  	"golang.org/x/xerrors"
    19  )
    20  
    21  func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) {
    22  	ctx, done := event.Start(ctx, "source.Implementation")
    23  	defer done()
    24  
    25  	impls, err := implementations(ctx, snapshot, f, pp)
    26  	if err != nil {
    27  		return nil, err
    28  	}
    29  	var locations []protocol.Location
    30  	for _, impl := range impls {
    31  		if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 {
    32  			continue
    33  		}
    34  		rng, err := objToMappedRange(snapshot, impl.pkg, impl.obj)
    35  		if err != nil {
    36  			return nil, err
    37  		}
    38  		pr, err := rng.Range()
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  		locations = append(locations, protocol.Location{
    43  			URI:   protocol.URIFromSpanURI(rng.URI()),
    44  			Range: pr,
    45  		})
    46  	}
    47  	sort.Slice(locations, func(i, j int) bool {
    48  		li, lj := locations[i], locations[j]
    49  		if li.URI == lj.URI {
    50  			return protocol.CompareRange(li.Range, lj.Range) < 0
    51  		}
    52  		return li.URI < lj.URI
    53  	})
    54  	return locations, nil
    55  }
    56  
    57  var ErrNotAType = errors.New("not a type name or method")
    58  
    59  // implementations returns the concrete implementations of the specified
    60  // interface, or the interfaces implemented by the specified concrete type.
    61  func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
    62  	var (
    63  		impls []qualifiedObject
    64  		seen  = make(map[token.Position]bool)
    65  		fset  = s.FileSet()
    66  	)
    67  
    68  	qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	for _, qo := range qos {
    73  		var (
    74  			queryType   types.Type
    75  			queryMethod *types.Func
    76  		)
    77  
    78  		switch obj := qo.obj.(type) {
    79  		case *types.Func:
    80  			queryMethod = obj
    81  			if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
    82  				queryType = ensurePointer(recv.Type())
    83  			}
    84  		case *types.TypeName:
    85  			queryType = ensurePointer(obj.Type())
    86  		}
    87  
    88  		if queryType == nil {
    89  			return nil, ErrNotAType
    90  		}
    91  
    92  		if types.NewMethodSet(queryType).Len() == 0 {
    93  			return nil, nil
    94  		}
    95  
    96  		// Find all named types, even local types (which can have methods
    97  		// due to promotion).
    98  		var (
    99  			allNamed []*types.Named
   100  			pkgs     = make(map[*types.Package]Package)
   101  		)
   102  		knownPkgs, err := s.KnownPackages(ctx)
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		for _, pkg := range knownPkgs {
   107  			pkgs[pkg.GetTypes()] = pkg
   108  			info := pkg.GetTypesInfo()
   109  			for _, obj := range info.Defs {
   110  				obj, ok := obj.(*types.TypeName)
   111  				// We ignore aliases 'type M = N' to avoid duplicate reporting
   112  				// of the Named type N.
   113  				if !ok || obj.IsAlias() {
   114  					continue
   115  				}
   116  				if named, ok := obj.Type().(*types.Named); ok {
   117  					allNamed = append(allNamed, named)
   118  				}
   119  			}
   120  		}
   121  
   122  		// Find all the named types that match our query.
   123  		for _, named := range allNamed {
   124  			var (
   125  				candObj  types.Object = named.Obj()
   126  				candType              = ensurePointer(named)
   127  			)
   128  
   129  			if !concreteImplementsIntf(candType, queryType) {
   130  				continue
   131  			}
   132  
   133  			ms := types.NewMethodSet(candType)
   134  			if ms.Len() == 0 {
   135  				// Skip empty interfaces.
   136  				continue
   137  			}
   138  
   139  			// If client queried a method, look up corresponding candType method.
   140  			if queryMethod != nil {
   141  				sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name())
   142  				if sel == nil {
   143  					continue
   144  				}
   145  				candObj = sel.Obj()
   146  			}
   147  
   148  			pos := fset.Position(candObj.Pos())
   149  			if candObj == queryMethod || seen[pos] {
   150  				continue
   151  			}
   152  
   153  			seen[pos] = true
   154  
   155  			impls = append(impls, qualifiedObject{
   156  				obj: candObj,
   157  				pkg: pkgs[candObj.Pkg()],
   158  			})
   159  		}
   160  	}
   161  
   162  	return impls, nil
   163  }
   164  
   165  // concreteImplementsIntf returns true if a is an interface type implemented by
   166  // concrete type b, or vice versa.
   167  func concreteImplementsIntf(a, b types.Type) bool {
   168  	aIsIntf, bIsIntf := isInterface(a), isInterface(b)
   169  
   170  	// Make sure exactly one is an interface type.
   171  	if aIsIntf == bIsIntf {
   172  		return false
   173  	}
   174  
   175  	// Rearrange if needed so "a" is the concrete type.
   176  	if aIsIntf {
   177  		a, b = b, a
   178  	}
   179  
   180  	return types.AssignableTo(a, b)
   181  }
   182  
   183  // ensurePointer wraps T in a *types.Pointer if T is a named, non-interface
   184  // type. This is useful to make sure you consider a named type's full method
   185  // set.
   186  func ensurePointer(T types.Type) types.Type {
   187  	if _, ok := T.(*types.Named); ok && !isInterface(T) {
   188  		return types.NewPointer(T)
   189  	}
   190  
   191  	return T
   192  }
   193  
// qualifiedObject pairs a types.Object with the packages involved in
// locating it: the package that defines it and, when it was reached from a
// source position, the AST node and package that position was in.
type qualifiedObject struct {
	// obj is the type-checker object being described.
	obj types.Object

	// pkg is the Package that contains obj's definition.
	// It may be nil when the defining package could not be found
	// (Implementation checks for this before using it).
	pkg Package

	// node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any.
	node ast.Node

	// sourcePkg is the Package that contains node, if any.
	sourcePkg Package
}
   206  
   207  var (
   208  	errBuiltin       = errors.New("builtin object")
   209  	errNoObjectFound = errors.New("no object found")
   210  )
   211  
   212  // qualifiedObjsAtProtocolPos returns info for all the type.Objects
   213  // referenced at the given position. An object will be returned for
   214  // every package that the file belongs to, in every typechecking mode
   215  // applicable.
   216  func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
   217  	pkgs, err := s.PackagesForFile(ctx, fh.URI(), TypecheckAll)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	// Check all the packages that the file belongs to.
   222  	var qualifiedObjs []qualifiedObject
   223  	for _, searchpkg := range pkgs {
   224  		astFile, pos, err := getASTFile(searchpkg, fh, pp)
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  		path := pathEnclosingObjNode(astFile, pos)
   229  		if path == nil {
   230  			continue
   231  		}
   232  		var objs []types.Object
   233  		switch leaf := path[0].(type) {
   234  		case *ast.Ident:
   235  			// If leaf represents an implicit type switch object or the type
   236  			// switch "assign" variable, expand to all of the type switch's
   237  			// implicit objects.
   238  			if implicits, _ := typeSwitchImplicits(searchpkg, path); len(implicits) > 0 {
   239  				objs = append(objs, implicits...)
   240  			} else {
   241  				obj := searchpkg.GetTypesInfo().ObjectOf(leaf)
   242  				if obj == nil {
   243  					return nil, xerrors.Errorf("%w for %q", errNoObjectFound, leaf.Name)
   244  				}
   245  				objs = append(objs, obj)
   246  			}
   247  		case *ast.ImportSpec:
   248  			// Look up the implicit *types.PkgName.
   249  			obj := searchpkg.GetTypesInfo().Implicits[leaf]
   250  			if obj == nil {
   251  				return nil, xerrors.Errorf("%w for import %q", errNoObjectFound, importPath(leaf))
   252  			}
   253  			objs = append(objs, obj)
   254  		}
   255  		// Get all of the transitive dependencies of the search package.
   256  		pkgs := make(map[*types.Package]Package)
   257  		var addPkg func(pkg Package)
   258  		addPkg = func(pkg Package) {
   259  			pkgs[pkg.GetTypes()] = pkg
   260  			for _, imp := range pkg.Imports() {
   261  				if _, ok := pkgs[imp.GetTypes()]; !ok {
   262  					addPkg(imp)
   263  				}
   264  			}
   265  		}
   266  		addPkg(searchpkg)
   267  		for _, obj := range objs {
   268  			if obj.Parent() == types.Universe {
   269  				return nil, xerrors.Errorf("%q: %w", obj.Name(), errBuiltin)
   270  			}
   271  			pkg, ok := pkgs[obj.Pkg()]
   272  			if !ok {
   273  				event.Error(ctx, fmt.Sprintf("no package for obj %s: %v", obj, obj.Pkg()), err)
   274  				continue
   275  			}
   276  			qualifiedObjs = append(qualifiedObjs, qualifiedObject{
   277  				obj:       obj,
   278  				pkg:       pkg,
   279  				sourcePkg: searchpkg,
   280  				node:      path[0],
   281  			})
   282  		}
   283  	}
   284  	// Return an error if no objects were found since callers will assume that
   285  	// the slice has at least 1 element.
   286  	if len(qualifiedObjs) == 0 {
   287  		return nil, errNoObjectFound
   288  	}
   289  	return qualifiedObjs, nil
   290  }
   291  
   292  func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) {
   293  	pgf, err := pkg.File(f.URI())
   294  	if err != nil {
   295  		return nil, 0, err
   296  	}
   297  	spn, err := pgf.Mapper.PointSpan(pos)
   298  	if err != nil {
   299  		return nil, 0, err
   300  	}
   301  	rng, err := spn.Range(pgf.Mapper.Converter)
   302  	if err != nil {
   303  		return nil, 0, err
   304  	}
   305  	return pgf.File, rng.Start, nil
   306  }
   307  
// pathEnclosingObjNode returns the AST path to the object-defining
// node associated with pos. "Object-defining" means either an
// *ast.Ident mapped directly to a types.Object or an ast.Node mapped
// implicitly to a types.Object.
//
// The returned path is ordered leaf first and root (f) last. The leaf is
// the *ast.Ident whose span contains pos, with pos allowed to sit just past
// its end. When pos is inside an import path, the leaf is the import's name
// identifier if it has one, and otherwise the *ast.ImportSpec itself.
// Before matching, a pos on the '*' of a star expression is moved to its
// operand, and a pos on the '.' of a selector is moved to the selected
// identifier. It returns nil if no such node is found.
func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node {
	var (
		path  []ast.Node
		found bool
	)

	ast.Inspect(f, func(n ast.Node) bool {
		// Once the leaf is found, stop visiting and leave path intact;
		// this also keeps the trailing nil calls from popping entries.
		if found {
			return false
		}

		// ast.Inspect calls f(nil) after a node's children; reaching it
		// while not yet found means that subtree did not contain the
		// target, so pop the node off the path.
		if n == nil {
			path = path[:len(path)-1]
			return false
		}

		path = append(path, n)

		switch n := n.(type) {
		case *ast.Ident:
			// Include the position directly after identifier. This handles
			// the common case where the cursor is right after the
			// identifier the user is currently typing. Previously we
			// handled this by calling astutil.PathEnclosingInterval twice,
			// once for "pos" and once for "pos-1".
			found = n.Pos() <= pos && pos <= n.End()
		case *ast.ImportSpec:
			if n.Path.Pos() <= pos && pos < n.Path.End() {
				found = true
				// If import spec has a name, add name to path even though
				// position isn't in the name.
				if n.Name != nil {
					path = append(path, n.Name)
				}
			}
		case *ast.StarExpr:
			// Follow star expressions to the inner identifier.
			if pos == n.Star {
				pos = n.X.Pos()
			}
		case *ast.SelectorExpr:
			// If pos is on the ".", move it into the selector.
			if pos == n.X.End() {
				pos = n.Sel.Pos()
			}
		}

		// Descend into children only while still searching.
		return !found
	})

	// Every pushed node was popped again: nothing matched pos.
	if len(path) == 0 {
		return nil
	}

	// Reverse path so leaf is first element.
	for i := 0; i < len(path)/2; i++ {
		path[i], path[len(path)-1-i] = path[len(path)-1-i], path[i]
	}

	return path
}