github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/code_action.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package lsp
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"sort"
    11  	"strings"
    12  
    13  	"github.com/powerman/golang-tools/internal/event"
    14  	"github.com/powerman/golang-tools/internal/imports"
    15  	"github.com/powerman/golang-tools/internal/lsp/command"
    16  	"github.com/powerman/golang-tools/internal/lsp/debug/tag"
    17  	"github.com/powerman/golang-tools/internal/lsp/mod"
    18  	"github.com/powerman/golang-tools/internal/lsp/protocol"
    19  	"github.com/powerman/golang-tools/internal/lsp/source"
    20  	"github.com/powerman/golang-tools/internal/span"
    21  	errors "golang.org/x/xerrors"
    22  )
    23  
    24  func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionParams) ([]protocol.CodeAction, error) {
    25  	snapshot, fh, ok, release, err := s.beginFileRequest(ctx, params.TextDocument.URI, source.UnknownKind)
    26  	defer release()
    27  	if !ok {
    28  		return nil, err
    29  	}
    30  	uri := fh.URI()
    31  
    32  	// Determine the supported actions for this file kind.
    33  	kind := snapshot.View().FileKind(fh)
    34  	supportedCodeActions, ok := snapshot.View().Options().SupportedCodeActions[kind]
    35  	if !ok {
    36  		return nil, fmt.Errorf("no supported code actions for %v file kind", kind)
    37  	}
    38  
    39  	// The Only field of the context specifies which code actions the client wants.
    40  	// If Only is empty, assume that the client wants all of the non-explicit code actions.
    41  	var wanted map[protocol.CodeActionKind]bool
    42  
    43  	// Explicit Code Actions are opt-in and shouldn't be returned to the client unless
    44  	// requested using Only.
    45  	// TODO: Add other CodeLenses such as GoGenerate, RegenerateCgo, etc..
    46  	explicit := map[protocol.CodeActionKind]bool{
    47  		protocol.GoTest: true,
    48  	}
    49  
    50  	if len(params.Context.Only) == 0 {
    51  		wanted = supportedCodeActions
    52  	} else {
    53  		wanted = make(map[protocol.CodeActionKind]bool)
    54  		for _, only := range params.Context.Only {
    55  			for k, v := range supportedCodeActions {
    56  				if only == k || strings.HasPrefix(string(k), string(only)+".") {
    57  					wanted[k] = wanted[k] || v
    58  				}
    59  			}
    60  			wanted[only] = wanted[only] || explicit[only]
    61  		}
    62  	}
    63  	if len(supportedCodeActions) == 0 {
    64  		return nil, nil // not an error if there are none supported
    65  	}
    66  	if len(wanted) == 0 {
    67  		return nil, fmt.Errorf("no supported code action to execute for %s, wanted %v", uri, params.Context.Only)
    68  	}
    69  
    70  	var codeActions []protocol.CodeAction
    71  	switch kind {
    72  	case source.Mod:
    73  		if diagnostics := params.Context.Diagnostics; len(diagnostics) > 0 {
    74  			diags, err := mod.DiagnosticsForMod(ctx, snapshot, fh)
    75  			if source.IsNonFatalGoModError(err) {
    76  				return nil, nil
    77  			}
    78  			if err != nil {
    79  				return nil, err
    80  			}
    81  			quickFixes, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, diags)
    82  			if err != nil {
    83  				return nil, err
    84  			}
    85  			codeActions = append(codeActions, quickFixes...)
    86  		}
    87  	case source.Go:
    88  		// Don't suggest fixes for generated files, since they are generally
    89  		// not useful and some editors may apply them automatically on save.
    90  		if source.IsGenerated(ctx, snapshot, uri) {
    91  			return nil, nil
    92  		}
    93  		diagnostics := params.Context.Diagnostics
    94  
    95  		// First, process any missing imports and pair them with the
    96  		// diagnostics they fix.
    97  		if wantQuickFixes := wanted[protocol.QuickFix] && len(diagnostics) > 0; wantQuickFixes || wanted[protocol.SourceOrganizeImports] {
    98  			importEdits, importEditsPerFix, err := source.AllImportsFixes(ctx, snapshot, fh)
    99  			if err != nil {
   100  				event.Error(ctx, "imports fixes", err, tag.File.Of(fh.URI().Filename()))
   101  			}
   102  			// Separate this into a set of codeActions per diagnostic, where
   103  			// each action is the addition, removal, or renaming of one import.
   104  			if wantQuickFixes {
   105  				for _, importFix := range importEditsPerFix {
   106  					fixes := importDiagnostics(importFix.Fix, diagnostics)
   107  					if len(fixes) == 0 {
   108  						continue
   109  					}
   110  					codeActions = append(codeActions, protocol.CodeAction{
   111  						Title: importFixTitle(importFix.Fix),
   112  						Kind:  protocol.QuickFix,
   113  						Edit: protocol.WorkspaceEdit{
   114  							DocumentChanges: documentChanges(fh, importFix.Edits),
   115  						},
   116  						Diagnostics: fixes,
   117  					})
   118  				}
   119  			}
   120  
   121  			// Send all of the import edits as one code action if the file is
   122  			// being organized.
   123  			if wanted[protocol.SourceOrganizeImports] && len(importEdits) > 0 {
   124  				codeActions = append(codeActions, protocol.CodeAction{
   125  					Title: "Organize Imports",
   126  					Kind:  protocol.SourceOrganizeImports,
   127  					Edit: protocol.WorkspaceEdit{
   128  						DocumentChanges: documentChanges(fh, importEdits),
   129  					},
   130  				})
   131  			}
   132  		}
   133  		if ctx.Err() != nil {
   134  			return nil, ctx.Err()
   135  		}
   136  		pkg, err := snapshot.PackageForFile(ctx, fh.URI(), source.TypecheckFull, source.WidestPackage)
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  
   141  		pkgDiagnostics, err := snapshot.DiagnosePackage(ctx, pkg)
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  		analysisDiags, err := source.Analyze(ctx, snapshot, pkg, true)
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  		fileDiags := append(pkgDiagnostics[uri], analysisDiags[uri]...)
   150  
   151  		// Split diagnostics into fixes, which must match incoming diagnostics,
   152  		// and non-fixes, which must match the requested range. Build actions
   153  		// for all of them.
   154  		var fixDiags, nonFixDiags []*source.Diagnostic
   155  		for _, d := range fileDiags {
   156  			if len(d.SuggestedFixes) == 0 {
   157  				continue
   158  			}
   159  			var isFix bool
   160  			for _, fix := range d.SuggestedFixes {
   161  				if fix.ActionKind == protocol.QuickFix || fix.ActionKind == protocol.SourceFixAll {
   162  					isFix = true
   163  					break
   164  				}
   165  			}
   166  			if isFix {
   167  				fixDiags = append(fixDiags, d)
   168  			} else {
   169  				nonFixDiags = append(nonFixDiags, d)
   170  			}
   171  		}
   172  
   173  		fixActions, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, fixDiags)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  		codeActions = append(codeActions, fixActions...)
   178  
   179  		for _, nonfix := range nonFixDiags {
   180  			// For now, only show diagnostics for matching lines. Maybe we should
   181  			// alter this behavior in the future, depending on the user experience.
   182  			if !protocol.Intersect(nonfix.Range, params.Range) {
   183  				continue
   184  			}
   185  			actions, err := codeActionsForDiagnostic(ctx, snapshot, nonfix, nil)
   186  			if err != nil {
   187  				return nil, err
   188  			}
   189  			codeActions = append(codeActions, actions...)
   190  		}
   191  
   192  		if wanted[protocol.RefactorExtract] {
   193  			fixes, err := extractionFixes(ctx, snapshot, pkg, uri, params.Range)
   194  			if err != nil {
   195  				return nil, err
   196  			}
   197  			codeActions = append(codeActions, fixes...)
   198  		}
   199  
   200  		if wanted[protocol.GoTest] {
   201  			fixes, err := goTest(ctx, snapshot, uri, params.Range)
   202  			if err != nil {
   203  				return nil, err
   204  			}
   205  			codeActions = append(codeActions, fixes...)
   206  		}
   207  
   208  	default:
   209  		// Unsupported file kind for a code action.
   210  		return nil, nil
   211  	}
   212  
   213  	var filtered []protocol.CodeAction
   214  	for _, action := range codeActions {
   215  		if wanted[action.Kind] {
   216  			filtered = append(filtered, action)
   217  		}
   218  	}
   219  	return filtered, nil
   220  }
   221  
   222  func (s *Server) getSupportedCodeActions() []protocol.CodeActionKind {
   223  	allCodeActionKinds := make(map[protocol.CodeActionKind]struct{})
   224  	for _, kinds := range s.session.Options().SupportedCodeActions {
   225  		for kind := range kinds {
   226  			allCodeActionKinds[kind] = struct{}{}
   227  		}
   228  	}
   229  	var result []protocol.CodeActionKind
   230  	for kind := range allCodeActionKinds {
   231  		result = append(result, kind)
   232  	}
   233  	sort.Slice(result, func(i, j int) bool {
   234  		return result[i] < result[j]
   235  	})
   236  	return result
   237  }
   238  
   239  func importFixTitle(fix *imports.ImportFix) string {
   240  	var str string
   241  	switch fix.FixType {
   242  	case imports.AddImport:
   243  		str = fmt.Sprintf("Add import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
   244  	case imports.DeleteImport:
   245  		str = fmt.Sprintf("Delete import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
   246  	case imports.SetImportName:
   247  		str = fmt.Sprintf("Rename import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
   248  	}
   249  	return str
   250  }
   251  
   252  func importDiagnostics(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) (results []protocol.Diagnostic) {
   253  	for _, diagnostic := range diagnostics {
   254  		switch {
   255  		// "undeclared name: X" may be an unresolved import.
   256  		case strings.HasPrefix(diagnostic.Message, "undeclared name: "):
   257  			ident := strings.TrimPrefix(diagnostic.Message, "undeclared name: ")
   258  			if ident == fix.IdentName {
   259  				results = append(results, diagnostic)
   260  			}
   261  		// "could not import: X" may be an invalid import.
   262  		case strings.HasPrefix(diagnostic.Message, "could not import: "):
   263  			ident := strings.TrimPrefix(diagnostic.Message, "could not import: ")
   264  			if ident == fix.IdentName {
   265  				results = append(results, diagnostic)
   266  			}
   267  		// "X imported but not used" is an unused import.
   268  		// "X imported but not used as Y" is an unused import.
   269  		case strings.Contains(diagnostic.Message, " imported but not used"):
   270  			idx := strings.Index(diagnostic.Message, " imported but not used")
   271  			importPath := diagnostic.Message[:idx]
   272  			if importPath == fmt.Sprintf("%q", fix.StmtInfo.ImportPath) {
   273  				results = append(results, diagnostic)
   274  			}
   275  		}
   276  	}
   277  	return results
   278  }
   279  
   280  func extractionFixes(ctx context.Context, snapshot source.Snapshot, pkg source.Package, uri span.URI, rng protocol.Range) ([]protocol.CodeAction, error) {
   281  	if rng.Start == rng.End {
   282  		return nil, nil
   283  	}
   284  	fh, err := snapshot.GetFile(ctx, uri)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	_, pgf, err := source.GetParsedFile(ctx, snapshot, fh, source.NarrowestPackage)
   289  	if err != nil {
   290  		return nil, errors.Errorf("getting file for Identifier: %w", err)
   291  	}
   292  	srng, err := pgf.Mapper.RangeToSpanRange(rng)
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  	puri := protocol.URIFromSpanURI(uri)
   297  	var commands []protocol.Command
   298  	if _, ok, methodOk, _ := source.CanExtractFunction(snapshot.FileSet(), srng, pgf.Src, pgf.File); ok {
   299  		cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
   300  			URI:   puri,
   301  			Fix:   source.ExtractFunction,
   302  			Range: rng,
   303  		})
   304  		if err != nil {
   305  			return nil, err
   306  		}
   307  		commands = append(commands, cmd)
   308  		if methodOk {
   309  			cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
   310  				URI:   puri,
   311  				Fix:   source.ExtractMethod,
   312  				Range: rng,
   313  			})
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  			commands = append(commands, cmd)
   318  		}
   319  	}
   320  	if _, _, ok, _ := source.CanExtractVariable(srng, pgf.File); ok {
   321  		cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
   322  			URI:   puri,
   323  			Fix:   source.ExtractVariable,
   324  			Range: rng,
   325  		})
   326  		if err != nil {
   327  			return nil, err
   328  		}
   329  		commands = append(commands, cmd)
   330  	}
   331  	var actions []protocol.CodeAction
   332  	for i := range commands {
   333  		actions = append(actions, protocol.CodeAction{
   334  			Title:   commands[i].Title,
   335  			Kind:    protocol.RefactorExtract,
   336  			Command: &commands[i],
   337  		})
   338  	}
   339  	return actions, nil
   340  }
   341  
   342  func documentChanges(fh source.VersionedFileHandle, edits []protocol.TextEdit) []protocol.TextDocumentEdit {
   343  	return []protocol.TextDocumentEdit{
   344  		{
   345  			TextDocument: protocol.OptionalVersionedTextDocumentIdentifier{
   346  				Version: fh.Version(),
   347  				TextDocumentIdentifier: protocol.TextDocumentIdentifier{
   348  					URI: protocol.URIFromSpanURI(fh.URI()),
   349  				},
   350  			},
   351  			Edits: edits,
   352  		},
   353  	}
   354  }
   355  
   356  func codeActionsMatchingDiagnostics(ctx context.Context, snapshot source.Snapshot, pdiags []protocol.Diagnostic, sdiags []*source.Diagnostic) ([]protocol.CodeAction, error) {
   357  	var actions []protocol.CodeAction
   358  	for _, sd := range sdiags {
   359  		var diag *protocol.Diagnostic
   360  		for _, pd := range pdiags {
   361  			if sameDiagnostic(pd, sd) {
   362  				diag = &pd
   363  				break
   364  			}
   365  		}
   366  		if diag == nil {
   367  			continue
   368  		}
   369  		diagActions, err := codeActionsForDiagnostic(ctx, snapshot, sd, diag)
   370  		if err != nil {
   371  			return nil, err
   372  		}
   373  		actions = append(actions, diagActions...)
   374  
   375  	}
   376  	return actions, nil
   377  }
   378  
   379  func codeActionsForDiagnostic(ctx context.Context, snapshot source.Snapshot, sd *source.Diagnostic, pd *protocol.Diagnostic) ([]protocol.CodeAction, error) {
   380  	var actions []protocol.CodeAction
   381  	for _, fix := range sd.SuggestedFixes {
   382  		var changes []protocol.TextDocumentEdit
   383  		for uri, edits := range fix.Edits {
   384  			fh, err := snapshot.GetVersionedFile(ctx, uri)
   385  			if err != nil {
   386  				return nil, err
   387  			}
   388  			changes = append(changes, protocol.TextDocumentEdit{
   389  				TextDocument: protocol.OptionalVersionedTextDocumentIdentifier{
   390  					Version: fh.Version(),
   391  					TextDocumentIdentifier: protocol.TextDocumentIdentifier{
   392  						URI: protocol.URIFromSpanURI(uri),
   393  					},
   394  				},
   395  				Edits: edits,
   396  			})
   397  		}
   398  		action := protocol.CodeAction{
   399  			Title: fix.Title,
   400  			Kind:  fix.ActionKind,
   401  			Edit: protocol.WorkspaceEdit{
   402  				DocumentChanges: changes,
   403  			},
   404  			Command: fix.Command,
   405  		}
   406  		if pd != nil {
   407  			action.Diagnostics = []protocol.Diagnostic{*pd}
   408  		}
   409  		actions = append(actions, action)
   410  	}
   411  	return actions, nil
   412  }
   413  
   414  func sameDiagnostic(pd protocol.Diagnostic, sd *source.Diagnostic) bool {
   415  	return pd.Message == sd.Message && protocol.CompareRange(pd.Range, sd.Range) == 0 && pd.Source == string(sd.Source)
   416  }
   417  
   418  func goTest(ctx context.Context, snapshot source.Snapshot, uri span.URI, rng protocol.Range) ([]protocol.CodeAction, error) {
   419  	fh, err := snapshot.GetFile(ctx, uri)
   420  	if err != nil {
   421  		return nil, err
   422  	}
   423  	fns, err := source.TestsAndBenchmarks(ctx, snapshot, fh)
   424  	if err != nil {
   425  		return nil, err
   426  	}
   427  
   428  	var tests, benchmarks []string
   429  	for _, fn := range fns.Tests {
   430  		if !protocol.Intersect(fn.Rng, rng) {
   431  			continue
   432  		}
   433  		tests = append(tests, fn.Name)
   434  	}
   435  	for _, fn := range fns.Benchmarks {
   436  		if !protocol.Intersect(fn.Rng, rng) {
   437  			continue
   438  		}
   439  		benchmarks = append(benchmarks, fn.Name)
   440  	}
   441  
   442  	if len(tests) == 0 && len(benchmarks) == 0 {
   443  		return nil, nil
   444  	}
   445  
   446  	cmd, err := command.NewTestCommand("Run tests and benchmarks", protocol.URIFromSpanURI(uri), tests, benchmarks)
   447  	if err != nil {
   448  		return nil, err
   449  	}
   450  	return []protocol.CodeAction{{
   451  		Title:   cmd.Title,
   452  		Kind:    protocol.GoTest,
   453  		Command: &cmd,
   454  	}}, nil
   455  }