golang.org/x/tools/gopls@v0.15.3/internal/golang/call_hierarchy.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package golang
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/token"
    13  	"go/types"
    14  	"path/filepath"
    15  
    16  	"golang.org/x/tools/go/ast/astutil"
    17  	"golang.org/x/tools/gopls/internal/cache"
    18  	"golang.org/x/tools/gopls/internal/file"
    19  	"golang.org/x/tools/gopls/internal/protocol"
    20  	"golang.org/x/tools/gopls/internal/util/bug"
    21  	"golang.org/x/tools/gopls/internal/util/safetoken"
    22  	"golang.org/x/tools/internal/event"
    23  	"golang.org/x/tools/internal/event/tag"
    24  )
    25  
    26  // PrepareCallHierarchy returns an array of CallHierarchyItem for a file and the position within the file.
    27  func PrepareCallHierarchy(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp protocol.Position) ([]protocol.CallHierarchyItem, error) {
    28  	ctx, done := event.Start(ctx, "golang.PrepareCallHierarchy")
    29  	defer done()
    30  
    31  	pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  	pos, err := pgf.PositionPos(pp)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	_, obj, _ := referencedObject(pkg, pgf, pos)
    41  	if obj == nil {
    42  		return nil, nil
    43  	}
    44  
    45  	if _, ok := obj.Type().Underlying().(*types.Signature); !ok {
    46  		return nil, nil
    47  	}
    48  
    49  	declLoc, err := mapPosition(ctx, pkg.FileSet(), snapshot, obj.Pos(), adjustedObjEnd(obj))
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	rng := declLoc.Range
    54  
    55  	callHierarchyItem := protocol.CallHierarchyItem{
    56  		Name:           obj.Name(),
    57  		Kind:           protocol.Function,
    58  		Tags:           []protocol.SymbolTag{},
    59  		Detail:         fmt.Sprintf("%s • %s", obj.Pkg().Path(), filepath.Base(declLoc.URI.Path())),
    60  		URI:            declLoc.URI,
    61  		Range:          rng,
    62  		SelectionRange: rng,
    63  	}
    64  	return []protocol.CallHierarchyItem{callHierarchyItem}, nil
    65  }
    66  
    67  // IncomingCalls returns an array of CallHierarchyIncomingCall for a file and the position within the file.
    68  func IncomingCalls(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pos protocol.Position) ([]protocol.CallHierarchyIncomingCall, error) {
    69  	ctx, done := event.Start(ctx, "golang.IncomingCalls")
    70  	defer done()
    71  
    72  	refs, err := references(ctx, snapshot, fh, pos, false)
    73  	if err != nil {
    74  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    75  			return nil, nil
    76  		}
    77  		return nil, err
    78  	}
    79  
    80  	// Group references by their enclosing function declaration.
    81  	incomingCalls := make(map[protocol.Location]*protocol.CallHierarchyIncomingCall)
    82  	for _, ref := range refs {
    83  		callItem, err := enclosingNodeCallItem(ctx, snapshot, ref.pkgPath, ref.location)
    84  		if err != nil {
    85  			event.Error(ctx, "error getting enclosing node", err, tag.Method.Of(string(ref.pkgPath)))
    86  			continue
    87  		}
    88  		loc := protocol.Location{
    89  			URI:   callItem.URI,
    90  			Range: callItem.Range,
    91  		}
    92  		call, ok := incomingCalls[loc]
    93  		if !ok {
    94  			call = &protocol.CallHierarchyIncomingCall{From: callItem}
    95  			incomingCalls[loc] = call
    96  		}
    97  		call.FromRanges = append(call.FromRanges, ref.location.Range)
    98  	}
    99  
   100  	// Flatten the map of pointers into a slice of values.
   101  	incomingCallItems := make([]protocol.CallHierarchyIncomingCall, 0, len(incomingCalls))
   102  	for _, callItem := range incomingCalls {
   103  		incomingCallItems = append(incomingCallItems, *callItem)
   104  	}
   105  	return incomingCallItems, nil
   106  }
   107  
   108  // enclosingNodeCallItem creates a CallHierarchyItem representing the function call at loc.
   109  func enclosingNodeCallItem(ctx context.Context, snapshot *cache.Snapshot, pkgPath PackagePath, loc protocol.Location) (protocol.CallHierarchyItem, error) {
   110  	// Parse the file containing the reference.
   111  	fh, err := snapshot.ReadFile(ctx, loc.URI)
   112  	if err != nil {
   113  		return protocol.CallHierarchyItem{}, err
   114  	}
   115  	// TODO(adonovan): opt: before parsing, trim the bodies of functions
   116  	// that don't contain the reference, using either a scanner-based
   117  	// implementation such as https://go.dev/play/p/KUrObH1YkX8
   118  	// (~31% speedup), or a byte-oriented implementation (2x speedup).
   119  	pgf, err := snapshot.ParseGo(ctx, fh, ParseFull)
   120  	if err != nil {
   121  		return protocol.CallHierarchyItem{}, err
   122  	}
   123  	start, end, err := pgf.RangePos(loc.Range)
   124  	if err != nil {
   125  		return protocol.CallHierarchyItem{}, err
   126  	}
   127  
   128  	// Find the enclosing function, if any, and the number of func literals in between.
   129  	var funcDecl *ast.FuncDecl
   130  	var funcLit *ast.FuncLit // innermost function literal
   131  	var litCount int
   132  	path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
   133  outer:
   134  	for _, node := range path {
   135  		switch n := node.(type) {
   136  		case *ast.FuncDecl:
   137  			funcDecl = n
   138  			break outer
   139  		case *ast.FuncLit:
   140  			litCount++
   141  			if litCount > 1 {
   142  				continue
   143  			}
   144  			funcLit = n
   145  		}
   146  	}
   147  
   148  	nameIdent := path[len(path)-1].(*ast.File).Name
   149  	kind := protocol.Package
   150  	if funcDecl != nil {
   151  		nameIdent = funcDecl.Name
   152  		kind = protocol.Function
   153  	}
   154  
   155  	nameStart, nameEnd := nameIdent.Pos(), nameIdent.End()
   156  	if funcLit != nil {
   157  		nameStart, nameEnd = funcLit.Type.Func, funcLit.Type.Params.Pos()
   158  		kind = protocol.Function
   159  	}
   160  	rng, err := pgf.PosRange(nameStart, nameEnd)
   161  	if err != nil {
   162  		return protocol.CallHierarchyItem{}, err
   163  	}
   164  
   165  	name := nameIdent.Name
   166  	for i := 0; i < litCount; i++ {
   167  		name += ".func()"
   168  	}
   169  
   170  	return protocol.CallHierarchyItem{
   171  		Name:           name,
   172  		Kind:           kind,
   173  		Tags:           []protocol.SymbolTag{},
   174  		Detail:         fmt.Sprintf("%s • %s", pkgPath, filepath.Base(fh.URI().Path())),
   175  		URI:            loc.URI,
   176  		Range:          rng,
   177  		SelectionRange: rng,
   178  	}, nil
   179  }
   180  
   181  // OutgoingCalls returns an array of CallHierarchyOutgoingCall for a file and the position within the file.
   182  func OutgoingCalls(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp protocol.Position) ([]protocol.CallHierarchyOutgoingCall, error) {
   183  	ctx, done := event.Start(ctx, "golang.OutgoingCalls")
   184  	defer done()
   185  
   186  	pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	pos, err := pgf.PositionPos(pp)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	_, obj, _ := referencedObject(pkg, pgf, pos)
   196  	if obj == nil {
   197  		return nil, nil
   198  	}
   199  
   200  	if _, ok := obj.Type().Underlying().(*types.Signature); !ok {
   201  		return nil, nil
   202  	}
   203  
   204  	// Skip builtins.
   205  	if obj.Pkg() == nil {
   206  		return nil, nil
   207  	}
   208  
   209  	if !obj.Pos().IsValid() {
   210  		return nil, bug.Errorf("internal error: object %s.%s missing position", obj.Pkg().Path(), obj.Name())
   211  	}
   212  
   213  	declFile := pkg.FileSet().File(obj.Pos())
   214  	if declFile == nil {
   215  		return nil, bug.Errorf("file not found for %d", obj.Pos())
   216  	}
   217  
   218  	uri := protocol.URIFromPath(declFile.Name())
   219  	offset, err := safetoken.Offset(declFile, obj.Pos())
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	// Use TypecheckFull as we want to inspect the body of the function declaration.
   225  	declPkg, declPGF, err := NarrowestPackageForFile(ctx, snapshot, uri)
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  
   230  	declPos, err := safetoken.Pos(declPGF.Tok, offset)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  
   235  	declNode, _, _ := findDeclInfo([]*ast.File{declPGF.File}, declPos)
   236  	if declNode == nil {
   237  		// TODO(rfindley): why don't we return an error here, or even bug.Errorf?
   238  		return nil, nil
   239  		// return nil, bug.Errorf("failed to find declaration for object %s.%s", obj.Pkg().Path(), obj.Name())
   240  	}
   241  
   242  	type callRange struct {
   243  		start, end token.Pos
   244  	}
   245  	callRanges := []callRange{}
   246  	ast.Inspect(declNode, func(n ast.Node) bool {
   247  		if call, ok := n.(*ast.CallExpr); ok {
   248  			var start, end token.Pos
   249  			switch n := call.Fun.(type) {
   250  			case *ast.SelectorExpr:
   251  				start, end = n.Sel.NamePos, call.Lparen
   252  			case *ast.Ident:
   253  				start, end = n.NamePos, call.Lparen
   254  			case *ast.FuncLit:
   255  				// while we don't add the function literal as an 'outgoing' call
   256  				// we still want to traverse into it
   257  				return true
   258  			default:
   259  				// ignore any other kind of call expressions
   260  				// for ex: direct function literal calls since that's not an 'outgoing' call
   261  				return false
   262  			}
   263  			callRanges = append(callRanges, callRange{start: start, end: end})
   264  		}
   265  		return true
   266  	})
   267  
   268  	outgoingCalls := map[token.Pos]*protocol.CallHierarchyOutgoingCall{}
   269  	for _, callRange := range callRanges {
   270  		_, obj, _ := referencedObject(declPkg, declPGF, callRange.start)
   271  		if obj == nil {
   272  			continue
   273  		}
   274  
   275  		// ignore calls to builtin functions
   276  		if obj.Pkg() == nil {
   277  			continue
   278  		}
   279  
   280  		outgoingCall, ok := outgoingCalls[obj.Pos()]
   281  		if !ok {
   282  			loc, err := mapPosition(ctx, declPkg.FileSet(), snapshot, obj.Pos(), obj.Pos()+token.Pos(len(obj.Name())))
   283  			if err != nil {
   284  				return nil, err
   285  			}
   286  			outgoingCall = &protocol.CallHierarchyOutgoingCall{
   287  				To: protocol.CallHierarchyItem{
   288  					Name:           obj.Name(),
   289  					Kind:           protocol.Function,
   290  					Tags:           []protocol.SymbolTag{},
   291  					Detail:         fmt.Sprintf("%s • %s", obj.Pkg().Path(), filepath.Base(loc.URI.Path())),
   292  					URI:            loc.URI,
   293  					Range:          loc.Range,
   294  					SelectionRange: loc.Range,
   295  				},
   296  			}
   297  			outgoingCalls[obj.Pos()] = outgoingCall
   298  		}
   299  
   300  		rng, err := declPGF.PosRange(callRange.start, callRange.end)
   301  		if err != nil {
   302  			return nil, err
   303  		}
   304  		outgoingCall.FromRanges = append(outgoingCall.FromRanges, rng)
   305  	}
   306  
   307  	outgoingCallItems := make([]protocol.CallHierarchyOutgoingCall, 0, len(outgoingCalls))
   308  	for _, callItem := range outgoingCalls {
   309  		outgoingCallItems = append(outgoingCallItems, *callItem)
   310  	}
   311  	return outgoingCallItems, nil
   312  }