github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/db/db.go (about) 1 package db 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "os" 7 "path/filepath" 8 "runtime/debug" 9 "strings" 10 11 bolt "go.etcd.io/bbolt" 12 "golang.org/x/xerrors" 13 14 "github.com/khulnasoft-lab/tunnel-db/pkg/log" 15 "github.com/khulnasoft-lab/tunnel-db/pkg/types" 16 ) 17 18 type CustomPut func(dbc Operation, tx *bolt.Tx, adv interface{}) error 19 20 const SchemaVersion = 2 21 22 var ( 23 db *bolt.DB 24 dbDir string 25 ) 26 27 type Operation interface { 28 BatchUpdate(fn func(*bolt.Tx) error) (err error) 29 30 GetVulnerabilityDetail(cveID string) (detail map[types.SourceID]types.VulnerabilityDetail, err error) 31 PutVulnerabilityDetail(tx *bolt.Tx, vulnerabilityID string, source types.SourceID, 32 vulnerability types.VulnerabilityDetail) (err error) 33 DeleteVulnerabilityDetailBucket() (err error) 34 35 ForEachAdvisory(sources []string, pkgName string) (value map[string]Value, err error) 36 GetAdvisories(source string, pkgName string) (advisories []types.Advisory, err error) 37 38 PutVulnerabilityID(tx *bolt.Tx, vulnerabilityID string) (err error) 39 ForEachVulnerabilityID(fn func(tx *bolt.Tx, cveID string) error) (err error) 40 41 PutVulnerability(tx *bolt.Tx, vulnerabilityID string, vulnerability types.Vulnerability) (err error) 42 GetVulnerability(vulnerabilityID string) (vulnerability types.Vulnerability, err error) 43 44 SaveAdvisoryDetails(tx *bolt.Tx, cveID string) (err error) 45 PutAdvisoryDetail(tx *bolt.Tx, vulnerabilityID, pkgName string, nestedBktNames []string, advisory interface{}) (err error) 46 DeleteAdvisoryDetailBucket() error 47 48 PutDataSource(tx *bolt.Tx, bktName string, source types.DataSource) (err error) 49 50 // For Red Hat 51 PutRedHatRepositories(tx *bolt.Tx, repository string, cpeIndices []int) (err error) 52 PutRedHatNVRs(tx *bolt.Tx, nvr string, cpeIndices []int) (err error) 53 PutRedHatCPEs(tx *bolt.Tx, cpeIndex int, cpe string) (err error) 54 RedHatRepoToCPEs(repository string) (cpeIndices []int, err error) 55 RedHatNVRToCPEs(nvr string) (cpeIndices []int, err error) 56 } 57 58 type Config struct { 59 } 60 61 func Init(cacheDir string) (err error) { 62 dbPath := Path(cacheDir) 63 dbDir = filepath.Dir(dbPath) 64 if err = os.MkdirAll(dbDir, 0700); err != nil { 65 return xerrors.Errorf("failed to mkdir: %w", err) 66 } 67 68 // bbolt sometimes occurs the fatal error of "unexpected fault address". 69 // In that case, the local DB should be broken and needs to be removed. 70 debug.SetPanicOnFault(true) 71 defer func() { 72 if r := recover(); r != nil { 73 if err = os.Remove(dbPath); err != nil { 74 return 75 } 76 db, err = bolt.Open(dbPath, 0600, nil) 77 } 78 debug.SetPanicOnFault(false) 79 }() 80 81 db, err = bolt.Open(dbPath, 0600, nil) 82 if err != nil { 83 return xerrors.Errorf("failed to open db: %w", err) 84 } 85 return nil 86 } 87 88 func Dir(cacheDir string) string { 89 return filepath.Join(cacheDir, "db") 90 } 91 92 func Path(cacheDir string) string { 93 dbPath := filepath.Join(Dir(cacheDir), "tunnel.db") 94 return dbPath 95 } 96 97 func Close() error { 98 // Skip closing the database if the connection is not established. 99 if db == nil { 100 return nil 101 } 102 if err := db.Close(); err != nil { 103 return xerrors.Errorf("failed to close DB: %w", err) 104 } 105 return nil 106 } 107 108 func (dbc Config) Connection() *bolt.DB { 109 return db 110 } 111 112 func (dbc Config) BatchUpdate(fn func(tx *bolt.Tx) error) error { 113 err := db.Batch(fn) 114 if err != nil { 115 return xerrors.Errorf("error in batch update: %w", err) 116 } 117 return nil 118 } 119 120 func (dbc Config) put(tx *bolt.Tx, bktNames []string, key string, value interface{}) error { 121 if len(bktNames) == 0 { 122 return xerrors.Errorf("empty bucket name") 123 } 124 125 bkt, err := tx.CreateBucketIfNotExists([]byte(bktNames[0])) 126 if err != nil { 127 return xerrors.Errorf("failed to create '%s' bucket: %w", bktNames[0], err) 128 } 129 130 for _, bktName := range bktNames[1:] { 131 bkt, err = bkt.CreateBucketIfNotExists([]byte(bktName)) 132 if err != nil { 133 return xerrors.Errorf("failed to create a bucket: %w", err) 134 } 135 } 136 v, err := json.Marshal(value) 137 if err != nil { 138 return xerrors.Errorf("failed to unmarshal JSON: %w", err) 139 } 140 141 return bkt.Put([]byte(key), v) 142 } 143 144 func (dbc Config) get(bktNames []string, key string) (value []byte, err error) { 145 err = db.View(func(tx *bolt.Tx) error { 146 if len(bktNames) == 0 { 147 return xerrors.Errorf("empty bucket name") 148 } 149 150 bkt := tx.Bucket([]byte(bktNames[0])) 151 if bkt == nil { 152 return nil 153 } 154 for _, bktName := range bktNames[1:] { 155 bkt = bkt.Bucket([]byte(bktName)) 156 if bkt == nil { 157 return nil 158 } 159 } 160 dbValue := bkt.Get([]byte(key)) 161 162 // Copy the byte slice so it can be used outside of the current transaction 163 value = make([]byte, len(dbValue)) 164 copy(value, dbValue) 165 166 return nil 167 }) 168 if err != nil { 169 return nil, xerrors.Errorf("failed to get data from db: %w", err) 170 } 171 return value, nil 172 } 173 174 type Value struct { 175 Source types.DataSource 176 Content []byte 177 } 178 179 func (dbc Config) forEach(bktNames []string) (map[string]Value, error) { 180 if len(bktNames) < 2 { 181 return nil, xerrors.Errorf("bucket must be nested: %v", bktNames) 182 } 183 rootBucket, nestedBuckets := bktNames[0], bktNames[1:] 184 185 values := map[string]Value{} 186 err := db.View(func(tx *bolt.Tx) error { 187 var rootBuckets []string 188 189 if strings.Contains(rootBucket, "::") { 190 // e.g. "pip::", "rubygems::" 191 prefix := []byte(rootBucket) 192 c := tx.Cursor() 193 for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() { 194 rootBuckets = append(rootBuckets, string(k)) 195 } 196 } else { 197 // e.g. "GitHub Security Advisory Composer" 198 rootBuckets = append(rootBuckets, rootBucket) 199 } 200 201 for _, r := range rootBuckets { 202 root := tx.Bucket([]byte(r)) 203 if root == nil { 204 continue 205 } 206 207 source, err := dbc.getDataSource(tx, r) 208 if err != nil { 209 log.Logger.Debugf("Data source error: %s", err) 210 } 211 212 bkt := root 213 for _, nestedBkt := range nestedBuckets { 214 bkt = bkt.Bucket([]byte(nestedBkt)) 215 if bkt == nil { 216 break 217 } 218 } 219 if bkt == nil { 220 continue 221 } 222 223 err = bkt.ForEach(func(k, v []byte) error { 224 if len(v) == 0 { 225 return nil 226 } 227 // Copy the byte slice so it can be used outside of the current transaction 228 copiedContent := make([]byte, len(v)) 229 copy(copiedContent, v) 230 231 values[string(k)] = Value{ 232 Source: source, 233 Content: copiedContent, 234 } 235 return nil 236 }) 237 if err != nil { 238 return xerrors.Errorf("db foreach error: %w", err) 239 } 240 } 241 return nil 242 }) 243 if err != nil { 244 return nil, xerrors.Errorf("failed to get all key/value in the specified bucket: %w", err) 245 } 246 return values, nil 247 } 248 249 func (dbc Config) deleteBucket(bucketName string) error { 250 return db.Update(func(tx *bolt.Tx) error { 251 if err := tx.DeleteBucket([]byte(bucketName)); err != nil { 252 return xerrors.Errorf("failed to delete bucket: %w", err) 253 } 254 return nil 255 }) 256 }