github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/report/table/vulnerability.go (about)

     1  package table
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"path/filepath"
     7  	"sort"
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/samber/lo"
    12  	"github.com/xlab/treeprint"
    13  	"golang.org/x/exp/maps"
    14  	"golang.org/x/exp/slices"
    15  
    16  	"github.com/aquasecurity/table"
    17  	"github.com/aquasecurity/tml"
    18  	dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
    19  	ftypes "github.com/devseccon/trivy/pkg/fanal/types"
    20  	"github.com/devseccon/trivy/pkg/log"
    21  	"github.com/devseccon/trivy/pkg/types"
    22  )
    23  
    24  type vulnerabilityRenderer struct {
    25  	w           *bytes.Buffer
    26  	tableWriter *table.Table
    27  	result      types.Result
    28  	isTerminal  bool
    29  	tree        bool
    30  	severities  []dbTypes.Severity
    31  	once        *sync.Once
    32  }
    33  
    34  func NewVulnerabilityRenderer(result types.Result, isTerminal, tree bool, severities []dbTypes.Severity) *vulnerabilityRenderer {
    35  	buf := bytes.NewBuffer([]byte{})
    36  	if !isTerminal {
    37  		tml.DisableFormatting()
    38  	}
    39  	return &vulnerabilityRenderer{
    40  		w:           buf,
    41  		tableWriter: newTableWriter(buf, isTerminal),
    42  		result:      result,
    43  		isTerminal:  isTerminal,
    44  		tree:        tree,
    45  		severities:  severities,
    46  		once:        new(sync.Once),
    47  	}
    48  }
    49  
    50  func (r *vulnerabilityRenderer) Render() string {
    51  	r.setHeaders()
    52  	r.setVulnerabilityRows(r.result.Vulnerabilities)
    53  
    54  	severityCount := r.countSeverities(r.result.Vulnerabilities)
    55  	total, summaries := summarize(r.severities, severityCount)
    56  
    57  	target := r.result.Target
    58  	if r.result.Class == types.ClassLangPkg {
    59  		target += fmt.Sprintf(" (%s)", r.result.Type)
    60  	}
    61  	RenderTarget(r.w, target, r.isTerminal)
    62  	r.printf("Total: %d (%s)\n\n", total, strings.Join(summaries, ", "))
    63  
    64  	r.tableWriter.Render()
    65  	if r.tree {
    66  		r.renderDependencyTree()
    67  	}
    68  
    69  	return r.w.String()
    70  }
    71  
    72  func (r *vulnerabilityRenderer) setHeaders() {
    73  	if len(r.result.Vulnerabilities) == 0 {
    74  		return
    75  	}
    76  	header := []string{
    77  		"Library",
    78  		"Vulnerability",
    79  		"Severity",
    80  		"Status",
    81  		"Installed Version",
    82  		"Fixed Version",
    83  		"Title",
    84  	}
    85  	r.tableWriter.SetHeaders(header...)
    86  }
    87  
    88  func (r *vulnerabilityRenderer) setVulnerabilityRows(vulns []types.DetectedVulnerability) {
    89  	for _, v := range vulns {
    90  		lib := v.PkgName
    91  		if v.PkgPath != "" {
    92  			// get path to root jar
    93  			// for other languages return unchanged path
    94  			pkgPath := rootJarFromPath(v.PkgPath)
    95  			fileName := filepath.Base(pkgPath)
    96  			lib = fmt.Sprintf("%s (%s)", v.PkgName, fileName)
    97  			r.once.Do(func() {
    98  				log.Logger.Infof("Table result includes only package filenames. Use '--format json' option to get the full path to the package file.")
    99  			})
   100  		}
   101  
   102  		title := v.Title
   103  		if title == "" {
   104  			title = v.Description
   105  		}
   106  		splitTitle := strings.Split(title, " ")
   107  		if len(splitTitle) >= 12 {
   108  			title = strings.Join(splitTitle[:12], " ") + "..."
   109  		}
   110  
   111  		if len(v.PrimaryURL) > 0 {
   112  			if r.isTerminal {
   113  				title = tml.Sprintf("%s\n<blue>%s</blue>", title, v.PrimaryURL)
   114  			} else {
   115  				title = fmt.Sprintf("%s\n%s", title, v.PrimaryURL)
   116  			}
   117  		}
   118  
   119  		var row []string
   120  		if r.isTerminal {
   121  			row = []string{
   122  				lib,
   123  				v.VulnerabilityID,
   124  				ColorizeSeverity(v.Severity, v.Severity),
   125  				v.Status.String(),
   126  				v.InstalledVersion,
   127  				v.FixedVersion,
   128  				strings.TrimSpace(title),
   129  			}
   130  		} else {
   131  			row = []string{
   132  				lib,
   133  				v.VulnerabilityID,
   134  				v.Severity,
   135  				v.Status.String(),
   136  				v.InstalledVersion,
   137  				v.FixedVersion,
   138  				strings.TrimSpace(title),
   139  			}
   140  		}
   141  
   142  		r.tableWriter.AddRow(row...)
   143  	}
   144  }
   145  
   146  func (r *vulnerabilityRenderer) countSeverities(vulns []types.DetectedVulnerability) map[string]int {
   147  	severityCount := make(map[string]int)
   148  	for _, v := range vulns {
   149  		severityCount[v.Severity]++
   150  	}
   151  	return severityCount
   152  }
   153  
   154  func (r *vulnerabilityRenderer) renderDependencyTree() {
   155  	// Get parents of each dependency
   156  	parents := ftypes.Packages(r.result.Packages).ParentDeps()
   157  	if len(parents) == 0 {
   158  		return
   159  	}
   160  	ancestors := traverseAncestors(r.result.Packages, parents)
   161  
   162  	root := treeprint.NewWithRoot(fmt.Sprintf(`
   163  Dependency Origin Tree (Reversed)
   164  =================================
   165  %s`, r.result.Target))
   166  
   167  	// This count is next to the package ID.
   168  	// e.g. node-fetch@1.7.3 (MEDIUM: 2, HIGH: 1, CRITICAL: 3)
   169  	pkgSeverityCount := make(map[string]map[string]int)
   170  	for _, vuln := range r.result.Vulnerabilities {
   171  		cnts, ok := pkgSeverityCount[vuln.PkgID]
   172  		if !ok {
   173  			cnts = make(map[string]int)
   174  		}
   175  
   176  		cnts[vuln.Severity]++
   177  		pkgSeverityCount[vuln.PkgID] = cnts
   178  	}
   179  
   180  	// Extract vulnerable packages
   181  	vulnPkgs := lo.Filter(r.result.Packages, func(pkg ftypes.Package, _ int) bool {
   182  		return lo.ContainsBy(r.result.Vulnerabilities, func(vuln types.DetectedVulnerability) bool {
   183  			return pkg.ID == vuln.PkgID
   184  		})
   185  	})
   186  
   187  	// Render tree
   188  	for _, vulnPkg := range vulnPkgs {
   189  		_, summaries := summarize(r.severities, pkgSeverityCount[vulnPkg.ID])
   190  		topLvlID := tml.Sprintf("<red>%s, (%s)</red>", vulnPkg.ID, strings.Join(summaries, ", "))
   191  
   192  		branch := root.AddBranch(topLvlID)
   193  		addParents(branch, vulnPkg, parents, ancestors, map[string]struct{}{vulnPkg.ID: {}}, 1)
   194  
   195  	}
   196  	r.printf(root.String())
   197  }
   198  
   199  func (r *vulnerabilityRenderer) printf(format string, args ...interface{}) {
   200  	// nolint
   201  	_ = tml.Fprintf(r.w, format, args...)
   202  }
   203  
   204  func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string,
   205  	seen map[string]struct{}, depth int) {
   206  	if !pkg.Indirect {
   207  		return
   208  	}
   209  
   210  	roots := make(map[string]struct{})
   211  	for _, parent := range parentMap[pkg.ID] {
   212  		if _, ok := seen[parent.ID]; ok {
   213  			continue
   214  		}
   215  		seen[parent.ID] = struct{}{} // to avoid infinite loops
   216  
   217  		if depth == 1 && !parent.Indirect {
   218  			topItem.AddBranch(parent.ID)
   219  		} else {
   220  			// We omit intermediate dependencies and show only direct dependencies
   221  			// as this could make the dependency tree huge.
   222  			for _, ancestor := range ancestors[parent.ID] {
   223  				roots[ancestor] = struct{}{}
   224  			}
   225  		}
   226  	}
   227  
   228  	// Omitted
   229  	rootIDs := lo.Filter(maps.Keys(roots), func(pkgID string, _ int) bool {
   230  		_, ok := seen[pkgID]
   231  		return !ok
   232  	})
   233  	sort.Strings(rootIDs)
   234  	if len(rootIDs) > 0 {
   235  		branch := topItem.AddBranch("...(omitted)...")
   236  		for _, rootID := range rootIDs {
   237  			branch.AddBranch(rootID)
   238  		}
   239  	}
   240  }
   241  
   242  func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string {
   243  	ancestors := make(map[string][]string)
   244  	for _, pkg := range pkgs {
   245  		ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, make(map[string]struct{}))
   246  	}
   247  	return ancestors
   248  }
   249  
   250  func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[string]struct{}) []string {
   251  	ancestors := make(map[string]struct{})
   252  	seen[pkgID] = struct{}{}
   253  	for _, parent := range parentMap[pkgID] {
   254  		if _, ok := seen[parent.ID]; ok {
   255  			continue
   256  		}
   257  		switch {
   258  		case !parent.Indirect:
   259  			ancestors[parent.ID] = struct{}{}
   260  		case len(parentMap[parent.ID]) == 0:
   261  			// Direct dependencies cannot be identified in some package managers like "package-lock.json" v1,
   262  			// then the "Indirect" field can be always true. We try to guess direct dependencies in this case.
   263  			// A dependency with no parents must be a direct dependency.
   264  			//
   265  			// e.g.
   266  			//   -> styled-components
   267  			//     -> fbjs
   268  			//       -> isomorphic-fetch
   269  			//         -> node-fetch
   270  			//
   271  			// Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency
   272  			// as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency.
   273  			ancestors[parent.ID] = struct{}{}
   274  		default:
   275  			for _, a := range findAncestor(parent.ID, parentMap, seen) {
   276  				ancestors[a] = struct{}{}
   277  			}
   278  		}
   279  	}
   280  	return maps.Keys(ancestors)
   281  }
   282  
   283  var jarExtensions = []string{
   284  	".jar",
   285  	".war",
   286  	".par",
   287  	".ear",
   288  }
   289  
   290  func rootJarFromPath(path string) string {
   291  	// File paths are always forward-slashed in Trivy
   292  	paths := strings.Split(path, "/")
   293  	for i, p := range paths {
   294  		if slices.Contains(jarExtensions, filepath.Ext(p)) {
   295  			return strings.Join(paths[:i+1], "/")
   296  		}
   297  	}
   298  	return path
   299  }