github.com/jd-ly/tools@v0.5.7/internal/lsp/source/call_hierarchy.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"path/filepath"
    14  
    15  	"github.com/jd-ly/tools/go/ast/astutil"
    16  	"github.com/jd-ly/tools/internal/event"
    17  	"github.com/jd-ly/tools/internal/lsp/debug/tag"
    18  	"github.com/jd-ly/tools/internal/lsp/protocol"
    19  	"github.com/jd-ly/tools/internal/span"
    20  	errors "golang.org/x/xerrors"
    21  )
    22  
    23  // PrepareCallHierarchy returns an array of CallHierarchyItem for a file and the position within the file.
    24  func PrepareCallHierarchy(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyItem, error) {
    25  	ctx, done := event.Start(ctx, "source.PrepareCallHierarchy")
    26  	defer done()
    27  
    28  	identifier, err := Identifier(ctx, snapshot, fh, pos)
    29  	if err != nil {
    30  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    31  			return nil, nil
    32  		}
    33  		return nil, err
    34  	}
    35  	// The identifier can be nil if it is an import spec.
    36  	if identifier == nil {
    37  		return nil, nil
    38  	}
    39  
    40  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
    41  		return nil, nil
    42  	}
    43  
    44  	if len(identifier.Declaration.MappedRange) == 0 {
    45  		return nil, nil
    46  	}
    47  	declMappedRange := identifier.Declaration.MappedRange[0]
    48  	rng, err := declMappedRange.Range()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	callHierarchyItem := protocol.CallHierarchyItem{
    54  		Name:           identifier.Name,
    55  		Kind:           protocol.Function,
    56  		Tags:           []protocol.SymbolTag{},
    57  		Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
    58  		URI:            protocol.DocumentURI(declMappedRange.URI()),
    59  		Range:          rng,
    60  		SelectionRange: rng,
    61  	}
    62  	return []protocol.CallHierarchyItem{callHierarchyItem}, nil
    63  }
    64  
    65  // IncomingCalls returns an array of CallHierarchyIncomingCall for a file and the position within the file.
    66  func IncomingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyIncomingCall, error) {
    67  	ctx, done := event.Start(ctx, "source.IncomingCalls")
    68  	defer done()
    69  
    70  	refs, err := References(ctx, snapshot, fh, pos, false)
    71  	if err != nil {
    72  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    73  			return nil, nil
    74  		}
    75  		return nil, err
    76  	}
    77  
    78  	return toProtocolIncomingCalls(ctx, snapshot, refs)
    79  }
    80  
    81  // toProtocolIncomingCalls returns an array of protocol.CallHierarchyIncomingCall for ReferenceInfo's.
    82  // References inside same enclosure are assigned to the same enclosing function.
    83  func toProtocolIncomingCalls(ctx context.Context, snapshot Snapshot, refs []*ReferenceInfo) ([]protocol.CallHierarchyIncomingCall, error) {
    84  	// an enclosing node could have multiple calls to a reference, we only show the enclosure
    85  	// once in the result but highlight all calls using FromRanges (ranges at which the calls occur)
    86  	var incomingCalls = map[protocol.Location]*protocol.CallHierarchyIncomingCall{}
    87  	for _, ref := range refs {
    88  		refRange, err := ref.Range()
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  
    93  		callItem, err := enclosingNodeCallItem(snapshot, ref.pkg, ref.URI(), ref.ident.NamePos)
    94  		if err != nil {
    95  			event.Error(ctx, "error getting enclosing node", err, tag.Method.Of(ref.Name))
    96  			continue
    97  		}
    98  		loc := protocol.Location{
    99  			URI:   callItem.URI,
   100  			Range: callItem.Range,
   101  		}
   102  
   103  		if incomingCall, ok := incomingCalls[loc]; ok {
   104  			incomingCall.FromRanges = append(incomingCall.FromRanges, refRange)
   105  			continue
   106  		}
   107  		incomingCalls[loc] = &protocol.CallHierarchyIncomingCall{
   108  			From:       callItem,
   109  			FromRanges: []protocol.Range{refRange},
   110  		}
   111  	}
   112  
   113  	incomingCallItems := make([]protocol.CallHierarchyIncomingCall, 0, len(incomingCalls))
   114  	for _, callItem := range incomingCalls {
   115  		incomingCallItems = append(incomingCallItems, *callItem)
   116  	}
   117  	return incomingCallItems, nil
   118  }
   119  
   120  // enclosingNodeCallItem creates a CallHierarchyItem representing the function call at pos
   121  func enclosingNodeCallItem(snapshot Snapshot, pkg Package, uri span.URI, pos token.Pos) (protocol.CallHierarchyItem, error) {
   122  	pgf, err := pkg.File(uri)
   123  	if err != nil {
   124  		return protocol.CallHierarchyItem{}, err
   125  	}
   126  
   127  	var funcDecl *ast.FuncDecl
   128  	var funcLit *ast.FuncLit // innermost function literal
   129  	var litCount int
   130  	// Find the enclosing function, if any, and the number of func literals in between.
   131  	path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
   132  outer:
   133  	for _, node := range path {
   134  		switch n := node.(type) {
   135  		case *ast.FuncDecl:
   136  			funcDecl = n
   137  			break outer
   138  		case *ast.FuncLit:
   139  			litCount++
   140  			if litCount > 1 {
   141  				continue
   142  			}
   143  			funcLit = n
   144  		}
   145  	}
   146  
   147  	nameIdent := path[len(path)-1].(*ast.File).Name
   148  	kind := protocol.Package
   149  	if funcDecl != nil {
   150  		nameIdent = funcDecl.Name
   151  		kind = protocol.Function
   152  	}
   153  
   154  	nameStart, nameEnd := nameIdent.NamePos, nameIdent.NamePos+token.Pos(len(nameIdent.Name))
   155  	if funcLit != nil {
   156  		nameStart, nameEnd = funcLit.Type.Func, funcLit.Type.Params.Pos()
   157  		kind = protocol.Function
   158  	}
   159  	rng, err := NewMappedRange(snapshot.FileSet(), pgf.Mapper, nameStart, nameEnd).Range()
   160  	if err != nil {
   161  		return protocol.CallHierarchyItem{}, err
   162  	}
   163  
   164  	name := nameIdent.Name
   165  	for i := 0; i < litCount; i++ {
   166  		name += ".func()"
   167  	}
   168  
   169  	return protocol.CallHierarchyItem{
   170  		Name:           name,
   171  		Kind:           kind,
   172  		Tags:           []protocol.SymbolTag{},
   173  		Detail:         fmt.Sprintf("%s • %s", pkg.PkgPath(), filepath.Base(uri.Filename())),
   174  		URI:            protocol.DocumentURI(uri),
   175  		Range:          rng,
   176  		SelectionRange: rng,
   177  	}, nil
   178  }
   179  
   180  // OutgoingCalls returns an array of CallHierarchyOutgoingCall for a file and the position within the file.
   181  func OutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyOutgoingCall, error) {
   182  	ctx, done := event.Start(ctx, "source.OutgoingCalls")
   183  	defer done()
   184  
   185  	identifier, err := Identifier(ctx, snapshot, fh, pos)
   186  	if err != nil {
   187  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   188  			return nil, nil
   189  		}
   190  		return nil, err
   191  	}
   192  
   193  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
   194  		return nil, nil
   195  	}
   196  
   197  	if len(identifier.Declaration.MappedRange) == 0 {
   198  		return nil, nil
   199  	}
   200  	declMappedRange := identifier.Declaration.MappedRange[0]
   201  	callExprs, err := collectCallExpressions(snapshot.FileSet(), declMappedRange.m, identifier.Declaration.node)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	return toProtocolOutgoingCalls(ctx, snapshot, fh, callExprs)
   207  }
   208  
   209  // collectCallExpressions collects call expression ranges inside a function.
   210  func collectCallExpressions(fset *token.FileSet, mapper *protocol.ColumnMapper, node ast.Node) ([]protocol.Range, error) {
   211  	type callPos struct {
   212  		start, end token.Pos
   213  	}
   214  	callPositions := []callPos{}
   215  
   216  	ast.Inspect(node, func(n ast.Node) bool {
   217  		if call, ok := n.(*ast.CallExpr); ok {
   218  			var start, end token.Pos
   219  			switch n := call.Fun.(type) {
   220  			case *ast.SelectorExpr:
   221  				start, end = n.Sel.NamePos, call.Lparen
   222  			case *ast.Ident:
   223  				start, end = n.NamePos, call.Lparen
   224  			default:
   225  				// ignore any other kind of call expressions
   226  				// for ex: direct function literal calls since that's not an 'outgoing' call
   227  				return false
   228  			}
   229  			callPositions = append(callPositions, callPos{start: start, end: end})
   230  		}
   231  		return true
   232  	})
   233  
   234  	callRanges := []protocol.Range{}
   235  	for _, call := range callPositions {
   236  		callRange, err := NewMappedRange(fset, mapper, call.start, call.end).Range()
   237  		if err != nil {
   238  			return nil, err
   239  		}
   240  		callRanges = append(callRanges, callRange)
   241  	}
   242  	return callRanges, nil
   243  }
   244  
   245  // toProtocolOutgoingCalls returns an array of protocol.CallHierarchyOutgoingCall for ast call expressions.
   246  // Calls to the same function are assigned to the same declaration.
   247  func toProtocolOutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, callRanges []protocol.Range) ([]protocol.CallHierarchyOutgoingCall, error) {
   248  	// multiple calls could be made to the same function
   249  	var outgoingCalls = map[ast.Node]*protocol.CallHierarchyOutgoingCall{}
   250  	for _, callRange := range callRanges {
   251  		identifier, err := Identifier(ctx, snapshot, fh, callRange.Start)
   252  		if err != nil {
   253  			if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   254  				continue
   255  			}
   256  			return nil, err
   257  		}
   258  
   259  		// ignore calls to builtin functions
   260  		if identifier.Declaration.obj.Pkg() == nil {
   261  			continue
   262  		}
   263  
   264  		if outgoingCall, ok := outgoingCalls[identifier.Declaration.node]; ok {
   265  			outgoingCall.FromRanges = append(outgoingCall.FromRanges, callRange)
   266  			continue
   267  		}
   268  
   269  		if len(identifier.Declaration.MappedRange) == 0 {
   270  			continue
   271  		}
   272  		declMappedRange := identifier.Declaration.MappedRange[0]
   273  		rng, err := declMappedRange.Range()
   274  		if err != nil {
   275  			return nil, err
   276  		}
   277  
   278  		outgoingCalls[identifier.Declaration.node] = &protocol.CallHierarchyOutgoingCall{
   279  			To: protocol.CallHierarchyItem{
   280  				Name:           identifier.Name,
   281  				Kind:           protocol.Function,
   282  				Tags:           []protocol.SymbolTag{},
   283  				Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
   284  				URI:            protocol.DocumentURI(declMappedRange.URI()),
   285  				Range:          rng,
   286  				SelectionRange: rng,
   287  			},
   288  			FromRanges: []protocol.Range{callRange},
   289  		}
   290  	}
   291  
   292  	outgoingCallItems := make([]protocol.CallHierarchyOutgoingCall, 0, len(outgoingCalls))
   293  	for _, callItem := range outgoingCalls {
   294  		outgoingCallItems = append(outgoingCallItems, *callItem)
   295  	}
   296  	return outgoingCallItems, nil
   297  }