go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/main_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"database/sql"
    12  	"fmt"
    13  	"os"
    14  	"testing"
    15  	"time"
    16  
    17  	"go.charczuk.com/sdk/uuid"
    18  )
    19  
    20  // TestMain is the testing entrypoint.
    21  func TestMain(m *testing.M) {
    22  	conn, err := OpenTestConnection()
    23  	if err != nil {
    24  		fmt.Fprintf(os.Stderr, "%+v\n", err)
    25  		os.Exit(1)
    26  	}
    27  	setDefaultDB(conn)
    28  	os.Exit(m.Run())
    29  }
    30  
    31  var (
    32  	defaultConnection *Connection
    33  )
    34  
    35  func setDefaultDB(conn *Connection) {
    36  	defaultConnection = conn
    37  }
    38  
    39  func defaultDB() *Connection {
    40  	if defaultConnection == nil || defaultConnection.conn == nil {
    41  		panic("default connection is unset or unopened")
    42  	}
    43  	return defaultConnection
    44  }
    45  
    46  // OpenTestConnection opens a test connection from the environment, disabling ssl.
    47  //
    48  // You should not use this function in production like settings, this is why it is kept in the _test.go file.
    49  func OpenTestConnection(opts ...Option) (*Connection, error) {
    50  	defaultOptions := []Option{
    51  		OptConfigFromEnv(),
    52  		OptEngine(DefaultEngine),
    53  		OptHost("localhost"),
    54  		OptDatabase("postgres"),
    55  		OptSSLMode(SSLModeDisable),
    56  	}
    57  	conn, err := New(append(defaultOptions, opts...)...)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	err = conn.Open()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	_, err = conn.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	return conn, nil
    70  }
    71  
    72  // helper types
    73  type benchObj struct {
    74  	ID        int       `db:"id,pk,auto"`
    75  	UUID      string    `db:"uuid"`
    76  	Name      string    `db:"name,uk"`
    77  	Timestamp time.Time `db:"timestamp_utc"`
    78  	Amount    float32   `db:"amount"`
    79  	Pending   bool      `db:"pending"`
    80  	Category  string    `db:"category"`
    81  }
    82  
    83  func (b *benchObj) Populate(rows Rows) error {
    84  	return rows.Scan(&b.ID, &b.UUID, &b.Name, &b.Timestamp, &b.Amount, &b.Pending, &b.Category)
    85  }
    86  
    87  func (b benchObj) TableName() string {
    88  	return "bench_object"
    89  }
    90  
    91  func createTable(tx *sql.Tx) error {
    92  	createSQL := `CREATE TABLE IF NOT EXISTS bench_object (
    93  		id serial not null primary key
    94  		, uuid uuid not null
    95  		, name varchar(255)
    96  		, timestamp_utc timestamp
    97  		, amount real
    98  		, pending boolean
    99  		, category varchar(255)
   100  	);`
   101  	_, err := defaultDB().Invoke(OptTx(tx)).Exec(createSQL)
   102  	return err
   103  }
   104  
   105  func createIndex(tx *sql.Tx) error {
   106  	createSQL := `CREATE UNIQUE INDEX ON bench_object (name)`
   107  	_, err := defaultDB().Invoke(OptTx(tx)).Exec(createSQL)
   108  	return err
   109  }
   110  
   111  func dropTableIfExists(tx *sql.Tx) error {
   112  	dropSQL := `DROP TABLE IF EXISTS bench_object;`
   113  	_, err := defaultDB().Invoke(OptTx(tx)).Exec(dropSQL)
   114  	return err
   115  }
   116  
   117  func ensureUUIDExtension() error {
   118  	uuidCreate := `CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`
   119  	_, err := defaultDB().Exec(uuidCreate)
   120  	return err
   121  }
   122  
   123  func createObject(index int, tx *sql.Tx) error {
   124  	obj := benchObj{
   125  		Name:      fmt.Sprintf("test_object_%d", index),
   126  		UUID:      uuid.V4().String(),
   127  		Timestamp: time.Now().UTC(),
   128  		Amount:    1000.0 + (5.0 * float32(index)),
   129  		Pending:   index%2 == 0,
   130  		Category:  fmt.Sprintf("category_%d", index),
   131  	}
   132  	return defaultDB().Invoke(OptTx(tx)).Create(&obj)
   133  }
   134  
   135  func seedObjects(count int, tx *sql.Tx) error {
   136  	if err := ensureUUIDExtension(); err != nil {
   137  		return err
   138  	}
   139  	if err := dropTableIfExists(tx); err != nil {
   140  		return err
   141  	}
   142  
   143  	if err := createTable(tx); err != nil {
   144  		return err
   145  	}
   146  
   147  	for i := 0; i < count; i++ {
   148  		if err := createObject(i, tx); err != nil {
   149  			return err
   150  		}
   151  	}
   152  	return nil
   153  }
   154  
   155  type upsertObj struct {
   156  	UUID      uuid.UUID `db:"uuid,pk,auto"`
   157  	Timestamp time.Time `db:"timestamp_utc"`
   158  	Category  string    `db:"category"`
   159  }
   160  
   161  func (uo upsertObj) TableName() string {
   162  	return "upsert_object"
   163  }
   164  
   165  func createUpsertObjectTable(tx *sql.Tx) error {
   166  	createSQL := `CREATE TABLE IF NOT EXISTS upsert_object (uuid uuid primary key default gen_random_uuid(), timestamp_utc timestamp, category varchar(255));`
   167  	_, err := defaultDB().Invoke(OptTx(tx)).Exec(createSQL)
   168  	return err
   169  }