github.com/v2fly/tools@v0.100.0/internal/lsp/source/completion/statements.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package completion
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  
    13  	"github.com/v2fly/tools/internal/lsp/protocol"
    14  	"github.com/v2fly/tools/internal/lsp/snippet"
    15  	"github.com/v2fly/tools/internal/lsp/source"
    16  )
    17  
    18  // addStatementCandidates adds full statement completion candidates
    19  // appropriate for the current context.
    20  func (c *completer) addStatementCandidates() {
    21  	c.addErrCheck()
    22  	c.addAssignAppend()
    23  }
    24  
    25  // addAssignAppend offers a completion candidate of the form:
    26  //
    27  //     someSlice = append(someSlice, )
    28  //
    29  // It will offer the "append" completion in two situations:
    30  //
    31  // 1. Position is in RHS of assign, prefix matches "append", and
    32  //    corresponding LHS object is a slice. For example,
    33  //    "foo = ap<>" completes to "foo = append(foo, )".
    34  //
    35  // Or
    36  //
    37  // 2. Prefix is an ident or selector in an *ast.ExprStmt (i.e.
    38  //    beginning of statement), and our best matching candidate is a
    39  //    slice. For example: "foo.ba" completes to "foo.bar = append(foo.bar, )".
    40  func (c *completer) addAssignAppend() {
    41  	if len(c.path) < 3 {
    42  		return
    43  	}
    44  
    45  	ident, _ := c.path[0].(*ast.Ident)
    46  	if ident == nil {
    47  		return
    48  	}
    49  
    50  	var (
    51  		// sliceText is the full name of our slice object, e.g. "s.abc" in
    52  		// "s.abc = app<>".
    53  		sliceText string
    54  		// needsLHS is true if we need to prepend the LHS slice name and
    55  		// "=" to our candidate.
    56  		needsLHS = false
    57  		fset     = c.snapshot.FileSet()
    58  	)
    59  
    60  	switch n := c.path[1].(type) {
    61  	case *ast.AssignStmt:
    62  		// We are already in an assignment. Make sure our prefix matches "append".
    63  		if c.matcher.Score("append") <= 0 {
    64  			return
    65  		}
    66  
    67  		exprIdx := exprAtPos(c.pos, n.Rhs)
    68  		if exprIdx == len(n.Rhs) || exprIdx > len(n.Lhs)-1 {
    69  			return
    70  		}
    71  
    72  		lhsType := c.pkg.GetTypesInfo().TypeOf(n.Lhs[exprIdx])
    73  		if lhsType == nil {
    74  			return
    75  		}
    76  
    77  		// Make sure our corresponding LHS object is a slice.
    78  		if _, isSlice := lhsType.Underlying().(*types.Slice); !isSlice {
    79  			return
    80  		}
    81  
    82  		// The name or our slice is whatever's in the LHS expression.
    83  		sliceText = source.FormatNode(fset, n.Lhs[exprIdx])
    84  	case *ast.SelectorExpr:
    85  		// Make sure we are a selector at the beginning of a statement.
    86  		if _, parentIsExprtStmt := c.path[2].(*ast.ExprStmt); !parentIsExprtStmt {
    87  			return
    88  		}
    89  
    90  		// So far we only know the first part of our slice name. For
    91  		// example in "s.a<>" we only know our slice begins with "s."
    92  		// since the user could still be typing.
    93  		sliceText = source.FormatNode(fset, n.X) + "."
    94  		needsLHS = true
    95  	case *ast.ExprStmt:
    96  		needsLHS = true
    97  	default:
    98  		return
    99  	}
   100  
   101  	var (
   102  		label string
   103  		snip  snippet.Builder
   104  		score = highScore
   105  	)
   106  
   107  	if needsLHS {
   108  		// Offer the long form assign + append candidate if our best
   109  		// candidate is a slice.
   110  		bestItem := c.topCandidate()
   111  		if bestItem == nil || bestItem.obj == nil || bestItem.obj.Type() == nil {
   112  			return
   113  		}
   114  
   115  		if _, isSlice := bestItem.obj.Type().Underlying().(*types.Slice); !isSlice {
   116  			return
   117  		}
   118  
   119  		// Don't rank the full form assign + append candidate above the
   120  		// slice itself.
   121  		score = bestItem.Score - 0.01
   122  
   123  		// Fill in rest of sliceText now that we have the object name.
   124  		sliceText += bestItem.Label
   125  
   126  		// Fill in the candidate's LHS bits.
   127  		label = fmt.Sprintf("%s = ", bestItem.Label)
   128  		snip.WriteText(label)
   129  	}
   130  
   131  	snip.WriteText(fmt.Sprintf("append(%s, ", sliceText))
   132  	snip.WritePlaceholder(nil)
   133  	snip.WriteText(")")
   134  
   135  	c.items = append(c.items, CompletionItem{
   136  		Label:   label + fmt.Sprintf("append(%s, )", sliceText),
   137  		Kind:    protocol.FunctionCompletion,
   138  		Score:   score,
   139  		snippet: &snip,
   140  	})
   141  }
   142  
   143  // topCandidate returns the strictly highest scoring candidate
   144  // collected so far. If the top two candidates have the same score,
   145  // nil is returned.
   146  func (c *completer) topCandidate() *CompletionItem {
   147  	var bestItem, secondBestItem *CompletionItem
   148  	for i := range c.items {
   149  		if bestItem == nil || c.items[i].Score > bestItem.Score {
   150  			bestItem = &c.items[i]
   151  		} else if secondBestItem == nil || c.items[i].Score > secondBestItem.Score {
   152  			secondBestItem = &c.items[i]
   153  		}
   154  	}
   155  
   156  	// If secondBestItem has the same score, bestItem isn't
   157  	// the strict best.
   158  	if secondBestItem != nil && secondBestItem.Score == bestItem.Score {
   159  		return nil
   160  	}
   161  
   162  	return bestItem
   163  }
   164  
   165  // addErrCheck offers a completion candidate of the form:
   166  //
   167  //     if err != nil {
   168  //       return nil, err
   169  //     }
   170  //
   171  // In the case of test functions, it offers a completion candidate of the form:
   172  //
   173  //     if err != nil {
   174  //       t.Fatal(err)
   175  //     }
   176  //
   177  // The position must be in a function that returns an error, and the
   178  // statement preceding the position must be an assignment where the
   179  // final LHS object is an error. addErrCheck will synthesize
   180  // zero values as necessary to make the return statement valid.
   181  func (c *completer) addErrCheck() {
   182  	if len(c.path) < 2 || c.enclosingFunc == nil || !c.opts.placeholders {
   183  		return
   184  	}
   185  
   186  	var (
   187  		errorType        = types.Universe.Lookup("error").Type()
   188  		result           = c.enclosingFunc.sig.Results()
   189  		testVar          = getTestVar(c.enclosingFunc, c.pkg)
   190  		isTest           = testVar != ""
   191  		doesNotReturnErr = result.Len() == 0 || !types.Identical(result.At(result.Len()-1).Type(), errorType)
   192  	)
   193  	// Make sure our enclosing function is a Test func or returns an error.
   194  	if !isTest && doesNotReturnErr {
   195  		return
   196  	}
   197  
   198  	prevLine := prevStmt(c.pos, c.path)
   199  	if prevLine == nil {
   200  		return
   201  	}
   202  
   203  	// Make sure our preceding statement was as assignment.
   204  	assign, _ := prevLine.(*ast.AssignStmt)
   205  	if assign == nil || len(assign.Lhs) == 0 {
   206  		return
   207  	}
   208  
   209  	lastAssignee := assign.Lhs[len(assign.Lhs)-1]
   210  
   211  	// Make sure the final assignee is an error.
   212  	if !types.Identical(c.pkg.GetTypesInfo().TypeOf(lastAssignee), errorType) {
   213  		return
   214  	}
   215  
   216  	var (
   217  		// errVar is e.g. "err" in "foo, err := bar()".
   218  		errVar = source.FormatNode(c.snapshot.FileSet(), lastAssignee)
   219  
   220  		// Whether we need to include the "if" keyword in our candidate.
   221  		needsIf = true
   222  	)
   223  
   224  	// If the returned error from the previous statement is "_", it is not a real object.
   225  	// If we don't have an error, and the function signature takes a testing.TB that is either ignored
   226  	// or an "_", then we also can't call t.Fatal(err).
   227  	if errVar == "_" {
   228  		return
   229  	}
   230  
   231  	// Below we try to detect if the user has already started typing "if
   232  	// err" so we can replace what they've typed with our complete
   233  	// statement.
   234  	switch n := c.path[0].(type) {
   235  	case *ast.Ident:
   236  		switch c.path[1].(type) {
   237  		case *ast.ExprStmt:
   238  			// This handles:
   239  			//
   240  			//     f, err := os.Open("foo")
   241  			//     i<>
   242  
   243  			// Make sure they are typing "if".
   244  			if c.matcher.Score("if") <= 0 {
   245  				return
   246  			}
   247  		case *ast.IfStmt:
   248  			// This handles:
   249  			//
   250  			//     f, err := os.Open("foo")
   251  			//     if er<>
   252  
   253  			// Make sure they are typing the error's name.
   254  			if c.matcher.Score(errVar) <= 0 {
   255  				return
   256  			}
   257  
   258  			needsIf = false
   259  		default:
   260  			return
   261  		}
   262  	case *ast.IfStmt:
   263  		// This handles:
   264  		//
   265  		//     f, err := os.Open("foo")
   266  		//     if <>
   267  
   268  		// Avoid false positives by ensuring the if's cond is a bad
   269  		// expression. For example, don't offer the completion in cases
   270  		// like "if <> somethingElse".
   271  		if _, bad := n.Cond.(*ast.BadExpr); !bad {
   272  			return
   273  		}
   274  
   275  		// If "if" is our direct prefix, we need to include it in our
   276  		// candidate since the existing "if" will be overwritten.
   277  		needsIf = c.pos == n.Pos()+token.Pos(len("if"))
   278  	}
   279  
   280  	// Build up a snippet that looks like:
   281  	//
   282  	//     if err != nil {
   283  	//       return <zero value>, ..., ${1:err}
   284  	//     }
   285  	//
   286  	// We make the error a placeholder so it is easy to alter the error.
   287  	var snip snippet.Builder
   288  	if needsIf {
   289  		snip.WriteText("if ")
   290  	}
   291  	snip.WriteText(fmt.Sprintf("%s != nil {\n\t", errVar))
   292  
   293  	var label string
   294  	if isTest {
   295  		snip.WriteText(fmt.Sprintf("%s.Fatal(%s)", testVar, errVar))
   296  		label = fmt.Sprintf("%[1]s != nil { %[2]s.Fatal(%[1]s) }", errVar, testVar)
   297  	} else {
   298  		snip.WriteText("return ")
   299  		for i := 0; i < result.Len()-1; i++ {
   300  			snip.WriteText(formatZeroValue(result.At(i).Type(), c.qf))
   301  			snip.WriteText(", ")
   302  		}
   303  		snip.WritePlaceholder(func(b *snippet.Builder) {
   304  			b.WriteText(errVar)
   305  		})
   306  		label = fmt.Sprintf("%[1]s != nil { return %[1]s }", errVar)
   307  	}
   308  
   309  	snip.WriteText("\n}")
   310  
   311  	if needsIf {
   312  		label = "if " + label
   313  	}
   314  
   315  	c.items = append(c.items, CompletionItem{
   316  		Label: label,
   317  		// There doesn't seem to be a more appropriate kind.
   318  		Kind:    protocol.KeywordCompletion,
   319  		Score:   highScore,
   320  		snippet: &snip,
   321  	})
   322  }
   323  
   324  // getTestVar checks the function signature's input parameters and returns
   325  // the name of the first parameter that implements "testing.TB". For example,
   326  // func someFunc(t *testing.T) returns the string "t", func someFunc(b *testing.B)
   327  // returns "b" etc. An empty string indicates that the function signature
   328  // does not take a testing.TB parameter or does so but is ignored such
   329  // as func someFunc(*testing.T).
   330  func getTestVar(enclosingFunc *funcInfo, pkg source.Package) string {
   331  	if enclosingFunc == nil || enclosingFunc.sig == nil {
   332  		return ""
   333  	}
   334  
   335  	sig := enclosingFunc.sig
   336  	for i := 0; i < sig.Params().Len(); i++ {
   337  		param := sig.Params().At(i)
   338  		if param.Name() == "_" {
   339  			continue
   340  		}
   341  		testingPkg, err := pkg.GetImport("testing")
   342  		if err != nil {
   343  			continue
   344  		}
   345  		tbObj := testingPkg.GetTypes().Scope().Lookup("TB")
   346  		if tbObj == nil {
   347  			continue
   348  		}
   349  		iface, ok := tbObj.Type().Underlying().(*types.Interface)
   350  		if !ok {
   351  			continue
   352  		}
   353  		if !types.Implements(param.Type(), iface) {
   354  			continue
   355  		}
   356  		return param.Name()
   357  	}
   358  
   359  	return ""
   360  }