github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/admin_storage.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "context" 19 "database/sql" 20 "fmt" 21 "sync" 22 "time" 23 24 "github.com/golang/glog" 25 "github.com/golang/protobuf/proto" 26 "github.com/golang/protobuf/ptypes" 27 "github.com/golang/protobuf/ptypes/any" 28 "github.com/google/trillian" 29 "github.com/google/trillian/crypto/keyspb" 30 spb "github.com/google/trillian/crypto/sigpb" 31 "github.com/google/trillian/storage" 32 "google.golang.org/grpc/codes" 33 "google.golang.org/grpc/status" 34 ) 35 36 const ( 37 defaultSequenceIntervalSeconds = 60 38 39 nonDeletedWhere = " WHERE (Deleted IS NULL OR Deleted = 'false')" 40 41 selectTreeIDs = "SELECT TreeId FROM Trees" 42 selectNonDeletedTreeIDs = selectTreeIDs + nonDeletedWhere 43 44 selectTrees = ` 45 SELECT 46 TreeId, 47 TreeState, 48 TreeType, 49 HashStrategy, 50 HashAlgorithm, 51 SignatureAlgorithm, 52 DisplayName, 53 Description, 54 CreateTimeMillis, 55 UpdateTimeMillis, 56 PrivateKey, 57 PublicKey, 58 MaxRootDurationMillis, 59 Deleted, 60 DeleteTimeMillis 61 FROM Trees` 62 selectNonDeletedTrees = selectTrees + nonDeletedWhere 63 selectTreeByID = selectTrees + " WHERE TreeId = ?" 64 65 updateTreeSQL = `UPDATE Trees 66 SET TreeState = ?, TreeType = ?, DisplayName = ?, Description = ?, UpdateTimeMillis = ?, MaxRootDurationMillis = ?, PrivateKey = ? 67 WHERE TreeId = ?` 68 ) 69 70 // NewAdminStorage returns a MySQL storage.AdminStorage implementation backed by DB. 71 func NewAdminStorage(db *sql.DB) storage.AdminStorage { 72 return &mysqlAdminStorage{db} 73 } 74 75 // mysqlAdminStorage implements storage.AdminStorage 76 type mysqlAdminStorage struct { 77 db *sql.DB 78 } 79 80 func (s *mysqlAdminStorage) Snapshot(ctx context.Context) (storage.ReadOnlyAdminTX, error) { 81 return s.beginInternal(ctx) 82 } 83 84 func (s *mysqlAdminStorage) beginInternal(ctx context.Context) (storage.AdminTX, error) { 85 tx, err := s.db.BeginTx(ctx, nil /* opts */) 86 if err != nil { 87 return nil, err 88 } 89 return &adminTX{tx: tx}, nil 90 } 91 92 func (s *mysqlAdminStorage) ReadWriteTransaction(ctx context.Context, f storage.AdminTXFunc) error { 93 tx, err := s.beginInternal(ctx) 94 if err != nil { 95 return err 96 } 97 defer tx.Close() 98 if err := f(ctx, tx); err != nil { 99 return err 100 } 101 return tx.Commit() 102 } 103 104 func (s *mysqlAdminStorage) CheckDatabaseAccessible(ctx context.Context) error { 105 return s.db.PingContext(ctx) 106 } 107 108 type adminTX struct { 109 tx *sql.Tx 110 111 // mu guards *direct* reads/writes on closed, which happen only on 112 // Commit/Rollback/IsClosed/Close methods. 113 // We don't check closed on *all* methods (apart from the ones above), 114 // as we trust tx to keep tabs on its state (and consequently fail to do 115 // queries after closed). 116 mu sync.RWMutex 117 closed bool 118 } 119 120 func (t *adminTX) Commit() error { 121 t.mu.Lock() 122 defer t.mu.Unlock() 123 t.closed = true 124 return t.tx.Commit() 125 } 126 127 func (t *adminTX) Rollback() error { 128 t.mu.Lock() 129 defer t.mu.Unlock() 130 t.closed = true 131 return t.tx.Rollback() 132 } 133 134 func (t *adminTX) IsClosed() bool { 135 t.mu.RLock() 136 defer t.mu.RUnlock() 137 return t.closed 138 } 139 140 func (t *adminTX) Close() error { 141 // Acquire and release read lock manually, without defer, as if the txn 142 // is not closed Rollback() will attempt to acquire the rw lock. 143 t.mu.RLock() 144 closed := t.closed 145 t.mu.RUnlock() 146 if !closed { 147 err := t.Rollback() 148 if err != nil { 149 glog.Warningf("Rollback error on Close(): %v", err) 150 } 151 return err 152 } 153 return nil 154 } 155 156 func (t *adminTX) GetTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { 157 stmt, err := t.tx.PrepareContext(ctx, selectTreeByID) 158 if err != nil { 159 return nil, err 160 } 161 defer stmt.Close() 162 163 // GetTree is an entry point for most RPCs, let's provide somewhat nicer error messages. 164 tree, err := readTree(stmt.QueryRowContext(ctx, treeID)) 165 switch { 166 case err == sql.ErrNoRows: 167 // ErrNoRows doesn't provide useful information, so we don't forward it. 168 return nil, status.Errorf(codes.NotFound, "tree %v not found", treeID) 169 case err != nil: 170 return nil, fmt.Errorf("error reading tree %v: %v", treeID, err) 171 } 172 return tree, nil 173 } 174 175 // There's no common interface between sql.Row and sql.Rows(!), so we have to 176 // define one. 177 type row interface { 178 Scan(dest ...interface{}) error 179 } 180 181 func readTree(row row) (*trillian.Tree, error) { 182 tree := &trillian.Tree{} 183 184 // Enums and Datetimes need an extra conversion step 185 var treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm string 186 var createMillis, updateMillis, maxRootDurationMillis int64 187 var displayName, description sql.NullString 188 var privateKey, publicKey []byte 189 var deleted sql.NullBool 190 var deleteMillis sql.NullInt64 191 err := row.Scan( 192 &tree.TreeId, 193 &treeState, 194 &treeType, 195 &hashStrategy, 196 &hashAlgorithm, 197 &signatureAlgorithm, 198 &displayName, 199 &description, 200 &createMillis, 201 &updateMillis, 202 &privateKey, 203 &publicKey, 204 &maxRootDurationMillis, 205 &deleted, 206 &deleteMillis, 207 ) 208 if err != nil { 209 return nil, err 210 } 211 212 setNullStringIfValid(displayName, &tree.DisplayName) 213 setNullStringIfValid(description, &tree.Description) 214 215 // Convert all things! 216 if ts, ok := trillian.TreeState_value[treeState]; ok { 217 tree.TreeState = trillian.TreeState(ts) 218 } else { 219 return nil, fmt.Errorf("unknown TreeState: %v", treeState) 220 } 221 if tt, ok := trillian.TreeType_value[treeType]; ok { 222 tree.TreeType = trillian.TreeType(tt) 223 } else { 224 return nil, fmt.Errorf("unknown TreeType: %v", treeType) 225 } 226 if hs, ok := trillian.HashStrategy_value[hashStrategy]; ok { 227 tree.HashStrategy = trillian.HashStrategy(hs) 228 } else { 229 return nil, fmt.Errorf("unknown HashStrategy: %v", hashStrategy) 230 } 231 if ha, ok := spb.DigitallySigned_HashAlgorithm_value[hashAlgorithm]; ok { 232 tree.HashAlgorithm = spb.DigitallySigned_HashAlgorithm(ha) 233 } else { 234 return nil, fmt.Errorf("unknown HashAlgorithm: %v", hashAlgorithm) 235 } 236 if sa, ok := spb.DigitallySigned_SignatureAlgorithm_value[signatureAlgorithm]; ok { 237 tree.SignatureAlgorithm = spb.DigitallySigned_SignatureAlgorithm(sa) 238 } else { 239 return nil, fmt.Errorf("unknown SignatureAlgorithm: %v", signatureAlgorithm) 240 } 241 242 // Let's make sure we didn't mismatch any of the casts above 243 ok := tree.TreeState.String() == treeState 244 ok = ok && tree.TreeType.String() == treeType 245 ok = ok && tree.HashStrategy.String() == hashStrategy 246 ok = ok && tree.HashAlgorithm.String() == hashAlgorithm 247 ok = ok && tree.SignatureAlgorithm.String() == signatureAlgorithm 248 if !ok { 249 return nil, fmt.Errorf( 250 "mismatched enum: tree = %v, enums = [%v, %v, %v, %v, %v]", 251 tree, 252 treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm) 253 } 254 255 tree.CreateTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(createMillis)) 256 if err != nil { 257 return nil, fmt.Errorf("failed to parse create time: %v", err) 258 } 259 tree.UpdateTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(updateMillis)) 260 if err != nil { 261 return nil, fmt.Errorf("failed to parse update time: %v", err) 262 } 263 tree.MaxRootDuration = ptypes.DurationProto(time.Duration(maxRootDurationMillis * int64(time.Millisecond))) 264 265 tree.PrivateKey = &any.Any{} 266 if err := proto.Unmarshal(privateKey, tree.PrivateKey); err != nil { 267 return nil, fmt.Errorf("could not unmarshal PrivateKey: %v", err) 268 } 269 tree.PublicKey = &keyspb.PublicKey{Der: publicKey} 270 271 tree.Deleted = deleted.Valid && deleted.Bool 272 if tree.Deleted && deleteMillis.Valid { 273 tree.DeleteTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(deleteMillis.Int64)) 274 if err != nil { 275 return nil, fmt.Errorf("failed to parse delete time: %v", err) 276 } 277 } 278 279 return tree, nil 280 } 281 282 // setNullStringIfValid assigns src to dest if src is Valid. 283 func setNullStringIfValid(src sql.NullString, dest *string) { 284 if src.Valid { 285 *dest = src.String 286 } 287 } 288 289 func (t *adminTX) ListTreeIDs(ctx context.Context, includeDeleted bool) ([]int64, error) { 290 var query string 291 if includeDeleted { 292 query = selectTreeIDs 293 } else { 294 query = selectNonDeletedTreeIDs 295 } 296 297 stmt, err := t.tx.PrepareContext(ctx, query) 298 if err != nil { 299 return nil, err 300 } 301 defer stmt.Close() 302 303 rows, err := stmt.QueryContext(ctx) 304 if err != nil { 305 return nil, err 306 } 307 defer rows.Close() 308 309 treeIDs := []int64{} 310 var treeID int64 311 for rows.Next() { 312 if err := rows.Scan(&treeID); err != nil { 313 return nil, err 314 } 315 treeIDs = append(treeIDs, treeID) 316 } 317 return treeIDs, nil 318 } 319 320 func (t *adminTX) ListTrees(ctx context.Context, includeDeleted bool) ([]*trillian.Tree, error) { 321 var query string 322 if includeDeleted { 323 query = selectTrees 324 } else { 325 query = selectNonDeletedTrees 326 } 327 328 stmt, err := t.tx.PrepareContext(ctx, query) 329 if err != nil { 330 return nil, err 331 } 332 defer stmt.Close() 333 rows, err := stmt.QueryContext(ctx) 334 if err != nil { 335 return nil, err 336 } 337 defer rows.Close() 338 trees := []*trillian.Tree{} 339 for rows.Next() { 340 tree, err := readTree(rows) 341 if err != nil { 342 return nil, err 343 } 344 trees = append(trees, tree) 345 } 346 return trees, nil 347 } 348 349 func (t *adminTX) CreateTree(ctx context.Context, tree *trillian.Tree) (*trillian.Tree, error) { 350 if err := storage.ValidateTreeForCreation(ctx, tree); err != nil { 351 return nil, err 352 } 353 if err := validateStorageSettings(tree); err != nil { 354 return nil, err 355 } 356 357 id, err := storage.NewTreeID() 358 if err != nil { 359 return nil, err 360 } 361 362 // Use the time truncated-to-millis throughout, as that's what's stored. 363 nowMillis := toMillisSinceEpoch(time.Now()) 364 now := fromMillisSinceEpoch(nowMillis) 365 366 newTree := *tree 367 newTree.TreeId = id 368 newTree.CreateTime, err = ptypes.TimestampProto(now) 369 if err != nil { 370 return nil, fmt.Errorf("failed to build create time: %v", err) 371 } 372 newTree.UpdateTime, err = ptypes.TimestampProto(now) 373 if err != nil { 374 return nil, fmt.Errorf("failed to build update time: %v", err) 375 } 376 rootDuration, err := ptypes.Duration(newTree.MaxRootDuration) 377 if err != nil { 378 return nil, fmt.Errorf("could not parse MaxRootDuration: %v", err) 379 } 380 381 insertTreeStmt, err := t.tx.PrepareContext( 382 ctx, 383 `INSERT INTO Trees( 384 TreeId, 385 TreeState, 386 TreeType, 387 HashStrategy, 388 HashAlgorithm, 389 SignatureAlgorithm, 390 DisplayName, 391 Description, 392 CreateTimeMillis, 393 UpdateTimeMillis, 394 PrivateKey, 395 PublicKey, 396 MaxRootDurationMillis) 397 VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) 398 if err != nil { 399 return nil, err 400 } 401 defer insertTreeStmt.Close() 402 403 privateKey, err := proto.Marshal(newTree.PrivateKey) 404 if err != nil { 405 return nil, fmt.Errorf("could not marshal PrivateKey: %v", err) 406 } 407 408 _, err = insertTreeStmt.ExecContext( 409 ctx, 410 newTree.TreeId, 411 newTree.TreeState.String(), 412 newTree.TreeType.String(), 413 newTree.HashStrategy.String(), 414 newTree.HashAlgorithm.String(), 415 newTree.SignatureAlgorithm.String(), 416 newTree.DisplayName, 417 newTree.Description, 418 nowMillis, 419 nowMillis, 420 privateKey, 421 newTree.PublicKey.GetDer(), 422 rootDuration/time.Millisecond, 423 ) 424 if err != nil { 425 return nil, err 426 } 427 428 // MySQL silently truncates data when running in non-strict mode. 429 // We shouldn't be using non-strict modes, but let's guard against it 430 // anyway. 431 if _, err := t.GetTree(ctx, newTree.TreeId); err != nil { 432 // GetTree will fail for truncated enums (they get recorded as 433 // empty strings, which will not match any known value). 434 return nil, fmt.Errorf("enum truncated: %v", err) 435 } 436 437 insertControlStmt, err := t.tx.PrepareContext( 438 ctx, 439 `INSERT INTO TreeControl( 440 TreeId, 441 SigningEnabled, 442 SequencingEnabled, 443 SequenceIntervalSeconds) 444 VALUES(?, ?, ?, ?)`) 445 if err != nil { 446 return nil, err 447 } 448 defer insertControlStmt.Close() 449 _, err = insertControlStmt.ExecContext( 450 ctx, 451 newTree.TreeId, 452 true, /* SigningEnabled */ 453 true, /* SequencingEnabled */ 454 defaultSequenceIntervalSeconds, 455 ) 456 if err != nil { 457 return nil, err 458 } 459 460 return &newTree, nil 461 } 462 463 func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func(*trillian.Tree)) (*trillian.Tree, error) { 464 tree, err := t.GetTree(ctx, treeID) 465 if err != nil { 466 return nil, err 467 } 468 469 beforeUpdate := *tree 470 updateFunc(tree) 471 if err := storage.ValidateTreeForUpdate(ctx, &beforeUpdate, tree); err != nil { 472 return nil, err 473 } 474 if err := validateStorageSettings(tree); err != nil { 475 return nil, err 476 } 477 478 // TODO(pavelkalinnikov): When switching TreeType from PREORDERED_LOG to LOG, 479 // ensure all entries in SequencedLeafData are integrated. 480 481 // Use the time truncated-to-millis throughout, as that's what's stored. 482 nowMillis := toMillisSinceEpoch(time.Now()) 483 now := fromMillisSinceEpoch(nowMillis) 484 tree.UpdateTime, err = ptypes.TimestampProto(now) 485 if err != nil { 486 return nil, fmt.Errorf("failed to build update time: %v", err) 487 } 488 rootDuration, err := ptypes.Duration(tree.MaxRootDuration) 489 if err != nil { 490 return nil, fmt.Errorf("could not parse MaxRootDuration: %v", err) 491 } 492 493 privateKey, err := proto.Marshal(tree.PrivateKey) 494 if err != nil { 495 return nil, fmt.Errorf("could not marshal PrivateKey: %v", err) 496 } 497 498 stmt, err := t.tx.PrepareContext(ctx, updateTreeSQL) 499 if err != nil { 500 return nil, err 501 } 502 defer stmt.Close() 503 504 if _, err = stmt.ExecContext( 505 ctx, 506 tree.TreeState.String(), 507 tree.TreeType.String(), 508 tree.DisplayName, 509 tree.Description, 510 nowMillis, 511 rootDuration/time.Millisecond, 512 privateKey, 513 tree.TreeId); err != nil { 514 return nil, err 515 } 516 517 return tree, nil 518 } 519 520 func (t *adminTX) SoftDeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { 521 return t.updateDeleted(ctx, treeID, true /* deleted */, toMillisSinceEpoch(time.Now()) /* deleteTimeMillis */) 522 } 523 524 func (t *adminTX) UndeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { 525 return t.updateDeleted(ctx, treeID, false /* deleted */, nil /* deleteTimeMillis */) 526 } 527 528 // updateDeleted updates the Deleted and DeleteTimeMillis fields of the specified tree. 529 // deleteTimeMillis must be either an int64 (in millis since epoch) or nil. 530 func (t *adminTX) updateDeleted(ctx context.Context, treeID int64, deleted bool, deleteTimeMillis interface{}) (*trillian.Tree, error) { 531 if err := validateDeleted(ctx, t.tx, treeID, !deleted); err != nil { 532 return nil, err 533 } 534 if _, err := t.tx.ExecContext( 535 ctx, 536 "UPDATE Trees SET Deleted = ?, DeleteTimeMillis = ? WHERE TreeId = ?", 537 deleted, deleteTimeMillis, treeID); err != nil { 538 return nil, err 539 } 540 return t.GetTree(ctx, treeID) 541 } 542 543 func (t *adminTX) HardDeleteTree(ctx context.Context, treeID int64) error { 544 if err := validateDeleted(ctx, t.tx, treeID, true /* wantDeleted */); err != nil { 545 return err 546 } 547 548 // TreeControl didn't have "ON DELETE CASCADE" on previous versions, so let's hit it explicitly 549 if _, err := t.tx.ExecContext(ctx, "DELETE FROM TreeControl WHERE TreeId = ?", treeID); err != nil { 550 return err 551 } 552 _, err := t.tx.ExecContext(ctx, "DELETE FROM Trees WHERE TreeId = ?", treeID) 553 return err 554 } 555 556 func validateDeleted(ctx context.Context, tx *sql.Tx, treeID int64, wantDeleted bool) error { 557 var nullDeleted sql.NullBool 558 switch err := tx.QueryRowContext(ctx, "SELECT Deleted FROM Trees WHERE TreeId = ?", treeID).Scan(&nullDeleted); { 559 case err == sql.ErrNoRows: 560 return status.Errorf(codes.NotFound, "tree %v not found", treeID) 561 case err != nil: 562 return err 563 } 564 565 switch deleted := nullDeleted.Valid && nullDeleted.Bool; { 566 case wantDeleted && !deleted: 567 return status.Errorf(codes.FailedPrecondition, "tree %v is not soft deleted", treeID) 568 case !wantDeleted && deleted: 569 return status.Errorf(codes.FailedPrecondition, "tree %v already soft deleted", treeID) 570 } 571 return nil 572 } 573 574 func toMillisSinceEpoch(t time.Time) int64 { 575 return t.UnixNano() / 1000000 576 } 577 578 func fromMillisSinceEpoch(ts int64) time.Time { 579 secs := int64(ts / 1000) 580 msecs := int64(ts % 1000) 581 return time.Unix(secs, msecs*1000000) 582 } 583 584 func validateStorageSettings(tree *trillian.Tree) error { 585 if tree.StorageSettings != nil { 586 return fmt.Errorf("storage_settings not supported, but got %v", tree.StorageSettings) 587 } 588 return nil 589 }