github.com/google/osv-scalibr@v0.4.1/guidedremediation/internal/remediation/remediation.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package remediation has the vulnerability remediation implementations.
    16  package remediation
    17  
    18  import (
    19  	"cmp"
    20  	"context"
    21  	"slices"
    22  
    23  	"deps.dev/util/resolve"
    24  	"deps.dev/util/resolve/dep"
    25  	"github.com/google/osv-scalibr/enricher"
    26  	"github.com/google/osv-scalibr/guidedremediation/internal/manifest"
    27  	"github.com/google/osv-scalibr/guidedremediation/internal/resolution"
    28  	"github.com/google/osv-scalibr/guidedremediation/options"
    29  	"github.com/google/osv-scalibr/guidedremediation/result"
    30  	"github.com/google/osv-scalibr/internal/mavenutil"
    31  )
    32  
// ResolvedGraph is a dependency graph and the vulnerabilities found in it.
//
// Graph is the resolved dependency graph. When built by ResolveGraphVulns,
// UnfilteredVulns holds every vulnerability found in Graph, while Vulns holds
// only the subset accepted by MatchVuln under the remediation options in use.
type ResolvedGraph struct {
	Graph           *resolve.Graph
	Vulns           []resolution.Vulnerability
	UnfilteredVulns []resolution.Vulnerability
}
    39  
// ResolvedManifest is a manifest, its resolved dependency graph, and the vulnerabilities found in it.
//
// The embedded ResolvedGraph provides Graph, Vulns and UnfilteredVulns. When
// built by ResolveManifest, Manifest is the manifest that was resolved to
// produce that graph.
type ResolvedManifest struct {
	ResolvedGraph

	Manifest manifest.Manifest
}
    46  
    47  // ResolveManifest resolves and find vulnerabilities in a manifest.
    48  func ResolveManifest(ctx context.Context, cl resolve.Client, ve enricher.Enricher, m manifest.Manifest, opts *options.RemediationOptions) (*ResolvedManifest, error) {
    49  	g, err := resolution.Resolve(ctx, cl, m, opts.ResolutionOptions)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	resGraph, err := ResolveGraphVulns(ctx, cl, ve, g, m.Groups(), opts)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	return &ResolvedManifest{
    60  		Manifest:      m,
    61  		ResolvedGraph: resGraph,
    62  	}, nil
    63  }
    64  
    65  // ResolveGraphVulns finds the vulnerabilities in a graph.
    66  func ResolveGraphVulns(ctx context.Context, cl resolve.Client, ve enricher.Enricher, g *resolve.Graph, depGroups map[manifest.RequirementKey][]string, opts *options.RemediationOptions) (ResolvedGraph, error) {
    67  	allVulns, err := resolution.FindVulnerabilities(ctx, ve, depGroups, g)
    68  	if err != nil {
    69  		return ResolvedGraph{}, err
    70  	}
    71  
    72  	// If explicit vulns are set, add the others to ignored vulns.
    73  	if len(opts.ExplicitVulns) > 0 {
    74  		for _, v := range allVulns {
    75  			if !slices.Contains(opts.ExplicitVulns, v.OSV.Id) {
    76  				opts.IgnoreVulns = append(opts.IgnoreVulns, v.OSV.Id)
    77  			}
    78  		}
    79  	}
    80  
    81  	filteredVulns := slices.Clone(allVulns)
    82  	filteredVulns = slices.DeleteFunc(filteredVulns, func(v resolution.Vulnerability) bool { return !MatchVuln(*opts, v) })
    83  	return ResolvedGraph{
    84  		Graph:           g,
    85  		Vulns:           filteredVulns,
    86  		UnfilteredVulns: allVulns,
    87  	}, nil
    88  }
    89  
    90  // ConstructPatches computes the effective Patches that were applied to oldRes to get newRes.
    91  func ConstructPatches(oldRes, newRes *ResolvedManifest) result.Patch {
    92  	fixedVulns := make(map[string]*resolution.Vulnerability)
    93  	for _, v := range oldRes.Vulns {
    94  		fixedVulns[v.OSV.Id] = &v
    95  	}
    96  	introducedVulns := make(map[string]*resolution.Vulnerability)
    97  	for _, v := range newRes.Vulns {
    98  		if _, ok := fixedVulns[v.OSV.Id]; !ok {
    99  			introducedVulns[v.OSV.Id] = &v
   100  		} else {
   101  			delete(fixedVulns, v.OSV.Id)
   102  		}
   103  	}
   104  
   105  	var output result.Patch
   106  	output.Fixed = make([]result.Vuln, 0, len(fixedVulns))
   107  	for _, v := range fixedVulns {
   108  		vuln := result.Vuln{ID: v.OSV.Id}
   109  		for _, sg := range v.Subgraphs {
   110  			n := oldRes.Graph.Nodes[sg.Dependency]
   111  			vuln.Packages = append(vuln.Packages, result.Package{Name: n.Version.Name, Version: n.Version.Version})
   112  		}
   113  		output.Fixed = append(output.Fixed, vuln)
   114  	}
   115  	slices.SortFunc(output.Fixed, func(a, b result.Vuln) int { return cmp.Compare(a.ID, b.ID) })
   116  
   117  	if len(introducedVulns) > 0 {
   118  		output.Introduced = make([]result.Vuln, 0, len(introducedVulns))
   119  	}
   120  	for _, v := range introducedVulns {
   121  		vuln := result.Vuln{ID: v.OSV.Id}
   122  		for _, sg := range v.Subgraphs {
   123  			n := newRes.Graph.Nodes[sg.Dependency]
   124  			vuln.Packages = append(vuln.Packages, result.Package{Name: n.Version.Name, Version: n.Version.Version})
   125  		}
   126  		output.Introduced = append(output.Introduced, vuln)
   127  	}
   128  	slices.SortFunc(output.Introduced, func(a, b result.Vuln) int { return cmp.Compare(a.ID, b.ID) })
   129  
   130  	oldReqs := make(map[manifest.RequirementKey]resolve.RequirementVersion)
   131  	for _, req := range oldRes.Manifest.Requirements() {
   132  		oldReqs[resolution.MakeRequirementKey(req)] = req
   133  	}
   134  	for _, req := range newRes.Manifest.Requirements() {
   135  		oldReq, ok := oldReqs[resolution.MakeRequirementKey(req)]
   136  		if !ok {
   137  			typ := dep.NewType()
   138  			typ.AddAttr(dep.MavenDependencyOrigin, mavenutil.OriginManagement)
   139  			output.PackageUpdates = append(output.PackageUpdates, result.PackageUpdate{
   140  				Name:        req.Name,
   141  				VersionFrom: "",
   142  				VersionTo:   req.Version,
   143  				Type:        typ,
   144  				Transitive:  true,
   145  			})
   146  			continue
   147  		}
   148  		if req.Version == oldReq.Version {
   149  			continue
   150  		}
   151  
   152  		// In Maven, a dependency can be in both <dependencies> and <dependencyManagement>.
   153  		// To work out if this is direct or transitive, we need to check if this is appears the regular dependencies.
   154  		direct := slices.ContainsFunc(oldRes.Manifest.Requirements(), func(r resolve.RequirementVersion) bool {
   155  			if r.Name != req.Name {
   156  				return false
   157  			}
   158  			origin, _ := r.Type.GetAttr(dep.MavenDependencyOrigin)
   159  			return origin != mavenutil.OriginManagement
   160  		})
   161  
   162  		output.PackageUpdates = append(output.PackageUpdates, result.PackageUpdate{
   163  			Name:        req.Name,
   164  			VersionFrom: oldReq.Version,
   165  			VersionTo:   req.Version,
   166  			Type:        oldReq.Type.Clone(),
   167  			Transitive:  !direct,
   168  		})
   169  	}
   170  	cmpFn := func(a, b result.PackageUpdate) int {
   171  		if c := cmp.Compare(a.Name, b.Name); c != 0 {
   172  			return c
   173  		}
   174  		if c := cmp.Compare(a.VersionFrom, b.VersionFrom); c != 0 {
   175  			return c
   176  		}
   177  		if c := cmp.Compare(a.VersionTo, b.VersionTo); c != 0 {
   178  			return c
   179  		}
   180  		return a.Type.Compare(b.Type)
   181  	}
   182  	slices.SortFunc(output.PackageUpdates, cmpFn)
   183  	// It's possible something is in the requirements twice (e.g. with Maven dependencyManagement)
   184  	// Deduplicate the patches in this case.
   185  	output.PackageUpdates = slices.CompactFunc(output.PackageUpdates, func(a, b result.PackageUpdate) bool {
   186  		return cmpFn(a, b) == 0
   187  	})
   188  
   189  	return output
   190  }