github.com/april1989/origin-go-tools@v0.0.32/internal/lsp/source/highlight.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
import (
	"context"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"sort"
	"strings"

	"github.com/april1989/origin-go-tools/go/ast/astutil"
	"github.com/april1989/origin-go-tools/internal/event"
	"github.com/april1989/origin-go-tools/internal/lsp/protocol"
	errors "golang.org/x/xerrors"
)
    20  
    21  func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.Range, error) {
    22  	ctx, done := event.Start(ctx, "source.Highlight")
    23  	defer done()
    24  
    25  	pkg, pgf, err := getParsedFile(ctx, snapshot, fh, WidestPackage)
    26  	if err != nil {
    27  		return nil, errors.Errorf("getting file for Highlight: %w", err)
    28  	}
    29  	spn, err := pgf.Mapper.PointSpan(pos)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	rng, err := spn.Range(pgf.Mapper.Converter)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	path, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start, rng.Start)
    38  	if len(path) == 0 {
    39  		return nil, fmt.Errorf("no enclosing position found for %v:%v", int(pos.Line), int(pos.Character))
    40  	}
    41  	// If start == end for astutil.PathEnclosingInterval, the 1-char interval
    42  	// following start is used instead. As a result, we might not get an exact
    43  	// match so we should check the 1-char interval to the left of the passed
    44  	// in position to see if that is an exact match.
    45  	if _, ok := path[0].(*ast.Ident); !ok {
    46  		if p, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start-1, rng.Start-1); p != nil {
    47  			switch p[0].(type) {
    48  			case *ast.Ident, *ast.SelectorExpr:
    49  				path = p // use preceding ident/selector
    50  			}
    51  		}
    52  	}
    53  	result, err := highlightPath(pkg, path)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	var ranges []protocol.Range
    58  	for rng := range result {
    59  		mRng, err := posToMappedRange(snapshot, pkg, rng.start, rng.end)
    60  		if err != nil {
    61  			return nil, err
    62  		}
    63  		pRng, err := mRng.Range()
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  		ranges = append(ranges, pRng)
    68  	}
    69  	return ranges, nil
    70  }
    71  
    72  func highlightPath(pkg Package, path []ast.Node) (map[posRange]struct{}, error) {
    73  	result := make(map[posRange]struct{})
    74  	switch node := path[0].(type) {
    75  	case *ast.BasicLit:
    76  		if len(path) > 1 {
    77  			if _, ok := path[1].(*ast.ImportSpec); ok {
    78  				err := highlightImportUses(pkg, path, result)
    79  				return result, err
    80  			}
    81  		}
    82  		highlightFuncControlFlow(path, result)
    83  	case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType:
    84  		highlightFuncControlFlow(path, result)
    85  	case *ast.Ident:
    86  		highlightIdentifiers(pkg, path, result)
    87  	case *ast.ForStmt, *ast.RangeStmt:
    88  		highlightLoopControlFlow(path, result)
    89  	case *ast.SwitchStmt:
    90  		highlightSwitchFlow(path, result)
    91  	case *ast.BranchStmt:
    92  		// BREAK can exit a loop, switch or select, while CONTINUE exit a loop so
    93  		// these need to be handled separately. They can also be embedded in any
    94  		// other loop/switch/select if they have a label. TODO: add support for
    95  		// GOTO and FALLTHROUGH as well.
    96  		if node.Label != nil {
    97  			highlightLabeledFlow(node, result)
    98  		} else {
    99  			switch node.Tok {
   100  			case token.BREAK:
   101  				highlightUnlabeledBreakFlow(path, result)
   102  			case token.CONTINUE:
   103  				highlightLoopControlFlow(path, result)
   104  			}
   105  		}
   106  	default:
   107  		// If the cursor is in an unidentified area, return empty results.
   108  		return nil, nil
   109  	}
   110  	return result, nil
   111  }
   112  
   113  type posRange struct {
   114  	start, end token.Pos
   115  }
   116  
// highlightFuncControlFlow adds to result the exit-point ranges of the
// function (*ast.FuncDecl or *ast.FuncLit) nearest to path[0]. path is
// ordered from the innermost node outward.
//
// Nothing is added when the cursor is not inside a function. Nothing is
// added when, between the cursor and that function, there is either:
//   - a key: value expression, or
//   - a call expression of which the cursor's branch is an argument.
//
// Otherwise:
//   - if path[0] is the return statement itself, the function node, or a
//     *ast.FuncType, the "func" keyword and every return statement of the
//     function are highlighted;
//   - if the cursor maps to the i-th returned value or i-th result field,
//     the i-th result field of the signature and the i-th value of each
//     return statement are highlighted instead of the whole statement.
//
// Return statements that belong to nested functions are ignored.
func highlightFuncControlFlow(path []ast.Node, result map[posRange]struct{}) {
	var enclosingFunc ast.Node
	var returnStmt *ast.ReturnStmt
	var resultsList *ast.FieldList
	// inReturnList is set when the cursor is inside a returned value, or when
	// any *ast.Field lies on the path. NOTE(review): that includes parameter
	// fields. In that case nodeAtPos presumably finds no matching result, and
	// nothing is highlighted.
	inReturnList := false

Outer:
	// Reverse walk the path till we get to the func block.
	for i, n := range path {
		switch node := n.(type) {
		case *ast.KeyValueExpr:
			// If cursor is in a key: value expr, we don't want control flow highlighting
			return
		case *ast.CallExpr:
			// If cursor is an arg in a callExpr, we don't want control flow highlighting.
			if i > 0 {
				for _, arg := range node.Args {
					if arg == path[i-1] {
						return
					}
				}
			}
		case *ast.Field:
			inReturnList = true
		case *ast.FuncLit:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.FuncDecl:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.ReturnStmt:
			returnStmt = node
			// If the cursor is not directly in a *ast.ReturnStmt, then
			// we need to know if it is within one of the values that is being returned.
			inReturnList = inReturnList || path[0] != returnStmt
		}
	}
	// Cursor is not in a function.
	if enclosingFunc == nil {
		return
	}
	// If the cursor is on a "return" or "func" keyword, we should highlight all of the exit
	// points of the function, including the "return" and "func" keywords.
	highlightAllReturnsAndFunc := path[0] == returnStmt || path[0] == enclosingFunc
	switch path[0].(type) {
	case *ast.Ident, *ast.BasicLit:
		// Cursor is in an identifier and not in a return statement or in the results list.
		if returnStmt == nil && !inReturnList {
			return
		}
	case *ast.FuncType:
		// Anywhere on the signature's FuncType counts as the "func" keyword.
		highlightAllReturnsAndFunc = true
	}
	// The user's cursor may be within the return statement of a function,
	// or within the result section of a function's signature. Collect the
	// candidate nodes from whichever of the two applies; the return statement
	// takes priority.
	var nodes []ast.Node
	if returnStmt != nil {
		for _, n := range returnStmt.Results {
			nodes = append(nodes, n)
		}
	} else if resultsList != nil {
		for _, n := range resultsList.List {
			nodes = append(nodes, n)
		}
	}
	// index is the position of the cursor's node among nodes. Based on the
	// -1 < index guards below, -1 presumably means "no match". nodeAtPos is
	// defined elsewhere; confirm its contract there.
	_, index := nodeAtPos(nodes, path[0].Pos())

	// Highlight the correct argument in the function declaration return types.
	if resultsList != nil && -1 < index && index < len(resultsList.List) {
		rng := posRange{
			start: resultsList.List[index].Pos(),
			end:   resultsList.List[index].End(),
		}
		result[rng] = struct{}{}
	}
	// Add the "func" part of the func declaration.
	if highlightAllReturnsAndFunc {
		r := posRange{
			start: enclosingFunc.Pos(),
			end:   enclosingFunc.Pos() + token.Pos(len("func")),
		}
		result[r] = struct{}{}
	}
	ast.Inspect(enclosingFunc, func(n ast.Node) bool {
		// Don't traverse any other functions.
		switch n.(type) {
		case *ast.FuncDecl, *ast.FuncLit:
			return enclosingFunc == n
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		var toAdd ast.Node
		// Add the entire return statement, applies when highlight the word "return" or "func".
		if highlightAllReturnsAndFunc {
			toAdd = n
		}
		// Add the relevant field within the entire return statement. When the
		// index is valid for this statement it overrides the whole-statement
		// range chosen above.
		if -1 < index && index < len(ret.Results) {
			toAdd = ret.Results[index]
		}
		if toAdd != nil {
			result[posRange{start: toAdd.Pos(), end: toAdd.End()}] = struct{}{}
		}
		// Any return statements nested inside this one belong to function
		// literals, i.e. to other functions.
		return false
	})
}
   228  
   229  func highlightUnlabeledBreakFlow(path []ast.Node, result map[posRange]struct{}) {
   230  	// Reverse walk the path until we find closest loop, select, or switch.
   231  	for _, n := range path {
   232  		switch n.(type) {
   233  		case *ast.ForStmt, *ast.RangeStmt:
   234  			highlightLoopControlFlow(path, result)
   235  			return // only highlight the innermost statement
   236  		case *ast.SwitchStmt:
   237  			highlightSwitchFlow(path, result)
   238  			return
   239  		case *ast.SelectStmt:
   240  			// TODO: add highlight when breaking a select.
   241  			return
   242  		}
   243  	}
   244  }
   245  
   246  func highlightLabeledFlow(node *ast.BranchStmt, result map[posRange]struct{}) {
   247  	obj := node.Label.Obj
   248  	if obj == nil || obj.Decl == nil {
   249  		return
   250  	}
   251  	label, ok := obj.Decl.(*ast.LabeledStmt)
   252  	if !ok {
   253  		return
   254  	}
   255  	switch label.Stmt.(type) {
   256  	case *ast.ForStmt, *ast.RangeStmt:
   257  		highlightLoopControlFlow([]ast.Node{label.Stmt, label}, result)
   258  	case *ast.SwitchStmt:
   259  		highlightSwitchFlow([]ast.Node{label.Stmt, label}, result)
   260  	}
   261  }
   262  
   263  func labelFor(path []ast.Node) *ast.Ident {
   264  	if len(path) > 1 {
   265  		if n, ok := path[1].(*ast.LabeledStmt); ok {
   266  			return n.Label
   267  		}
   268  	}
   269  	return nil
   270  }
   271  
   272  func highlightLoopControlFlow(path []ast.Node, result map[posRange]struct{}) {
   273  	var loop ast.Node
   274  	var loopLabel *ast.Ident
   275  	stmtLabel := labelFor(path)
   276  Outer:
   277  	// Reverse walk the path till we get to the for loop.
   278  	for i := range path {
   279  		switch n := path[i].(type) {
   280  		case *ast.ForStmt, *ast.RangeStmt:
   281  			loopLabel = labelFor(path[i:])
   282  
   283  			if stmtLabel == nil || loopLabel == stmtLabel {
   284  				loop = n
   285  				break Outer
   286  			}
   287  		}
   288  	}
   289  	if loop == nil {
   290  		return
   291  	}
   292  
   293  	// Add the for statement.
   294  	rng := posRange{
   295  		start: loop.Pos(),
   296  		end:   loop.Pos() + token.Pos(len("for")),
   297  	}
   298  	result[rng] = struct{}{}
   299  
   300  	// Traverse AST to find branch statements within the same for-loop.
   301  	ast.Inspect(loop, func(n ast.Node) bool {
   302  		switch n.(type) {
   303  		case *ast.ForStmt, *ast.RangeStmt:
   304  			return loop == n
   305  		case *ast.SwitchStmt, *ast.SelectStmt:
   306  			return false
   307  		}
   308  		b, ok := n.(*ast.BranchStmt)
   309  		if !ok {
   310  			return true
   311  		}
   312  		if b.Label == nil || labelDecl(b.Label) == loopLabel {
   313  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   314  		}
   315  		return true
   316  	})
   317  
   318  	// Find continue statements in the same loop or switches/selects.
   319  	ast.Inspect(loop, func(n ast.Node) bool {
   320  		switch n.(type) {
   321  		case *ast.ForStmt, *ast.RangeStmt:
   322  			return loop == n
   323  		}
   324  
   325  		if n, ok := n.(*ast.BranchStmt); ok && n.Tok == token.CONTINUE {
   326  			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   327  		}
   328  		return true
   329  	})
   330  
   331  	// We don't need to check other for loops if we aren't looking for labeled statements.
   332  	if loopLabel == nil {
   333  		return
   334  	}
   335  
   336  	// Find labeled branch statements in any loop
   337  	ast.Inspect(loop, func(n ast.Node) bool {
   338  		b, ok := n.(*ast.BranchStmt)
   339  		if !ok {
   340  			return true
   341  		}
   342  		// Statment with labels that matches the loop.
   343  		if b.Label != nil && labelDecl(b.Label) == loopLabel {
   344  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   345  		}
   346  		return true
   347  	})
   348  }
   349  
   350  func highlightSwitchFlow(path []ast.Node, result map[posRange]struct{}) {
   351  	var switchNode ast.Node
   352  	var switchNodeLabel *ast.Ident
   353  	stmtLabel := labelFor(path)
   354  Outer:
   355  	// Reverse walk the path till we get to the switch statement.
   356  	for i := range path {
   357  		switch n := path[i].(type) {
   358  		case *ast.SwitchStmt:
   359  			switchNodeLabel = labelFor(path[i:])
   360  			if stmtLabel == nil || switchNodeLabel == stmtLabel {
   361  				switchNode = n
   362  				break Outer
   363  			}
   364  		}
   365  	}
   366  	// Cursor is not in a switch statement
   367  	if switchNode == nil {
   368  		return
   369  	}
   370  
   371  	// Add the switch statement.
   372  	rng := posRange{
   373  		start: switchNode.Pos(),
   374  		end:   switchNode.Pos() + token.Pos(len("switch")),
   375  	}
   376  	result[rng] = struct{}{}
   377  
   378  	// Traverse AST to find break statements within the same switch.
   379  	ast.Inspect(switchNode, func(n ast.Node) bool {
   380  		switch n.(type) {
   381  		case *ast.SwitchStmt:
   382  			return switchNode == n
   383  		case *ast.ForStmt, *ast.RangeStmt, *ast.SelectStmt:
   384  			return false
   385  		}
   386  
   387  		b, ok := n.(*ast.BranchStmt)
   388  		if !ok || b.Tok != token.BREAK {
   389  			return true
   390  		}
   391  
   392  		if b.Label == nil || labelDecl(b.Label) == switchNodeLabel {
   393  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   394  		}
   395  		return true
   396  	})
   397  
   398  	// We don't need to check other switches if we aren't looking for labeled statements.
   399  	if switchNodeLabel == nil {
   400  		return
   401  	}
   402  
   403  	// Find labeled break statements in any switch
   404  	ast.Inspect(switchNode, func(n ast.Node) bool {
   405  		b, ok := n.(*ast.BranchStmt)
   406  		if !ok || b.Tok != token.BREAK {
   407  			return true
   408  		}
   409  
   410  		if b.Label != nil && labelDecl(b.Label) == switchNodeLabel {
   411  			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
   412  		}
   413  
   414  		return true
   415  	})
   416  }
   417  
   418  func labelDecl(n *ast.Ident) *ast.Ident {
   419  	if n == nil {
   420  		return nil
   421  	}
   422  	if n.Obj == nil {
   423  		return nil
   424  	}
   425  	if n.Obj.Decl == nil {
   426  		return nil
   427  	}
   428  	stmt, ok := n.Obj.Decl.(*ast.LabeledStmt)
   429  	if !ok {
   430  		return nil
   431  	}
   432  	return stmt.Label
   433  }
   434  
   435  func highlightImportUses(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
   436  	basicLit, ok := path[0].(*ast.BasicLit)
   437  	if !ok {
   438  		return errors.Errorf("highlightImportUses called with an ast.Node of type %T", basicLit)
   439  	}
   440  	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
   441  		if imp, ok := node.(*ast.ImportSpec); ok && imp.Path == basicLit {
   442  			result[posRange{start: node.Pos(), end: node.End()}] = struct{}{}
   443  			return false
   444  		}
   445  		n, ok := node.(*ast.Ident)
   446  		if !ok {
   447  			return true
   448  		}
   449  		obj, ok := pkg.GetTypesInfo().ObjectOf(n).(*types.PkgName)
   450  		if !ok {
   451  			return true
   452  		}
   453  		if !strings.Contains(basicLit.Value, obj.Name()) {
   454  			return true
   455  		}
   456  		result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   457  		return false
   458  	})
   459  	return nil
   460  }
   461  
   462  func highlightIdentifiers(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
   463  	id, ok := path[0].(*ast.Ident)
   464  	if !ok {
   465  		return errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id)
   466  	}
   467  	// Check if ident is inside return or func decl.
   468  	highlightFuncControlFlow(path, result)
   469  
   470  	// TODO: maybe check if ident is a reserved word, if true then don't continue and return results.
   471  
   472  	idObj := pkg.GetTypesInfo().ObjectOf(id)
   473  	pkgObj, isImported := idObj.(*types.PkgName)
   474  	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
   475  		if imp, ok := node.(*ast.ImportSpec); ok && isImported {
   476  			highlightImport(pkgObj, imp, result)
   477  		}
   478  		n, ok := node.(*ast.Ident)
   479  		if !ok {
   480  			return true
   481  		}
   482  		if n.Name != id.Name {
   483  			return false
   484  		}
   485  		if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj == idObj {
   486  			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
   487  		}
   488  		return false
   489  	})
   490  	return nil
   491  }
   492  
   493  func highlightImport(obj *types.PkgName, imp *ast.ImportSpec, result map[posRange]struct{}) {
   494  	if imp.Name != nil || imp.Path == nil {
   495  		return
   496  	}
   497  	if !strings.Contains(imp.Path.Value, obj.Name()) {
   498  		return
   499  	}
   500  	result[posRange{start: imp.Path.Pos(), end: imp.Path.End()}] = struct{}{}
   501  }