github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/dbtest/assert.go (about)

     1  package dbtest
     2  
     3  import (
     4  	"encoding/json"
     5  	"reflect"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  	bolt "go.etcd.io/bbolt"
    11  	"golang.org/x/xerrors"
    12  )
    13  
    14  var (
    15  	ErrNoBucket = xerrors.New("no such bucket")
    16  )
    17  
    18  func NoKey(t *testing.T, dbPath string, keys []string, msgAndArgs ...interface{}) {
    19  	t.Helper()
    20  
    21  	value := get(t, dbPath, keys)
    22  	assert.Nil(t, value, msgAndArgs...)
    23  }
    24  
    25  func NoBucket(t *testing.T, dbPath string, buckets []string, msgAndArgs ...interface{}) {
    26  	t.Helper()
    27  
    28  	db := open(t, dbPath)
    29  	defer db.Close()
    30  
    31  	err := db.View(func(tx *bolt.Tx) error {
    32  		bkt, err := nestedBuckets(tx, buckets)
    33  		if err != nil {
    34  			return err
    35  		}
    36  
    37  		// The specified bucket must not exist.
    38  		assert.Nil(t, bkt, msgAndArgs...)
    39  
    40  		return nil
    41  	})
    42  
    43  	require.NoError(t, err, msgAndArgs...)
    44  }
    45  
    46  func JSONEq(t *testing.T, dbPath string, key []string, want interface{}, msgAndArgs ...interface{}) {
    47  	t.Helper()
    48  
    49  	wantByte, err := json.Marshal(want)
    50  	require.NoError(t, err, msgAndArgs...)
    51  
    52  	got := get(t, dbPath, key, msgAndArgs...)
    53  	assert.JSONEq(t, string(wantByte), string(got), msgAndArgs...)
    54  }
    55  
    56  type bucketer interface {
    57  	Bucket(name []byte) *bolt.Bucket
    58  }
    59  
    60  func get(t *testing.T, dbPath string, keys []string, msgAndArgs ...interface{}) []byte {
    61  	if len(keys) < 2 {
    62  		require.Failf(t, "malformed keys: %v", "", keys)
    63  	}
    64  	db := open(t, dbPath)
    65  	defer db.Close()
    66  
    67  	var b []byte
    68  	err := db.View(func(tx *bolt.Tx) error {
    69  		bkts, key := keys[:len(keys)-1], keys[len(keys)-1]
    70  
    71  		var bucket bucketer = tx
    72  		for _, k := range bkts {
    73  			if reflect.ValueOf(bucket).IsNil() {
    74  				return xerrors.Errorf("bucket error %s: %w", k, ErrNoBucket)
    75  			}
    76  			bucket = bucket.Bucket([]byte(k))
    77  		}
    78  		bkt, ok := bucket.(*bolt.Bucket)
    79  		if !ok {
    80  			return xerrors.Errorf("bucket error %v: %w", keys, ErrNoBucket)
    81  		} else if bkt == nil {
    82  			return xerrors.Errorf("empty bucket %v: %w", keys, ErrNoBucket)
    83  		}
    84  		res := bkt.Get([]byte(key))
    85  		if res == nil {
    86  			return nil
    87  		}
    88  
    89  		// Copy the returned value
    90  		b = make([]byte, len(res))
    91  		copy(b, res)
    92  		return nil
    93  	})
    94  	require.NoError(t, err, msgAndArgs...)
    95  
    96  	return b
    97  }
    98  
    99  func open(t *testing.T, dbPath string) *bolt.DB {
   100  	db, err := bolt.Open(dbPath, 0600, &bolt.Options{ReadOnly: true})
   101  	require.NoError(t, err)
   102  
   103  	return db
   104  }
   105  
   106  func nestedBuckets(start bucketer, buckets []string) (*bolt.Bucket, error) {
   107  	bucket := start
   108  	for _, k := range buckets {
   109  		if reflect.ValueOf(bucket).IsNil() {
   110  			return nil, xerrors.Errorf("bucket error %v: %w", buckets, ErrNoBucket)
   111  		}
   112  		bucket = bucket.Bucket([]byte(k))
   113  	}
   114  	bkt, ok := bucket.(*bolt.Bucket)
   115  	if !ok {
   116  		return nil, xerrors.Errorf("bucket error %v: %w", buckets, ErrNoBucket)
   117  	}
   118  	return bkt, nil
   119  }