github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/source/call_hierarchy.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"path/filepath"
    14  
    15  	"github.com/jhump/golang-x-tools/go/ast/astutil"
    16  	"github.com/jhump/golang-x-tools/internal/event"
    17  	"github.com/jhump/golang-x-tools/internal/lsp/debug/tag"
    18  	"github.com/jhump/golang-x-tools/internal/lsp/protocol"
    19  	"github.com/jhump/golang-x-tools/internal/span"
    20  	errors "golang.org/x/xerrors"
    21  )
    22  
    23  // PrepareCallHierarchy returns an array of CallHierarchyItem for a file and the position within the file.
    24  func PrepareCallHierarchy(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyItem, error) {
    25  	ctx, done := event.Start(ctx, "source.PrepareCallHierarchy")
    26  	defer done()
    27  
    28  	identifier, err := Identifier(ctx, snapshot, fh, pos)
    29  	if err != nil {
    30  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    31  			return nil, nil
    32  		}
    33  		return nil, err
    34  	}
    35  
    36  	// The identifier can be nil if it is an import spec.
    37  	if identifier == nil || identifier.Declaration.obj == nil {
    38  		return nil, nil
    39  	}
    40  
    41  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
    42  		return nil, nil
    43  	}
    44  
    45  	if len(identifier.Declaration.MappedRange) == 0 {
    46  		return nil, nil
    47  	}
    48  	declMappedRange := identifier.Declaration.MappedRange[0]
    49  	rng, err := declMappedRange.Range()
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	callHierarchyItem := protocol.CallHierarchyItem{
    55  		Name:           identifier.Name,
    56  		Kind:           protocol.Function,
    57  		Tags:           []protocol.SymbolTag{},
    58  		Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
    59  		URI:            protocol.DocumentURI(declMappedRange.URI()),
    60  		Range:          rng,
    61  		SelectionRange: rng,
    62  	}
    63  	return []protocol.CallHierarchyItem{callHierarchyItem}, nil
    64  }
    65  
    66  // IncomingCalls returns an array of CallHierarchyIncomingCall for a file and the position within the file.
    67  func IncomingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyIncomingCall, error) {
    68  	ctx, done := event.Start(ctx, "source.IncomingCalls")
    69  	defer done()
    70  
    71  	refs, err := References(ctx, snapshot, fh, pos, false)
    72  	if err != nil {
    73  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    74  			return nil, nil
    75  		}
    76  		return nil, err
    77  	}
    78  
    79  	return toProtocolIncomingCalls(ctx, snapshot, refs)
    80  }
    81  
    82  // toProtocolIncomingCalls returns an array of protocol.CallHierarchyIncomingCall for ReferenceInfo's.
    83  // References inside same enclosure are assigned to the same enclosing function.
    84  func toProtocolIncomingCalls(ctx context.Context, snapshot Snapshot, refs []*ReferenceInfo) ([]protocol.CallHierarchyIncomingCall, error) {
    85  	// an enclosing node could have multiple calls to a reference, we only show the enclosure
    86  	// once in the result but highlight all calls using FromRanges (ranges at which the calls occur)
    87  	var incomingCalls = map[protocol.Location]*protocol.CallHierarchyIncomingCall{}
    88  	for _, ref := range refs {
    89  		refRange, err := ref.Range()
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  
    94  		callItem, err := enclosingNodeCallItem(snapshot, ref.pkg, ref.URI(), ref.ident.NamePos)
    95  		if err != nil {
    96  			event.Error(ctx, "error getting enclosing node", err, tag.Method.Of(ref.Name))
    97  			continue
    98  		}
    99  		loc := protocol.Location{
   100  			URI:   callItem.URI,
   101  			Range: callItem.Range,
   102  		}
   103  
   104  		if incomingCall, ok := incomingCalls[loc]; ok {
   105  			incomingCall.FromRanges = append(incomingCall.FromRanges, refRange)
   106  			continue
   107  		}
   108  		incomingCalls[loc] = &protocol.CallHierarchyIncomingCall{
   109  			From:       callItem,
   110  			FromRanges: []protocol.Range{refRange},
   111  		}
   112  	}
   113  
   114  	incomingCallItems := make([]protocol.CallHierarchyIncomingCall, 0, len(incomingCalls))
   115  	for _, callItem := range incomingCalls {
   116  		incomingCallItems = append(incomingCallItems, *callItem)
   117  	}
   118  	return incomingCallItems, nil
   119  }
   120  
   121  // enclosingNodeCallItem creates a CallHierarchyItem representing the function call at pos
   122  func enclosingNodeCallItem(snapshot Snapshot, pkg Package, uri span.URI, pos token.Pos) (protocol.CallHierarchyItem, error) {
   123  	pgf, err := pkg.File(uri)
   124  	if err != nil {
   125  		return protocol.CallHierarchyItem{}, err
   126  	}
   127  
   128  	var funcDecl *ast.FuncDecl
   129  	var funcLit *ast.FuncLit // innermost function literal
   130  	var litCount int
   131  	// Find the enclosing function, if any, and the number of func literals in between.
   132  	path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
   133  outer:
   134  	for _, node := range path {
   135  		switch n := node.(type) {
   136  		case *ast.FuncDecl:
   137  			funcDecl = n
   138  			break outer
   139  		case *ast.FuncLit:
   140  			litCount++
   141  			if litCount > 1 {
   142  				continue
   143  			}
   144  			funcLit = n
   145  		}
   146  	}
   147  
   148  	nameIdent := path[len(path)-1].(*ast.File).Name
   149  	kind := protocol.Package
   150  	if funcDecl != nil {
   151  		nameIdent = funcDecl.Name
   152  		kind = protocol.Function
   153  	}
   154  
   155  	nameStart, nameEnd := nameIdent.NamePos, nameIdent.NamePos+token.Pos(len(nameIdent.Name))
   156  	if funcLit != nil {
   157  		nameStart, nameEnd = funcLit.Type.Func, funcLit.Type.Params.Pos()
   158  		kind = protocol.Function
   159  	}
   160  	rng, err := NewMappedRange(snapshot.FileSet(), pgf.Mapper, nameStart, nameEnd).Range()
   161  	if err != nil {
   162  		return protocol.CallHierarchyItem{}, err
   163  	}
   164  
   165  	name := nameIdent.Name
   166  	for i := 0; i < litCount; i++ {
   167  		name += ".func()"
   168  	}
   169  
   170  	return protocol.CallHierarchyItem{
   171  		Name:           name,
   172  		Kind:           kind,
   173  		Tags:           []protocol.SymbolTag{},
   174  		Detail:         fmt.Sprintf("%s • %s", pkg.PkgPath(), filepath.Base(uri.Filename())),
   175  		URI:            protocol.DocumentURI(uri),
   176  		Range:          rng,
   177  		SelectionRange: rng,
   178  	}, nil
   179  }
   180  
   181  // OutgoingCalls returns an array of CallHierarchyOutgoingCall for a file and the position within the file.
   182  func OutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyOutgoingCall, error) {
   183  	ctx, done := event.Start(ctx, "source.OutgoingCalls")
   184  	defer done()
   185  
   186  	identifier, err := Identifier(ctx, snapshot, fh, pos)
   187  	if err != nil {
   188  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   189  			return nil, nil
   190  		}
   191  		return nil, err
   192  	}
   193  
   194  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
   195  		return nil, nil
   196  	}
   197  	if identifier.Declaration.node == nil {
   198  		return nil, nil
   199  	}
   200  	if len(identifier.Declaration.MappedRange) == 0 {
   201  		return nil, nil
   202  	}
   203  	declMappedRange := identifier.Declaration.MappedRange[0]
   204  	callExprs, err := collectCallExpressions(snapshot.FileSet(), declMappedRange.m, identifier.Declaration.node)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	return toProtocolOutgoingCalls(ctx, snapshot, fh, callExprs)
   210  }
   211  
   212  // collectCallExpressions collects call expression ranges inside a function.
   213  func collectCallExpressions(fset *token.FileSet, mapper *protocol.ColumnMapper, node ast.Node) ([]protocol.Range, error) {
   214  	type callPos struct {
   215  		start, end token.Pos
   216  	}
   217  	callPositions := []callPos{}
   218  
   219  	ast.Inspect(node, func(n ast.Node) bool {
   220  		if call, ok := n.(*ast.CallExpr); ok {
   221  			var start, end token.Pos
   222  			switch n := call.Fun.(type) {
   223  			case *ast.SelectorExpr:
   224  				start, end = n.Sel.NamePos, call.Lparen
   225  			case *ast.Ident:
   226  				start, end = n.NamePos, call.Lparen
   227  			case *ast.FuncLit:
   228  				// while we don't add the function literal as an 'outgoing' call
   229  				// we still want to traverse into it
   230  				return true
   231  			default:
   232  				// ignore any other kind of call expressions
   233  				// for ex: direct function literal calls since that's not an 'outgoing' call
   234  				return false
   235  			}
   236  			callPositions = append(callPositions, callPos{start: start, end: end})
   237  		}
   238  		return true
   239  	})
   240  
   241  	callRanges := []protocol.Range{}
   242  	for _, call := range callPositions {
   243  		callRange, err := NewMappedRange(fset, mapper, call.start, call.end).Range()
   244  		if err != nil {
   245  			return nil, err
   246  		}
   247  		callRanges = append(callRanges, callRange)
   248  	}
   249  	return callRanges, nil
   250  }
   251  
   252  // toProtocolOutgoingCalls returns an array of protocol.CallHierarchyOutgoingCall for ast call expressions.
   253  // Calls to the same function are assigned to the same declaration.
   254  func toProtocolOutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, callRanges []protocol.Range) ([]protocol.CallHierarchyOutgoingCall, error) {
   255  	// Multiple calls could be made to the same function, defined by "same declaration
   256  	// AST node & same idenfitier name" to provide a unique identifier key even when
   257  	// the func is declared in a struct or interface.
   258  	type key struct {
   259  		decl ast.Node
   260  		name string
   261  	}
   262  	outgoingCalls := map[key]*protocol.CallHierarchyOutgoingCall{}
   263  	for _, callRange := range callRanges {
   264  		identifier, err := Identifier(ctx, snapshot, fh, callRange.Start)
   265  		if err != nil {
   266  			if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   267  				continue
   268  			}
   269  			return nil, err
   270  		}
   271  
   272  		// ignore calls to builtin functions
   273  		if identifier.Declaration.obj.Pkg() == nil {
   274  			continue
   275  		}
   276  
   277  		if outgoingCall, ok := outgoingCalls[key{identifier.Declaration.node, identifier.Name}]; ok {
   278  			outgoingCall.FromRanges = append(outgoingCall.FromRanges, callRange)
   279  			continue
   280  		}
   281  
   282  		if len(identifier.Declaration.MappedRange) == 0 {
   283  			continue
   284  		}
   285  		declMappedRange := identifier.Declaration.MappedRange[0]
   286  		rng, err := declMappedRange.Range()
   287  		if err != nil {
   288  			return nil, err
   289  		}
   290  
   291  		outgoingCalls[key{identifier.Declaration.node, identifier.Name}] = &protocol.CallHierarchyOutgoingCall{
   292  			To: protocol.CallHierarchyItem{
   293  				Name:           identifier.Name,
   294  				Kind:           protocol.Function,
   295  				Tags:           []protocol.SymbolTag{},
   296  				Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
   297  				URI:            protocol.DocumentURI(declMappedRange.URI()),
   298  				Range:          rng,
   299  				SelectionRange: rng,
   300  			},
   301  			FromRanges: []protocol.Range{callRange},
   302  		}
   303  	}
   304  
   305  	outgoingCallItems := make([]protocol.CallHierarchyOutgoingCall, 0, len(outgoingCalls))
   306  	for _, callItem := range outgoingCalls {
   307  		outgoingCallItems = append(outgoingCallItems, *callItem)
   308  	}
   309  	return outgoingCallItems, nil
   310  }