github.com/rkt/rkt@v1.30.1-0.20200224141603-171c416fac02/store/db/db_test.go (about) 1 // Copyright 2015 The rkt Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package db 16 17 import ( 18 "database/sql" 19 "fmt" 20 "io/ioutil" 21 "os" 22 "runtime" 23 "testing" 24 25 "github.com/rkt/rkt/tests/testutils" 26 ) 27 28 func queryValue(query string, tx *sql.Tx) (int, error) { 29 var value int 30 rows, err := tx.Query(query) 31 if err != nil { 32 return -1, err 33 } 34 defer rows.Close() 35 36 if !rows.Next() { 37 return -1, fmt.Errorf("no result of %q", query) 38 } 39 if err := rows.Scan(&value); err != nil { 40 return -1, err 41 } 42 return value, nil 43 } 44 45 func insertValue(db *DB) error { 46 return db.Do(func(tx *sql.Tx) error { 47 // Get the current count. 48 count, err := queryValue("SELECT count(*) FROM rkt_db_test", tx) 49 if err != nil { 50 return err 51 } 52 // Increase the count. 53 _, err = tx.Exec(fmt.Sprintf("INSERT INTO rkt_db_test VALUES (%d)", count+1)) 54 return err 55 }) 56 } 57 58 func getMaxCount(db *DB, t *testing.T) int { 59 var maxCount int 60 var err error 61 if err := db.Do(func(tx *sql.Tx) error { 62 // Get the maximum count. 63 maxCount, err = queryValue("SELECT max(counts) FROM rkt_db_test", tx) 64 return err 65 }); err != nil { 66 t.Fatalf("Failed to get the maximum count: %v", err) 67 } 68 return maxCount 69 } 70 71 func createTable(db *DB, t *testing.T) { 72 if err := db.Do(func(tx *sql.Tx) error { 73 _, err := tx.Exec("CREATE TABLE rkt_db_test (counts int)") 74 return err 75 }); err != nil { 76 t.Fatalf("Unexpected error: %v", err) 77 } 78 } 79 80 func TestDBRace(t *testing.T) { 81 oldGoMaxProcs := runtime.GOMAXPROCS(runtime.NumCPU()) 82 defer runtime.GOMAXPROCS(oldGoMaxProcs) 83 84 dir, err := ioutil.TempDir("", "rkt_db_test") 85 if err != nil { 86 t.Fatalf("Unexpected error: %v", err) 87 } 88 defer os.RemoveAll(dir) 89 90 db, err := NewDB(dir) 91 if err != nil { 92 t.Fatalf("Unexpected error: %v", err) 93 } 94 95 // Create the table. 96 createTable(db, t) 97 98 // Insert values concurrently. 99 ga := testutils.NewGoroutineAssistant(t) 100 runs := 100 101 ga.Add(runs) 102 for i := 0; i < runs; i++ { 103 go func() { 104 if err := insertValue(db); err != nil { 105 ga.Fatalf("Failed to insert value: %v", err) 106 } 107 ga.Done() 108 }() 109 } 110 ga.Wait() 111 112 // Check the final values. 113 maxCount := getMaxCount(db, t) 114 if maxCount != runs { 115 t.Errorf("Expected: %v, saw: %v", runs, maxCount) 116 } 117 }