golang.org/x/tools/gopls@v0.15.3/internal/cache/mod_vuln.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cache
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"sort"
    13  	"strings"
    14  	"sync"
    15  
    16  	"golang.org/x/mod/semver"
    17  	"golang.org/x/sync/errgroup"
    18  	"golang.org/x/tools/go/packages"
    19  	"golang.org/x/tools/gopls/internal/cache/metadata"
    20  	"golang.org/x/tools/gopls/internal/protocol"
    21  	"golang.org/x/tools/gopls/internal/vulncheck"
    22  	"golang.org/x/tools/gopls/internal/vulncheck/govulncheck"
    23  	"golang.org/x/tools/gopls/internal/vulncheck/osv"
    24  	isem "golang.org/x/tools/gopls/internal/vulncheck/semver"
    25  	"golang.org/x/tools/internal/memoize"
    26  	"golang.org/x/vuln/scan"
    27  )
    28  
    29  // ModVuln returns import vulnerability analysis for the given go.mod URI.
    30  // Concurrent requests are combined into a single command.
    31  func (s *Snapshot) ModVuln(ctx context.Context, modURI protocol.DocumentURI) (*vulncheck.Result, error) {
    32  	s.mu.Lock()
    33  	entry, hit := s.modVulnHandles.Get(modURI)
    34  	s.mu.Unlock()
    35  
    36  	type modVuln struct {
    37  		result *vulncheck.Result
    38  		err    error
    39  	}
    40  
    41  	// Cache miss?
    42  	if !hit {
    43  		handle := memoize.NewPromise("modVuln", func(ctx context.Context, arg interface{}) interface{} {
    44  			result, err := modVulnImpl(ctx, arg.(*Snapshot))
    45  			return modVuln{result, err}
    46  		})
    47  
    48  		entry = handle
    49  		s.mu.Lock()
    50  		s.modVulnHandles.Set(modURI, entry, nil)
    51  		s.mu.Unlock()
    52  	}
    53  
    54  	// Await result.
    55  	v, err := s.awaitPromise(ctx, entry)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	res := v.(modVuln)
    60  	return res.result, res.err
    61  }
    62  
// GoVersionForVulnTest is an internal environment variable used in gopls
// testing to examine govulncheck behavior with a go version different
// from what `go version` returns in the system.
//
// When the snapshot's Env options hold a non-empty value for this key,
// modVulnImpl uses it instead of snapshot.GoVersionString() as the version
// of the standard library module.
const GoVersionForVulnTest = "_GOPLS_TEST_VULNCHECK_GOVERSION"
    67  
    68  // modVulnImpl queries the vulndb and reports which vulnerabilities
    69  // apply to this snapshot. The result contains a set of packages,
    70  // grouped by vuln ID and by module. This implements the "import-based"
    71  // vulnerability report on go.mod files.
    72  func modVulnImpl(ctx context.Context, snapshot *Snapshot) (*vulncheck.Result, error) {
    73  	// TODO(hyangah): can we let 'govulncheck' take a package list
    74  	// used in the workspace and implement this function?
    75  
    76  	// We want to report the intersection of vulnerable packages in the vulndb
    77  	// and packages transitively imported by this module ('go list -deps all').
    78  	// We use snapshot.AllMetadata to retrieve the list of packages
    79  	// as an approximation.
    80  	//
    81  	// TODO(hyangah): snapshot.AllMetadata is a superset of
    82  	// `go list all` - e.g. when the workspace has multiple main modules
    83  	// (multiple go.mod files), that can include packages that are not
    84  	// used by this module. Vulncheck behavior with go.work is not well
    85  	// defined. Figure out the meaning, and if we decide to present
    86  	// the result as if each module is analyzed independently, make
    87  	// gopls track a separate build list for each module and use that
    88  	// information instead of snapshot.AllMetadata.
    89  	allMeta, err := snapshot.AllMetadata(ctx)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	// TODO(hyangah): handle vulnerabilities in the standard library.
    95  
    96  	// Group packages by modules since vuln db is keyed by module.
    97  	packagesByModule := map[metadata.PackagePath][]*metadata.Package{}
    98  	for _, mp := range allMeta {
    99  		modulePath := metadata.PackagePath(osv.GoStdModulePath)
   100  		if mi := mp.Module; mi != nil {
   101  			modulePath = metadata.PackagePath(mi.Path)
   102  		}
   103  		packagesByModule[modulePath] = append(packagesByModule[modulePath], mp)
   104  	}
   105  
   106  	var (
   107  		mu sync.Mutex
   108  		// Keys are osv.Entry.ID
   109  		osvs     = map[string]*osv.Entry{}
   110  		findings []*govulncheck.Finding
   111  	)
   112  
   113  	goVersion := snapshot.Options().Env[GoVersionForVulnTest]
   114  	if goVersion == "" {
   115  		goVersion = snapshot.GoVersionString()
   116  	}
   117  
   118  	stdlibModule := &packages.Module{
   119  		Path:    osv.GoStdModulePath,
   120  		Version: goVersion,
   121  	}
   122  
   123  	// GOVULNDB may point the test db URI.
   124  	db := GetEnv(snapshot, "GOVULNDB")
   125  
   126  	var group errgroup.Group
   127  	group.SetLimit(10) // limit govulncheck api runs
   128  	for _, mps := range packagesByModule {
   129  		mps := mps
   130  		group.Go(func() error {
   131  			effectiveModule := stdlibModule
   132  			if m := mps[0].Module; m != nil {
   133  				effectiveModule = m
   134  			}
   135  			for effectiveModule.Replace != nil {
   136  				effectiveModule = effectiveModule.Replace
   137  			}
   138  			ver := effectiveModule.Version
   139  			if ver == "" || !isem.Valid(ver) {
   140  				// skip invalid version strings. the underlying scan api is strict.
   141  				return nil
   142  			}
   143  
   144  			// TODO(hyangah): batch these requests and add in-memory cache for efficiency.
   145  			vulns, err := osvsByModule(ctx, db, effectiveModule.Path+"@"+ver)
   146  			if err != nil {
   147  				return err
   148  			}
   149  			if len(vulns) == 0 { // No known vulnerability.
   150  				return nil
   151  			}
   152  
   153  			// set of packages in this module known to gopls.
   154  			// This will be lazily initialized when we need it.
   155  			var knownPkgs map[metadata.PackagePath]bool
   156  
   157  			// Report vulnerabilities that affect packages of this module.
   158  			for _, entry := range vulns {
   159  				var vulnerablePkgs []*govulncheck.Finding
   160  				fixed := fixedVersion(effectiveModule.Path, entry.Affected)
   161  
   162  				for _, a := range entry.Affected {
   163  					if a.Module.Ecosystem != osv.GoEcosystem || a.Module.Path != effectiveModule.Path {
   164  						continue
   165  					}
   166  					for _, imp := range a.EcosystemSpecific.Packages {
   167  						if knownPkgs == nil {
   168  							knownPkgs = toPackagePathSet(mps)
   169  						}
   170  						if knownPkgs[metadata.PackagePath(imp.Path)] {
   171  							vulnerablePkgs = append(vulnerablePkgs, &govulncheck.Finding{
   172  								OSV:          entry.ID,
   173  								FixedVersion: fixed,
   174  								Trace: []*govulncheck.Frame{
   175  									{
   176  										Module:  effectiveModule.Path,
   177  										Version: effectiveModule.Version,
   178  										Package: imp.Path,
   179  									},
   180  								},
   181  							})
   182  						}
   183  					}
   184  				}
   185  				if len(vulnerablePkgs) == 0 {
   186  					continue
   187  				}
   188  				mu.Lock()
   189  				osvs[entry.ID] = entry
   190  				findings = append(findings, vulnerablePkgs...)
   191  				mu.Unlock()
   192  			}
   193  			return nil
   194  		})
   195  	}
   196  	if err := group.Wait(); err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	// Sort so the results are deterministic.
   201  	sort.Slice(findings, func(i, j int) bool {
   202  		x, y := findings[i], findings[j]
   203  		if x.OSV != y.OSV {
   204  			return x.OSV < y.OSV
   205  		}
   206  		return x.Trace[0].Package < y.Trace[0].Package
   207  	})
   208  	ret := &vulncheck.Result{
   209  		Entries:  osvs,
   210  		Findings: findings,
   211  		Mode:     vulncheck.ModeImports,
   212  	}
   213  	return ret, nil
   214  }
   215  
   216  // TODO(rfindley): this function was exposed during refactoring. Reconsider it.
   217  func GetEnv(snapshot *Snapshot, key string) string {
   218  	val, ok := snapshot.Options().Env[key]
   219  	if ok {
   220  		return val
   221  	}
   222  	return os.Getenv(key)
   223  }
   224  
   225  // toPackagePathSet transforms the metadata to a set of package paths.
   226  func toPackagePathSet(mds []*metadata.Package) map[metadata.PackagePath]bool {
   227  	pkgPaths := make(map[metadata.PackagePath]bool, len(mds))
   228  	for _, md := range mds {
   229  		pkgPaths[md.PkgPath] = true
   230  	}
   231  	return pkgPaths
   232  }
   233  
   234  func fixedVersion(modulePath string, affected []osv.Affected) string {
   235  	fixed := latestFixed(modulePath, affected)
   236  	if fixed != "" {
   237  		fixed = versionString(modulePath, fixed)
   238  	}
   239  	return fixed
   240  }
   241  
   242  // latestFixed returns the latest fixed version in the list of affected ranges,
   243  // or the empty string if there are no fixed versions.
   244  func latestFixed(modulePath string, as []osv.Affected) string {
   245  	v := ""
   246  	for _, a := range as {
   247  		if a.Module.Path != modulePath {
   248  			continue
   249  		}
   250  		for _, r := range a.Ranges {
   251  			if r.Type == osv.RangeTypeSemver {
   252  				for _, e := range r.Events {
   253  					if e.Fixed != "" && (v == "" ||
   254  						semver.Compare(isem.CanonicalizeSemverPrefix(e.Fixed), isem.CanonicalizeSemverPrefix(v)) > 0) {
   255  						v = e.Fixed
   256  					}
   257  				}
   258  			}
   259  		}
   260  	}
   261  	return v
   262  }
   263  
   264  // versionString prepends a version string prefix (`v` or `go`
   265  // depending on the modulePath) to the given semver-style version string.
   266  func versionString(modulePath, version string) string {
   267  	if version == "" {
   268  		return ""
   269  	}
   270  	v := "v" + version
   271  	// These are internal Go module paths used by the vuln DB
   272  	// when listing vulns in standard library and the go command.
   273  	if modulePath == "stdlib" || modulePath == "toolchain" {
   274  		return semverToGoTag(v)
   275  	}
   276  	return v
   277  }
   278  
   279  // semverToGoTag returns the Go standard library repository tag corresponding
   280  // to semver, a version string without the initial "v".
   281  // Go tags differ from standard semantic versions in a few ways,
   282  // such as beginning with "go" instead of "v".
   283  func semverToGoTag(v string) string {
   284  	if strings.HasPrefix(v, "v0.0.0") {
   285  		return "master"
   286  	}
   287  	// Special case: v1.0.0 => go1.
   288  	if v == "v1.0.0" {
   289  		return "go1"
   290  	}
   291  	if !semver.IsValid(v) {
   292  		return fmt.Sprintf("<!%s:invalid semver>", v)
   293  	}
   294  	goVersion := semver.Canonical(v)
   295  	prerelease := semver.Prerelease(goVersion)
   296  	versionWithoutPrerelease := strings.TrimSuffix(goVersion, prerelease)
   297  	patch := strings.TrimPrefix(versionWithoutPrerelease, semver.MajorMinor(goVersion)+".")
   298  	if patch == "0" {
   299  		versionWithoutPrerelease = strings.TrimSuffix(versionWithoutPrerelease, ".0")
   300  	}
   301  	goVersion = fmt.Sprintf("go%s", strings.TrimPrefix(versionWithoutPrerelease, "v"))
   302  	if prerelease != "" {
   303  		// Go prereleases look like  "beta1" instead of "beta.1".
   304  		// "beta1" is bad for sorting (since beta10 comes before beta9), so
   305  		// require the dot form.
   306  		i := finalDigitsIndex(prerelease)
   307  		if i >= 1 {
   308  			if prerelease[i-1] != '.' {
   309  				return fmt.Sprintf("<!%s:final digits in a prerelease must follow a period>", v)
   310  			}
   311  			// Remove the dot.
   312  			prerelease = prerelease[:i-1] + prerelease[i:]
   313  		}
   314  		goVersion += strings.TrimPrefix(prerelease, "-")
   315  	}
   316  	return goVersion
   317  }
   318  
   319  // finalDigitsIndex returns the index of the first digit in the sequence of digits ending s.
   320  // If s doesn't end in digits, it returns -1.
   321  func finalDigitsIndex(s string) int {
   322  	// Assume ASCII (since the semver package does anyway).
   323  	var i int
   324  	for i = len(s) - 1; i >= 0; i-- {
   325  		if s[i] < '0' || s[i] > '9' {
   326  			break
   327  		}
   328  	}
   329  	if i == len(s)-1 {
   330  		return -1
   331  	}
   332  	return i + 1
   333  }
   334  
   335  // osvsByModule runs a govulncheck database query.
   336  func osvsByModule(ctx context.Context, db, moduleVersion string) ([]*osv.Entry, error) {
   337  	var args []string
   338  	args = append(args, "-mode=query", "-json")
   339  	if db != "" {
   340  		args = append(args, "-db="+db)
   341  	}
   342  	args = append(args, moduleVersion)
   343  
   344  	ir, iw := io.Pipe()
   345  	handler := &osvReader{}
   346  
   347  	var g errgroup.Group
   348  	g.Go(func() error {
   349  		defer iw.Close() // scan API doesn't close cmd.Stderr/cmd.Stdout.
   350  		cmd := scan.Command(ctx, args...)
   351  		cmd.Stdout = iw
   352  		// TODO(hakim): Do we need to set cmd.Env = getEnvSlices(),
   353  		// or is the process environment good enough?
   354  		if err := cmd.Start(); err != nil {
   355  			return err
   356  		}
   357  		return cmd.Wait()
   358  	})
   359  	g.Go(func() error {
   360  		return govulncheck.HandleJSON(ir, handler)
   361  	})
   362  
   363  	if err := g.Wait(); err != nil {
   364  		return nil, err
   365  	}
   366  	return handler.entry, nil
   367  }
   368  
   369  // osvReader implements govulncheck.Handler.
   370  type osvReader struct {
   371  	entry []*osv.Entry
   372  }
   373  
   374  func (h *osvReader) OSV(entry *osv.Entry) error {
   375  	h.entry = append(h.entry, entry)
   376  	return nil
   377  }
   378  
   379  func (h *osvReader) Config(config *govulncheck.Config) error {
   380  	return nil
   381  }
   382  
   383  func (h *osvReader) Finding(finding *govulncheck.Finding) error {
   384  	return nil
   385  }
   386  
   387  func (h *osvReader) Progress(progress *govulncheck.Progress) error {
   388  	return nil
   389  }