github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/source/references.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"sort"
    14  
    15  	"github.com/jhump/golang-x-tools/internal/event"
    16  	"github.com/jhump/golang-x-tools/internal/lsp/protocol"
    17  	"github.com/jhump/golang-x-tools/internal/span"
    18  	errors "golang.org/x/xerrors"
    19  )
    20  
    21  // ReferenceInfo holds information about reference to an identifier in Go source.
    22  type ReferenceInfo struct {
    23  	Name string
    24  	MappedRange
    25  	ident         *ast.Ident
    26  	obj           types.Object
    27  	pkg           Package
    28  	isDeclaration bool
    29  }
    30  
    31  // References returns a list of references for a given identifier within the packages
    32  // containing i.File. Declarations appear first in the result.
    33  func References(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position, includeDeclaration bool) ([]*ReferenceInfo, error) {
    34  	ctx, done := event.Start(ctx, "source.References")
    35  	defer done()
    36  
    37  	qualifiedObjs, err := qualifiedObjsAtProtocolPos(ctx, s, f.URI(), pp)
    38  	// Don't return references for builtin types.
    39  	if errors.Is(err, errBuiltin) {
    40  		return nil, nil
    41  	}
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	refs, err := references(ctx, s, qualifiedObjs, includeDeclaration, true, false)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	toSort := refs
    52  	if includeDeclaration {
    53  		toSort = refs[1:]
    54  	}
    55  	sort.Slice(toSort, func(i, j int) bool {
    56  		x := CompareURI(toSort[i].URI(), toSort[j].URI())
    57  		if x == 0 {
    58  			return toSort[i].ident.Pos() < toSort[j].ident.Pos()
    59  		}
    60  		return x < 0
    61  	})
    62  	return refs, nil
    63  }
    64  
    65  // references is a helper function to avoid recomputing qualifiedObjsAtProtocolPos.
    66  func references(ctx context.Context, snapshot Snapshot, qos []qualifiedObject, includeDeclaration, includeInterfaceRefs, includeEmbeddedRefs bool) ([]*ReferenceInfo, error) {
    67  	var (
    68  		references []*ReferenceInfo
    69  		seen       = make(map[token.Pos]bool)
    70  	)
    71  
    72  	pos := qos[0].obj.Pos()
    73  	if pos == token.NoPos {
    74  		return nil, fmt.Errorf("no position for %s", qos[0].obj)
    75  	}
    76  	filename := snapshot.FileSet().Position(pos).Filename
    77  	pgf, err := qos[0].pkg.File(span.URIFromPath(filename))
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	declIdent, err := findIdentifier(ctx, snapshot, qos[0].pkg, pgf, qos[0].obj.Pos())
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	// Make sure declaration is the first item in the response.
    86  	if includeDeclaration {
    87  		references = append(references, &ReferenceInfo{
    88  			MappedRange:   declIdent.MappedRange,
    89  			Name:          qos[0].obj.Name(),
    90  			ident:         declIdent.ident,
    91  			obj:           qos[0].obj,
    92  			pkg:           declIdent.pkg,
    93  			isDeclaration: true,
    94  		})
    95  	}
    96  
    97  	for _, qo := range qos {
    98  		var searchPkgs []Package
    99  
   100  		// Only search dependents if the object is exported.
   101  		if qo.obj.Exported() {
   102  			reverseDeps, err := snapshot.GetReverseDependencies(ctx, qo.pkg.ID())
   103  			if err != nil {
   104  				return nil, err
   105  			}
   106  			searchPkgs = append(searchPkgs, reverseDeps...)
   107  		}
   108  		// Add the package in which the identifier is declared.
   109  		searchPkgs = append(searchPkgs, qo.pkg)
   110  		for _, pkg := range searchPkgs {
   111  			for ident, obj := range pkg.GetTypesInfo().Uses {
   112  				if obj != qo.obj {
   113  					// If ident is not a use of qo.obj, skip it, with one exception: uses
   114  					// of an embedded field can be considered references of the embedded
   115  					// type name.
   116  					if !includeEmbeddedRefs {
   117  						continue
   118  					}
   119  					v, ok := obj.(*types.Var)
   120  					if !ok || !v.Embedded() {
   121  						continue
   122  					}
   123  					named, ok := v.Type().(*types.Named)
   124  					if !ok || named.Obj() != qo.obj {
   125  						continue
   126  					}
   127  				}
   128  				if seen[ident.Pos()] {
   129  					continue
   130  				}
   131  				seen[ident.Pos()] = true
   132  				rng, err := posToMappedRange(snapshot, pkg, ident.Pos(), ident.End())
   133  				if err != nil {
   134  					return nil, err
   135  				}
   136  				references = append(references, &ReferenceInfo{
   137  					Name:        ident.Name,
   138  					ident:       ident,
   139  					pkg:         pkg,
   140  					obj:         obj,
   141  					MappedRange: rng,
   142  				})
   143  			}
   144  		}
   145  	}
   146  
   147  	// When searching on type name, don't include interface references -- they
   148  	// would be things like all references to Stringer for any type that
   149  	// happened to have a String method.
   150  	_, isType := declIdent.Declaration.obj.(*types.TypeName)
   151  	if includeInterfaceRefs && !isType {
   152  		declRange, err := declIdent.Range()
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		fh, err := snapshot.GetFile(ctx, declIdent.URI())
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  		interfaceRefs, err := interfaceReferences(ctx, snapshot, fh, declRange.Start)
   161  		if err != nil {
   162  			return nil, err
   163  		}
   164  		references = append(references, interfaceRefs...)
   165  	}
   166  
   167  	return references, nil
   168  }
   169  
   170  // interfaceReferences returns the references to the interfaces implemented by
   171  // the type or method at the given position.
   172  func interfaceReferences(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]*ReferenceInfo, error) {
   173  	implementations, err := implementations(ctx, s, f, pp)
   174  	if err != nil {
   175  		if errors.Is(err, ErrNotAType) {
   176  			return nil, nil
   177  		}
   178  		return nil, err
   179  	}
   180  
   181  	var refs []*ReferenceInfo
   182  	for _, impl := range implementations {
   183  		implRefs, err := references(ctx, s, []qualifiedObject{impl}, false, false, false)
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  		refs = append(refs, implRefs...)
   188  	}
   189  	return refs, nil
   190  }