github.com/decred/politeia@v1.4.0/politeiad/backendv2/tstorebe/store/mysql/mysql.go (about) 1 // Copyright (c) 2020-2022 The Decred developers 2 // Use of this source code is governed by an ISC 3 // license that can be found in the LICENSE file. 4 5 package mysql 6 7 import ( 8 "context" 9 "database/sql" 10 "fmt" 11 "strings" 12 "sync/atomic" 13 "time" 14 15 "github.com/DATA-DOG/go-sqlmock" 16 "github.com/decred/politeia/politeiad/backendv2/tstorebe/store" 17 "github.com/decred/politeia/util" 18 "github.com/pkg/errors" 19 20 _ "github.com/go-sql-driver/mysql" 21 ) 22 23 const ( 24 // Database options 25 connTimeout = 1 * time.Minute 26 connMaxLifetime = 1 * time.Minute 27 maxOpenConns = 0 // 0 is unlimited 28 maxIdleConns = 100 29 30 // Database table names 31 tableNameKeyValue = "kv" 32 tableNameNonce = "nonce" 33 34 // maxPlaceholders is the maximum number of placeholders, "(?, ?, ?)", that 35 // can be used in a prepared statement. MySQL uses an uint16 for this, so 36 // the limit is the the maximum value of an uint16. 37 maxPlaceholders = 65535 38 ) 39 40 // tableKeyValue defines the key-value table. 41 const tableKeyValue = ` 42 k VARCHAR(255) NOT NULL PRIMARY KEY, 43 v LONGBLOB NOT NULL 44 ` 45 46 // tableNonce defines the table used to track the encryption nonce. 47 const tableNonce = ` 48 n BIGINT PRIMARY KEY AUTO_INCREMENT 49 ` 50 51 var ( 52 _ store.BlobKV = (*mysqlCtx)(nil) 53 ) 54 55 // mysqlCtx implements the store BlobKV interface using a mysql driver. 56 type mysqlCtx struct { 57 shutdown uint64 58 db *sql.DB 59 key [32]byte 60 61 // The following fields are only used during unit tests. 62 testing bool 63 mock sqlmock.Sqlmock 64 } 65 66 func ctxWithTimeout() (context.Context, func()) { 67 return context.WithTimeout(context.Background(), connTimeout) 68 } 69 70 func (s *mysqlCtx) isShutdown() bool { 71 return atomic.LoadUint64(&s.shutdown) != 0 72 } 73 74 // put saves the provided key-value pairs to the database using a transaction. 75 // New entries are inserted. Existing entries are updated. 76 func (s *mysqlCtx) put(blobs map[string][]byte, encrypt bool, ctx context.Context, tx *sql.Tx) error { 77 // Encrypt blobs 78 if encrypt { 79 encrypted := make(map[string][]byte, len(blobs)) 80 for k, v := range blobs { 81 e, err := s.encrypt(ctx, tx, v) 82 if err != nil { 83 return err 84 } 85 encrypted[k] = e 86 } 87 88 // Sanity check 89 if len(encrypted) != len(blobs) { 90 return errors.Errorf("unexpected number of encrypted blobs") 91 } 92 93 blobs = encrypted 94 } 95 96 // Save blobs 97 for k, v := range blobs { 98 _, err := tx.ExecContext(ctx, 99 "REPLACE INTO kv (k, v) VALUES (?, ?);", k, v) 100 if err != nil { 101 return errors.WithStack(err) 102 } 103 } 104 105 return nil 106 } 107 108 // Put saves the provided key-value entries to the database. New entries are 109 // inserted. Existing entries are updated. 110 // 111 // This operation is atomic. 112 // 113 // This function satisfies the store BlobKV interface. 114 func (s *mysqlCtx) Put(blobs map[string][]byte, encrypt bool) error { 115 log.Tracef("Put: %v blobs", len(blobs)) 116 117 if s.isShutdown() { 118 return store.ErrShutdown 119 } 120 121 ctx, cancel := ctxWithTimeout() 122 defer cancel() 123 124 // Start transaction 125 opts := &sql.TxOptions{ 126 Isolation: sql.LevelDefault, 127 } 128 tx, err := s.db.BeginTx(ctx, opts) 129 if err != nil { 130 return err 131 } 132 133 // Save blobs 134 err = s.put(blobs, encrypt, ctx, tx) 135 if err != nil { 136 // Attempt to roll back the transaction 137 if err2 := tx.Rollback(); err2 != nil { 138 // We're in trouble! 139 e := fmt.Sprintf("put: %v, unable to rollback: %v", err, err2) 140 panic(e) 141 } 142 return err 143 } 144 145 // Commit transaction 146 err = tx.Commit() 147 if err != nil { 148 return err 149 } 150 151 log.Debugf("Saved blobs (%v) to store", len(blobs)) 152 153 return nil 154 } 155 156 // Del deletes the key-value entries from the database for the provided keys. 157 // 158 // This operation is atomic. 159 // 160 // This function satisfies the store BlobKV interface. 161 func (s *mysqlCtx) Del(keys []string) error { 162 log.Tracef("Del: %v", keys) 163 164 if s.isShutdown() { 165 return store.ErrShutdown 166 } 167 168 ctx, cancel := ctxWithTimeout() 169 defer cancel() 170 171 // Start transaction 172 opts := &sql.TxOptions{ 173 Isolation: sql.LevelDefault, 174 } 175 tx, err := s.db.BeginTx(ctx, opts) 176 if err != nil { 177 return err 178 } 179 180 // Delete blobs 181 for _, v := range keys { 182 _, err = tx.ExecContext(ctx, "DELETE FROM kv WHERE k IN (?);", v) 183 if err != nil { 184 // Attempt to roll back the transaction 185 if err2 := tx.Rollback(); err2 != nil { 186 // We're in trouble! 187 e := fmt.Sprintf("del: %v, unable to rollback: %v", err, err2) 188 panic(e) 189 } 190 return err 191 } 192 } 193 194 // Commit transaction 195 err = tx.Commit() 196 if err != nil { 197 return err 198 } 199 200 log.Debugf("Deleted blobs (%v) from store", len(keys)) 201 202 return nil 203 } 204 205 // Get retrieves the key-value entries from the database for the provided 206 // keys. 207 // 208 // An entry will not exist in the returned map for any blobs that are not 209 // found. It is the responsibility of the caller to ensure a blob was returned 210 // for all provided keys. 211 // 212 // This function satisfies the store BlobKV interface. 213 func (s *mysqlCtx) Get(keys []string) (map[string][]byte, error) { 214 log.Tracef("Get: %v", keys) 215 216 if s.isShutdown() { 217 return nil, store.ErrShutdown 218 } 219 220 // Build the select statements 221 statements := buildSelectStatements(keys, maxPlaceholders) 222 223 log.Debugf("Get %v blobs using %v prepared statements", 224 len(keys), len(statements)) 225 226 // Execute the statements 227 reply := make(map[string][]byte, len(keys)) 228 for i, e := range statements { 229 log.Debugf("Executing select statement %v/%v", i+1, len(statements)) 230 231 ctx, cancel := ctxWithTimeout() 232 defer cancel() 233 234 rows, err := s.db.QueryContext(ctx, e.Query, e.Args...) 235 if err != nil { 236 return nil, errors.WithStack(err) 237 } 238 defer rows.Close() 239 240 // Unpack the reply 241 for rows.Next() { 242 var k string 243 var v []byte 244 err = rows.Scan(&k, &v) 245 if err != nil { 246 return nil, errors.WithStack(err) 247 } 248 249 // Decrypt the blob if required 250 if isEncrypted(v) { 251 log.Tracef("Encrypted blob: %v", k) 252 v, _, err = s.decrypt(v) 253 if err != nil { 254 return nil, err 255 } 256 } 257 258 // Save the blob 259 reply[k] = v 260 } 261 err = rows.Err() 262 if err != nil { 263 return nil, errors.WithStack(err) 264 } 265 } 266 267 return reply, nil 268 } 269 270 // Close closes the database connection. 271 func (s *mysqlCtx) Close() { 272 log.Tracef("Close") 273 274 atomic.AddUint64(&s.shutdown, 1) 275 276 // Zero the encryption key 277 util.Zero(s.key[:]) 278 279 // Close mysql connection 280 s.db.Close() 281 } 282 283 // selectStatement contains the query string and arguments for a SELECT 284 // statement. 285 type selectStatement struct { 286 Query string 287 Args []interface{} 288 } 289 290 // buildSelectStatements builds the SELECT statements that can be executed 291 // against the MySQL key-value store. The maximum number of records that will 292 // be retrieved in any individual SELECT statement is determined by the size 293 // argument. The keys are split up into multiple statements if they exceed this 294 // limit. 295 func buildSelectStatements(keys []string, size int) []selectStatement { 296 statements := make([]selectStatement, 0, (len(keys)/size)+1) 297 var startIdx int 298 for startIdx < len(keys) { 299 // Find the end index 300 endIdx := startIdx + size 301 if endIdx > len(keys) { 302 // We've reached the end of the slice 303 endIdx = len(keys) 304 } 305 306 // startIdx is included. endIdx is excluded. 307 statementKeys := keys[startIdx:endIdx] 308 309 // Build the query 310 q := buildSelectQuery(len(statementKeys)) 311 log.Tracef("%v", q) 312 313 // Convert the keys to interfaces. The sql query 314 // methods require arguments be interfaces. 315 args := make([]interface{}, len(statementKeys)) 316 for i, v := range statementKeys { 317 args[i] = v 318 } 319 320 // Save the statement 321 statements = append(statements, selectStatement{ 322 Query: q, 323 Args: args, 324 }) 325 326 // Update the start index 327 startIdx = endIdx 328 } 329 330 return statements 331 } 332 333 // buildSelectQuery returns a query string for the MySQL key-value store. 334 // 335 // Example: "SELECT k, v FROM kv WHERE k IN (?,?);" 336 func buildSelectQuery(placeholders int) string { 337 return fmt.Sprintf("SELECT k, v FROM kv WHERE k IN %v;", 338 buildPlaceholders(placeholders)) 339 } 340 341 // buildPlaceholders builds and returns a parameter placeholder string with the 342 // specified number of placeholders. 343 // 344 // Input: 1 Output: "(?)" 345 // Input: 3 Output: "(?,?,?)" 346 func buildPlaceholders(placeholders int) string { 347 var b strings.Builder 348 349 b.WriteString("(") 350 for i := 0; i < placeholders; i++ { 351 b.WriteString("?") 352 // Don't add a comma on the last one 353 if i < placeholders-1 { 354 b.WriteString(",") 355 } 356 } 357 b.WriteString(")") 358 359 return b.String() 360 } 361 362 // New connects to a mysql instance using the given connection params, 363 // and returns pointer to the created mysql struct. 364 func New(host, user, password, dbname string) (*mysqlCtx, error) { 365 // The password is required to derive the encryption key 366 if password == "" { 367 return nil, errors.Errorf("password not provided") 368 } 369 370 // Connect to database 371 log.Infof("MySQL host: %v:[password]@tcp(%v)/%v", user, host, dbname) 372 373 h := fmt.Sprintf("%v:%v@tcp(%v)/%v", user, password, host, dbname) 374 db, err := sql.Open("mysql", h) 375 if err != nil { 376 return nil, err 377 } 378 379 // Setup database options 380 db.SetConnMaxLifetime(connMaxLifetime) 381 db.SetMaxOpenConns(maxOpenConns) 382 db.SetMaxIdleConns(maxIdleConns) 383 384 // Verify database connection 385 err = db.Ping() 386 if err != nil { 387 return nil, err 388 } 389 390 // Setup key-value table 391 q := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 392 tableNameKeyValue, tableKeyValue) 393 _, err = db.Exec(q) 394 if err != nil { 395 return nil, errors.WithStack(err) 396 } 397 398 // Setup nonce table 399 q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 400 tableNameNonce, tableNonce) 401 _, err = db.Exec(q) 402 if err != nil { 403 return nil, errors.WithStack(err) 404 } 405 406 // Setup mysql context 407 s := &mysqlCtx{ 408 db: db, 409 } 410 411 // Derive encryption key from password. Key is set in argon2idKey 412 err = s.deriveEncryptionKey(password) 413 if err != nil { 414 return nil, err 415 } 416 417 return s, nil 418 }