github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/cache/mod_tidy.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cache
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"go/ast"
    11  	"io/ioutil"
    12  	"os"
    13  	"path/filepath"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"golang.org/x/mod/modfile"
    19  	"github.com/jhump/golang-x-tools/internal/event"
    20  	"github.com/jhump/golang-x-tools/internal/gocommand"
    21  	"github.com/jhump/golang-x-tools/internal/lsp/command"
    22  	"github.com/jhump/golang-x-tools/internal/lsp/debug/tag"
    23  	"github.com/jhump/golang-x-tools/internal/lsp/diff"
    24  	"github.com/jhump/golang-x-tools/internal/lsp/protocol"
    25  	"github.com/jhump/golang-x-tools/internal/lsp/source"
    26  	"github.com/jhump/golang-x-tools/internal/memoize"
    27  	"github.com/jhump/golang-x-tools/internal/span"
    28  )
    29  
    30  type modTidyKey struct {
    31  	sessionID       string
    32  	env             string
    33  	gomod           source.FileIdentity
    34  	imports         string
    35  	unsavedOverlays string
    36  	view            string
    37  }
    38  
    39  type modTidyHandle struct {
    40  	handle *memoize.Handle
    41  }
    42  
    43  type modTidyData struct {
    44  	tidied *source.TidiedModule
    45  	err    error
    46  }
    47  
    48  func (mth *modTidyHandle) tidy(ctx context.Context, snapshot *snapshot) (*source.TidiedModule, error) {
    49  	v, err := mth.handle.Get(ctx, snapshot.generation, snapshot)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	data := v.(*modTidyData)
    54  	return data.tidied, data.err
    55  }
    56  
    57  func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*source.TidiedModule, error) {
    58  	if pm.File == nil {
    59  		return nil, fmt.Errorf("cannot tidy unparseable go.mod file: %v", pm.URI)
    60  	}
    61  	if handle := s.getModTidyHandle(pm.URI); handle != nil {
    62  		return handle.tidy(ctx, s)
    63  	}
    64  	fh, err := s.GetFile(ctx, pm.URI)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	// If the file handle is an overlay, it may not be written to disk.
    69  	// The go.mod file has to be on disk for `go mod tidy` to work.
    70  	if _, ok := fh.(*overlay); ok {
    71  		if info, _ := os.Stat(fh.URI().Filename()); info == nil {
    72  			return nil, source.ErrNoModOnDisk
    73  		}
    74  	}
    75  	if criticalErr := s.GetCriticalError(ctx); criticalErr != nil {
    76  		return &source.TidiedModule{
    77  			Diagnostics: criticalErr.DiagList,
    78  		}, nil
    79  	}
    80  	workspacePkgs, err := s.workspacePackageHandles(ctx)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	importHash, err := s.hashImports(ctx, workspacePkgs)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	s.mu.Lock()
    90  	overlayHash := hashUnsavedOverlays(s.files)
    91  	s.mu.Unlock()
    92  
    93  	key := modTidyKey{
    94  		sessionID:       s.view.session.id,
    95  		view:            s.view.folder.Filename(),
    96  		imports:         importHash,
    97  		unsavedOverlays: overlayHash,
    98  		gomod:           fh.FileIdentity(),
    99  		env:             hashEnv(s),
   100  	}
   101  	h := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
   102  		ctx, done := event.Start(ctx, "cache.ModTidyHandle", tag.URI.Of(fh.URI()))
   103  		defer done()
   104  
   105  		snapshot := arg.(*snapshot)
   106  		inv := &gocommand.Invocation{
   107  			Verb:       "mod",
   108  			Args:       []string{"tidy"},
   109  			WorkingDir: filepath.Dir(fh.URI().Filename()),
   110  		}
   111  		tmpURI, inv, cleanup, err := snapshot.goCommandInvocation(ctx, source.WriteTemporaryModFile, inv)
   112  		if err != nil {
   113  			return &modTidyData{err: err}
   114  		}
   115  		// Keep the temporary go.mod file around long enough to parse it.
   116  		defer cleanup()
   117  
   118  		if _, err := s.view.session.gocmdRunner.Run(ctx, *inv); err != nil {
   119  			return &modTidyData{err: err}
   120  		}
   121  		// Go directly to disk to get the temporary mod file, since it is
   122  		// always on disk.
   123  		tempContents, err := ioutil.ReadFile(tmpURI.Filename())
   124  		if err != nil {
   125  			return &modTidyData{err: err}
   126  		}
   127  		ideal, err := modfile.Parse(tmpURI.Filename(), tempContents, nil)
   128  		if err != nil {
   129  			// We do not need to worry about the temporary file's parse errors
   130  			// since it has been "tidied".
   131  			return &modTidyData{err: err}
   132  		}
   133  		// Compare the original and tidied go.mod files to compute errors and
   134  		// suggested fixes.
   135  		diagnostics, err := modTidyDiagnostics(ctx, snapshot, pm, ideal, workspacePkgs)
   136  		if err != nil {
   137  			return &modTidyData{err: err}
   138  		}
   139  		return &modTidyData{
   140  			tidied: &source.TidiedModule{
   141  				Diagnostics:   diagnostics,
   142  				TidiedContent: tempContents,
   143  			},
   144  		}
   145  	}, nil)
   146  
   147  	mth := &modTidyHandle{handle: h}
   148  	s.mu.Lock()
   149  	s.modTidyHandles[fh.URI()] = mth
   150  	s.mu.Unlock()
   151  
   152  	return mth.tidy(ctx, s)
   153  }
   154  
   155  func (s *snapshot) uriToModDecl(ctx context.Context, uri span.URI) (protocol.Range, error) {
   156  	fh, err := s.GetFile(ctx, uri)
   157  	if err != nil {
   158  		return protocol.Range{}, nil
   159  	}
   160  	pmf, err := s.ParseMod(ctx, fh)
   161  	if err != nil {
   162  		return protocol.Range{}, nil
   163  	}
   164  	if pmf.File.Module == nil || pmf.File.Module.Syntax == nil {
   165  		return protocol.Range{}, nil
   166  	}
   167  	return rangeFromPositions(pmf.Mapper, pmf.File.Module.Syntax.Start, pmf.File.Module.Syntax.End)
   168  }
   169  
   170  func (s *snapshot) hashImports(ctx context.Context, wsPackages []*packageHandle) (string, error) {
   171  	seen := map[string]struct{}{}
   172  	var imports []string
   173  	for _, ph := range wsPackages {
   174  		for _, imp := range ph.imports(ctx, s) {
   175  			if _, ok := seen[imp]; !ok {
   176  				imports = append(imports, imp)
   177  				seen[imp] = struct{}{}
   178  			}
   179  		}
   180  	}
   181  	sort.Strings(imports)
   182  	hashed := strings.Join(imports, ",")
   183  	return hashContents([]byte(hashed)), nil
   184  }
   185  
   186  // modTidyDiagnostics computes the differences between the original and tidied
   187  // go.mod files to produce diagnostic and suggested fixes. Some diagnostics
   188  // may appear on the Go files that import packages from missing modules.
   189  func modTidyDiagnostics(ctx context.Context, snapshot source.Snapshot, pm *source.ParsedModule, ideal *modfile.File, workspacePkgs []*packageHandle) (diagnostics []*source.Diagnostic, err error) {
   190  	// First, determine which modules are unused and which are missing from the
   191  	// original go.mod file.
   192  	var (
   193  		unused          = make(map[string]*modfile.Require, len(pm.File.Require))
   194  		missing         = make(map[string]*modfile.Require, len(ideal.Require))
   195  		wrongDirectness = make(map[string]*modfile.Require, len(pm.File.Require))
   196  	)
   197  	for _, req := range pm.File.Require {
   198  		unused[req.Mod.Path] = req
   199  	}
   200  	for _, req := range ideal.Require {
   201  		origReq := unused[req.Mod.Path]
   202  		if origReq == nil {
   203  			missing[req.Mod.Path] = req
   204  			continue
   205  		} else if origReq.Indirect != req.Indirect {
   206  			wrongDirectness[req.Mod.Path] = origReq
   207  		}
   208  		delete(unused, req.Mod.Path)
   209  	}
   210  	for _, req := range wrongDirectness {
   211  		// Handle dependencies that are incorrectly labeled indirect and
   212  		// vice versa.
   213  		srcDiag, err := directnessDiagnostic(pm.Mapper, req, snapshot.View().Options().ComputeEdits)
   214  		if err != nil {
   215  			// We're probably in a bad state if we can't compute a
   216  			// directnessDiagnostic, but try to keep going so as to not suppress
   217  			// other, valid diagnostics.
   218  			event.Error(ctx, "computing directness diagnostic", err)
   219  			continue
   220  		}
   221  		diagnostics = append(diagnostics, srcDiag)
   222  	}
   223  	// Next, compute any diagnostics for modules that are missing from the
   224  	// go.mod file. The fixes will be for the go.mod file, but the
   225  	// diagnostics should also appear in both the go.mod file and the import
   226  	// statements in the Go files in which the dependencies are used.
   227  	missingModuleFixes := map[*modfile.Require][]source.SuggestedFix{}
   228  	for _, req := range missing {
   229  		srcDiag, err := missingModuleDiagnostic(pm, req)
   230  		if err != nil {
   231  			return nil, err
   232  		}
   233  		missingModuleFixes[req] = srcDiag.SuggestedFixes
   234  		diagnostics = append(diagnostics, srcDiag)
   235  	}
   236  	// Add diagnostics for missing modules anywhere they are imported in the
   237  	// workspace.
   238  	for _, ph := range workspacePkgs {
   239  		missingImports := map[string]*modfile.Require{}
   240  
   241  		// If -mod=readonly is not set we may have successfully imported
   242  		// packages from missing modules. Otherwise they'll be in
   243  		// MissingDependencies. Combine both.
   244  		importedPkgs := ph.imports(ctx, snapshot)
   245  
   246  		for _, imp := range importedPkgs {
   247  			if req, ok := missing[imp]; ok {
   248  				missingImports[imp] = req
   249  				break
   250  			}
   251  			// If the import is a package of the dependency, then add the
   252  			// package to the map, this will eliminate the need to do this
   253  			// prefix package search on each import for each file.
   254  			// Example:
   255  			//
   256  			// import (
   257  			//   "github.com/jhump/golang-x-tools/go/expect"
   258  			//   "github.com/jhump/golang-x-tools/go/packages"
   259  			// )
   260  			// They both are related to the same module: "golang.org/x/tools".
   261  			var match string
   262  			for _, req := range ideal.Require {
   263  				if strings.HasPrefix(imp, req.Mod.Path) && len(req.Mod.Path) > len(match) {
   264  					match = req.Mod.Path
   265  				}
   266  			}
   267  			if req, ok := missing[match]; ok {
   268  				missingImports[imp] = req
   269  			}
   270  		}
   271  		// None of this package's imports are from missing modules.
   272  		if len(missingImports) == 0 {
   273  			continue
   274  		}
   275  		for _, pgh := range ph.compiledGoFiles {
   276  			pgf, err := snapshot.ParseGo(ctx, pgh.file, source.ParseHeader)
   277  			if err != nil {
   278  				continue
   279  			}
   280  			file, m := pgf.File, pgf.Mapper
   281  			if file == nil || m == nil {
   282  				continue
   283  			}
   284  			imports := make(map[string]*ast.ImportSpec)
   285  			for _, imp := range file.Imports {
   286  				if imp.Path == nil {
   287  					continue
   288  				}
   289  				if target, err := strconv.Unquote(imp.Path.Value); err == nil {
   290  					imports[target] = imp
   291  				}
   292  			}
   293  			if len(imports) == 0 {
   294  				continue
   295  			}
   296  			for importPath, req := range missingImports {
   297  				imp, ok := imports[importPath]
   298  				if !ok {
   299  					continue
   300  				}
   301  				fixes, ok := missingModuleFixes[req]
   302  				if !ok {
   303  					return nil, fmt.Errorf("no missing module fix for %q (%q)", importPath, req.Mod.Path)
   304  				}
   305  				srcErr, err := missingModuleForImport(snapshot, m, imp, req, fixes)
   306  				if err != nil {
   307  					return nil, err
   308  				}
   309  				diagnostics = append(diagnostics, srcErr)
   310  			}
   311  		}
   312  	}
   313  	// Finally, add errors for any unused dependencies.
   314  	onlyDiagnostic := len(diagnostics) == 0 && len(unused) == 1
   315  	for _, req := range unused {
   316  		srcErr, err := unusedDiagnostic(pm.Mapper, req, onlyDiagnostic)
   317  		if err != nil {
   318  			return nil, err
   319  		}
   320  		diagnostics = append(diagnostics, srcErr)
   321  	}
   322  	return diagnostics, nil
   323  }
   324  
   325  // unusedDiagnostic returns a source.Diagnostic for an unused require.
   326  func unusedDiagnostic(m *protocol.ColumnMapper, req *modfile.Require, onlyDiagnostic bool) (*source.Diagnostic, error) {
   327  	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
   328  	if err != nil {
   329  		return nil, err
   330  	}
   331  	title := fmt.Sprintf("Remove dependency: %s", req.Mod.Path)
   332  	cmd, err := command.NewRemoveDependencyCommand(title, command.RemoveDependencyArgs{
   333  		URI:            protocol.URIFromSpanURI(m.URI),
   334  		OnlyDiagnostic: onlyDiagnostic,
   335  		ModulePath:     req.Mod.Path,
   336  	})
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  	return &source.Diagnostic{
   341  		URI:            m.URI,
   342  		Range:          rng,
   343  		Severity:       protocol.SeverityWarning,
   344  		Source:         source.ModTidyError,
   345  		Message:        fmt.Sprintf("%s is not used in this module", req.Mod.Path),
   346  		SuggestedFixes: []source.SuggestedFix{source.SuggestedFixFromCommand(cmd, protocol.QuickFix)},
   347  	}, nil
   348  }
   349  
   350  // directnessDiagnostic extracts errors when a dependency is labeled indirect when
   351  // it should be direct and vice versa.
   352  func directnessDiagnostic(m *protocol.ColumnMapper, req *modfile.Require, computeEdits diff.ComputeEdits) (*source.Diagnostic, error) {
   353  	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
   354  	if err != nil {
   355  		return nil, err
   356  	}
   357  	direction := "indirect"
   358  	if req.Indirect {
   359  		direction = "direct"
   360  
   361  		// If the dependency should be direct, just highlight the // indirect.
   362  		if comments := req.Syntax.Comment(); comments != nil && len(comments.Suffix) > 0 {
   363  			end := comments.Suffix[0].Start
   364  			end.LineRune += len(comments.Suffix[0].Token)
   365  			end.Byte += len([]byte(comments.Suffix[0].Token))
   366  			rng, err = rangeFromPositions(m, comments.Suffix[0].Start, end)
   367  			if err != nil {
   368  				return nil, err
   369  			}
   370  		}
   371  	}
   372  	// If the dependency should be indirect, add the // indirect.
   373  	edits, err := switchDirectness(req, m, computeEdits)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  	return &source.Diagnostic{
   378  		URI:      m.URI,
   379  		Range:    rng,
   380  		Severity: protocol.SeverityWarning,
   381  		Source:   source.ModTidyError,
   382  		Message:  fmt.Sprintf("%s should be %s", req.Mod.Path, direction),
   383  		SuggestedFixes: []source.SuggestedFix{{
   384  			Title: fmt.Sprintf("Change %s to %s", req.Mod.Path, direction),
   385  			Edits: map[span.URI][]protocol.TextEdit{
   386  				m.URI: edits,
   387  			},
   388  			ActionKind: protocol.QuickFix,
   389  		}},
   390  	}, nil
   391  }
   392  
   393  func missingModuleDiagnostic(pm *source.ParsedModule, req *modfile.Require) (*source.Diagnostic, error) {
   394  	var rng protocol.Range
   395  	// Default to the start of the file if there is no module declaration.
   396  	if pm.File != nil && pm.File.Module != nil && pm.File.Module.Syntax != nil {
   397  		start, end := pm.File.Module.Syntax.Span()
   398  		var err error
   399  		rng, err = rangeFromPositions(pm.Mapper, start, end)
   400  		if err != nil {
   401  			return nil, err
   402  		}
   403  	}
   404  	title := fmt.Sprintf("Add %s to your go.mod file", req.Mod.Path)
   405  	cmd, err := command.NewAddDependencyCommand(title, command.DependencyArgs{
   406  		URI:        protocol.URIFromSpanURI(pm.Mapper.URI),
   407  		AddRequire: !req.Indirect,
   408  		GoCmdArgs:  []string{req.Mod.Path + "@" + req.Mod.Version},
   409  	})
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  	return &source.Diagnostic{
   414  		URI:            pm.Mapper.URI,
   415  		Range:          rng,
   416  		Severity:       protocol.SeverityError,
   417  		Source:         source.ModTidyError,
   418  		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
   419  		SuggestedFixes: []source.SuggestedFix{source.SuggestedFixFromCommand(cmd, protocol.QuickFix)},
   420  	}, nil
   421  }
   422  
   423  // switchDirectness gets the edits needed to change an indirect dependency to
   424  // direct and vice versa.
   425  func switchDirectness(req *modfile.Require, m *protocol.ColumnMapper, computeEdits diff.ComputeEdits) ([]protocol.TextEdit, error) {
   426  	// We need a private copy of the parsed go.mod file, since we're going to
   427  	// modify it.
   428  	copied, err := modfile.Parse("", m.Content, nil)
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  	// Change the directness in the matching require statement. To avoid
   433  	// reordering the require statements, rewrite all of them.
   434  	var requires []*modfile.Require
   435  	seenVersions := make(map[string]string)
   436  	for _, r := range copied.Require {
   437  		if seen := seenVersions[r.Mod.Path]; seen != "" && seen != r.Mod.Version {
   438  			// Avoid a panic in SetRequire below, which panics on conflicting
   439  			// versions.
   440  			return nil, fmt.Errorf("%q has conflicting versions: %q and %q", r.Mod.Path, seen, r.Mod.Version)
   441  		}
   442  		seenVersions[r.Mod.Path] = r.Mod.Version
   443  		if r.Mod.Path == req.Mod.Path {
   444  			requires = append(requires, &modfile.Require{
   445  				Mod:      r.Mod,
   446  				Syntax:   r.Syntax,
   447  				Indirect: !r.Indirect,
   448  			})
   449  			continue
   450  		}
   451  		requires = append(requires, r)
   452  	}
   453  	copied.SetRequire(requires)
   454  	newContent, err := copied.Format()
   455  	if err != nil {
   456  		return nil, err
   457  	}
   458  	// Calculate the edits to be made due to the change.
   459  	diff, err := computeEdits(m.URI, string(m.Content), string(newContent))
   460  	if err != nil {
   461  		return nil, err
   462  	}
   463  	return source.ToProtocolEdits(m, diff)
   464  }
   465  
   466  // missingModuleForImport creates an error for a given import path that comes
   467  // from a missing module.
   468  func missingModuleForImport(snapshot source.Snapshot, m *protocol.ColumnMapper, imp *ast.ImportSpec, req *modfile.Require, fixes []source.SuggestedFix) (*source.Diagnostic, error) {
   469  	if req.Syntax == nil {
   470  		return nil, fmt.Errorf("no syntax for %v", req)
   471  	}
   472  	spn, err := span.NewRange(snapshot.FileSet(), imp.Path.Pos(), imp.Path.End()).Span()
   473  	if err != nil {
   474  		return nil, err
   475  	}
   476  	rng, err := m.Range(spn)
   477  	if err != nil {
   478  		return nil, err
   479  	}
   480  	return &source.Diagnostic{
   481  		URI:            m.URI,
   482  		Range:          rng,
   483  		Severity:       protocol.SeverityError,
   484  		Source:         source.ModTidyError,
   485  		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
   486  		SuggestedFixes: fixes,
   487  	}, nil
   488  }
   489  
   490  func rangeFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (protocol.Range, error) {
   491  	spn, err := spanFromPositions(m, s, e)
   492  	if err != nil {
   493  		return protocol.Range{}, err
   494  	}
   495  	return m.Range(spn)
   496  }
   497  
   498  func spanFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (span.Span, error) {
   499  	toPoint := func(offset int) (span.Point, error) {
   500  		l, c, err := m.Converter.ToPosition(offset)
   501  		if err != nil {
   502  			return span.Point{}, err
   503  		}
   504  		return span.NewPoint(l, c, offset), nil
   505  	}
   506  	start, err := toPoint(s.Byte)
   507  	if err != nil {
   508  		return span.Span{}, err
   509  	}
   510  	end, err := toPoint(e.Byte)
   511  	if err != nil {
   512  		return span.Span{}, err
   513  	}
   514  	return span.New(m.URI, start, end), nil
   515  }