github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/source/implementation.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/token"
    13  	"go/types"
    14  	"sort"
    15  
    16  	"github.com/powerman/golang-tools/internal/event"
    17  	"github.com/powerman/golang-tools/internal/lsp/protocol"
    18  	"github.com/powerman/golang-tools/internal/span"
    19  	"golang.org/x/xerrors"
    20  )
    21  
    22  func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) {
    23  	ctx, done := event.Start(ctx, "source.Implementation")
    24  	defer done()
    25  
    26  	impls, err := implementations(ctx, snapshot, f, pp)
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  	var locations []protocol.Location
    31  	for _, impl := range impls {
    32  		if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 {
    33  			continue
    34  		}
    35  		rng, err := objToMappedRange(snapshot, impl.pkg, impl.obj)
    36  		if err != nil {
    37  			return nil, err
    38  		}
    39  		pr, err := rng.Range()
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  		locations = append(locations, protocol.Location{
    44  			URI:   protocol.URIFromSpanURI(rng.URI()),
    45  			Range: pr,
    46  		})
    47  	}
    48  	sort.Slice(locations, func(i, j int) bool {
    49  		li, lj := locations[i], locations[j]
    50  		if li.URI == lj.URI {
    51  			return protocol.CompareRange(li.Range, lj.Range) < 0
    52  		}
    53  		return li.URI < lj.URI
    54  	})
    55  	return locations, nil
    56  }
    57  
// ErrNotAType is returned by implementations (and therefore by
// Implementation) when the object at the queried position is neither a type
// name nor a function with a receiver.
var ErrNotAType = errors.New("not a type name or method")
    59  
    60  // implementations returns the concrete implementations of the specified
    61  // interface, or the interfaces implemented by the specified concrete type.
    62  func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
    63  	var (
    64  		impls []qualifiedObject
    65  		seen  = make(map[token.Position]bool)
    66  		fset  = s.FileSet()
    67  	)
    68  
    69  	qos, err := qualifiedObjsAtProtocolPos(ctx, s, f.URI(), pp)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	for _, qo := range qos {
    74  		var (
    75  			queryType   types.Type
    76  			queryMethod *types.Func
    77  		)
    78  
    79  		switch obj := qo.obj.(type) {
    80  		case *types.Func:
    81  			queryMethod = obj
    82  			if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
    83  				queryType = ensurePointer(recv.Type())
    84  			}
    85  		case *types.TypeName:
    86  			queryType = ensurePointer(obj.Type())
    87  		}
    88  
    89  		if queryType == nil {
    90  			return nil, ErrNotAType
    91  		}
    92  
    93  		if types.NewMethodSet(queryType).Len() == 0 {
    94  			return nil, nil
    95  		}
    96  
    97  		// Find all named types, even local types (which can have methods
    98  		// due to promotion).
    99  		var (
   100  			allNamed []*types.Named
   101  			pkgs     = make(map[*types.Package]Package)
   102  		)
   103  		knownPkgs, err := s.KnownPackages(ctx)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		for _, pkg := range knownPkgs {
   108  			pkgs[pkg.GetTypes()] = pkg
   109  			info := pkg.GetTypesInfo()
   110  			for _, obj := range info.Defs {
   111  				obj, ok := obj.(*types.TypeName)
   112  				// We ignore aliases 'type M = N' to avoid duplicate reporting
   113  				// of the Named type N.
   114  				if !ok || obj.IsAlias() {
   115  					continue
   116  				}
   117  				if named, ok := obj.Type().(*types.Named); ok {
   118  					allNamed = append(allNamed, named)
   119  				}
   120  			}
   121  		}
   122  
   123  		// Find all the named types that match our query.
   124  		for _, named := range allNamed {
   125  			var (
   126  				candObj  types.Object = named.Obj()
   127  				candType              = ensurePointer(named)
   128  			)
   129  
   130  			if !concreteImplementsIntf(candType, queryType) {
   131  				continue
   132  			}
   133  
   134  			ms := types.NewMethodSet(candType)
   135  			if ms.Len() == 0 {
   136  				// Skip empty interfaces.
   137  				continue
   138  			}
   139  
   140  			// If client queried a method, look up corresponding candType method.
   141  			if queryMethod != nil {
   142  				sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name())
   143  				if sel == nil {
   144  					continue
   145  				}
   146  				candObj = sel.Obj()
   147  			}
   148  
   149  			pos := fset.Position(candObj.Pos())
   150  			if candObj == queryMethod || seen[pos] {
   151  				continue
   152  			}
   153  
   154  			seen[pos] = true
   155  
   156  			impls = append(impls, qualifiedObject{
   157  				obj: candObj,
   158  				pkg: pkgs[candObj.Pkg()],
   159  			})
   160  		}
   161  	}
   162  
   163  	return impls, nil
   164  }
   165  
   166  // concreteImplementsIntf returns true if a is an interface type implemented by
   167  // concrete type b, or vice versa.
   168  func concreteImplementsIntf(a, b types.Type) bool {
   169  	aIsIntf, bIsIntf := IsInterface(a), IsInterface(b)
   170  
   171  	// Make sure exactly one is an interface type.
   172  	if aIsIntf == bIsIntf {
   173  		return false
   174  	}
   175  
   176  	// Rearrange if needed so "a" is the concrete type.
   177  	if aIsIntf {
   178  		a, b = b, a
   179  	}
   180  
   181  	return types.AssignableTo(a, b)
   182  }
   183  
   184  // ensurePointer wraps T in a *types.Pointer if T is a named, non-interface
   185  // type. This is useful to make sure you consider a named type's full method
   186  // set.
   187  func ensurePointer(T types.Type) types.Type {
   188  	if _, ok := T.(*types.Named); ok && !IsInterface(T) {
   189  		return types.NewPointer(T)
   190  	}
   191  
   192  	return T
   193  }
   194  
// qualifiedObject pairs a types.Object with the package that defines it and,
// when the object was reached from a source location, the AST node and
// package it was reached from.
type qualifiedObject struct {
	// obj is the resolved object.
	obj types.Object

	// pkg is the Package that contains obj's definition.
	pkg Package

	// node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any.
	node ast.Node

	// sourcePkg is the Package that contains node, if any.
	sourcePkg Package
}
   207  
var (
	// errBuiltin is wrapped and returned by qualifiedObjsAtLocation when the
	// object found belongs to the universe scope.
	errBuiltin = errors.New("builtin object")
	// errNoObjectFound is returned (sometimes wrapped) when a position does
	// not resolve to any object, or the file belongs to no package.
	errNoObjectFound = errors.New("no object found")
)
   212  
   213  // qualifiedObjsAtProtocolPos returns info for all the type.Objects
   214  // referenced at the given position. An object will be returned for
   215  // every package that the file belongs to, in every typechecking mode
   216  // applicable.
   217  func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, uri span.URI, pp protocol.Position) ([]qualifiedObject, error) {
   218  	pkgs, err := s.PackagesForFile(ctx, uri, TypecheckAll, false)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	if len(pkgs) == 0 {
   223  		return nil, errNoObjectFound
   224  	}
   225  	pkg := pkgs[0]
   226  	pgf, err := pkg.File(uri)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  	spn, err := pgf.Mapper.PointSpan(pp)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	rng, err := spn.Range(pgf.Mapper.Converter)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	offset, err := Offset(pgf.Tok, rng.Start)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	return qualifiedObjsAtLocation(ctx, s, objSearchKey{uri, offset}, map[objSearchKey]bool{})
   243  }
   244  
// objSearchKey identifies a location to search for objects: a file and an
// offset within that file. qualifiedObjsAtLocation uses it as the key of its
// set of already-visited locations.
type objSearchKey struct {
	uri    span.URI
	offset int
}
   249  
   250  // qualifiedObjsAtLocation finds all objects referenced at offset in uri, across
   251  // all packages in the snapshot.
   252  func qualifiedObjsAtLocation(ctx context.Context, s Snapshot, key objSearchKey, seen map[objSearchKey]bool) ([]qualifiedObject, error) {
   253  	if seen[key] {
   254  		return nil, nil
   255  	}
   256  	seen[key] = true
   257  
   258  	// We search for referenced objects starting with all packages containing the
   259  	// current location, and then repeating the search for every distinct object
   260  	// location discovered.
   261  	//
   262  	// In the common case, there should be at most one additional location to
   263  	// consider: the definition of the object referenced by the location. But we
   264  	// try to be comprehensive in case we ever support variations on build
   265  	// constraints.
   266  
   267  	pkgs, err := s.PackagesForFile(ctx, key.uri, TypecheckAll, false)
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	// report objects in the order we encounter them. This ensures that the first
   273  	// result is at the cursor...
   274  	var qualifiedObjs []qualifiedObject
   275  	// ...but avoid duplicates.
   276  	seenObjs := map[types.Object]bool{}
   277  
   278  	for _, searchpkg := range pkgs {
   279  		pgf, err := searchpkg.File(key.uri)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		pos := pgf.Tok.Pos(key.offset)
   284  		path := pathEnclosingObjNode(pgf.File, pos)
   285  		if path == nil {
   286  			continue
   287  		}
   288  		var objs []types.Object
   289  		switch leaf := path[0].(type) {
   290  		case *ast.Ident:
   291  			// If leaf represents an implicit type switch object or the type
   292  			// switch "assign" variable, expand to all of the type switch's
   293  			// implicit objects.
   294  			if implicits, _ := typeSwitchImplicits(searchpkg, path); len(implicits) > 0 {
   295  				objs = append(objs, implicits...)
   296  			} else {
   297  				obj := searchpkg.GetTypesInfo().ObjectOf(leaf)
   298  				if obj == nil {
   299  					return nil, xerrors.Errorf("%w for %q", errNoObjectFound, leaf.Name)
   300  				}
   301  				objs = append(objs, obj)
   302  			}
   303  		case *ast.ImportSpec:
   304  			// Look up the implicit *types.PkgName.
   305  			obj := searchpkg.GetTypesInfo().Implicits[leaf]
   306  			if obj == nil {
   307  				return nil, xerrors.Errorf("%w for import %q", errNoObjectFound, ImportPath(leaf))
   308  			}
   309  			objs = append(objs, obj)
   310  		}
   311  		// Get all of the transitive dependencies of the search package.
   312  		pkgs := make(map[*types.Package]Package)
   313  		var addPkg func(pkg Package)
   314  		addPkg = func(pkg Package) {
   315  			pkgs[pkg.GetTypes()] = pkg
   316  			for _, imp := range pkg.Imports() {
   317  				if _, ok := pkgs[imp.GetTypes()]; !ok {
   318  					addPkg(imp)
   319  				}
   320  			}
   321  		}
   322  		addPkg(searchpkg)
   323  		for _, obj := range objs {
   324  			if obj.Parent() == types.Universe {
   325  				return nil, xerrors.Errorf("%q: %w", obj.Name(), errBuiltin)
   326  			}
   327  			pkg, ok := pkgs[obj.Pkg()]
   328  			if !ok {
   329  				event.Error(ctx, fmt.Sprintf("no package for obj %s: %v", obj, obj.Pkg()), err)
   330  				continue
   331  			}
   332  			qualifiedObjs = append(qualifiedObjs, qualifiedObject{
   333  				obj:       obj,
   334  				pkg:       pkg,
   335  				sourcePkg: searchpkg,
   336  				node:      path[0],
   337  			})
   338  			seenObjs[obj] = true
   339  
   340  			// If the qualified object is in another file (or more likely, another
   341  			// package), it's possible that there is another copy of it in a package
   342  			// that we haven't searched, e.g. a test variant. See golang/go#47564.
   343  			//
   344  			// In order to be sure we've considered all packages, call
   345  			// qualifiedObjsAtLocation recursively for all locations we encounter. We
   346  			// could probably be more precise here, only continuing the search if obj
   347  			// is in another package, but this should be good enough to find all
   348  			// uses.
   349  
   350  			pos := obj.Pos()
   351  			var uri span.URI
   352  			offset := -1
   353  			for _, pgf := range pkg.CompiledGoFiles() {
   354  				if pgf.Tok.Base() <= int(pos) && int(pos) <= pgf.Tok.Base()+pgf.Tok.Size() {
   355  					var err error
   356  					offset, err = Offset(pgf.Tok, pos)
   357  					if err != nil {
   358  						return nil, err
   359  					}
   360  					uri = pgf.URI
   361  				}
   362  			}
   363  			if offset >= 0 {
   364  				otherObjs, err := qualifiedObjsAtLocation(ctx, s, objSearchKey{uri, offset}, seen)
   365  				if err != nil {
   366  					return nil, err
   367  				}
   368  				for _, other := range otherObjs {
   369  					if !seenObjs[other.obj] {
   370  						qualifiedObjs = append(qualifiedObjs, other)
   371  						seenObjs[other.obj] = true
   372  					}
   373  				}
   374  			} else {
   375  				return nil, fmt.Errorf("missing file for position of %q in %q", obj.Name(), obj.Pkg().Name())
   376  			}
   377  		}
   378  	}
   379  	// Return an error if no objects were found since callers will assume that
   380  	// the slice has at least 1 element.
   381  	if len(qualifiedObjs) == 0 {
   382  		return nil, errNoObjectFound
   383  	}
   384  	return qualifiedObjs, nil
   385  }
   386  
   387  // pathEnclosingObjNode returns the AST path to the object-defining
   388  // node associated with pos. "Object-defining" means either an
   389  // *ast.Ident mapped directly to a types.Object or an ast.Node mapped
   390  // implicitly to a types.Object.
   391  func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node {
   392  	var (
   393  		path  []ast.Node
   394  		found bool
   395  	)
   396  
   397  	ast.Inspect(f, func(n ast.Node) bool {
   398  		if found {
   399  			return false
   400  		}
   401  
   402  		if n == nil {
   403  			path = path[:len(path)-1]
   404  			return false
   405  		}
   406  
   407  		path = append(path, n)
   408  
   409  		switch n := n.(type) {
   410  		case *ast.Ident:
   411  			// Include the position directly after identifier. This handles
   412  			// the common case where the cursor is right after the
   413  			// identifier the user is currently typing. Previously we
   414  			// handled this by calling astutil.PathEnclosingInterval twice,
   415  			// once for "pos" and once for "pos-1".
   416  			found = n.Pos() <= pos && pos <= n.End()
   417  		case *ast.ImportSpec:
   418  			if n.Path.Pos() <= pos && pos < n.Path.End() {
   419  				found = true
   420  				// If import spec has a name, add name to path even though
   421  				// position isn't in the name.
   422  				if n.Name != nil {
   423  					path = append(path, n.Name)
   424  				}
   425  			}
   426  		case *ast.StarExpr:
   427  			// Follow star expressions to the inner identifier.
   428  			if pos == n.Star {
   429  				pos = n.X.Pos()
   430  			}
   431  		}
   432  
   433  		return !found
   434  	})
   435  
   436  	if len(path) == 0 {
   437  		return nil
   438  	}
   439  
   440  	// Reverse path so leaf is first element.
   441  	for i := 0; i < len(path)/2; i++ {
   442  		path[i], path[len(path)-1-i] = path[len(path)-1-i], path[i]
   443  	}
   444  
   445  	return path
   446  }