github.com/v2fly/tools@v0.100.0/internal/lsp/source/highlight.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"strings"
    14  
    15  	"github.com/v2fly/tools/go/ast/astutil"
    16  	"github.com/v2fly/tools/internal/event"
    17  	"github.com/v2fly/tools/internal/lsp/protocol"
    18  	errors "golang.org/x/xerrors"
    19  )
    20  
    21  func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.Range, error) {
    22  	ctx, done := event.Start(ctx, "source.Highlight")
    23  	defer done()
    24  
    25  	// Don't use GetParsedFile because it uses TypecheckWorkspace, and we
    26  	// always want fully parsed files for highlight, regardless of whether
    27  	// the file belongs to a workspace package.
    28  	pkg, err := snapshot.PackageForFile(ctx, fh.URI(), TypecheckFull, WidestPackage)
    29  	if err != nil {
    30  		return nil, errors.Errorf("getting package for Highlight: %w", err)
    31  	}
    32  	pgf, err := pkg.File(fh.URI())
    33  	if err != nil {
    34  		return nil, errors.Errorf("getting file for Highlight: %w", err)
    35  	}
    36  
    37  	spn, err := pgf.Mapper.PointSpan(pos)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	rng, err := spn.Range(pgf.Mapper.Converter)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	path, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start, rng.Start)
    46  	if len(path) == 0 {
    47  		return nil, fmt.Errorf("no enclosing position found for %v:%v", int(pos.Line), int(pos.Character))
    48  	}
    49  	// If start == end for astutil.PathEnclosingInterval, the 1-char interval
    50  	// following start is used instead. As a result, we might not get an exact
    51  	// match so we should check the 1-char interval to the left of the passed
    52  	// in position to see if that is an exact match.
    53  	if _, ok := path[0].(*ast.Ident); !ok {
    54  		if p, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start-1, rng.Start-1); p != nil {
    55  			switch p[0].(type) {
    56  			case *ast.Ident, *ast.SelectorExpr:
    57  				path = p // use preceding ident/selector
    58  			}
    59  		}
    60  	}
    61  	result, err := highlightPath(pkg, path)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	var ranges []protocol.Range
    66  	for rng := range result {
    67  		mRng, err := posToMappedRange(snapshot, pkg, rng.start, rng.end)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		pRng, err := mRng.Range()
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  		ranges = append(ranges, pRng)
    76  	}
    77  	return ranges, nil
    78  }
    79  
    80  func highlightPath(pkg Package, path []ast.Node) (map[posRange]struct{}, error) {
    81  	result := make(map[posRange]struct{})
    82  	switch node := path[0].(type) {
    83  	case *ast.BasicLit:
    84  		if len(path) > 1 {
    85  			if _, ok := path[1].(*ast.ImportSpec); ok {
    86  				err := highlightImportUses(pkg, path, result)
    87  				return result, err
    88  			}
    89  		}
    90  		highlightFuncControlFlow(path, result)
    91  	case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType:
    92  		highlightFuncControlFlow(path, result)
    93  	case *ast.Ident:
    94  		highlightIdentifiers(pkg, path, result)
    95  	case *ast.ForStmt, *ast.RangeStmt:
    96  		highlightLoopControlFlow(path, result)
    97  	case *ast.SwitchStmt:
    98  		highlightSwitchFlow(path, result)
    99  	case *ast.BranchStmt:
   100  		// BREAK can exit a loop, switch or select, while CONTINUE exit a loop so
   101  		// these need to be handled separately. They can also be embedded in any
   102  		// other loop/switch/select if they have a label. TODO: add support for
   103  		// GOTO and FALLTHROUGH as well.
   104  		if node.Label != nil {
   105  			highlightLabeledFlow(node, result)
   106  		} else {
   107  			switch node.Tok {
   108  			case token.BREAK:
   109  				highlightUnlabeledBreakFlow(path, result)
   110  			case token.CONTINUE:
   111  				highlightLoopControlFlow(path, result)
   112  			}
   113  		}
   114  	default:
   115  		// If the cursor is in an unidentified area, return empty results.
   116  		return nil, nil
   117  	}
   118  	return result, nil
   119  }
   120  
   121  type posRange struct {
   122  	start, end token.Pos
   123  }
   124  
// highlightFuncControlFlow adds to result the exit points of the function
// (an *ast.FuncDecl or *ast.FuncLit) that encloses path[0], where path runs
// from the cursor's innermost node outwards.
//
// Behavior, as implemented below:
//   - If a KeyValueExpr, or a CallExpr whose argument contains the cursor,
//     lies between the cursor and the enclosing function, nothing is added.
//   - If no enclosing function is found, nothing is added.
//   - If path[0] is the enclosing function, a return statement (the cursor
//     is not inside one of its values), or an *ast.FuncType, the "func"
//     keyword and every return statement of the function are added.
//   - If the cursor is inside the N-th value of a return statement, or, when
//     not in a return statement, inside the N-th field of the signature's
//     result list, then that result field and the N-th value of every return
//     statement are added.
//   - An identifier or basic literal that is neither inside a return
//     statement nor inside any *ast.Field adds nothing.
//
// Return statements belonging to nested function literals are ignored.
//
// NOTE(review): the N-th result field (resultsList.List[index]) and the N-th
// returned value (ret.Results[index]) only line up when every result field
// declares at most one name. For a signature like (a, b int, err error), the
// two indices diverge after the first field. TODO confirm and fix.
func highlightFuncControlFlow(path []ast.Node, result map[posRange]struct{}) {
	var enclosingFunc ast.Node
	var returnStmt *ast.ReturnStmt
	var resultsList *ast.FieldList
	inReturnList := false

Outer:
	// Walk the path outwards from the cursor until we reach the enclosing
	// function (FuncLit or FuncDecl).
	for i, n := range path {
		switch node := n.(type) {
		case *ast.KeyValueExpr:
			// If cursor is in a key: value expr, we don't want control flow highlighting
			return
		case *ast.CallExpr:
			// If cursor is an arg in a callExpr, we don't want control flow highlighting.
			if i > 0 {
				for _, arg := range node.Args {
					if arg == path[i-1] {
						return
					}
				}
			}
		case *ast.Field:
			// Set for any field (parameter, result, struct field, ...); the
			// position lookup below decides which result, if any, the cursor
			// is actually in.
			inReturnList = true
		case *ast.FuncLit:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.FuncDecl:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.ReturnStmt:
			returnStmt = node
			// If the cursor is not directly in a *ast.ReturnStmt, then
			// we need to know if it is within one of the values that is being returned.
			inReturnList = inReturnList || path[0] != returnStmt
		}
	}
	// Cursor is not in a function.
	if enclosingFunc == nil {
		return
	}
	// If the cursor is on a "return" or "func" keyword, we should highlight all of the exit
	// points of the function, including the "return" and "func" keywords.
	highlightAllReturnsAndFunc := path[0] == returnStmt || path[0] == enclosingFunc
	switch path[0].(type) {
	case *ast.Ident, *ast.BasicLit:
		// Cursor is in an identifier and not in a return statement or in the results list.
		if returnStmt == nil && !inReturnList {
			return
		}
	case *ast.FuncType:
		highlightAllReturnsAndFunc = true
	}
	// The user's cursor may be within the return statement of a function,
	// or within the result section of a function's signature.
	// Collect the candidates (the returned values, or else the result
	// fields) and find the index of the one containing the cursor. The
	// checks below treat a negative index as "no match".
	var nodes []ast.Node
	if returnStmt != nil {
		for _, n := range returnStmt.Results {
			nodes = append(nodes, n)
		}
	} else if resultsList != nil {
		for _, n := range resultsList.List {
			nodes = append(nodes, n)
		}
	}
	_, index := nodeAtPos(nodes, path[0].Pos())

	// Highlight the correct argument in the function declaration return types.
	if resultsList != nil && -1 < index && index < len(resultsList.List) {
		rng := posRange{
			start: resultsList.List[index].Pos(),
			end:   resultsList.List[index].End(),
		}
		result[rng] = struct{}{}
	}
	// Add the "func" part of the func declaration.
	// (Relies on enclosingFunc.Pos() being the position of its "func" keyword.)
	if highlightAllReturnsAndFunc {
		r := posRange{
			start: enclosingFunc.Pos(),
			end:   enclosingFunc.Pos() + token.Pos(len("func")),
		}
		result[r] = struct{}{}
	}
	ast.Inspect(enclosingFunc, func(n ast.Node) bool {
		// Don't traverse any other functions.
		switch n.(type) {
		case *ast.FuncDecl, *ast.FuncLit:
			return enclosingFunc == n
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		var toAdd ast.Node
		// Add the entire return statement; applies when the cursor is on "return" or "func".
		if highlightAllReturnsAndFunc {
			toAdd = n
		}
		// Add the relevant field within the entire return statement.
		// When both conditions hold, this overrides the whole-statement
		// choice made just above.
		if -1 < index && index < len(ret.Results) {
			toAdd = ret.Results[index]
		}
		if toAdd != nil {
			result[posRange{start: toAdd.Pos(), end: toAdd.End()}] = struct{}{}
		}
		// Do not descend into the return statement's values.
		return false
	})
}
   236  
   237  func highlightUnlabeledBreakFlow(path []ast.Node, result map[posRange]struct{}) {
   238  	// Reverse walk the path until we find closest loop, select, or switch.
   239  	for _, n := range path {
   240  		switch n.(type) {
   241  		case *ast.ForStmt, *ast.RangeStmt:
   242  			highlightLoopControlFlow(path, result)
   243  			return // only highlight the innermost statement
   244  		case *ast.SwitchStmt:
   245  			highlightSwitchFlow(path, result)
   246  			return
   247  		case *ast.SelectStmt:
   248  			// TODO: add highlight when breaking a select.
   249  			return
   250  		}
   251  	}
   252  }
   253  
   254  func highlightLabeledFlow(node *ast.BranchStmt, result map[posRange]struct{}) {
   255  	obj := node.Label.Obj
   256  	if obj == nil || obj.Decl == nil {
   257  		return
   258  	}
   259  	label, ok := obj.Decl.(*ast.LabeledStmt)
   260  	if !ok {
   261  		return
   262  	}
   263  	switch label.Stmt.(type) {
   264  	case *ast.ForStmt, *ast.RangeStmt:
   265  		highlightLoopControlFlow([]ast.Node{label.Stmt, label}, result)
   266  	case *ast.SwitchStmt:
   267  		highlightSwitchFlow([]ast.Node{label.Stmt, label}, result)
   268  	}
   269  }
   270  
   271  func labelFor(path []ast.Node) *ast.Ident {
   272  	if len(path) > 1 {
   273  		if n, ok := path[1].(*ast.LabeledStmt); ok {
   274  			return n.Label
   275  		}
   276  	}
   277  	return nil
   278  }
   279  
   280  func highlightLoopControlFlow(path []ast.Node, result map[posRange]struct{}) {
   281  	var loop ast.Node
   282  	var loopLabel *ast.Ident
   283  	stmtLabel := labelFor(path)
   284  Outer:
   285  	// Reverse walk the path till we get to the for loop.
   286  	for i := range path {
   287  		switch n := path[i].(type) {
   288  		case *ast.ForStmt, *ast.RangeStmt:
   289  			loopLabel = labelFor(path[i:])
   290  
   291  			if stmtLabel == nil || loopLabel == stmtLabel {
   292  				loop = n
   293  				break Outer
   294  			}
   295  		}
   296  	}
   297  	if loop == nil {
   298  		return
   299  	}
   300  
   301  	// Add the for statement.
   302  	rng := posRange{
   303  		start: loop.Pos(),
   304  		end:   loop.Pos() + token.Pos(len("for")),
   305  	}
   306  	result[rng] = struct{}{}
   307  
   308  	// Traverse AST to find branch statements within the same for-loop.
   309  	ast.Inspect(loop, func(n ast.Node) bool {
   310  		switch n.(type) {
   311  		case *ast.ForStmt, *ast.RangeStmt:
   312  			return loop == n
   313  		case *ast.SwitchStmt, *ast.SelectStmt:
   314  			return false
   315  		}
   316  		b, ok := n.(*ast.BranchStmt)
   317  		if !ok {
   318  			return true
   319  		}
   320  		if b.Label == nil || labelDecl(b.Label) == loopLabel {
   321  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   322  		}
   323  		return true
   324  	})
   325  
   326  	// Find continue statements in the same loop or switches/selects.
   327  	ast.Inspect(loop, func(n ast.Node) bool {
   328  		switch n.(type) {
   329  		case *ast.ForStmt, *ast.RangeStmt:
   330  			return loop == n
   331  		}
   332  
   333  		if n, ok := n.(*ast.BranchStmt); ok && n.Tok == token.CONTINUE {
   334  			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   335  		}
   336  		return true
   337  	})
   338  
   339  	// We don't need to check other for loops if we aren't looking for labeled statements.
   340  	if loopLabel == nil {
   341  		return
   342  	}
   343  
   344  	// Find labeled branch statements in any loop.
   345  	ast.Inspect(loop, func(n ast.Node) bool {
   346  		b, ok := n.(*ast.BranchStmt)
   347  		if !ok {
   348  			return true
   349  		}
   350  		// statement with labels that matches the loop
   351  		if b.Label != nil && labelDecl(b.Label) == loopLabel {
   352  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   353  		}
   354  		return true
   355  	})
   356  }
   357  
   358  func highlightSwitchFlow(path []ast.Node, result map[posRange]struct{}) {
   359  	var switchNode ast.Node
   360  	var switchNodeLabel *ast.Ident
   361  	stmtLabel := labelFor(path)
   362  Outer:
   363  	// Reverse walk the path till we get to the switch statement.
   364  	for i := range path {
   365  		switch n := path[i].(type) {
   366  		case *ast.SwitchStmt:
   367  			switchNodeLabel = labelFor(path[i:])
   368  			if stmtLabel == nil || switchNodeLabel == stmtLabel {
   369  				switchNode = n
   370  				break Outer
   371  			}
   372  		}
   373  	}
   374  	// Cursor is not in a switch statement
   375  	if switchNode == nil {
   376  		return
   377  	}
   378  
   379  	// Add the switch statement.
   380  	rng := posRange{
   381  		start: switchNode.Pos(),
   382  		end:   switchNode.Pos() + token.Pos(len("switch")),
   383  	}
   384  	result[rng] = struct{}{}
   385  
   386  	// Traverse AST to find break statements within the same switch.
   387  	ast.Inspect(switchNode, func(n ast.Node) bool {
   388  		switch n.(type) {
   389  		case *ast.SwitchStmt:
   390  			return switchNode == n
   391  		case *ast.ForStmt, *ast.RangeStmt, *ast.SelectStmt:
   392  			return false
   393  		}
   394  
   395  		b, ok := n.(*ast.BranchStmt)
   396  		if !ok || b.Tok != token.BREAK {
   397  			return true
   398  		}
   399  
   400  		if b.Label == nil || labelDecl(b.Label) == switchNodeLabel {
   401  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   402  		}
   403  		return true
   404  	})
   405  
   406  	// We don't need to check other switches if we aren't looking for labeled statements.
   407  	if switchNodeLabel == nil {
   408  		return
   409  	}
   410  
   411  	// Find labeled break statements in any switch
   412  	ast.Inspect(switchNode, func(n ast.Node) bool {
   413  		b, ok := n.(*ast.BranchStmt)
   414  		if !ok || b.Tok != token.BREAK {
   415  			return true
   416  		}
   417  
   418  		if b.Label != nil && labelDecl(b.Label) == switchNodeLabel {
   419  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   420  		}
   421  
   422  		return true
   423  	})
   424  }
   425  
   426  func labelDecl(n *ast.Ident) *ast.Ident {
   427  	if n == nil {
   428  		return nil
   429  	}
   430  	if n.Obj == nil {
   431  		return nil
   432  	}
   433  	if n.Obj.Decl == nil {
   434  		return nil
   435  	}
   436  	stmt, ok := n.Obj.Decl.(*ast.LabeledStmt)
   437  	if !ok {
   438  		return nil
   439  	}
   440  	return stmt.Label
   441  }
   442  
   443  func highlightImportUses(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
   444  	basicLit, ok := path[0].(*ast.BasicLit)
   445  	if !ok {
   446  		return errors.Errorf("highlightImportUses called with an ast.Node of type %T", basicLit)
   447  	}
   448  	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
   449  		if imp, ok := node.(*ast.ImportSpec); ok && imp.Path == basicLit {
   450  			result[posRange{start: node.Pos(), end: node.End()}] = struct{}{}
   451  			return false
   452  		}
   453  		n, ok := node.(*ast.Ident)
   454  		if !ok {
   455  			return true
   456  		}
   457  		obj, ok := pkg.GetTypesInfo().ObjectOf(n).(*types.PkgName)
   458  		if !ok {
   459  			return true
   460  		}
   461  		if !strings.Contains(basicLit.Value, obj.Name()) {
   462  			return true
   463  		}
   464  		result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   465  		return false
   466  	})
   467  	return nil
   468  }
   469  
   470  func highlightIdentifiers(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
   471  	id, ok := path[0].(*ast.Ident)
   472  	if !ok {
   473  		return errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id)
   474  	}
   475  	// Check if ident is inside return or func decl.
   476  	highlightFuncControlFlow(path, result)
   477  
   478  	// TODO: maybe check if ident is a reserved word, if true then don't continue and return results.
   479  
   480  	idObj := pkg.GetTypesInfo().ObjectOf(id)
   481  	pkgObj, isImported := idObj.(*types.PkgName)
   482  	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
   483  		if imp, ok := node.(*ast.ImportSpec); ok && isImported {
   484  			highlightImport(pkgObj, imp, result)
   485  		}
   486  		n, ok := node.(*ast.Ident)
   487  		if !ok {
   488  			return true
   489  		}
   490  		if n.Name != id.Name {
   491  			return false
   492  		}
   493  		if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj == idObj {
   494  			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   495  		}
   496  		return false
   497  	})
   498  	return nil
   499  }
   500  
   501  func highlightImport(obj *types.PkgName, imp *ast.ImportSpec, result map[posRange]struct{}) {
   502  	if imp.Name != nil || imp.Path == nil {
   503  		return
   504  	}
   505  	if !strings.Contains(imp.Path.Value, obj.Name()) {
   506  		return
   507  	}
   508  	result[posRange{start: imp.Path.Pos(), end: imp.Path.End()}] = struct{}{}
   509  }