github.com/bartle-stripe/trillian@v1.2.1/storage/testdb/testdb.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package testdb creates new databases for tests. 16 package testdb 17 18 import ( 19 "bytes" 20 "context" 21 "database/sql" 22 "fmt" 23 "io/ioutil" 24 "log" 25 "strings" 26 "testing" 27 "time" 28 29 "github.com/google/trillian/testonly" 30 31 _ "github.com/go-sql-driver/mysql" // mysql driver 32 ) 33 34 var ( 35 trillianSQL = testonly.RelativeToPackage("../mysql/storage.sql") 36 dataSource = "root@/" 37 ) 38 39 // MySQLAvailable indicates whether a default MySQL database is available. 40 func MySQLAvailable() bool { 41 db, err := sql.Open("mysql", dataSource) 42 if err != nil { 43 log.Printf("sql.Open(): %v", err) 44 return false 45 } 46 defer db.Close() 47 if err := db.Ping(); err != nil { 48 log.Printf("db.Ping(): %v", err) 49 return false 50 } 51 return true 52 } 53 54 // newEmptyDB creates a new, empty database. 55 func newEmptyDB(ctx context.Context) (*sql.DB, error) { 56 db, err := sql.Open("mysql", dataSource) 57 if err != nil { 58 return nil, err 59 } 60 61 // Create a randomly-named database and then connect using the new name. 62 name := fmt.Sprintf("trl_%v", time.Now().UnixNano()) 63 64 stmt := fmt.Sprintf("CREATE DATABASE %v", name) 65 if _, err := db.ExecContext(ctx, stmt); err != nil { 66 return nil, fmt.Errorf("error running statement %q: %v", stmt, err) 67 } 68 69 db.Close() 70 db, err = sql.Open("mysql", dataSource+name) 71 if err != nil { 72 return nil, err 73 } 74 75 return db, db.Ping() 76 } 77 78 // NewTrillianDB creates an empty database with the Trillian schema. The database name is randomly 79 // generated. 80 // NewTrillianDB is equivalent to Default().NewTrillianDB(ctx). 81 func NewTrillianDB(ctx context.Context) (*sql.DB, error) { 82 db, err := newEmptyDB(ctx) 83 if err != nil { 84 return nil, err 85 } 86 87 sqlBytes, err := ioutil.ReadFile(trillianSQL) 88 if err != nil { 89 return nil, err 90 } 91 92 for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";") { 93 stmt = strings.TrimSpace(stmt) 94 if stmt == "" { 95 continue 96 } 97 if _, err := db.ExecContext(ctx, stmt); err != nil { 98 return nil, fmt.Errorf("error running statement %q: %v", stmt, err) 99 } 100 } 101 return db, nil 102 } 103 104 func sanitize(script string) string { 105 buf := &bytes.Buffer{} 106 for _, line := range strings.Split(string(script), "\n") { 107 line = strings.TrimSpace(line) 108 if line == "" || line[0] == '#' || strings.Index(line, "--") == 0 { 109 continue // skip empty lines and comments 110 } 111 buf.WriteString(line) 112 buf.WriteString("\n") 113 } 114 return buf.String() 115 } 116 117 // SkipIfNoMySQL is a test helper that skips tests that require a local MySQL. 118 func SkipIfNoMySQL(t *testing.T) { 119 t.Helper() 120 if !MySQLAvailable() { 121 t.Skip("Skipping test as MySQL not available") 122 } 123 }