github.com/blend/go-sdk@v1.20220411.3/db/main_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"context"
    12  	"database/sql"
    13  	"fmt"
    14  	"os"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/blend/go-sdk/uuid"
    19  )
    20  
    21  //------------------------------------------------------------------------------------------------
    22  // Testing Entrypoint
    23  //------------------------------------------------------------------------------------------------
    24  
    25  // TestMain is the testing entrypoint.
    26  func TestMain(m *testing.M) {
    27  	conn, err := OpenTestConnection()
    28  	if err != nil {
    29  		fmt.Fprintf(os.Stderr, "%+v\n", err)
    30  		os.Exit(1)
    31  	}
    32  
    33  	setDefaultDB(conn)
    34  	os.Exit(m.Run())
    35  }
    36  
    37  // BenchmarkMain is the benchmarking entrypoint.
    38  func BenchmarkMain(b *testing.B) {
    39  	tx, err := defaultDB().Begin()
    40  	if err != nil {
    41  		b.Error("Unable to create transaction")
    42  		b.FailNow()
    43  	}
    44  	if tx == nil {
    45  		b.Error("`tx` is nil")
    46  		b.FailNow()
    47  	}
    48  
    49  	defer func() {
    50  		if tx != nil {
    51  			if err := tx.Rollback(); err != nil {
    52  				b.Errorf("Error rolling back transaction: %v", err)
    53  				b.FailNow()
    54  			}
    55  		}
    56  	}()
    57  
    58  	err = seedObjects(10000, tx)
    59  	if err != nil {
    60  		b.Errorf("Error seeding objects: %v", err)
    61  		b.FailNow()
    62  	}
    63  
    64  	var manual time.Duration
    65  	for x := 0; x < b.N*10; x++ {
    66  		manualStart := time.Now()
    67  		_, err = readManual(tx)
    68  		if err != nil {
    69  			b.Errorf("Error using manual query: %v", err)
    70  			b.FailNow()
    71  		}
    72  		manual += time.Since(manualStart)
    73  	}
    74  
    75  	var orm time.Duration
    76  	for x := 0; x < b.N*10; x++ {
    77  		ormStart := time.Now()
    78  		_, err = readOrm(tx)
    79  		if err != nil {
    80  			b.Errorf("Error using orm: %v", err)
    81  			b.FailNow()
    82  		}
    83  		orm += time.Since(ormStart)
    84  	}
    85  
    86  	var ormCached time.Duration
    87  	for x := 0; x < b.N*10; x++ {
    88  		ormCachedStart := time.Now()
    89  		_, err = readCachedOrm(tx)
    90  		if err != nil {
    91  			b.Errorf("Error using orm: %v", err)
    92  			b.FailNow()
    93  		}
    94  		ormCached += time.Since(ormCachedStart)
    95  	}
    96  
    97  	b.Logf("Benchmark Test Results:\nManual: %v \nOrm: %v\nOrm (Cached Plan): %v\n", manual, orm, ormCached)
    98  }
    99  
   100  // OpenTestConnection opens a test connection from the environment, disabling ssl.
   101  //
   102  // You should not use this function in production like settings, this is why it is kept in the _test.go file.
   103  func OpenTestConnection(opts ...Option) (*Connection, error) {
   104  	defaultOptions := []Option{OptConfigFromEnv(), OptSSLMode(SSLModeDisable)}
   105  	conn, err := Open(New(append(defaultOptions, opts...)...))
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	_, err = conn.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	return conn, nil
   116  }
   117  
   118  //------------------------------------------------------------------------------------------------
   119  // Util Types
   120  //------------------------------------------------------------------------------------------------
   121  
// upsertObj is a test fixture mapped to the upsert_object table.
//
// UUID is tagged `pk,auto`. The table created by createUpsertObjectTable
// defaults that column to gen_random_uuid(), so the value is presumably left
// for the database to generate on insert — confirm against the mapper.
type upsertObj struct {
	UUID      uuid.UUID `db:"uuid,pk,auto"`
	Timestamp time.Time `db:"timestamp_utc"`
	Category  string    `db:"category"`
}

// TableName returns the mapped table name.
func (uo upsertObj) TableName() string {
	return "upsert_object"
}
   131  
   132  func createUpsertObjectTable(tx *sql.Tx) error {
   133  	createSQL := `CREATE TABLE IF NOT EXISTS upsert_object (uuid uuid primary key default gen_random_uuid(), timestamp_utc timestamp, category varchar(255));`
   134  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL))
   135  }
   136  
// upsertNoAutosObj is a test fixture mapped to the upsert_no_autos_object
// table. Unlike upsertObj, its primary key is tagged `pk` without `auto`.
//
// NOTE(review): the backing uuid column is varchar(255) (see
// createUpsertNoAutosObjectTable), while the field is a uuid.UUID. This relies
// on uuid.UUID converting to and from a string-compatible driver value —
// verify.
type upsertNoAutosObj struct {
	UUID      uuid.UUID `db:"uuid,pk"`
	Timestamp time.Time `db:"timestamp_utc"`
	Category  string    `db:"category"`
}

// TableName returns the mapped table name.
func (uo upsertNoAutosObj) TableName() string {
	return "upsert_no_autos_object"
}
   146  
   147  func createUpsertNoAutosObjectTable(tx *sql.Tx) error {
   148  	createSQL := `CREATE TABLE IF NOT EXISTS upsert_no_autos_object (uuid varchar(255) primary key, timestamp_utc timestamp, category varchar(255));`
   149  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL))
   150  }
   151  
   152  //------------------------------------------------------------------------------------------------
   153  // Benchmarking
   154  //------------------------------------------------------------------------------------------------
   155  
// benchObj is the row type for the bench_object table used by the benchmarks.
//
// Field order matters: Populate scans columns positionally in this order, and
// the hand-written query in readManual selects the columns in the same order.
type benchObj struct {
	ID        int       `db:"id,pk,auto"`
	UUID      string    `db:"uuid"`
	Name      string    `db:"name,uk"`
	Timestamp time.Time `db:"timestamp_utc"`
	Amount    float32   `db:"amount"`
	Pending   bool      `db:"pending"`
	Category  string    `db:"category"`
}

// Populate scans the current row into b. The row's columns must be, in order:
// id, uuid, name, timestamp_utc, amount, pending, category.
func (b *benchObj) Populate(rows Scanner) error {
	return rows.Scan(&b.ID, &b.UUID, &b.Name, &b.Timestamp, &b.Amount, &b.Pending, &b.Category)
}

// TableName returns the mapped table name.
func (b benchObj) TableName() string {
	return "bench_object"
}
   173  
   174  func createTable(tx *sql.Tx) error {
   175  	createSQL := `CREATE TABLE IF NOT EXISTS bench_object (
   176  		id serial not null primary key
   177  		, uuid uuid not null
   178  		, name varchar(255)
   179  		, timestamp_utc timestamp
   180  		, amount real
   181  		, pending boolean
   182  		, category varchar(255)
   183  	);`
   184  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL))
   185  }
   186  
   187  func createIndex(tx *sql.Tx) error {
   188  	createSQL := `CREATE UNIQUE INDEX ON bench_object (name)`
   189  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL))
   190  }
   191  
   192  func dropTableIfExists(tx *sql.Tx) error {
   193  	dropSQL := `DROP TABLE IF EXISTS bench_object;`
   194  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(dropSQL))
   195  }
   196  
   197  func ensureUUIDExtension() error {
   198  	uuidCreate := `CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`
   199  	return IgnoreExecResult(defaultDB().Exec(uuidCreate))
   200  }
   201  
   202  func createObject(index int, tx *sql.Tx) error {
   203  	obj := benchObj{
   204  		Name:      fmt.Sprintf("test_object_%d", index),
   205  		UUID:      uuid.V4().String(),
   206  		Timestamp: time.Now().UTC(),
   207  		Amount:    1000.0 + (5.0 * float32(index)),
   208  		Pending:   index%2 == 0,
   209  		Category:  fmt.Sprintf("category_%d", index),
   210  	}
   211  	return defaultDB().Invoke(OptTx(tx)).Create(&obj)
   212  }
   213  
   214  func seedObjects(count int, tx *sql.Tx) error {
   215  	if err := ensureUUIDExtension(); err != nil {
   216  		return err
   217  	}
   218  	if err := dropTableIfExists(tx); err != nil {
   219  		return err
   220  	}
   221  
   222  	if err := createTable(tx); err != nil {
   223  		return err
   224  	}
   225  
   226  	for i := 0; i < count; i++ {
   227  		if err := createObject(i, tx); err != nil {
   228  			return err
   229  		}
   230  	}
   231  	return nil
   232  }
   233  
   234  func readManual(tx *sql.Tx) ([]benchObj, error) {
   235  	var objs []benchObj
   236  	readSQL := `select id,uuid,name,timestamp_utc,amount,pending,category from bench_object`
   237  	readStmt, err := defaultDB().PrepareContext(context.Background(), readSQL, tx)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	defer readStmt.Close()
   242  
   243  	rows, err := readStmt.Query()
   244  	defer func() { _ = rows.Close() }()
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	for rows.Next() {
   250  		obj := &benchObj{}
   251  		err = obj.Populate(rows)
   252  		if err != nil {
   253  			return nil, err
   254  		}
   255  		objs = append(objs, *obj)
   256  	}
   257  
   258  	return objs, nil
   259  }
   260  
   261  func readOrm(tx *sql.Tx) ([]benchObj, error) {
   262  	var objs []benchObj
   263  	allErr := defaultDB().Invoke(OptTx(tx)).Query(fmt.Sprintf("select %s from bench_object", ColumnNamesCSV(benchObj{}))).OutMany(&objs)
   264  	return objs, allErr
   265  }
   266  
   267  func readCachedOrm(tx *sql.Tx) ([]benchObj, error) {
   268  	var objs []benchObj
   269  	allErr := defaultDB().Invoke(OptTx(tx), OptLabel("get_all_bench_object")).Query(fmt.Sprintf("select %s from bench_object", ColumnNamesCSV(benchObj{}))).OutMany(&objs)
   270  	return objs, allErr
   271  }
   272  
   273  var (
   274  	defaultConnection *Connection
   275  )
   276  
   277  func setDefaultDB(conn *Connection) {
   278  	defaultConnection = conn
   279  }
   280  
   281  func defaultDB() *Connection {
   282  	return defaultConnection
   283  }
   284  
   285  type mockTracer struct {
   286  	PrepareHandler       func(context.Context, Config, string)
   287  	QueryHandler         func(context.Context, Config, string, string) TraceFinisher
   288  	FinishPrepareHandler func(context.Context, error)
   289  	FinishQueryHandler   func(context.Context, sql.Result, error)
   290  }
   291  
   292  func (mt mockTracer) Prepare(ctx context.Context, cfg Config, statement string) TraceFinisher {
   293  	if mt.PrepareHandler != nil {
   294  		mt.PrepareHandler(ctx, cfg, statement)
   295  	}
   296  	return mockTraceFinisher{
   297  		FinishPrepareHandler: mt.FinishPrepareHandler,
   298  		FinishQueryHandler:   mt.FinishQueryHandler,
   299  	}
   300  }
   301  
   302  func (mt mockTracer) Query(ctx context.Context, cfg Config, label, statement string) TraceFinisher {
   303  	if mt.PrepareHandler != nil {
   304  		mt.PrepareHandler(ctx, cfg, statement)
   305  	}
   306  	return mockTraceFinisher{
   307  		FinishPrepareHandler: mt.FinishPrepareHandler,
   308  		FinishQueryHandler:   mt.FinishQueryHandler,
   309  	}
   310  }
   311  
   312  type mockTraceFinisher struct {
   313  	FinishPrepareHandler func(context.Context, error)
   314  	FinishQueryHandler   func(context.Context, sql.Result, error)
   315  }
   316  
   317  func (mtf mockTraceFinisher) FinishPrepare(ctx context.Context, err error) {
   318  	if mtf.FinishPrepareHandler != nil {
   319  		mtf.FinishPrepareHandler(ctx, err)
   320  	}
   321  }
   322  
   323  func (mtf mockTraceFinisher) FinishQuery(ctx context.Context, res sql.Result, err error) {
   324  	if mtf.FinishQueryHandler != nil {
   325  		mtf.FinishQueryHandler(ctx, res, err)
   326  	}
   327  }
   328  
   329  var (
   330  	_ Tracer = (*captureStatementTracer)(nil)
   331  )
   332  
   333  type captureStatementTracer struct {
   334  	Tracer
   335  
   336  	Label     string
   337  	Statement string
   338  	Err       error
   339  }
   340  
   341  func (cst *captureStatementTracer) Query(_ context.Context, cfg Config, label string, statement string) TraceFinisher {
   342  	cst.Label = label
   343  	cst.Statement = statement
   344  	return &captureStatementTracerFinisher{cst}
   345  }
   346  
   347  type captureStatementTracerFinisher struct {
   348  	*captureStatementTracer
   349  }
   350  
   351  func (cstf *captureStatementTracerFinisher) FinishPrepare(context.Context, error) {}
   352  func (cstf *captureStatementTracerFinisher) FinishQuery(_ context.Context, _ sql.Result, err error) {
   353  	cstf.captureStatementTracer.Err = err
   354  }
   355  
   356  var failInterceptorError = "this is just an interceptor error"
   357  
   358  func failInterceptor(_ context.Context, _, statement string) (string, error) {
   359  	return "", fmt.Errorf(failInterceptorError)
   360  }
   361  
// uniqueObj is a minimal fixture mapped to the unique_obj table. Its integer
// ID is tagged `pk` but not `auto`.
type uniqueObj struct {
	ID   int    `db:"id,pk"`
	Name string `db:"name"`
}

// TableName returns the mapped table name.
func (uo uniqueObj) TableName() string {
	return "unique_obj"
}
   371  
// uuidTest is a fixture mapped to the uuid_test table. It is presumably used to
// exercise uuid.UUID column mapping. Note that ID carries no `pk` tag option.
type uuidTest struct {
	ID   uuid.UUID `db:"id"`
	Name string    `db:"name"`
}

// TableName returns the mapped table name.
func (ut uuidTest) TableName() string {
	return "uuid_test"
}
   380  
// EmbeddedTestMeta holds the id and timestamp_utc columns that embeddedTest
// pulls in by embedding it with a `db:",inline"` tag. ID is tagged as the
// primary key.
type EmbeddedTestMeta struct {
	ID           uuid.UUID `db:"id,pk"`
	TimestampUTC time.Time `db:"timestamp_utc"`
}
   385  
// embeddedTest is a fixture mapped to the embedded_test table. It embeds
// EmbeddedTestMeta with the `inline` option, presumably so the mapper treats
// the embedded fields' columns as columns of this table — confirm.
type embeddedTest struct {
	EmbeddedTestMeta `db:",inline"`
	Name             string `db:"name"`
}

// TableName returns the mapped table name.
func (et embeddedTest) TableName() string {
	return "embedded_test"
}
   394  
// jsonTestChild is the value type of jsonTest.NotNull. Its `json:"label"` tag
// suggests it is marshaled as JSON when stored.
type jsonTestChild struct {
	Label string `json:"label"`
}
   398  
// jsonTest is a fixture mapped to the json_test table (see createJSONTestTable).
//
// NotNull and Nullable carry the `json` tag option, and their columns are
// declared as json. The mapper therefore presumably marshals these fields to
// and from JSON. NOTE(review): whether a nil Nullable is stored as SQL NULL is
// not visible here — confirm.
type jsonTest struct {
	ID   int    `db:"id,pk,auto"`
	Name string `db:"name"`

	NotNull  jsonTestChild `db:"not_null,json"`
	Nullable []string      `db:"nullable,json"`
}

// TableName returns the mapped table name.
func (jt jsonTest) TableName() string {
	return "json_test"
}
   410  
   411  func secondArgErr(_ interface{}, err error) error {
   412  	return err
   413  }
   414  
   415  func createJSONTestTable(tx *sql.Tx) error {
   416  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec("create table json_test (id serial primary key, name varchar(255), not_null json, nullable json)"))
   417  }
   418  
   419  func dropJSONTextTable(tx *sql.Tx) error {
   420  	return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec("drop table if exists json_test"))
   421  }
   422  
   423  func createUpsertAutosRegressionTable(tx *sql.Tx) error {
   424  	schemaDefinition := `CREATE TABLE upsert_auto_regression (
   425  		id uuid not null,
   426  		status smallint not null,
   427  		required boolean not null default false,
   428  		created_at timestamp default current_timestamp,
   429  		updated_at timestamp,
   430  		migrated_at timestamp
   431  	);`
   432  	schemaPrimaryKey := "ALTER TABLE upsert_auto_regression ADD CONSTRAINT pk_upsert_auto_regression_id PRIMARY KEY (id);"
   433  	if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaDefinition); err != nil {
   434  		return err
   435  	}
   436  	if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaPrimaryKey); err != nil {
   437  		return err
   438  	}
   439  	return nil
   440  }
   441  
   442  func dropUpsertRegressionTable(tx *sql.Tx) error {
   443  	_, err := defaultDB().Invoke(OptTx(tx)).Exec("DROP TABLE upsert_auto_regression")
   444  	return err
   445  }
   446  
   447  func createUpsertSerialPKTable(tx *sql.Tx) error {
   448  	schemaDefinition := `CREATE TABLE upsert_serial_pk (
   449  		id serial not null primary key,
   450  		status smallint not null,
   451  		required boolean not null default false,
   452  		created_at timestamp default current_timestamp,
   453  		updated_at timestamp,
   454  		migrated_at timestamp
   455  	);`
   456  	if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaDefinition); err != nil {
   457  		return err
   458  	}
   459  	return nil
   460  }
   461  
   462  func dropUpsertSerialPKTable(tx *sql.Tx) error {
   463  	_, err := defaultDB().Invoke(OptTx(tx)).Exec("DROP TABLE upsert_serial_pk")
   464  	return err
   465  }
   466  
// upsertAutoRegression is a regression fixture mapped to the
// upsert_auto_regression table (see createUpsertAutosRegressionTable), whose
// uuid primary key is added with a separate ALTER TABLE.
//
// CreatedAt and UpdatedAt are tagged `auto`, and ReadOnly is tagged `readonly`.
// NOTE(review): the table created in this file has no read_only column, so
// readonly fields are presumably excluded from generated write statements —
// confirm.
type upsertAutoRegression struct {
	ID         uuid.UUID  `db:"id,pk"`
	Status     uint8      `db:"status"`
	Required   bool       `db:"required"`
	CreatedAt  *time.Time `db:"created_at,auto"`
	UpdatedAt  *time.Time `db:"updated_at,auto"`
	MigratedAt *time.Time `db:"migrated_at"`
	ReadOnly   string     `db:"read_only,readonly"`
}

// TableName returns the table name.
func (uar upsertAutoRegression) TableName() string {
	return "upsert_auto_regression"
}
   482  
// upsertSerialPK has the same columns as upsertAutoRegression but maps to the
// upsert_serial_pk table (see createUpsertSerialPKTable). That table's id
// column is a serial primary key, and the ID field is tagged `pk,serial`.
//
// NOTE(review): as with upsertAutoRegression, the table has no read_only column
// for the `readonly` ReadOnly field — confirm it is excluded from writes.
type upsertSerialPK struct {
	ID         int        `db:"id,pk,serial"`
	Status     uint8      `db:"status"`
	Required   bool       `db:"required"`
	CreatedAt  *time.Time `db:"created_at,auto"`
	UpdatedAt  *time.Time `db:"updated_at,auto"`
	MigratedAt *time.Time `db:"migrated_at"`
	ReadOnly   string     `db:"read_only,readonly"`
}

// TableName returns the table name.
func (uar upsertSerialPK) TableName() string {
	return "upsert_serial_pk"
}