github.com/blend/go-sdk@v1.20220411.3/db/main_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package db 9 10 import ( 11 "context" 12 "database/sql" 13 "fmt" 14 "os" 15 "testing" 16 "time" 17 18 "github.com/blend/go-sdk/uuid" 19 ) 20 21 //------------------------------------------------------------------------------------------------ 22 // Testing Entrypoint 23 //------------------------------------------------------------------------------------------------ 24 25 // TestMain is the testing entrypoint. 26 func TestMain(m *testing.M) { 27 conn, err := OpenTestConnection() 28 if err != nil { 29 fmt.Fprintf(os.Stderr, "%+v\n", err) 30 os.Exit(1) 31 } 32 33 setDefaultDB(conn) 34 os.Exit(m.Run()) 35 } 36 37 // BenchmarkMain is the benchmarking entrypoint. 38 func BenchmarkMain(b *testing.B) { 39 tx, err := defaultDB().Begin() 40 if err != nil { 41 b.Error("Unable to create transaction") 42 b.FailNow() 43 } 44 if tx == nil { 45 b.Error("`tx` is nil") 46 b.FailNow() 47 } 48 49 defer func() { 50 if tx != nil { 51 if err := tx.Rollback(); err != nil { 52 b.Errorf("Error rolling back transaction: %v", err) 53 b.FailNow() 54 } 55 } 56 }() 57 58 err = seedObjects(10000, tx) 59 if err != nil { 60 b.Errorf("Error seeding objects: %v", err) 61 b.FailNow() 62 } 63 64 var manual time.Duration 65 for x := 0; x < b.N*10; x++ { 66 manualStart := time.Now() 67 _, err = readManual(tx) 68 if err != nil { 69 b.Errorf("Error using manual query: %v", err) 70 b.FailNow() 71 } 72 manual += time.Since(manualStart) 73 } 74 75 var orm time.Duration 76 for x := 0; x < b.N*10; x++ { 77 ormStart := time.Now() 78 _, err = readOrm(tx) 79 if err != nil { 80 b.Errorf("Error using orm: %v", err) 81 b.FailNow() 82 } 83 orm += time.Since(ormStart) 84 } 85 86 var ormCached time.Duration 87 for x := 0; x < b.N*10; x++ { 88 ormCachedStart := time.Now() 89 _, err = readCachedOrm(tx) 90 if err != nil { 91 b.Errorf("Error using orm: %v", err) 92 b.FailNow() 93 } 94 ormCached += time.Since(ormCachedStart) 95 } 96 97 b.Logf("Benchmark Test Results:\nManual: %v \nOrm: %v\nOrm (Cached Plan): %v\n", manual, orm, ormCached) 98 } 99 100 // OpenTestConnection opens a test connection from the environment, disabling ssl. 101 // 102 // You should not use this function in production like settings, this is why it is kept in the _test.go file. 103 func OpenTestConnection(opts ...Option) (*Connection, error) { 104 defaultOptions := []Option{OptConfigFromEnv(), OptSSLMode(SSLModeDisable)} 105 conn, err := Open(New(append(defaultOptions, opts...)...)) 106 if err != nil { 107 return nil, err 108 } 109 110 _, err = conn.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;") 111 if err != nil { 112 return nil, err 113 } 114 115 return conn, nil 116 } 117 118 //------------------------------------------------------------------------------------------------ 119 // Util Types 120 //------------------------------------------------------------------------------------------------ 121 122 type upsertObj struct { 123 UUID uuid.UUID `db:"uuid,pk,auto"` 124 Timestamp time.Time `db:"timestamp_utc"` 125 Category string `db:"category"` 126 } 127 128 func (uo upsertObj) TableName() string { 129 return "upsert_object" 130 } 131 132 func createUpsertObjectTable(tx *sql.Tx) error { 133 createSQL := `CREATE TABLE IF NOT EXISTS upsert_object (uuid uuid primary key default gen_random_uuid(), timestamp_utc timestamp, category varchar(255));` 134 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL)) 135 } 136 137 type upsertNoAutosObj struct { 138 UUID uuid.UUID `db:"uuid,pk"` 139 Timestamp time.Time `db:"timestamp_utc"` 140 Category string `db:"category"` 141 } 142 143 func (uo upsertNoAutosObj) TableName() string { 144 return "upsert_no_autos_object" 145 } 146 147 func createUpsertNoAutosObjectTable(tx *sql.Tx) error { 148 createSQL := `CREATE TABLE IF NOT EXISTS upsert_no_autos_object (uuid varchar(255) primary key, timestamp_utc timestamp, category varchar(255));` 149 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL)) 150 } 151 152 //------------------------------------------------------------------------------------------------ 153 // Benchmarking 154 //------------------------------------------------------------------------------------------------ 155 156 type benchObj struct { 157 ID int `db:"id,pk,auto"` 158 UUID string `db:"uuid"` 159 Name string `db:"name,uk"` 160 Timestamp time.Time `db:"timestamp_utc"` 161 Amount float32 `db:"amount"` 162 Pending bool `db:"pending"` 163 Category string `db:"category"` 164 } 165 166 func (b *benchObj) Populate(rows Scanner) error { 167 return rows.Scan(&b.ID, &b.UUID, &b.Name, &b.Timestamp, &b.Amount, &b.Pending, &b.Category) 168 } 169 170 func (b benchObj) TableName() string { 171 return "bench_object" 172 } 173 174 func createTable(tx *sql.Tx) error { 175 createSQL := `CREATE TABLE IF NOT EXISTS bench_object ( 176 id serial not null primary key 177 , uuid uuid not null 178 , name varchar(255) 179 , timestamp_utc timestamp 180 , amount real 181 , pending boolean 182 , category varchar(255) 183 );` 184 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL)) 185 } 186 187 func createIndex(tx *sql.Tx) error { 188 createSQL := `CREATE UNIQUE INDEX ON bench_object (name)` 189 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(createSQL)) 190 } 191 192 func dropTableIfExists(tx *sql.Tx) error { 193 dropSQL := `DROP TABLE IF EXISTS bench_object;` 194 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec(dropSQL)) 195 } 196 197 func ensureUUIDExtension() error { 198 uuidCreate := `CREATE EXTENSION IF NOT EXISTS "uuid-ossp";` 199 return IgnoreExecResult(defaultDB().Exec(uuidCreate)) 200 } 201 202 func createObject(index int, tx *sql.Tx) error { 203 obj := benchObj{ 204 Name: fmt.Sprintf("test_object_%d", index), 205 UUID: uuid.V4().String(), 206 Timestamp: time.Now().UTC(), 207 Amount: 1000.0 + (5.0 * float32(index)), 208 Pending: index%2 == 0, 209 Category: fmt.Sprintf("category_%d", index), 210 } 211 return defaultDB().Invoke(OptTx(tx)).Create(&obj) 212 } 213 214 func seedObjects(count int, tx *sql.Tx) error { 215 if err := ensureUUIDExtension(); err != nil { 216 return err 217 } 218 if err := dropTableIfExists(tx); err != nil { 219 return err 220 } 221 222 if err := createTable(tx); err != nil { 223 return err 224 } 225 226 for i := 0; i < count; i++ { 227 if err := createObject(i, tx); err != nil { 228 return err 229 } 230 } 231 return nil 232 } 233 234 func readManual(tx *sql.Tx) ([]benchObj, error) { 235 var objs []benchObj 236 readSQL := `select id,uuid,name,timestamp_utc,amount,pending,category from bench_object` 237 readStmt, err := defaultDB().PrepareContext(context.Background(), readSQL, tx) 238 if err != nil { 239 return nil, err 240 } 241 defer readStmt.Close() 242 243 rows, err := readStmt.Query() 244 defer func() { _ = rows.Close() }() 245 if err != nil { 246 return nil, err 247 } 248 249 for rows.Next() { 250 obj := &benchObj{} 251 err = obj.Populate(rows) 252 if err != nil { 253 return nil, err 254 } 255 objs = append(objs, *obj) 256 } 257 258 return objs, nil 259 } 260 261 func readOrm(tx *sql.Tx) ([]benchObj, error) { 262 var objs []benchObj 263 allErr := defaultDB().Invoke(OptTx(tx)).Query(fmt.Sprintf("select %s from bench_object", ColumnNamesCSV(benchObj{}))).OutMany(&objs) 264 return objs, allErr 265 } 266 267 func readCachedOrm(tx *sql.Tx) ([]benchObj, error) { 268 var objs []benchObj 269 allErr := defaultDB().Invoke(OptTx(tx), OptLabel("get_all_bench_object")).Query(fmt.Sprintf("select %s from bench_object", ColumnNamesCSV(benchObj{}))).OutMany(&objs) 270 return objs, allErr 271 } 272 273 var ( 274 defaultConnection *Connection 275 ) 276 277 func setDefaultDB(conn *Connection) { 278 defaultConnection = conn 279 } 280 281 func defaultDB() *Connection { 282 return defaultConnection 283 } 284 285 type mockTracer struct { 286 PrepareHandler func(context.Context, Config, string) 287 QueryHandler func(context.Context, Config, string, string) TraceFinisher 288 FinishPrepareHandler func(context.Context, error) 289 FinishQueryHandler func(context.Context, sql.Result, error) 290 } 291 292 func (mt mockTracer) Prepare(ctx context.Context, cfg Config, statement string) TraceFinisher { 293 if mt.PrepareHandler != nil { 294 mt.PrepareHandler(ctx, cfg, statement) 295 } 296 return mockTraceFinisher{ 297 FinishPrepareHandler: mt.FinishPrepareHandler, 298 FinishQueryHandler: mt.FinishQueryHandler, 299 } 300 } 301 302 func (mt mockTracer) Query(ctx context.Context, cfg Config, label, statement string) TraceFinisher { 303 if mt.PrepareHandler != nil { 304 mt.PrepareHandler(ctx, cfg, statement) 305 } 306 return mockTraceFinisher{ 307 FinishPrepareHandler: mt.FinishPrepareHandler, 308 FinishQueryHandler: mt.FinishQueryHandler, 309 } 310 } 311 312 type mockTraceFinisher struct { 313 FinishPrepareHandler func(context.Context, error) 314 FinishQueryHandler func(context.Context, sql.Result, error) 315 } 316 317 func (mtf mockTraceFinisher) FinishPrepare(ctx context.Context, err error) { 318 if mtf.FinishPrepareHandler != nil { 319 mtf.FinishPrepareHandler(ctx, err) 320 } 321 } 322 323 func (mtf mockTraceFinisher) FinishQuery(ctx context.Context, res sql.Result, err error) { 324 if mtf.FinishQueryHandler != nil { 325 mtf.FinishQueryHandler(ctx, res, err) 326 } 327 } 328 329 var ( 330 _ Tracer = (*captureStatementTracer)(nil) 331 ) 332 333 type captureStatementTracer struct { 334 Tracer 335 336 Label string 337 Statement string 338 Err error 339 } 340 341 func (cst *captureStatementTracer) Query(_ context.Context, cfg Config, label string, statement string) TraceFinisher { 342 cst.Label = label 343 cst.Statement = statement 344 return &captureStatementTracerFinisher{cst} 345 } 346 347 type captureStatementTracerFinisher struct { 348 *captureStatementTracer 349 } 350 351 func (cstf *captureStatementTracerFinisher) FinishPrepare(context.Context, error) {} 352 func (cstf *captureStatementTracerFinisher) FinishQuery(_ context.Context, _ sql.Result, err error) { 353 cstf.captureStatementTracer.Err = err 354 } 355 356 var failInterceptorError = "this is just an interceptor error" 357 358 func failInterceptor(_ context.Context, _, statement string) (string, error) { 359 return "", fmt.Errorf(failInterceptorError) 360 } 361 362 type uniqueObj struct { 363 ID int `db:"id,pk"` 364 Name string `db:"name"` 365 } 366 367 // TableName returns the mapped table name. 368 func (uo uniqueObj) TableName() string { 369 return "unique_obj" 370 } 371 372 type uuidTest struct { 373 ID uuid.UUID `db:"id"` 374 Name string `db:"name"` 375 } 376 377 func (ut uuidTest) TableName() string { 378 return "uuid_test" 379 } 380 381 type EmbeddedTestMeta struct { 382 ID uuid.UUID `db:"id,pk"` 383 TimestampUTC time.Time `db:"timestamp_utc"` 384 } 385 386 type embeddedTest struct { 387 EmbeddedTestMeta `db:",inline"` 388 Name string `db:"name"` 389 } 390 391 func (et embeddedTest) TableName() string { 392 return "embedded_test" 393 } 394 395 type jsonTestChild struct { 396 Label string `json:"label"` 397 } 398 399 type jsonTest struct { 400 ID int `db:"id,pk,auto"` 401 Name string `db:"name"` 402 403 NotNull jsonTestChild `db:"not_null,json"` 404 Nullable []string `db:"nullable,json"` 405 } 406 407 func (jt jsonTest) TableName() string { 408 return "json_test" 409 } 410 411 func secondArgErr(_ interface{}, err error) error { 412 return err 413 } 414 415 func createJSONTestTable(tx *sql.Tx) error { 416 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec("create table json_test (id serial primary key, name varchar(255), not_null json, nullable json)")) 417 } 418 419 func dropJSONTextTable(tx *sql.Tx) error { 420 return IgnoreExecResult(defaultDB().Invoke(OptTx(tx)).Exec("drop table if exists json_test")) 421 } 422 423 func createUpsertAutosRegressionTable(tx *sql.Tx) error { 424 schemaDefinition := `CREATE TABLE upsert_auto_regression ( 425 id uuid not null, 426 status smallint not null, 427 required boolean not null default false, 428 created_at timestamp default current_timestamp, 429 updated_at timestamp, 430 migrated_at timestamp 431 );` 432 schemaPrimaryKey := "ALTER TABLE upsert_auto_regression ADD CONSTRAINT pk_upsert_auto_regression_id PRIMARY KEY (id);" 433 if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaDefinition); err != nil { 434 return err 435 } 436 if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaPrimaryKey); err != nil { 437 return err 438 } 439 return nil 440 } 441 442 func dropUpsertRegressionTable(tx *sql.Tx) error { 443 _, err := defaultDB().Invoke(OptTx(tx)).Exec("DROP TABLE upsert_auto_regression") 444 return err 445 } 446 447 func createUpsertSerialPKTable(tx *sql.Tx) error { 448 schemaDefinition := `CREATE TABLE upsert_serial_pk ( 449 id serial not null primary key, 450 status smallint not null, 451 required boolean not null default false, 452 created_at timestamp default current_timestamp, 453 updated_at timestamp, 454 migrated_at timestamp 455 );` 456 if _, err := defaultDB().Invoke(OptTx(tx)).Exec(schemaDefinition); err != nil { 457 return err 458 } 459 return nil 460 } 461 462 func dropUpsertSerialPKTable(tx *sql.Tx) error { 463 _, err := defaultDB().Invoke(OptTx(tx)).Exec("DROP TABLE upsert_serial_pk") 464 return err 465 } 466 467 // upsertAutoRegression contains all data associated with an envelope of documents. 468 type upsertAutoRegression struct { 469 ID uuid.UUID `db:"id,pk"` 470 Status uint8 `db:"status"` 471 Required bool `db:"required"` 472 CreatedAt *time.Time `db:"created_at,auto"` 473 UpdatedAt *time.Time `db:"updated_at,auto"` 474 MigratedAt *time.Time `db:"migrated_at"` 475 ReadOnly string `db:"read_only,readonly"` 476 } 477 478 // TableName returns the table name. 479 func (uar upsertAutoRegression) TableName() string { 480 return "upsert_auto_regression" 481 } 482 483 type upsertSerialPK struct { 484 ID int `db:"id,pk,serial"` 485 Status uint8 `db:"status"` 486 Required bool `db:"required"` 487 CreatedAt *time.Time `db:"created_at,auto"` 488 UpdatedAt *time.Time `db:"updated_at,auto"` 489 MigratedAt *time.Time `db:"migrated_at"` 490 ReadOnly string `db:"read_only,readonly"` 491 } 492 493 // TableName returns the table name. 494 func (uar upsertSerialPK) TableName() string { 495 return "upsert_serial_pk" 496 }