github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/vulnsrctest/vulnsrctest.go (about)

     1  package vulnsrctest
     2  
     3  import (
     4  	"sort"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"github.com/stretchr/testify/require"
     9  
    10  	"github.com/khulnasoft-lab/tunnel-db/pkg/db"
    11  	"github.com/khulnasoft-lab/tunnel-db/pkg/dbtest"
    12  	"github.com/khulnasoft-lab/tunnel-db/pkg/types"
    13  )
    14  
    15  type Updater interface {
    16  	Update(dir string) (err error)
    17  }
    18  
    19  type WantValues struct {
    20  	Key   []string
    21  	Value interface{}
    22  }
    23  
    24  type TestUpdateArgs struct {
    25  	Dir        string
    26  	WantValues []WantValues
    27  	WantErr    string
    28  	NoBuckets  [][]string
    29  }
    30  
    31  func TestUpdate(t *testing.T, vulnsrc Updater, args TestUpdateArgs) {
    32  	t.Helper()
    33  
    34  	tempDir := t.TempDir()
    35  	dbPath := db.Path(tempDir)
    36  
    37  	err := db.Init(tempDir)
    38  	require.NoError(t, err)
    39  	defer db.Close()
    40  
    41  	err = vulnsrc.Update(args.Dir)
    42  	if args.WantErr != "" {
    43  		require.NotNil(t, err)
    44  		assert.Contains(t, err.Error(), args.WantErr)
    45  		return
    46  	}
    47  
    48  	require.NoError(t, err)
    49  	require.NoError(t, db.Close()) // Need to close before dbtest.JSONEq is called
    50  	for _, want := range args.WantValues {
    51  		dbtest.JSONEq(t, dbPath, want.Key, want.Value, want.Key)
    52  	}
    53  
    54  	for _, noBucket := range args.NoBuckets {
    55  		dbtest.NoBucket(t, dbPath, noBucket, noBucket)
    56  	}
    57  }
    58  
    59  type Getter interface {
    60  	Get(string, string) ([]types.Advisory, error)
    61  }
    62  
    63  type TestGetArgs struct {
    64  	Fixtures   []string
    65  	WantValues []types.Advisory
    66  	Release    string
    67  	PkgName    string
    68  	WantErr    string
    69  }
    70  
    71  func TestGet(t *testing.T, vulnsrc Getter, args TestGetArgs) {
    72  	t.Helper()
    73  
    74  	_ = dbtest.InitDB(t, args.Fixtures)
    75  	defer db.Close()
    76  
    77  	got, err := vulnsrc.Get(args.Release, args.PkgName)
    78  
    79  	if args.WantErr != "" {
    80  		require.Error(t, err)
    81  		assert.Contains(t, err.Error(), args.WantErr)
    82  		return
    83  	}
    84  
    85  	sort.Slice(got, func(i, j int) bool {
    86  		return got[i].VulnerabilityID < got[j].VulnerabilityID
    87  	})
    88  
    89  	assert.NoError(t, err)
    90  	assert.Equal(t, args.WantValues, got)
    91  }