github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"os"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    12  	"github.com/google/fleetspeak/fleetspeak/src/server/dbtesting"
    13  
    14  	// We access the driver through sql.Open, but need to bring in the
    15  	// dependency.
    16  	_ "github.com/go-sql-driver/mysql"
    17  )
    18  
    19  type mysqlTestEnv struct {
    20  	user string
    21  	pass string
    22  	addr string
    23  
    24  	dbName string
    25  
    26  	aconn *sql.DB
    27  	conn  *sql.DB
    28  }
    29  
    30  func (e *mysqlTestEnv) Create() error {
    31  	ctx, fin := context.WithTimeout(context.Background(), 30*time.Second)
    32  	defer fin()
    33  
    34  	cs := fmt.Sprintf("%s:%s@tcp(%s)/", e.user, e.pass, e.addr)
    35  	var err error
    36  	e.aconn, err = sql.Open("mysql", cs)
    37  	if err != nil {
    38  		return fmt.Errorf("Unable to open connection [%s] to create database: %v", cs, err)
    39  	}
    40  	if _, err = e.aconn.ExecContext(ctx, "DROP DATABASE IF EXISTS "+e.dbName); err != nil {
    41  		return fmt.Errorf("Unable to drop database [%s]: %v", e.dbName, err)
    42  	}
    43  	if _, err = e.aconn.ExecContext(ctx, "CREATE DATABASE "+e.dbName); err != nil {
    44  		return fmt.Errorf("Unable to create database [%s]: %v", e.dbName, err)
    45  	}
    46  	if _, err = e.aconn.ExecContext(ctx, "USE "+e.dbName); err != nil {
    47  		return fmt.Errorf("Unable to use database [%s]: %v", e.dbName, err)
    48  	}
    49  
    50  	return nil
    51  }
    52  
    53  func (e *mysqlTestEnv) Clean() (db.Store, error) {
    54  	if e.conn != nil {
    55  		e.conn.Close()
    56  	}
    57  
    58  	ctx := context.Background()
    59  
    60  	rows, err := e.aconn.QueryContext(ctx, "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE table_schema=?", e.dbName)
    61  	if err != nil {
    62  		return nil, fmt.Errorf("Can't fetch list of tables: %v", err)
    63  	}
    64  
    65  	tables := make([]string, 0)
    66  
    67  	defer rows.Close()
    68  	for rows.Next() {
    69  		var tname string
    70  		if err := rows.Scan(&tname); err != nil {
    71  			return nil, err
    72  		}
    73  
    74  		tables = append(tables, tname)
    75  	}
    76  
    77  	if _, err = e.aconn.ExecContext(ctx, "SET FOREIGN_KEY_CHECKS=0"); err != nil {
    78  		return nil, fmt.Errorf("Unable to disable foreign key checks: %v", err)
    79  	}
    80  	defer e.aconn.ExecContext(ctx, "SET FOREIGN_KEY_CHECKS=1")
    81  
    82  	for _, tname := range tables {
    83  		if _, err = e.aconn.ExecContext(ctx, "TRUNCATE TABLE "+tname); err != nil {
    84  			return nil, fmt.Errorf("Unable to truncate table %v: %v", tname, err)
    85  		}
    86  	}
    87  
    88  	cs := fmt.Sprintf("%s:%s@tcp(%s)/%s", e.user, e.pass, e.addr, e.dbName)
    89  	e.conn, err = sql.Open("mysql", cs)
    90  	if err != nil {
    91  		return nil, fmt.Errorf("Unable to open connection [%s] to database: %v", cs, err)
    92  	}
    93  	s, err := MakeDatastore(e.conn)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return s, nil
    99  }
   100  
   101  func (e *mysqlTestEnv) Destroy() error {
   102  	if e.conn != nil {
   103  		e.conn.Close()
   104  	}
   105  
   106  	ctx, fin := context.WithTimeout(context.Background(), 30*time.Second)
   107  	defer fin()
   108  	if _, err := e.aconn.ExecContext(ctx, "DROP DATABASE "+e.dbName); err != nil {
   109  		return fmt.Errorf("Unable to drop database [%s]: %v", e.dbName, err)
   110  	}
   111  	return e.aconn.Close()
   112  }
   113  
   114  func newMysqlTestEnv(user string, pass string, addr string) *mysqlTestEnv {
   115  	return &mysqlTestEnv{
   116  		user:   user,
   117  		pass:   pass,
   118  		addr:   addr,
   119  		dbName: "fleetspeaktestdb",
   120  	}
   121  }
   122  
   123  func TestMysqlStore(t *testing.T) {
   124  	var user = os.Getenv("MYSQL_TEST_USER")
   125  	var pass = os.Getenv("MYSQL_TEST_PASS")
   126  	var addr = os.Getenv("MYSQL_TEST_ADDR")
   127  
   128  	if user == "" {
   129  		t.Skip("MYSQL_TEST_USER not set")
   130  	}
   131  	if addr == "" {
   132  		t.Skip("MYSQL_TEST_ADDR not set")
   133  	}
   134  
   135  	dbtesting.DataStoreTestSuite(t, newMysqlTestEnv(user, pass, addr))
   136  }