github.com/jd-ly/tools@v0.5.7/internal/lsp/tests/util.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tests
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"go/token"
    12  	"path/filepath"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/jd-ly/tools/internal/lsp/diff"
    19  	"github.com/jd-ly/tools/internal/lsp/diff/myers"
    20  	"github.com/jd-ly/tools/internal/lsp/protocol"
    21  	"github.com/jd-ly/tools/internal/lsp/source"
    22  	"github.com/jd-ly/tools/internal/lsp/source/completion"
    23  	"github.com/jd-ly/tools/internal/span"
    24  )
    25  
    26  // DiffLinks takes the links we got and checks if they are located within the source or a Note.
    27  // If the link is within a Note, the link is removed.
    28  // Returns an diff comment if there are differences and empty string if no diffs.
    29  func DiffLinks(mapper *protocol.ColumnMapper, wantLinks []Link, gotLinks []protocol.DocumentLink) string {
    30  	var notePositions []token.Position
    31  	links := make(map[span.Span]string, len(wantLinks))
    32  	for _, link := range wantLinks {
    33  		links[link.Src] = link.Target
    34  		notePositions = append(notePositions, link.NotePosition)
    35  	}
    36  	for _, link := range gotLinks {
    37  		spn, err := mapper.RangeSpan(link.Range)
    38  		if err != nil {
    39  			return fmt.Sprintf("%v", err)
    40  		}
    41  		linkInNote := false
    42  		for _, notePosition := range notePositions {
    43  			// Drop the links found inside expectation notes arguments as this links are not collected by expect package.
    44  			if notePosition.Line == spn.Start().Line() &&
    45  				notePosition.Column <= spn.Start().Column() {
    46  				delete(links, spn)
    47  				linkInNote = true
    48  			}
    49  		}
    50  		if linkInNote {
    51  			continue
    52  		}
    53  		if target, ok := links[spn]; ok {
    54  			delete(links, spn)
    55  			if target != link.Target {
    56  				return fmt.Sprintf("for %v want %v, got %v\n", spn, target, link.Target)
    57  			}
    58  		} else {
    59  			return fmt.Sprintf("unexpected link %v:%v\n", spn, link.Target)
    60  		}
    61  	}
    62  	for spn, target := range links {
    63  		return fmt.Sprintf("missing link %v:%v\n", spn, target)
    64  	}
    65  	return ""
    66  }
    67  
    68  // DiffSymbols prints the diff between expected and actual symbols test results.
    69  func DiffSymbols(t *testing.T, uri span.URI, want, got []protocol.DocumentSymbol) string {
    70  	sort.Slice(want, func(i, j int) bool { return want[i].Name < want[j].Name })
    71  	sort.Slice(got, func(i, j int) bool { return got[i].Name < got[j].Name })
    72  	if len(got) != len(want) {
    73  		return summarizeSymbols(-1, want, got, "different lengths got %v want %v", len(got), len(want))
    74  	}
    75  	for i, w := range want {
    76  		g := got[i]
    77  		if w.Name != g.Name {
    78  			return summarizeSymbols(i, want, got, "incorrect name got %v want %v", g.Name, w.Name)
    79  		}
    80  		if w.Kind != g.Kind {
    81  			return summarizeSymbols(i, want, got, "incorrect kind got %v want %v", g.Kind, w.Kind)
    82  		}
    83  		if protocol.CompareRange(w.SelectionRange, g.SelectionRange) != 0 {
    84  			return summarizeSymbols(i, want, got, "incorrect span got %v want %v", g.SelectionRange, w.SelectionRange)
    85  		}
    86  		if msg := DiffSymbols(t, uri, w.Children, g.Children); msg != "" {
    87  			return fmt.Sprintf("children of %s: %s", w.Name, msg)
    88  		}
    89  	}
    90  	return ""
    91  }
    92  
    93  func summarizeSymbols(i int, want, got []protocol.DocumentSymbol, reason string, args ...interface{}) string {
    94  	msg := &bytes.Buffer{}
    95  	fmt.Fprint(msg, "document symbols failed")
    96  	if i >= 0 {
    97  		fmt.Fprintf(msg, " at %d", i)
    98  	}
    99  	fmt.Fprint(msg, " because of ")
   100  	fmt.Fprintf(msg, reason, args...)
   101  	fmt.Fprint(msg, ":\nexpected:\n")
   102  	for _, s := range want {
   103  		fmt.Fprintf(msg, "  %v %v %v\n", s.Name, s.Kind, s.SelectionRange)
   104  	}
   105  	fmt.Fprintf(msg, "got:\n")
   106  	for _, s := range got {
   107  		fmt.Fprintf(msg, "  %v %v %v\n", s.Name, s.Kind, s.SelectionRange)
   108  	}
   109  	return msg.String()
   110  }
   111  
   112  // DiffDiagnostics prints the diff between expected and actual diagnostics test
   113  // results.
   114  func DiffDiagnostics(uri span.URI, want, got []*source.Diagnostic) string {
   115  	source.SortDiagnostics(want)
   116  	source.SortDiagnostics(got)
   117  
   118  	if len(got) != len(want) {
   119  		return summarizeDiagnostics(-1, uri, want, got, "different lengths got %v want %v", len(got), len(want))
   120  	}
   121  	for i, w := range want {
   122  		g := got[i]
   123  		if w.Message != g.Message {
   124  			return summarizeDiagnostics(i, uri, want, got, "incorrect Message got %v want %v", g.Message, w.Message)
   125  		}
   126  		if w.Severity != g.Severity {
   127  			return summarizeDiagnostics(i, uri, want, got, "incorrect Severity got %v want %v", g.Severity, w.Severity)
   128  		}
   129  		if w.Source != g.Source {
   130  			return summarizeDiagnostics(i, uri, want, got, "incorrect Source got %v want %v", g.Source, w.Source)
   131  		}
   132  		if protocol.ComparePosition(w.Range.Start, g.Range.Start) != 0 {
   133  			return summarizeDiagnostics(i, uri, want, got, "incorrect Start got %v want %v", g.Range.Start, w.Range.Start)
   134  		}
   135  		if !protocol.IsPoint(g.Range) { // Accept any 'want' range if the diagnostic returns a zero-length range.
   136  			if protocol.ComparePosition(w.Range.End, g.Range.End) != 0 {
   137  				return summarizeDiagnostics(i, uri, want, got, "incorrect End got %v want %v", g.Range.End, w.Range.End)
   138  			}
   139  		}
   140  	}
   141  	return ""
   142  }
   143  
   144  func summarizeDiagnostics(i int, uri span.URI, want, got []*source.Diagnostic, reason string, args ...interface{}) string {
   145  	msg := &bytes.Buffer{}
   146  	fmt.Fprint(msg, "diagnostics failed")
   147  	if i >= 0 {
   148  		fmt.Fprintf(msg, " at %d", i)
   149  	}
   150  	fmt.Fprint(msg, " because of ")
   151  	fmt.Fprintf(msg, reason, args...)
   152  	fmt.Fprint(msg, ":\nexpected:\n")
   153  	for _, d := range want {
   154  		fmt.Fprintf(msg, "  %s:%v: %s\n", uri, d.Range, d.Message)
   155  	}
   156  	fmt.Fprintf(msg, "got:\n")
   157  	for _, d := range got {
   158  		fmt.Fprintf(msg, "  %s:%v: %s\n", uri, d.Range, d.Message)
   159  	}
   160  	return msg.String()
   161  }
   162  
   163  func DiffCodeLens(uri span.URI, want, got []protocol.CodeLens) string {
   164  	sortCodeLens(want)
   165  	sortCodeLens(got)
   166  
   167  	if len(got) != len(want) {
   168  		return summarizeCodeLens(-1, uri, want, got, "different lengths got %v want %v", len(got), len(want))
   169  	}
   170  	for i, w := range want {
   171  		g := got[i]
   172  		if w.Command.Command != g.Command.Command {
   173  			return summarizeCodeLens(i, uri, want, got, "incorrect Command Name got %v want %v", g.Command.Command, w.Command.Command)
   174  		}
   175  		if w.Command.Title != g.Command.Title {
   176  			return summarizeCodeLens(i, uri, want, got, "incorrect Command Title got %v want %v", g.Command.Title, w.Command.Title)
   177  		}
   178  		if protocol.ComparePosition(w.Range.Start, g.Range.Start) != 0 {
   179  			return summarizeCodeLens(i, uri, want, got, "incorrect Start got %v want %v", g.Range.Start, w.Range.Start)
   180  		}
   181  		if !protocol.IsPoint(g.Range) { // Accept any 'want' range if the codelens returns a zero-length range.
   182  			if protocol.ComparePosition(w.Range.End, g.Range.End) != 0 {
   183  				return summarizeCodeLens(i, uri, want, got, "incorrect End got %v want %v", g.Range.End, w.Range.End)
   184  			}
   185  		}
   186  	}
   187  	return ""
   188  }
   189  
   190  func sortCodeLens(c []protocol.CodeLens) {
   191  	sort.Slice(c, func(i int, j int) bool {
   192  		if r := protocol.CompareRange(c[i].Range, c[j].Range); r != 0 {
   193  			return r < 0
   194  		}
   195  		if c[i].Command.Command < c[j].Command.Command {
   196  			return true
   197  		} else if c[i].Command.Command == c[j].Command.Command {
   198  			return c[i].Command.Title < c[j].Command.Title
   199  		} else {
   200  			return false
   201  		}
   202  	})
   203  }
   204  
   205  func summarizeCodeLens(i int, uri span.URI, want, got []protocol.CodeLens, reason string, args ...interface{}) string {
   206  	msg := &bytes.Buffer{}
   207  	fmt.Fprint(msg, "codelens failed")
   208  	if i >= 0 {
   209  		fmt.Fprintf(msg, " at %d", i)
   210  	}
   211  	fmt.Fprint(msg, " because of ")
   212  	fmt.Fprintf(msg, reason, args...)
   213  	fmt.Fprint(msg, ":\nexpected:\n")
   214  	for _, d := range want {
   215  		fmt.Fprintf(msg, "  %s:%v: %s | %s\n", uri, d.Range, d.Command.Command, d.Command.Title)
   216  	}
   217  	fmt.Fprintf(msg, "got:\n")
   218  	for _, d := range got {
   219  		fmt.Fprintf(msg, "  %s:%v: %s | %s\n", uri, d.Range, d.Command.Command, d.Command.Title)
   220  	}
   221  	return msg.String()
   222  }
   223  
   224  func DiffSignatures(spn span.Span, want, got *protocol.SignatureHelp) string {
   225  	decorate := func(f string, args ...interface{}) string {
   226  		return fmt.Sprintf("invalid signature at %s: %s", spn, fmt.Sprintf(f, args...))
   227  	}
   228  	if len(got.Signatures) != 1 {
   229  		return decorate("wanted 1 signature, got %d", len(got.Signatures))
   230  	}
   231  	if got.ActiveSignature != 0 {
   232  		return decorate("wanted active signature of 0, got %d", int(got.ActiveSignature))
   233  	}
   234  	if want.ActiveParameter != got.ActiveParameter {
   235  		return decorate("wanted active parameter of %d, got %d", want.ActiveParameter, int(got.ActiveParameter))
   236  	}
   237  	g := got.Signatures[0]
   238  	w := want.Signatures[0]
   239  	if w.Label != g.Label {
   240  		wLabel := w.Label + "\n"
   241  		d := myers.ComputeEdits("", wLabel, g.Label+"\n")
   242  		return decorate("mismatched labels:\n%q", diff.ToUnified("want", "got", wLabel, d))
   243  	}
   244  	var paramParts []string
   245  	for _, p := range g.Parameters {
   246  		paramParts = append(paramParts, p.Label)
   247  	}
   248  	paramsStr := strings.Join(paramParts, ", ")
   249  	if !strings.Contains(g.Label, paramsStr) {
   250  		return decorate("expected signature %q to contain params %q", g.Label, paramsStr)
   251  	}
   252  	return ""
   253  }
   254  
   255  // DiffCallHierarchyItems returns the diff between expected and actual call locations for incoming/outgoing call hierarchies
   256  func DiffCallHierarchyItems(gotCalls []protocol.CallHierarchyItem, expectedCalls []protocol.CallHierarchyItem) string {
   257  	expected := make(map[protocol.Location]bool)
   258  	for _, call := range expectedCalls {
   259  		expected[protocol.Location{URI: call.URI, Range: call.Range}] = true
   260  	}
   261  
   262  	got := make(map[protocol.Location]bool)
   263  	for _, call := range gotCalls {
   264  		got[protocol.Location{URI: call.URI, Range: call.Range}] = true
   265  	}
   266  	if len(got) != len(expected) {
   267  		return fmt.Sprintf("expected %d calls but got %d", len(expected), len(got))
   268  	}
   269  	for spn := range got {
   270  		if !expected[spn] {
   271  			return fmt.Sprintf("incorrect calls, expected locations %v but got locations %v", expected, got)
   272  		}
   273  	}
   274  	return ""
   275  }
   276  
   277  func ToProtocolCompletionItems(items []completion.CompletionItem) []protocol.CompletionItem {
   278  	var result []protocol.CompletionItem
   279  	for _, item := range items {
   280  		result = append(result, ToProtocolCompletionItem(item))
   281  	}
   282  	return result
   283  }
   284  
   285  func ToProtocolCompletionItem(item completion.CompletionItem) protocol.CompletionItem {
   286  	pItem := protocol.CompletionItem{
   287  		Label:         item.Label,
   288  		Kind:          item.Kind,
   289  		Detail:        item.Detail,
   290  		Documentation: item.Documentation,
   291  		InsertText:    item.InsertText,
   292  		TextEdit: &protocol.TextEdit{
   293  			NewText: item.Snippet(),
   294  		},
   295  		// Negate score so best score has lowest sort text like real API.
   296  		SortText: fmt.Sprint(-item.Score),
   297  	}
   298  	if pItem.InsertText == "" {
   299  		pItem.InsertText = pItem.Label
   300  	}
   301  	return pItem
   302  }
   303  
   304  func FilterBuiltins(src span.Span, items []protocol.CompletionItem) []protocol.CompletionItem {
   305  	var (
   306  		got          []protocol.CompletionItem
   307  		wantBuiltins = strings.Contains(string(src.URI()), "builtins")
   308  		wantKeywords = strings.Contains(string(src.URI()), "keywords")
   309  	)
   310  	for _, item := range items {
   311  		if !wantBuiltins && isBuiltin(item.Label, item.Detail, item.Kind) {
   312  			continue
   313  		}
   314  
   315  		if !wantKeywords && token.Lookup(item.Label).IsKeyword() {
   316  			continue
   317  		}
   318  
   319  		got = append(got, item)
   320  	}
   321  	return got
   322  }
   323  
   324  func isBuiltin(label, detail string, kind protocol.CompletionItemKind) bool {
   325  	if detail == "" && kind == protocol.ClassCompletion {
   326  		return true
   327  	}
   328  	// Remaining builtin constants, variables, interfaces, and functions.
   329  	trimmed := label
   330  	if i := strings.Index(trimmed, "("); i >= 0 {
   331  		trimmed = trimmed[:i]
   332  	}
   333  	switch trimmed {
   334  	case "append", "cap", "close", "complex", "copy", "delete",
   335  		"error", "false", "imag", "iota", "len", "make", "new",
   336  		"nil", "panic", "print", "println", "real", "recover", "true":
   337  		return true
   338  	}
   339  	return false
   340  }
   341  
   342  func CheckCompletionOrder(want, got []protocol.CompletionItem, strictScores bool) string {
   343  	var (
   344  		matchedIdxs []int
   345  		lastGotIdx  int
   346  		lastGotSort float64
   347  		inOrder     = true
   348  		errorMsg    = "completions out of order"
   349  	)
   350  	for _, w := range want {
   351  		var found bool
   352  		for i, g := range got {
   353  			if w.Label == g.Label && w.Detail == g.Detail && w.Kind == g.Kind {
   354  				matchedIdxs = append(matchedIdxs, i)
   355  				found = true
   356  
   357  				if i < lastGotIdx {
   358  					inOrder = false
   359  				}
   360  				lastGotIdx = i
   361  
   362  				sort, _ := strconv.ParseFloat(g.SortText, 64)
   363  				if strictScores && len(matchedIdxs) > 1 && sort <= lastGotSort {
   364  					inOrder = false
   365  					errorMsg = "candidate scores not strictly decreasing"
   366  				}
   367  				lastGotSort = sort
   368  
   369  				break
   370  			}
   371  		}
   372  		if !found {
   373  			return summarizeCompletionItems(-1, []protocol.CompletionItem{w}, got, "didn't find expected completion")
   374  		}
   375  	}
   376  
   377  	sort.Ints(matchedIdxs)
   378  	matched := make([]protocol.CompletionItem, 0, len(matchedIdxs))
   379  	for _, idx := range matchedIdxs {
   380  		matched = append(matched, got[idx])
   381  	}
   382  
   383  	if !inOrder {
   384  		return summarizeCompletionItems(-1, want, matched, errorMsg)
   385  	}
   386  
   387  	return ""
   388  }
   389  
   390  func DiffSnippets(want string, got *protocol.CompletionItem) string {
   391  	if want == "" {
   392  		if got != nil {
   393  			x := got.TextEdit
   394  			return fmt.Sprintf("expected no snippet but got %s", x.NewText)
   395  		}
   396  	} else {
   397  		if got == nil {
   398  			return fmt.Sprintf("couldn't find completion matching %q", want)
   399  		}
   400  		x := got.TextEdit
   401  		if want != x.NewText {
   402  			return fmt.Sprintf("expected snippet %q, got %q", want, x.NewText)
   403  		}
   404  	}
   405  	return ""
   406  }
   407  
   408  func FindItem(list []protocol.CompletionItem, want completion.CompletionItem) *protocol.CompletionItem {
   409  	for _, item := range list {
   410  		if item.Label == want.Label {
   411  			return &item
   412  		}
   413  	}
   414  	return nil
   415  }
   416  
   417  // DiffCompletionItems prints the diff between expected and actual completion
   418  // test results.
   419  func DiffCompletionItems(want, got []protocol.CompletionItem) string {
   420  	if len(got) != len(want) {
   421  		return summarizeCompletionItems(-1, want, got, "different lengths got %v want %v", len(got), len(want))
   422  	}
   423  	for i, w := range want {
   424  		g := got[i]
   425  		if w.Label != g.Label {
   426  			return summarizeCompletionItems(i, want, got, "incorrect Label got %v want %v", g.Label, w.Label)
   427  		}
   428  		if w.Detail != g.Detail {
   429  			return summarizeCompletionItems(i, want, got, "incorrect Detail got %v want %v", g.Detail, w.Detail)
   430  		}
   431  		if w.Documentation != "" && !strings.HasPrefix(w.Documentation, "@") {
   432  			if w.Documentation != g.Documentation {
   433  				return summarizeCompletionItems(i, want, got, "incorrect Documentation got %v want %v", g.Documentation, w.Documentation)
   434  			}
   435  		}
   436  		if w.Kind != g.Kind {
   437  			return summarizeCompletionItems(i, want, got, "incorrect Kind got %v want %v", g.Kind, w.Kind)
   438  		}
   439  	}
   440  	return ""
   441  }
   442  
   443  func summarizeCompletionItems(i int, want, got []protocol.CompletionItem, reason string, args ...interface{}) string {
   444  	msg := &bytes.Buffer{}
   445  	fmt.Fprint(msg, "completion failed")
   446  	if i >= 0 {
   447  		fmt.Fprintf(msg, " at %d", i)
   448  	}
   449  	fmt.Fprint(msg, " because of ")
   450  	fmt.Fprintf(msg, reason, args...)
   451  	fmt.Fprint(msg, ":\nexpected:\n")
   452  	for _, d := range want {
   453  		fmt.Fprintf(msg, "  %v\n", d)
   454  	}
   455  	fmt.Fprintf(msg, "got:\n")
   456  	for _, d := range got {
   457  		fmt.Fprintf(msg, "  %v\n", d)
   458  	}
   459  	return msg.String()
   460  }
   461  
   462  func EnableAllAnalyzers(view source.View, opts *source.Options) {
   463  	if opts.Analyses == nil {
   464  		opts.Analyses = make(map[string]bool)
   465  	}
   466  	for _, a := range opts.DefaultAnalyzers {
   467  		if !a.IsEnabled(view) {
   468  			opts.Analyses[a.Analyzer.Name] = true
   469  		}
   470  	}
   471  	for _, a := range opts.TypeErrorAnalyzers {
   472  		if !a.IsEnabled(view) {
   473  			opts.Analyses[a.Analyzer.Name] = true
   474  		}
   475  	}
   476  	for _, a := range opts.ConvenienceAnalyzers {
   477  		if !a.IsEnabled(view) {
   478  			opts.Analyses[a.Analyzer.Name] = true
   479  		}
   480  	}
   481  	for _, a := range opts.StaticcheckAnalyzers {
   482  		if !a.IsEnabled(view) {
   483  			opts.Analyses[a.Analyzer.Name] = true
   484  		}
   485  	}
   486  }
   487  
   488  func WorkspaceSymbolsString(ctx context.Context, data *Data, queryURI span.URI, symbols []protocol.SymbolInformation) (string, error) {
   489  	queryDir := filepath.Dir(queryURI.Filename())
   490  	var filtered []string
   491  	for _, s := range symbols {
   492  		uri := s.Location.URI.SpanURI()
   493  		dir := filepath.Dir(uri.Filename())
   494  		if !source.InDir(queryDir, dir) { // assume queries always issue from higher directories
   495  			continue
   496  		}
   497  		m, err := data.Mapper(uri)
   498  		if err != nil {
   499  			return "", err
   500  		}
   501  		spn, err := m.Span(s.Location)
   502  		if err != nil {
   503  			return "", err
   504  		}
   505  		filtered = append(filtered, fmt.Sprintf("%s %s %s", spn, s.Name, s.Kind))
   506  	}
   507  	sort.Strings(filtered)
   508  	return strings.Join(filtered, "\n") + "\n", nil
   509  }
   510  
   511  func WorkspaceSymbolsTestTypeToMatcher(typ WorkspaceSymbolsTestType) source.SymbolMatcher {
   512  	switch typ {
   513  	case WorkspaceSymbolsFuzzy:
   514  		return source.SymbolFuzzy
   515  	case WorkspaceSymbolsCaseSensitive:
   516  		return source.SymbolCaseSensitive
   517  	default:
   518  		return source.SymbolCaseInsensitive
   519  	}
   520  }
   521  
   522  func Diff(want, got string) string {
   523  	if want == got {
   524  		return ""
   525  	}
   526  	// Add newlines to avoid newline messages in diff.
   527  	want += "\n"
   528  	got += "\n"
   529  	d := myers.ComputeEdits("", want, got)
   530  	return fmt.Sprintf("%q", diff.ToUnified("want", "got", want, d))
   531  }