zotregistry.dev/zot@v1.4.4-0.20240314164342-eec277e14d20/pkg/extensions/search/cve/pagination.go (about)

     1  package cveinfo
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	zerr "zotregistry.dev/zot/errors"
     8  	"zotregistry.dev/zot/pkg/common"
     9  	cvemodel "zotregistry.dev/zot/pkg/extensions/search/cve/model"
    10  )
    11  
    12  const (
    13  	AlphabeticAsc = cvemodel.SortCriteria("ALPHABETIC_ASC")
    14  	AlphabeticDsc = cvemodel.SortCriteria("ALPHABETIC_DSC")
    15  	SeverityDsc   = cvemodel.SortCriteria("SEVERITY")
    16  )
    17  
    18  func SortFunctions() map[cvemodel.SortCriteria]func(pageBuffer []cvemodel.CVE) func(i, j int) bool {
    19  	return map[cvemodel.SortCriteria]func(pageBuffer []cvemodel.CVE) func(i, j int) bool{
    20  		AlphabeticAsc: SortByAlphabeticAsc,
    21  		AlphabeticDsc: SortByAlphabeticDsc,
    22  		SeverityDsc:   SortBySeverity,
    23  	}
    24  }
    25  
    26  func SortByAlphabeticAsc(pageBuffer []cvemodel.CVE) func(i, j int) bool {
    27  	return func(i, j int) bool {
    28  		return pageBuffer[i].ID < pageBuffer[j].ID
    29  	}
    30  }
    31  
    32  func SortByAlphabeticDsc(pageBuffer []cvemodel.CVE) func(i, j int) bool {
    33  	return func(i, j int) bool {
    34  		return pageBuffer[i].ID > pageBuffer[j].ID
    35  	}
    36  }
    37  
    38  func SortBySeverity(pageBuffer []cvemodel.CVE) func(i, j int) bool {
    39  	return func(i, j int) bool {
    40  		if cvemodel.CompareSeverities(pageBuffer[i].Severity, pageBuffer[j].Severity) == 0 {
    41  			return pageBuffer[i].ID < pageBuffer[j].ID
    42  		}
    43  
    44  		return cvemodel.CompareSeverities(pageBuffer[i].Severity, pageBuffer[j].Severity) < 0
    45  	}
    46  }
    47  
    48  // PageFinder permits keeping a pool of objects using Add
    49  // and returning a specific page.
    50  type PageFinder interface {
    51  	Add(cve cvemodel.CVE)
    52  	Page() ([]cvemodel.CVE, common.PageInfo)
    53  	Reset()
    54  }
    55  
    56  // CvePageFinder implements PageFinder. It manages Cve objects and calculates the page
    57  // using the given limit, offset and sortBy option.
    58  type CvePageFinder struct {
    59  	limit      int
    60  	offset     int
    61  	sortBy     cvemodel.SortCriteria
    62  	pageBuffer []cvemodel.CVE
    63  }
    64  
    65  func NewCvePageFinder(limit, offset int, sortBy cvemodel.SortCriteria) (*CvePageFinder, error) {
    66  	if sortBy == "" {
    67  		sortBy = SeverityDsc
    68  	}
    69  
    70  	if limit < 0 {
    71  		return nil, zerr.ErrLimitIsNegative
    72  	}
    73  
    74  	if offset < 0 {
    75  		return nil, zerr.ErrOffsetIsNegative
    76  	}
    77  
    78  	if _, found := SortFunctions()[sortBy]; !found {
    79  		return nil, fmt.Errorf("sorting CVEs by '%s' is not supported %w", sortBy, zerr.ErrSortCriteriaNotSupported)
    80  	}
    81  
    82  	return &CvePageFinder{
    83  		limit:      limit,
    84  		offset:     offset,
    85  		sortBy:     sortBy,
    86  		pageBuffer: make([]cvemodel.CVE, 0, limit),
    87  	}, nil
    88  }
    89  
    90  func (bpt *CvePageFinder) Reset() {
    91  	bpt.pageBuffer = []cvemodel.CVE{}
    92  }
    93  
    94  func (bpt *CvePageFinder) Add(cve cvemodel.CVE) {
    95  	bpt.pageBuffer = append(bpt.pageBuffer, cve)
    96  }
    97  
    98  func (bpt *CvePageFinder) Page() ([]cvemodel.CVE, common.PageInfo) {
    99  	if len(bpt.pageBuffer) == 0 {
   100  		return []cvemodel.CVE{}, common.PageInfo{}
   101  	}
   102  
   103  	pageInfo := &common.PageInfo{}
   104  
   105  	sort.Slice(bpt.pageBuffer, SortFunctions()[bpt.sortBy](bpt.pageBuffer))
   106  
   107  	// the offset and limit are calculated in terms of CVEs counted
   108  	start := bpt.offset
   109  	end := bpt.offset + bpt.limit
   110  
   111  	// we'll return an empty array when the offset is greater than the number of elements
   112  	if start >= len(bpt.pageBuffer) {
   113  		start = len(bpt.pageBuffer)
   114  		end = start
   115  	}
   116  
   117  	if end >= len(bpt.pageBuffer) {
   118  		end = len(bpt.pageBuffer)
   119  	}
   120  
   121  	cves := bpt.pageBuffer[start:end]
   122  
   123  	pageInfo.ItemCount = len(cves)
   124  
   125  	if start == 0 && end == 0 {
   126  		cves = bpt.pageBuffer
   127  		pageInfo.ItemCount = len(cves)
   128  	}
   129  
   130  	pageInfo.TotalCount = len(bpt.pageBuffer)
   131  
   132  	return cves, *pageInfo
   133  }