github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/db/db.go (about)

     1  package db
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"os"
     7  	"path/filepath"
     8  	"runtime/debug"
     9  	"strings"
    10  
    11  	bolt "go.etcd.io/bbolt"
    12  	"golang.org/x/xerrors"
    13  
    14  	"github.com/khulnasoft-lab/tunnel-db/pkg/log"
    15  	"github.com/khulnasoft-lab/tunnel-db/pkg/types"
    16  )
    17  
// CustomPut is a function type that receives an Operation, an open bbolt
// transaction, and an arbitrary advisory value, and returns an error.
// NOTE(review): presumably lets a data-source updater customize how an
// advisory is written; no caller is visible in this file — confirm.
type CustomPut func(dbc Operation, tx *bolt.Tx, adv interface{}) error
    19  
// SchemaVersion is the DB schema version number.
// NOTE(review): not referenced in this file; presumably compared against the
// metadata of a downloaded DB elsewhere — confirm.
const SchemaVersion = 2
    21  
var (
	// db is the shared bbolt handle assigned by Init and used by every Config
	// method. It is nil until Init assigns it (Close treats nil as "not opened").
	db    *bolt.DB
	// dbDir is the directory containing the DB file, recorded by Init.
	dbDir string
)
    26  
// Operation is the set of read/write operations on the vulnerability DB that
// CustomPut hooks receive.
// NOTE(review): Config defines BatchUpdate with a matching signature in this
// file; the remaining methods are presumably implemented by Config elsewhere
// in the package — confirm.
type Operation interface {
	// BatchUpdate runs fn in a batched read-write transaction.
	BatchUpdate(fn func(*bolt.Tx) error) (err error)

	// Vulnerability details, keyed by vulnerability ID and source.
	GetVulnerabilityDetail(cveID string) (detail map[types.SourceID]types.VulnerabilityDetail, err error)
	PutVulnerabilityDetail(tx *bolt.Tx, vulnerabilityID string, source types.SourceID,
		vulnerability types.VulnerabilityDetail) (err error)
	DeleteVulnerabilityDetailBucket() (err error)

	// Advisories, looked up by source(s) and package name.
	ForEachAdvisory(sources []string, pkgName string) (value map[string]Value, err error)
	GetAdvisories(source string, pkgName string) (advisories []types.Advisory, err error)

	// Vulnerability ID set.
	PutVulnerabilityID(tx *bolt.Tx, vulnerabilityID string) (err error)
	ForEachVulnerabilityID(fn func(tx *bolt.Tx, cveID string) error) (err error)

	// Merged vulnerability records.
	PutVulnerability(tx *bolt.Tx, vulnerabilityID string, vulnerability types.Vulnerability) (err error)
	GetVulnerability(vulnerabilityID string) (vulnerability types.Vulnerability, err error)

	// Advisory details staging.
	SaveAdvisoryDetails(tx *bolt.Tx, cveID string) (err error)
	PutAdvisoryDetail(tx *bolt.Tx, vulnerabilityID, pkgName string, nestedBktNames []string, advisory interface{}) (err error)
	DeleteAdvisoryDetailBucket() error

	// Data source metadata attached to a bucket.
	PutDataSource(tx *bolt.Tx, bktName string, source types.DataSource) (err error)

	// For Red Hat
	PutRedHatRepositories(tx *bolt.Tx, repository string, cpeIndices []int) (err error)
	PutRedHatNVRs(tx *bolt.Tx, nvr string, cpeIndices []int) (err error)
	PutRedHatCPEs(tx *bolt.Tx, cpeIndex int, cpe string) (err error)
	RedHatRepoToCPEs(repository string) (cpeIndices []int, err error)
	RedHatNVRToCPEs(nvr string) (cpeIndices []int, err error)
}
    57  
// Config is a field-less receiver for DB operations. Its methods operate on
// the package-level db handle opened by Init rather than on per-instance state.
type Config struct {
}
    60  
    61  func Init(cacheDir string) (err error) {
    62  	dbPath := Path(cacheDir)
    63  	dbDir = filepath.Dir(dbPath)
    64  	if err = os.MkdirAll(dbDir, 0700); err != nil {
    65  		return xerrors.Errorf("failed to mkdir: %w", err)
    66  	}
    67  
    68  	// bbolt sometimes occurs the fatal error of "unexpected fault address".
    69  	// In that case, the local DB should be broken and needs to be removed.
    70  	debug.SetPanicOnFault(true)
    71  	defer func() {
    72  		if r := recover(); r != nil {
    73  			if err = os.Remove(dbPath); err != nil {
    74  				return
    75  			}
    76  			db, err = bolt.Open(dbPath, 0600, nil)
    77  		}
    78  		debug.SetPanicOnFault(false)
    79  	}()
    80  
    81  	db, err = bolt.Open(dbPath, 0600, nil)
    82  	if err != nil {
    83  		return xerrors.Errorf("failed to open db: %w", err)
    84  	}
    85  	return nil
    86  }
    87  
    88  func Dir(cacheDir string) string {
    89  	return filepath.Join(cacheDir, "db")
    90  }
    91  
    92  func Path(cacheDir string) string {
    93  	dbPath := filepath.Join(Dir(cacheDir), "tunnel.db")
    94  	return dbPath
    95  }
    96  
    97  func Close() error {
    98  	// Skip closing the database if the connection is not established.
    99  	if db == nil {
   100  		return nil
   101  	}
   102  	if err := db.Close(); err != nil {
   103  		return xerrors.Errorf("failed to close DB: %w", err)
   104  	}
   105  	return nil
   106  }
   107  
// Connection returns the package-level bbolt handle assigned by Init.
// It returns nil if Init has not assigned the handle yet.
func (dbc Config) Connection() *bolt.DB {
	return db
}
   111  
   112  func (dbc Config) BatchUpdate(fn func(tx *bolt.Tx) error) error {
   113  	err := db.Batch(fn)
   114  	if err != nil {
   115  		return xerrors.Errorf("error in batch update: %w", err)
   116  	}
   117  	return nil
   118  }
   119  
   120  func (dbc Config) put(tx *bolt.Tx, bktNames []string, key string, value interface{}) error {
   121  	if len(bktNames) == 0 {
   122  		return xerrors.Errorf("empty bucket name")
   123  	}
   124  
   125  	bkt, err := tx.CreateBucketIfNotExists([]byte(bktNames[0]))
   126  	if err != nil {
   127  		return xerrors.Errorf("failed to create '%s' bucket: %w", bktNames[0], err)
   128  	}
   129  
   130  	for _, bktName := range bktNames[1:] {
   131  		bkt, err = bkt.CreateBucketIfNotExists([]byte(bktName))
   132  		if err != nil {
   133  			return xerrors.Errorf("failed to create a bucket: %w", err)
   134  		}
   135  	}
   136  	v, err := json.Marshal(value)
   137  	if err != nil {
   138  		return xerrors.Errorf("failed to unmarshal JSON: %w", err)
   139  	}
   140  
   141  	return bkt.Put([]byte(key), v)
   142  }
   143  
   144  func (dbc Config) get(bktNames []string, key string) (value []byte, err error) {
   145  	err = db.View(func(tx *bolt.Tx) error {
   146  		if len(bktNames) == 0 {
   147  			return xerrors.Errorf("empty bucket name")
   148  		}
   149  
   150  		bkt := tx.Bucket([]byte(bktNames[0]))
   151  		if bkt == nil {
   152  			return nil
   153  		}
   154  		for _, bktName := range bktNames[1:] {
   155  			bkt = bkt.Bucket([]byte(bktName))
   156  			if bkt == nil {
   157  				return nil
   158  			}
   159  		}
   160  		dbValue := bkt.Get([]byte(key))
   161  
   162  		// Copy the byte slice so it can be used outside of the current transaction
   163  		value = make([]byte, len(dbValue))
   164  		copy(value, dbValue)
   165  
   166  		return nil
   167  	})
   168  	if err != nil {
   169  		return nil, xerrors.Errorf("failed to get data from db: %w", err)
   170  	}
   171  	return value, nil
   172  }
   173  
// Value is a single entry collected by forEach.
type Value struct {
	// Source is the data source recorded for the root bucket the entry came from.
	Source  types.DataSource
	// Content is a copy of the stored bytes, safe to use after the transaction ends.
	Content []byte
}
   178  
   179  func (dbc Config) forEach(bktNames []string) (map[string]Value, error) {
   180  	if len(bktNames) < 2 {
   181  		return nil, xerrors.Errorf("bucket must be nested: %v", bktNames)
   182  	}
   183  	rootBucket, nestedBuckets := bktNames[0], bktNames[1:]
   184  
   185  	values := map[string]Value{}
   186  	err := db.View(func(tx *bolt.Tx) error {
   187  		var rootBuckets []string
   188  
   189  		if strings.Contains(rootBucket, "::") {
   190  			// e.g. "pip::", "rubygems::"
   191  			prefix := []byte(rootBucket)
   192  			c := tx.Cursor()
   193  			for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() {
   194  				rootBuckets = append(rootBuckets, string(k))
   195  			}
   196  		} else {
   197  			// e.g. "GitHub Security Advisory Composer"
   198  			rootBuckets = append(rootBuckets, rootBucket)
   199  		}
   200  
   201  		for _, r := range rootBuckets {
   202  			root := tx.Bucket([]byte(r))
   203  			if root == nil {
   204  				continue
   205  			}
   206  
   207  			source, err := dbc.getDataSource(tx, r)
   208  			if err != nil {
   209  				log.Logger.Debugf("Data source error: %s", err)
   210  			}
   211  
   212  			bkt := root
   213  			for _, nestedBkt := range nestedBuckets {
   214  				bkt = bkt.Bucket([]byte(nestedBkt))
   215  				if bkt == nil {
   216  					break
   217  				}
   218  			}
   219  			if bkt == nil {
   220  				continue
   221  			}
   222  
   223  			err = bkt.ForEach(func(k, v []byte) error {
   224  				if len(v) == 0 {
   225  					return nil
   226  				}
   227  				// Copy the byte slice so it can be used outside of the current transaction
   228  				copiedContent := make([]byte, len(v))
   229  				copy(copiedContent, v)
   230  
   231  				values[string(k)] = Value{
   232  					Source:  source,
   233  					Content: copiedContent,
   234  				}
   235  				return nil
   236  			})
   237  			if err != nil {
   238  				return xerrors.Errorf("db foreach error: %w", err)
   239  			}
   240  		}
   241  		return nil
   242  	})
   243  	if err != nil {
   244  		return nil, xerrors.Errorf("failed to get all key/value in the specified bucket: %w", err)
   245  	}
   246  	return values, nil
   247  }
   248  
   249  func (dbc Config) deleteBucket(bucketName string) error {
   250  	return db.Update(func(tx *bolt.Tx) error {
   251  		if err := tx.DeleteBucket([]byte(bucketName)); err != nil {
   252  			return xerrors.Errorf("failed to delete bucket: %w", err)
   253  		}
   254  		return nil
   255  	})
   256  }