github.com/google/osv-scalibr@v0.4.1/guidedremediation/internal/strategy/override/override.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package override implements the override remediation strategy.
    16  package override
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"slices"
    22  
    23  	"deps.dev/util/resolve"
    24  	"deps.dev/util/resolve/dep"
    25  	"deps.dev/util/semver"
    26  	"github.com/google/osv-scalibr/enricher"
    27  	"github.com/google/osv-scalibr/guidedremediation/internal/remediation"
    28  	"github.com/google/osv-scalibr/guidedremediation/internal/resolution"
    29  	"github.com/google/osv-scalibr/guidedremediation/internal/strategy/common"
    30  	"github.com/google/osv-scalibr/guidedremediation/internal/vulns"
    31  	"github.com/google/osv-scalibr/guidedremediation/options"
    32  	"github.com/google/osv-scalibr/guidedremediation/upgrade"
    33  	"github.com/google/osv-scalibr/internal/mavenutil"
    34  	"github.com/google/osv-scalibr/log"
    35  )
    36  
    37  // ComputePatches attempts to resolve each vulnerability found in result independently, returning the list of unique possible patches.
    38  // Vulnerabilities are resolved by directly overriding versions of vulnerable packages to non-vulnerable versions.
    39  // If a patch introduces new vulnerabilities, additional overrides are attempted for the new vulnerabilities.
    40  func ComputePatches(ctx context.Context, cl resolve.Client, ve enricher.Enricher, resolved *remediation.ResolvedManifest, opts *options.RemediationOptions) (common.PatchResult, error) {
    41  	patchFn := func(vulnIDs []string) common.StrategyResult {
    42  		patched, err := patchVulns(ctx, cl, ve, resolved, vulnIDs, opts)
    43  		return common.StrategyResult{
    44  			VulnIDs:  vulnIDs,
    45  			Resolved: patched,
    46  			Err:      err}
    47  	}
    48  
    49  	return common.ComputePatches(patchFn, resolved, true)
    50  }
    51  
    52  // patchVulns tries to fix as many vulns in vulnIDs as possible by overriding dependency versions.
    53  // returns ErrPatchImpossible if 0 vulns are patchable, otherwise returns the most possible patches.
    54  func patchVulns(ctx context.Context, cl resolve.Client, ve enricher.Enricher, resolved *remediation.ResolvedManifest, vulnIDs []string, opts *options.RemediationOptions) (*remediation.ResolvedManifest, error) {
    55  	resolved = &remediation.ResolvedManifest{
    56  		Manifest:      resolved.Manifest.Clone(),
    57  		ResolvedGraph: resolved.ResolvedGraph,
    58  	}
    59  
    60  	for {
    61  		// Find the relevant vulns affecting each version key.
    62  		vkVulns := make(map[resolve.VersionKey][]*resolution.Vulnerability)
    63  		for i, v := range resolved.Vulns {
    64  			if !slices.Contains(vulnIDs, v.OSV.Id) {
    65  				continue
    66  			}
    67  			// Keep track of VersionKeys we've seen for this vuln to avoid duplicates.
    68  			// Usually, there will only be one VersionKey per vuln, but some vulns affect multiple packages.
    69  			seenVKs := make(map[resolve.VersionKey]struct{})
    70  			// Use the Subgraphs to find all the affected nodes.
    71  			for _, sg := range v.Subgraphs {
    72  				for _, e := range sg.Nodes[sg.Dependency].Parents {
    73  					// It's hard to know if a specific classifier or type exists for a given version.
    74  					// Blindly updating versions can lead to compilation failures if the artifact+version+classifier+type doesn't exist.
    75  					// We can't reliably attempt remediation in these cases, so don't try.
    76  					if e.Type.HasAttr(dep.MavenClassifier) || e.Type.HasAttr(dep.MavenArtifactType) {
    77  						return nil, fmt.Errorf("%w: cannot fix vulns in artifacts with classifier or type", common.ErrPatchImpossible)
    78  					}
    79  					vk := sg.Nodes[sg.Dependency].Version
    80  					if _, seen := seenVKs[vk]; !seen {
    81  						vkVulns[vk] = append(vkVulns[vk], &resolved.Vulns[i])
    82  						seenVKs[vk] = struct{}{}
    83  					}
    84  				}
    85  			}
    86  		}
    87  
    88  		if len(vkVulns) == 0 {
    89  			// All vulns have been fixed.
    90  			break
    91  		}
    92  
    93  		didPatch := false
    94  
    95  		// For each VersionKey, try fix as many of the vulns affecting it as possible.
    96  		for vk, vulnerabilities := range vkVulns {
    97  			// Consider vulns affecting packages we don't want to change unfixable
    98  			if opts.UpgradeConfig.Get(vk.Name) == upgrade.None {
    99  				continue
   100  			}
   101  
   102  			bestVK := vk
   103  			bestCount := len(vulnerabilities) // remaining vulns
   104  			versions, err := getVersionsGreater(ctx, cl, vk)
   105  			if err != nil {
   106  				return nil, err
   107  			}
   108  
   109  			// Find the minimal greater version that fixes as many vulnerabilities as possible.
   110  			for _, ver := range versions {
   111  				// Break if we've encountered a disallowed version update.
   112  				if _, diff, _ := vk.System.Semver().Difference(vk.Version, ver.Version); !opts.UpgradeConfig.Get(vk.Name).Allows(diff) {
   113  					break
   114  				}
   115  
   116  				// Count the remaining known vulns that affect this version.
   117  				count := 0 // remaining vulns
   118  				for _, rv := range vulnerabilities {
   119  					if vulns.IsAffected(rv.OSV, vulns.VKToPackage(ver.VersionKey)) {
   120  						count++
   121  					}
   122  				}
   123  				if count < bestCount {
   124  					// Found a new candidate.
   125  					bestCount = count
   126  					bestVK = ver.VersionKey
   127  					if bestCount == 0 { // stop if there are 0 vulns remaining
   128  						break
   129  					}
   130  				}
   131  			}
   132  
   133  			if bestCount < len(vulnerabilities) {
   134  				// Found a version that fixes some vulns.
   135  				if err := resolved.Manifest.PatchRequirement(resolve.RequirementVersion{VersionKey: bestVK}); err != nil {
   136  					return nil, err
   137  				}
   138  				didPatch = true
   139  			}
   140  		}
   141  
   142  		if !didPatch {
   143  			break
   144  		}
   145  
   146  		// Re-resolve the manifest
   147  		var err error
   148  		resolved.Graph, err = resolution.Resolve(ctx, cl, resolved.Manifest, opts.ResolutionOptions)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  		resolved.UnfilteredVulns, err = resolution.FindVulnerabilities(ctx, ve, resolved.Manifest.Groups(), resolved.Graph)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		resolved.Vulns = slices.Clone(resolved.UnfilteredVulns)
   157  		resolved.Vulns = slices.DeleteFunc(resolved.Vulns, func(v resolution.Vulnerability) bool { return !remediation.MatchVuln(*opts, v) })
   158  	}
   159  
   160  	return resolved, nil
   161  }
   162  
   163  // getVersionsGreater gets the known versions of a package that are greater than the given version, sorted in ascending order.
   164  func getVersionsGreater(ctx context.Context, cl resolve.Client, vk resolve.VersionKey) ([]resolve.Version, error) {
   165  	// Get & sort all the valid versions of this package
   166  	versions, err := cl.Versions(ctx, vk.PackageKey)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	semvers := make(map[resolve.VersionKey]*semver.Version)
   171  	sv := vk.Semver()
   172  	for _, ver := range versions {
   173  		parsed, err := sv.Parse(ver.Version)
   174  		if err != nil {
   175  			log.Warnf("error parsing version %s: %v", parsed, err)
   176  			continue
   177  		}
   178  		semvers[ver.VersionKey] = parsed
   179  	}
   180  
   181  	cmpFunc := func(a, b resolve.Version) int {
   182  		if vk.System == resolve.Maven {
   183  			return mavenutil.CompareVersions(vk, semvers[a.VersionKey], semvers[b.VersionKey])
   184  		}
   185  
   186  		return sv.Compare(a.Version, b.Version)
   187  	}
   188  	if !slices.IsSortedFunc(versions, cmpFunc) {
   189  		versions = slices.Clone(versions)
   190  		slices.SortFunc(versions, cmpFunc)
   191  	}
   192  	// Find the index of the next higher version
   193  	offset, vkFound := slices.BinarySearchFunc(versions, resolve.Version{VersionKey: vk}, cmpFunc)
   194  	if vkFound { // if the given version somehow doesn't exist, offset will already be at the next higher version
   195  		offset++
   196  	}
   197  
   198  	return versions[offset:], nil
   199  }