github.com/rkt/rkt@v1.30.1-0.20200224141603-171c416fac02/store/db/db_test.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package db
    16  
    17  import (
    18  	"database/sql"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"os"
    22  	"runtime"
    23  	"testing"
    24  
    25  	"github.com/rkt/rkt/tests/testutils"
    26  )
    27  
    28  func queryValue(query string, tx *sql.Tx) (int, error) {
    29  	var value int
    30  	rows, err := tx.Query(query)
    31  	if err != nil {
    32  		return -1, err
    33  	}
    34  	defer rows.Close()
    35  
    36  	if !rows.Next() {
    37  		return -1, fmt.Errorf("no result of %q", query)
    38  	}
    39  	if err := rows.Scan(&value); err != nil {
    40  		return -1, err
    41  	}
    42  	return value, nil
    43  }
    44  
    45  func insertValue(db *DB) error {
    46  	return db.Do(func(tx *sql.Tx) error {
    47  		// Get the current count.
    48  		count, err := queryValue("SELECT count(*) FROM rkt_db_test", tx)
    49  		if err != nil {
    50  			return err
    51  		}
    52  		// Increase the count.
    53  		_, err = tx.Exec(fmt.Sprintf("INSERT INTO rkt_db_test VALUES (%d)", count+1))
    54  		return err
    55  	})
    56  }
    57  
    58  func getMaxCount(db *DB, t *testing.T) int {
    59  	var maxCount int
    60  	var err error
    61  	if err := db.Do(func(tx *sql.Tx) error {
    62  		// Get the maximum count.
    63  		maxCount, err = queryValue("SELECT max(counts) FROM rkt_db_test", tx)
    64  		return err
    65  	}); err != nil {
    66  		t.Fatalf("Failed to get the maximum count: %v", err)
    67  	}
    68  	return maxCount
    69  }
    70  
    71  func createTable(db *DB, t *testing.T) {
    72  	if err := db.Do(func(tx *sql.Tx) error {
    73  		_, err := tx.Exec("CREATE TABLE rkt_db_test (counts int)")
    74  		return err
    75  	}); err != nil {
    76  		t.Fatalf("Unexpected error: %v", err)
    77  	}
    78  }
    79  
    80  func TestDBRace(t *testing.T) {
    81  	oldGoMaxProcs := runtime.GOMAXPROCS(runtime.NumCPU())
    82  	defer runtime.GOMAXPROCS(oldGoMaxProcs)
    83  
    84  	dir, err := ioutil.TempDir("", "rkt_db_test")
    85  	if err != nil {
    86  		t.Fatalf("Unexpected error: %v", err)
    87  	}
    88  	defer os.RemoveAll(dir)
    89  
    90  	db, err := NewDB(dir)
    91  	if err != nil {
    92  		t.Fatalf("Unexpected error: %v", err)
    93  	}
    94  
    95  	// Create the table.
    96  	createTable(db, t)
    97  
    98  	// Insert values concurrently.
    99  	ga := testutils.NewGoroutineAssistant(t)
   100  	runs := 100
   101  	ga.Add(runs)
   102  	for i := 0; i < runs; i++ {
   103  		go func() {
   104  			if err := insertValue(db); err != nil {
   105  				ga.Fatalf("Failed to insert value: %v", err)
   106  			}
   107  			ga.Done()
   108  		}()
   109  	}
   110  	ga.Wait()
   111  
   112  	// Check the final values.
   113  	maxCount := getMaxCount(db, t)
   114  	if maxCount != runs {
   115  		t.Errorf("Expected: %v, saw: %v", runs, maxCount)
   116  	}
   117  }