github.com/google/osv-scalibr@v0.4.1/guidedremediation/internal/tui/model/state_choose_strategy.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package model
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"slices"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"deps.dev/util/resolve"
    25  	"github.com/charmbracelet/bubbles/key"
    26  	"github.com/charmbracelet/bubbles/textinput"
    27  	tea "github.com/charmbracelet/bubbletea"
    28  	"github.com/google/osv-scalibr/guidedremediation/internal/remediation"
    29  	"github.com/google/osv-scalibr/guidedremediation/internal/resolution"
    30  	"github.com/google/osv-scalibr/guidedremediation/internal/strategy/inplace"
    31  	"github.com/google/osv-scalibr/guidedremediation/internal/tui/components"
    32  	"github.com/google/osv-scalibr/guidedremediation/options"
    33  	"github.com/google/osv-scalibr/guidedremediation/result"
    34  )
    35  
// stateChooseStrategy is the model state for the menu where the user picks a
// remediation action (modify the lockfile in-place or re-lock the project)
// and edits the vulnerability filtering criteria.
type stateChooseStrategy struct {
	// cursorPos is the menu line currently under the cursor.
	cursorPos chooseStratCursorPos
	// canRelock is set by newStateChooseStrategy when re-locking is offered.
	canRelock bool

	// vulnCount holds the summary counts shown on the first menu line.
	vulnCount vulnCount

	// Pre-generated info views, one per menu option, shown by InfoView.
	vulnList       components.ViewModel
	inPlaceInfo    components.ViewModel
	relockFixVulns components.ViewModel
	errorsView     components.ViewModel

	// Editable criteria: max dependency depth and min CVSS score.
	depthInput    textinput.Model
	severityInput textinput.Model

	focusedInfo components.ViewModel // the infoview that is currently focused, nil if not focused
}
    52  
// chooseStratCursorPos identifies a selectable line of the strategy menu.
type chooseStratCursorPos int

// Menu lines in the order View renders them. Up/Down move the cursor by
// decrementing/incrementing these values, so the order here must match the
// rendered layout.
const (
	// Vulnerability summary line.
	chooseStratInfo chooseStratCursorPos = iota
	// Resolution-errors warning; only rendered (and reachable) when errors exist.
	chooseStratErrors
	// "Modify lockfile in-place" action.
	chooseStratInPlace
	// "Re-lock project" action.
	chooseStratRelock
	// Max dependency depth text input.
	chooseStratDepth
	// Min CVSS score text input.
	chooseStratSeverity
	// "Exclude dev only" toggle.
	chooseStratDev
	// "Apply criteria" action.
	chooseStratApplyCriteria
	// Quit line.
	chooseStratQuit
	// Sentinel one past the last line; the cursor never moves beyond chooseStratEnd-1.
	chooseStratEnd
)
    67  
    68  func newStateChooseStrategy(m Model) stateChooseStrategy {
    69  	s := stateChooseStrategy{
    70  		cursorPos: chooseStratInPlace,
    71  	}
    72  
    73  	// pre-generate the info views for each option
    74  	s.vulnList = components.NewVulnList(m.lockfileGraph.Vulns, "", m.detailsRenderer)
    75  
    76  	// make the in-place view
    77  	s.inPlaceInfo = components.NewInPlaceInfo(m.lockfilePatches, m.lockfileGraph.Vulns, m.detailsRenderer)
    78  
    79  	if m.options.Manifest != "" {
    80  		var relockFixes []resolution.Vulnerability
    81  		for _, v := range m.lockfileGraph.Vulns {
    82  			if !slices.ContainsFunc(m.relockBaseManifest.Vulns, func(r resolution.Vulnerability) bool {
    83  				return r.OSV.Id == v.OSV.Id
    84  			}) {
    85  				relockFixes = append(relockFixes, v)
    86  			}
    87  		}
    88  		s.canRelock = true
    89  		s.relockFixVulns = components.NewVulnList(relockFixes, "Relocking fixes the following vulns:", m.detailsRenderer)
    90  	} else {
    91  		s.canRelock = false
    92  		s.relockFixVulns = components.TextView("Re-run with manifest to resolve vulnerabilities by re-locking")
    93  	}
    94  
    95  	s.depthInput = textinput.New()
    96  	s.depthInput.CharLimit = 3
    97  	s.depthInput.SetValue(strconv.Itoa(m.options.MaxDepth))
    98  
    99  	s.severityInput = textinput.New()
   100  	s.severityInput.CharLimit = 4
   101  	s.severityInput.SetValue(strconv.FormatFloat(m.options.MinSeverity, 'g', -1, 64))
   102  
   103  	s.errorsView = makeErrorsView(m.relockBaseErrors)
   104  
   105  	s.vulnCount = countVulns(m.lockfileGraph.Vulns, m.options.RemediationOptions)
   106  
   107  	return s
   108  }
   109  
   110  func (s stateChooseStrategy) Init(m Model) tea.Cmd {
   111  	return nil
   112  }
   113  
   114  func (s stateChooseStrategy) Update(m Model, msg tea.Msg) (tea.Model, tea.Cmd) {
   115  	var cmds []tea.Cmd
   116  	switch msg := msg.(type) {
   117  	case components.ViewModelCloseMsg:
   118  		// info view wants to quit, just unfocus it
   119  		s.focusedInfo = nil
   120  	case tea.KeyMsg:
   121  		switch {
   122  		case key.Matches(msg, components.Keys.SwitchView):
   123  			if s.IsInfoFocused() {
   124  				s.focusedInfo = nil
   125  			} else if view, canFocus := s.currentInfoView(); canFocus {
   126  				s.focusedInfo = view
   127  			}
   128  		case s.IsInfoFocused():
   129  			var cmd tea.Cmd
   130  			s.focusedInfo, cmd = s.focusedInfo.Update(msg)
   131  
   132  			return m, cmd
   133  		case key.Matches(msg, components.Keys.Quit):
   134  			// only quit if the cursor is over the quit line
   135  			if s.cursorPos == chooseStratQuit {
   136  				return m, tea.Quit
   137  			}
   138  			// otherwise move the cursor to the quit line if it's not already there
   139  			s.cursorPos = chooseStratQuit
   140  		case key.Matches(msg, components.Keys.Select):
   141  			// enter key was pressed, parse input
   142  			return s.parseInput(m)
   143  		// move the cursor and show the corresponding info view
   144  		case key.Matches(msg, components.Keys.Up):
   145  			if s.cursorPos > chooseStratInfo {
   146  				s.cursorPos--
   147  				// Resolution errors aren't rendered if there are none
   148  				if s.cursorPos == chooseStratErrors && len(m.relockBaseErrors) == 0 {
   149  					s.cursorPos--
   150  				}
   151  			}
   152  			s = s.UpdateTextFocus()
   153  		case key.Matches(msg, components.Keys.Down):
   154  			if s.cursorPos < chooseStratEnd-1 {
   155  				s.cursorPos++
   156  				if s.cursorPos == chooseStratErrors && len(m.relockBaseErrors) == 0 {
   157  					s.cursorPos++
   158  				}
   159  			}
   160  			s = s.UpdateTextFocus()
   161  		}
   162  	}
   163  
   164  	var cmd tea.Cmd
   165  	s.depthInput, cmd = s.depthInput.Update(msg)
   166  	cmds = append(cmds, cmd)
   167  
   168  	s.severityInput, cmd = s.severityInput.Update(msg)
   169  	cmds = append(cmds, cmd)
   170  
   171  	m.st = s
   172  	return m, tea.Batch(cmds...)
   173  }
   174  
   175  func (s stateChooseStrategy) UpdateTextFocus() stateChooseStrategy {
   176  	s.depthInput.Blur()
   177  	s.severityInput.Blur()
   178  
   179  	switch s.cursorPos {
   180  	case chooseStratDepth:
   181  		s.depthInput.Focus()
   182  	case chooseStratSeverity:
   183  		s.severityInput.Focus()
   184  	case
   185  		chooseStratInfo,
   186  		chooseStratErrors,
   187  		chooseStratInPlace,
   188  		chooseStratRelock,
   189  		chooseStratDev,
   190  		chooseStratApplyCriteria,
   191  		chooseStratQuit,
   192  		chooseStratEnd:
   193  	}
   194  	return s
   195  }
   196  
   197  func (s stateChooseStrategy) IsInfoFocused() bool {
   198  	return s.focusedInfo != nil
   199  }
   200  
   201  func (s stateChooseStrategy) currentInfoView() (view components.ViewModel, canFocus bool) {
   202  	switch s.cursorPos {
   203  	case chooseStratInfo: // info line
   204  		return s.vulnList, true
   205  	case chooseStratErrors:
   206  		return s.errorsView, false
   207  	case chooseStratInPlace: // in-place
   208  		return s.inPlaceInfo, true
   209  	case chooseStratRelock: // relock
   210  		return s.relockFixVulns, s.canRelock
   211  	case chooseStratQuit: // quit
   212  		return components.TextView("Exit Guided Remediation"), false
   213  	case
   214  		chooseStratDepth,
   215  		chooseStratSeverity,
   216  		chooseStratDev,
   217  		chooseStratApplyCriteria,
   218  		chooseStratEnd:
   219  		fallthrough
   220  	default:
   221  		return components.TextView(""), false
   222  	}
   223  }
   224  
   225  func (s stateChooseStrategy) parseInput(m Model) (tea.Model, tea.Cmd) {
   226  	var cmd tea.Cmd
   227  	switch s.cursorPos {
   228  	case chooseStratInfo: // info line, focus on info view
   229  		s.focusedInfo = s.vulnList
   230  		m.st = s
   231  	case chooseStratInPlace: // in-place
   232  		m.st = newStateInPlaceResult(m, s.inPlaceInfo, nil)
   233  		cmd = m.st.Init(m)
   234  	case chooseStratRelock: // relock
   235  		if s.canRelock {
   236  			m.st = newStateRelockResult(m)
   237  			cmd = m.st.Init(m)
   238  		}
   239  	case chooseStratDev:
   240  		m.options.DevDeps = !m.options.DevDeps
   241  	case chooseStratApplyCriteria:
   242  		maxDepth, err := strconv.Atoi(s.depthInput.Value())
   243  		if err == nil {
   244  			m.options.MaxDepth = maxDepth
   245  		}
   246  
   247  		minSeverity, err := strconv.ParseFloat(s.severityInput.Value(), 64)
   248  		if err == nil {
   249  			m.options.MinSeverity = minSeverity
   250  		}
   251  
   252  		// Recompute vulns/patches with the new filters.
   253  		fn := func(v resolution.Vulnerability) bool { return !remediation.MatchVuln(m.options.RemediationOptions, v) }
   254  		m.lockfileGraph.Vulns = slices.Clone(m.lockfileGraph.UnfilteredVulns)
   255  		m.lockfileGraph.Vulns = slices.DeleteFunc(m.lockfileGraph.Vulns, fn)
   256  		m.lockfilePatches, err = inplace.ComputePatches(context.Background(), m.options.ResolveClient, m.lockfileGraph, &m.options.RemediationOptions)
   257  		if err != nil {
   258  			return errorAndExit(m, err)
   259  		}
   260  		if m.relockBaseManifest != nil {
   261  			m.relockBaseManifest.Vulns = slices.Clone(m.relockBaseManifest.UnfilteredVulns)
   262  			m.relockBaseManifest.Vulns = slices.DeleteFunc(m.relockBaseManifest.Vulns, fn)
   263  		}
   264  
   265  		m.st = newStateChooseStrategy(m)
   266  		cmd = m.st.Init(m)
   267  	case chooseStratQuit: // quit line
   268  		cmd = tea.Quit
   269  	case
   270  		chooseStratErrors,
   271  		chooseStratDepth,
   272  		chooseStratSeverity,
   273  		chooseStratEnd:
   274  	}
   275  
   276  	return m, cmd
   277  }
   278  
   279  func (s stateChooseStrategy) View(m Model) string {
   280  	vulnCount := s.vulnCount
   281  	fixCount := vulnCount.total
   282  	pkgChange := 0
   283  	for _, p := range m.lockfilePatches {
   284  		fixCount -= len(p.Fixed)
   285  		pkgChange += len(p.PackageUpdates)
   286  	}
   287  
   288  	sb := strings.Builder{}
   289  	sb.WriteString(components.RenderSelectorOption(
   290  		s.cursorPos == chooseStratInfo,
   291  		"",
   292  		fmt.Sprintf("Found %%s in lockfile (%d direct, %d transitive, %d dev only) matching the criteria.\n",
   293  			vulnCount.direct, vulnCount.transitive, vulnCount.devOnly),
   294  		fmt.Sprintf("%d vulnerabilities", vulnCount.total),
   295  	))
   296  	if len(m.relockBaseErrors) > 0 {
   297  		sb.WriteString(components.RenderSelectorOption(
   298  			s.cursorPos == chooseStratErrors,
   299  			"",
   300  			"WARNING: Encountered %s during graph resolution.\n",
   301  			fmt.Sprintf("%d errors", len(m.relockBaseErrors)),
   302  		))
   303  	}
   304  	sb.WriteString("\n")
   305  	sb.WriteString("Actions:\n")
   306  	sb.WriteString(components.RenderSelectorOption(
   307  		s.cursorPos == chooseStratInPlace,
   308  		" > ",
   309  		fmt.Sprintf("%%s (fixes %d/%d vulns, changes %d packages)\n", fixCount, vulnCount.total, pkgChange),
   310  		"Modify lockfile in-place",
   311  	))
   312  
   313  	if s.canRelock {
   314  		relockFix := vulnCount.total - len(m.relockBaseManifest.Vulns)
   315  		sb.WriteString(components.RenderSelectorOption(
   316  			s.cursorPos == chooseStratRelock,
   317  			" > ",
   318  			fmt.Sprintf("%%s (fixes %d/%d vulns) and try direct dependency upgrades\n", relockFix, vulnCount.total),
   319  			"Re-lock project",
   320  		))
   321  	} else {
   322  		sb.WriteString(components.RenderSelectorOption(
   323  			s.cursorPos == chooseStratRelock,
   324  			" > ",
   325  			components.DisabledTextStyle.Render("Cannot re-lock - missing manifest file\n"),
   326  		))
   327  	}
   328  	sb.WriteString("\n")
   329  	sb.WriteString("Criteria:\n")
   330  	sb.WriteString(components.RenderSelectorOption(
   331  		s.cursorPos == chooseStratDepth,
   332  		" > ",
   333  		fmt.Sprintf("%%s: %s\n", s.depthInput.View()),
   334  		"Max dependency depth",
   335  	))
   336  	sb.WriteString(components.RenderSelectorOption(
   337  		s.cursorPos == chooseStratSeverity,
   338  		" > ",
   339  		fmt.Sprintf("%%s: %s\n", s.severityInput.View()),
   340  		"Min CVSS score",
   341  	))
   342  
   343  	devString := "YES"
   344  	if m.options.DevDeps {
   345  		devString = "NO"
   346  	}
   347  	sb.WriteString(components.RenderSelectorOption(
   348  		s.cursorPos == chooseStratDev,
   349  		" > ",
   350  		fmt.Sprintf("%%s: %s\n", devString),
   351  		"Exclude dev only",
   352  	))
   353  	sb.WriteString(components.RenderSelectorOption(
   354  		s.cursorPos == chooseStratApplyCriteria,
   355  		" > ",
   356  		"%s\n",
   357  		"Apply criteria",
   358  	))
   359  
   360  	sb.WriteString("\n")
   361  	sb.WriteString(components.RenderSelectorOption(
   362  		s.cursorPos == chooseStratQuit,
   363  		"> ",
   364  		"%s\n",
   365  		"quit",
   366  	))
   367  
   368  	return sb.String()
   369  }
   370  
   371  func (s stateChooseStrategy) InfoView() string {
   372  	v, _ := s.currentInfoView()
   373  	return v.View()
   374  }
   375  
   376  func (s stateChooseStrategy) Resize(_, _ int) modelState { return s }
   377  
   378  func (s stateChooseStrategy) ResizeInfo(w, h int) modelState {
   379  	s.vulnList = s.vulnList.Resize(w, h)
   380  	s.inPlaceInfo = s.inPlaceInfo.Resize(w, h)
   381  	s.relockFixVulns = s.relockFixVulns.Resize(w, h)
   382  
   383  	return s
   384  }
   385  
// vulnCount summarizes the vulnerabilities matching the current criteria.
// As computed by countVulns, each (vulnerability, affected package version)
// pair is counted once.
type vulnCount struct {
	// total is the number of counted pairs; direct + transitive == total.
	total int
	// direct counts pairs where some matching path has the package as a
	// direct dependency.
	direct int
	// transitive counts pairs that are never a direct dependency.
	transitive int
	// devOnly counts pairs reached only through dev-only subgraphs.
	devOnly int
}
   392  
   393  func countVulns(vulns []resolution.Vulnerability, opts options.RemediationOptions) vulnCount {
   394  	var vc vulnCount
   395  	for _, v := range vulns {
   396  		// count vulns per in-place, i.e. unique per ID & package version.
   397  		seen := make(map[resolve.VersionKey]struct{})
   398  		nonDev := make(map[resolve.VersionKey]struct{})
   399  		seenAsDirect := make(map[resolve.VersionKey]struct{})
   400  		for _, sg := range v.Subgraphs {
   401  			devOnly := sg.IsDevOnly(nil)
   402  			// check if the vulnerability should be filtered out.
   403  			if !remediation.MatchVuln(opts, resolution.Vulnerability{
   404  				OSV:       v.OSV,
   405  				Subgraphs: []*resolution.DependencySubgraph{sg},
   406  				DevOnly:   devOnly,
   407  			}) {
   408  				continue
   409  			}
   410  			node := sg.Nodes[sg.Dependency]
   411  			vk := node.Version
   412  			seen[vk] = struct{}{}
   413  			if slices.ContainsFunc(node.Parents, func(e resolve.Edge) bool { return e.From == 0 }) {
   414  				seenAsDirect[vk] = struct{}{}
   415  			}
   416  			if !devOnly {
   417  				nonDev[vk] = struct{}{}
   418  			}
   419  		}
   420  		for vk := range seen {
   421  			vc.total++
   422  			if _, ok := nonDev[vk]; !ok {
   423  				vc.devOnly++
   424  			}
   425  			if _, ok := seenAsDirect[vk]; ok {
   426  				vc.direct++
   427  			} else {
   428  				vc.transitive++
   429  			}
   430  		}
   431  	}
   432  
   433  	return vc
   434  }
   435  
   436  func makeErrorsView(errs []result.ResolveError) components.ViewModel {
   437  	if len(errs) == 0 {
   438  		return components.TextView("")
   439  	}
   440  
   441  	s := strings.Builder{}
   442  	s.WriteString("The following errors were encountered during resolution which may impact results:\n")
   443  	for _, e := range errs {
   444  		fmt.Fprintf(&s, "Error when resolving %s@%s:\n", e.Package.Name, e.Package.Version)
   445  		if strings.Contains(e.Requirement.Version, ":") {
   446  			// this will be the case with unsupported npm requirements e.g. `file:...`, `git+https://...`
   447  			// No easy access to the `knownAs` field to find which package this corresponds to...
   448  			fmt.Fprintf(&s, "\tSkipped resolving unsupported version specification: %s\n", e.Requirement.Version)
   449  		} else {
   450  			fmt.Fprintf(&s, "\t%v: %s@%s\n", e.Error, e.Requirement.Name, e.Requirement.Version)
   451  		}
   452  	}
   453  	return components.TextView(s.String())
   454  }