github.com/v2fly/tools@v0.100.0/internal/lsp/source/call_hierarchy.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"path/filepath"
    14  
    15  	"github.com/v2fly/tools/go/ast/astutil"
    16  	"github.com/v2fly/tools/internal/event"
    17  	"github.com/v2fly/tools/internal/lsp/debug/tag"
    18  	"github.com/v2fly/tools/internal/lsp/protocol"
    19  	"github.com/v2fly/tools/internal/span"
    20  	errors "golang.org/x/xerrors"
    21  )
    22  
    23  // PrepareCallHierarchy returns an array of CallHierarchyItem for a file and the position within the file.
    24  func PrepareCallHierarchy(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyItem, error) {
    25  	ctx, done := event.Start(ctx, "source.PrepareCallHierarchy")
    26  	defer done()
    27  
    28  	identifier, err := Identifier(ctx, snapshot, fh, pos)
    29  	if err != nil {
    30  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    31  			return nil, nil
    32  		}
    33  		return nil, err
    34  	}
    35  	// The identifier can be nil if it is an import spec.
    36  	if identifier == nil {
    37  		return nil, nil
    38  	}
    39  
    40  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
    41  		return nil, nil
    42  	}
    43  
    44  	if len(identifier.Declaration.MappedRange) == 0 {
    45  		return nil, nil
    46  	}
    47  	declMappedRange := identifier.Declaration.MappedRange[0]
    48  	rng, err := declMappedRange.Range()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	callHierarchyItem := protocol.CallHierarchyItem{
    54  		Name:           identifier.Name,
    55  		Kind:           protocol.Function,
    56  		Tags:           []protocol.SymbolTag{},
    57  		Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
    58  		URI:            protocol.DocumentURI(declMappedRange.URI()),
    59  		Range:          rng,
    60  		SelectionRange: rng,
    61  	}
    62  	return []protocol.CallHierarchyItem{callHierarchyItem}, nil
    63  }
    64  
    65  // IncomingCalls returns an array of CallHierarchyIncomingCall for a file and the position within the file.
    66  func IncomingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyIncomingCall, error) {
    67  	ctx, done := event.Start(ctx, "source.IncomingCalls")
    68  	defer done()
    69  
    70  	refs, err := References(ctx, snapshot, fh, pos, false)
    71  	if err != nil {
    72  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
    73  			return nil, nil
    74  		}
    75  		return nil, err
    76  	}
    77  
    78  	return toProtocolIncomingCalls(ctx, snapshot, refs)
    79  }
    80  
    81  // toProtocolIncomingCalls returns an array of protocol.CallHierarchyIncomingCall for ReferenceInfo's.
    82  // References inside same enclosure are assigned to the same enclosing function.
    83  func toProtocolIncomingCalls(ctx context.Context, snapshot Snapshot, refs []*ReferenceInfo) ([]protocol.CallHierarchyIncomingCall, error) {
    84  	// an enclosing node could have multiple calls to a reference, we only show the enclosure
    85  	// once in the result but highlight all calls using FromRanges (ranges at which the calls occur)
    86  	var incomingCalls = map[protocol.Location]*protocol.CallHierarchyIncomingCall{}
    87  	for _, ref := range refs {
    88  		refRange, err := ref.Range()
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  
    93  		callItem, err := enclosingNodeCallItem(snapshot, ref.pkg, ref.URI(), ref.ident.NamePos)
    94  		if err != nil {
    95  			event.Error(ctx, "error getting enclosing node", err, tag.Method.Of(ref.Name))
    96  			continue
    97  		}
    98  		loc := protocol.Location{
    99  			URI:   callItem.URI,
   100  			Range: callItem.Range,
   101  		}
   102  
   103  		if incomingCall, ok := incomingCalls[loc]; ok {
   104  			incomingCall.FromRanges = append(incomingCall.FromRanges, refRange)
   105  			continue
   106  		}
   107  		incomingCalls[loc] = &protocol.CallHierarchyIncomingCall{
   108  			From:       callItem,
   109  			FromRanges: []protocol.Range{refRange},
   110  		}
   111  	}
   112  
   113  	incomingCallItems := make([]protocol.CallHierarchyIncomingCall, 0, len(incomingCalls))
   114  	for _, callItem := range incomingCalls {
   115  		incomingCallItems = append(incomingCallItems, *callItem)
   116  	}
   117  	return incomingCallItems, nil
   118  }
   119  
   120  // enclosingNodeCallItem creates a CallHierarchyItem representing the function call at pos
   121  func enclosingNodeCallItem(snapshot Snapshot, pkg Package, uri span.URI, pos token.Pos) (protocol.CallHierarchyItem, error) {
   122  	pgf, err := pkg.File(uri)
   123  	if err != nil {
   124  		return protocol.CallHierarchyItem{}, err
   125  	}
   126  
   127  	var funcDecl *ast.FuncDecl
   128  	var funcLit *ast.FuncLit // innermost function literal
   129  	var litCount int
   130  	// Find the enclosing function, if any, and the number of func literals in between.
   131  	path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
   132  outer:
   133  	for _, node := range path {
   134  		switch n := node.(type) {
   135  		case *ast.FuncDecl:
   136  			funcDecl = n
   137  			break outer
   138  		case *ast.FuncLit:
   139  			litCount++
   140  			if litCount > 1 {
   141  				continue
   142  			}
   143  			funcLit = n
   144  		}
   145  	}
   146  
   147  	nameIdent := path[len(path)-1].(*ast.File).Name
   148  	kind := protocol.Package
   149  	if funcDecl != nil {
   150  		nameIdent = funcDecl.Name
   151  		kind = protocol.Function
   152  	}
   153  
   154  	nameStart, nameEnd := nameIdent.NamePos, nameIdent.NamePos+token.Pos(len(nameIdent.Name))
   155  	if funcLit != nil {
   156  		nameStart, nameEnd = funcLit.Type.Func, funcLit.Type.Params.Pos()
   157  		kind = protocol.Function
   158  	}
   159  	rng, err := NewMappedRange(snapshot.FileSet(), pgf.Mapper, nameStart, nameEnd).Range()
   160  	if err != nil {
   161  		return protocol.CallHierarchyItem{}, err
   162  	}
   163  
   164  	name := nameIdent.Name
   165  	for i := 0; i < litCount; i++ {
   166  		name += ".func()"
   167  	}
   168  
   169  	return protocol.CallHierarchyItem{
   170  		Name:           name,
   171  		Kind:           kind,
   172  		Tags:           []protocol.SymbolTag{},
   173  		Detail:         fmt.Sprintf("%s • %s", pkg.PkgPath(), filepath.Base(uri.Filename())),
   174  		URI:            protocol.DocumentURI(uri),
   175  		Range:          rng,
   176  		SelectionRange: rng,
   177  	}, nil
   178  }
   179  
   180  // OutgoingCalls returns an array of CallHierarchyOutgoingCall for a file and the position within the file.
   181  func OutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.CallHierarchyOutgoingCall, error) {
   182  	ctx, done := event.Start(ctx, "source.OutgoingCalls")
   183  	defer done()
   184  
   185  	identifier, err := Identifier(ctx, snapshot, fh, pos)
   186  	if err != nil {
   187  		if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   188  			return nil, nil
   189  		}
   190  		return nil, err
   191  	}
   192  
   193  	if _, ok := identifier.Declaration.obj.Type().Underlying().(*types.Signature); !ok {
   194  		return nil, nil
   195  	}
   196  	if identifier.Declaration.node == nil {
   197  		return nil, nil
   198  	}
   199  	if len(identifier.Declaration.MappedRange) == 0 {
   200  		return nil, nil
   201  	}
   202  	declMappedRange := identifier.Declaration.MappedRange[0]
   203  	callExprs, err := collectCallExpressions(snapshot.FileSet(), declMappedRange.m, identifier.Declaration.node)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	return toProtocolOutgoingCalls(ctx, snapshot, fh, callExprs)
   209  }
   210  
   211  // collectCallExpressions collects call expression ranges inside a function.
   212  func collectCallExpressions(fset *token.FileSet, mapper *protocol.ColumnMapper, node ast.Node) ([]protocol.Range, error) {
   213  	type callPos struct {
   214  		start, end token.Pos
   215  	}
   216  	callPositions := []callPos{}
   217  
   218  	ast.Inspect(node, func(n ast.Node) bool {
   219  		if call, ok := n.(*ast.CallExpr); ok {
   220  			var start, end token.Pos
   221  			switch n := call.Fun.(type) {
   222  			case *ast.SelectorExpr:
   223  				start, end = n.Sel.NamePos, call.Lparen
   224  			case *ast.Ident:
   225  				start, end = n.NamePos, call.Lparen
   226  			default:
   227  				// ignore any other kind of call expressions
   228  				// for ex: direct function literal calls since that's not an 'outgoing' call
   229  				return false
   230  			}
   231  			callPositions = append(callPositions, callPos{start: start, end: end})
   232  		}
   233  		return true
   234  	})
   235  
   236  	callRanges := []protocol.Range{}
   237  	for _, call := range callPositions {
   238  		callRange, err := NewMappedRange(fset, mapper, call.start, call.end).Range()
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  		callRanges = append(callRanges, callRange)
   243  	}
   244  	return callRanges, nil
   245  }
   246  
   247  // toProtocolOutgoingCalls returns an array of protocol.CallHierarchyOutgoingCall for ast call expressions.
   248  // Calls to the same function are assigned to the same declaration.
   249  func toProtocolOutgoingCalls(ctx context.Context, snapshot Snapshot, fh FileHandle, callRanges []protocol.Range) ([]protocol.CallHierarchyOutgoingCall, error) {
   250  	// Multiple calls could be made to the same function, defined by "same declaration
   251  	// AST node & same idenfitier name" to provide a unique identifier key even when
   252  	// the func is declared in a struct or interface.
   253  	type key struct {
   254  		decl ast.Node
   255  		name string
   256  	}
   257  	outgoingCalls := map[key]*protocol.CallHierarchyOutgoingCall{}
   258  	for _, callRange := range callRanges {
   259  		identifier, err := Identifier(ctx, snapshot, fh, callRange.Start)
   260  		if err != nil {
   261  			if errors.Is(err, ErrNoIdentFound) || errors.Is(err, errNoObjectFound) {
   262  				continue
   263  			}
   264  			return nil, err
   265  		}
   266  
   267  		// ignore calls to builtin functions
   268  		if identifier.Declaration.obj.Pkg() == nil {
   269  			continue
   270  		}
   271  
   272  		if outgoingCall, ok := outgoingCalls[key{identifier.Declaration.node, identifier.Name}]; ok {
   273  			outgoingCall.FromRanges = append(outgoingCall.FromRanges, callRange)
   274  			continue
   275  		}
   276  
   277  		if len(identifier.Declaration.MappedRange) == 0 {
   278  			continue
   279  		}
   280  		declMappedRange := identifier.Declaration.MappedRange[0]
   281  		rng, err := declMappedRange.Range()
   282  		if err != nil {
   283  			return nil, err
   284  		}
   285  
   286  		outgoingCalls[key{identifier.Declaration.node, identifier.Name}] = &protocol.CallHierarchyOutgoingCall{
   287  			To: protocol.CallHierarchyItem{
   288  				Name:           identifier.Name,
   289  				Kind:           protocol.Function,
   290  				Tags:           []protocol.SymbolTag{},
   291  				Detail:         fmt.Sprintf("%s • %s", identifier.Declaration.obj.Pkg().Path(), filepath.Base(declMappedRange.URI().Filename())),
   292  				URI:            protocol.DocumentURI(declMappedRange.URI()),
   293  				Range:          rng,
   294  				SelectionRange: rng,
   295  			},
   296  			FromRanges: []protocol.Range{callRange},
   297  		}
   298  	}
   299  
   300  	outgoingCallItems := make([]protocol.CallHierarchyOutgoingCall, 0, len(outgoingCalls))
   301  	for _, callItem := range outgoingCalls {
   302  		outgoingCallItems = append(outgoingCallItems, *callItem)
   303  	}
   304  	return outgoingCallItems, nil
   305  }