github.com/google/osv-scalibr@v0.4.1/guidedremediation/internal/strategy/common/common.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package common implements functions common to multiple remediation strategies.
    16  package common
    17  
    18  import (
    19  	"errors"
    20  	"slices"
    21  
    22  	"github.com/google/osv-scalibr/guidedremediation/internal/remediation"
    23  	"github.com/google/osv-scalibr/guidedremediation/result"
    24  	"github.com/google/osv-scalibr/log"
    25  )
    26  
    27  // PatchFunc is a bound function that attempts to patch the given vulns.
    28  type PatchFunc func([]string) StrategyResult
    29  
    30  // StrategyResult is the result of a remediation strategy.
    31  type StrategyResult struct {
    32  	VulnIDs  []string
    33  	Resolved *remediation.ResolvedManifest
    34  	Err      error
    35  }
    36  
    37  // ErrPatchImpossible is returned when no patch is possible for the vulns.
    38  var ErrPatchImpossible = errors.New("cannot find a patch for the vulns")
    39  
    40  // PatchResult is the result of computing patches.
    41  type PatchResult struct {
    42  	Patches  []result.Patch                  // the list of unique patches
    43  	Resolved []*remediation.ResolvedManifest // the resolved manifest after each patch is applied
    44  }
    45  
    46  // ComputePatches attempts to resolve each vulnerability found in the resolved manifest independently,
    47  // returning the list of unique possible patches.
    48  // Vulnerabilities are resolved by calling patchFunc for each vulnerability.
    49  // If a patch introduces new vulnerabilities, additional patches are attempted for the new vulnerabilities.
    50  // If groupIntroduced is true, introduced vulns are all attempted to be patched together.
    51  // Otherwise, they are patched one-by-one independently.
    52  func ComputePatches(patchFunc PatchFunc, resolved *remediation.ResolvedManifest, groupIntroduced bool) (PatchResult, error) {
    53  	ch := make(chan StrategyResult)
    54  	doPatch := func(vulnIDs []string) {
    55  		ch <- patchFunc(vulnIDs)
    56  	}
    57  
    58  	toProcess := 0
    59  	for _, v := range resolved.Vulns {
    60  		go doPatch([]string{v.OSV.Id})
    61  		toProcess++
    62  	}
    63  
    64  	type patchRes struct {
    65  		patch    result.Patch
    66  		resolved *remediation.ResolvedManifest
    67  	}
    68  
    69  	var allResults []patchRes
    70  	for toProcess > 0 {
    71  		r := <-ch
    72  		toProcess--
    73  		if r.Err != nil {
    74  			if !errors.Is(r.Err, ErrPatchImpossible) {
    75  				log.Warnf("error attempting to patch for vulns %v: %v", r.VulnIDs, r.Err)
    76  			}
    77  			continue
    78  		}
    79  
    80  		patch := remediation.ConstructPatches(resolved, r.Resolved)
    81  		if len(patch.PackageUpdates) == 0 {
    82  			continue
    83  		}
    84  		allResults = append(allResults, patchRes{patch: patch, resolved: r.Resolved})
    85  
    86  		// If there are any new vulns, try patching them as well
    87  		var newlyAdded []string
    88  		for _, v := range patch.Introduced {
    89  			if !slices.Contains(r.VulnIDs, v.ID) {
    90  				newlyAdded = append(newlyAdded, v.ID)
    91  			}
    92  		}
    93  		if len(newlyAdded) > 0 {
    94  			if groupIntroduced {
    95  				// If we group introduced vulns, try patch them all together.
    96  				go doPatch(append(r.VulnIDs, newlyAdded...)) // No need to clone r.VulnIDs here
    97  				toProcess++
    98  			} else {
    99  				// If we don't group introduced vulns, try patch individually.
   100  				// This can cause every permutation of introduced vulns to be computed.
   101  				for _, v := range newlyAdded {
   102  					go doPatch(append(slices.Clone(r.VulnIDs), v))
   103  					toProcess++
   104  				}
   105  			}
   106  		}
   107  	}
   108  
   109  	// Sort and remove duplicate patches
   110  	cmpFn := func(a, b patchRes) int { return a.patch.Compare(b.patch, resolved.Manifest.System().Semver()) }
   111  	slices.SortFunc(allResults, cmpFn)
   112  	allResults = slices.CompactFunc(allResults, func(a, b patchRes) bool { return cmpFn(a, b) == 0 })
   113  
   114  	var output PatchResult
   115  	output.Patches = make([]result.Patch, 0, len(allResults))
   116  	output.Resolved = make([]*remediation.ResolvedManifest, 0, len(allResults))
   117  	for _, r := range allResults {
   118  		output.Patches = append(output.Patches, r.patch)
   119  		output.Resolved = append(output.Resolved, r.resolved)
   120  	}
   121  
   122  	return output, nil
   123  }