github.com/blend/go-sdk@v1.20220411.3/db/invocation.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package db 9 10 import ( 11 "context" 12 "database/sql" 13 "fmt" 14 "reflect" 15 "strconv" 16 "time" 17 18 "github.com/blend/go-sdk/bufferutil" 19 "github.com/blend/go-sdk/ex" 20 "github.com/blend/go-sdk/logger" 21 ) 22 23 // Invocation is a specific operation against a context. 24 type Invocation struct { 25 DB DB 26 Label string 27 Context context.Context 28 Cancel func() 29 Config Config 30 Log logger.Triggerable 31 BufferPool *bufferutil.Pool 32 StatementInterceptor StatementInterceptor 33 Tracer Tracer 34 StartTime time.Time 35 TraceFinisher TraceFinisher 36 } 37 38 // Exec executes a sql statement with a given set of arguments and returns the rows affected. 39 func (i *Invocation) Exec(statement string, args ...interface{}) (res sql.Result, err error) { 40 statement, err = i.start(statement) 41 if err != nil { 42 return 43 } 44 defer func() { err = i.finish(statement, recover(), res, err) }() 45 46 res, err = i.DB.ExecContext(i.Context, statement, args...) 47 if err != nil { 48 err = Error(err) 49 return 50 } 51 return 52 } 53 54 // Query returns a new query object for a given sql query and arguments. 55 func (i *Invocation) Query(statement string, args ...interface{}) *Query { 56 q := &Query{ 57 Invocation: i, 58 Args: args, 59 } 60 q.Statement, q.Err = i.start(statement) 61 return q 62 } 63 64 func (i *Invocation) maybeSetLabel(label string) { 65 if i.Label != "" { 66 return 67 } 68 i.Label = label 69 } 70 71 // Get returns a given object based on a group of primary key ids within a transaction. 72 func (i *Invocation) Get(object DatabaseMapped, ids ...interface{}) (found bool, err error) { 73 if len(ids) == 0 { 74 err = Error(ErrInvalidIDs) 75 return 76 } 77 78 var queryBody, label string 79 if label, queryBody, err = i.generateGet(object); err != nil { 80 err = Error(err) 81 return 82 } 83 i.maybeSetLabel(label) 84 return i.Query(queryBody, ids...).Out(object) 85 } 86 87 // All returns all rows of an object mapped table wrapped in a transaction. 88 func (i *Invocation) All(collection interface{}) (err error) { 89 label, queryBody := i.generateGetAll(collection) 90 i.maybeSetLabel(label) 91 return i.Query(queryBody).OutMany(collection) 92 } 93 94 // Create writes an object to the database within a transaction. 95 func (i *Invocation) Create(object DatabaseMapped) (err error) { 96 var queryBody, label string 97 var insertCols, autos *ColumnCollection 98 var res sql.Result 99 defer func() { err = i.finish(queryBody, recover(), res, err) }() 100 101 label, queryBody, insertCols, autos = i.generateCreate(object) 102 i.maybeSetLabel(label) 103 104 queryBody, err = i.start(queryBody) 105 if err != nil { 106 return 107 } 108 if autos.Len() == 0 { 109 if res, err = i.DB.ExecContext(i.Context, queryBody, insertCols.ColumnValues(object)...); err != nil { 110 err = Error(err) 111 return 112 } 113 return 114 } 115 116 autoValues := i.autoValues(autos) 117 if err = i.DB.QueryRowContext(i.Context, queryBody, insertCols.ColumnValues(object)...).Scan(autoValues...); err != nil { 118 err = Error(err) 119 return 120 } 121 if err = i.setAutos(object, autos, autoValues); err != nil { 122 err = Error(err) 123 return 124 } 125 126 return 127 } 128 129 // CreateIfNotExists writes an object to the database if it does not already exist within a transaction. 130 // This will _ignore_ auto columns, as they will always invalidate the assertion that there already exists 131 // a row with a given primary key set. 132 func (i *Invocation) CreateIfNotExists(object DatabaseMapped) (err error) { 133 var queryBody, label string 134 var insertCols *ColumnCollection 135 var res sql.Result 136 defer func() { err = i.finish(queryBody, recover(), res, err) }() 137 138 label, queryBody, insertCols = i.generateCreateIfNotExists(object) 139 i.maybeSetLabel(label) 140 141 queryBody, err = i.start(queryBody) 142 if err != nil { 143 return 144 } 145 if res, err = i.DB.ExecContext(i.Context, queryBody, insertCols.ColumnValues(object)...); err != nil { 146 err = Error(err) 147 } 148 return 149 } 150 151 // CreateMany writes many objects to the database in a single insert. 152 func (i *Invocation) CreateMany(objects interface{}) (err error) { 153 return i.insertOrUpsertMany(objects, false) 154 } 155 156 // UpsertMany writes many objects to the database in a single upsert. 157 func (i *Invocation) UpsertMany(objects interface{}) (err error) { 158 return i.insertOrUpsertMany(objects, true) 159 } 160 161 // insertOrUpsertManinsertOrUpsertMany writes many objects to the database in a single insert or upsert. 162 func (i *Invocation) insertOrUpsertMany(objects interface{}, overwrite bool) (err error) { 163 var queryBody string 164 var insertCols *ColumnCollection 165 var sliceValue reflect.Value 166 var res sql.Result 167 defer func() { err = i.finish(queryBody, recover(), res, err) }() 168 169 if overwrite { 170 queryBody, insertCols, sliceValue = i.generateUpsertMany(objects) 171 } else { 172 queryBody, insertCols, sliceValue = i.generateCreateMany(objects) 173 } 174 if sliceValue.Len() == 0 { 175 // If there is nothing to create, then we're done here 176 return 177 } 178 179 queryBody, err = i.start(queryBody) 180 if err != nil { 181 return 182 } 183 var colValues []interface{} 184 for row := 0; row < sliceValue.Len(); row++ { 185 colValues = append(colValues, insertCols.ColumnValues(sliceValue.Index(row).Interface())...) 186 } 187 188 res, err = i.DB.ExecContext(i.Context, queryBody, colValues...) 189 if err != nil { 190 err = Error(err) 191 return 192 } 193 return 194 } 195 196 // Update updates an object wrapped in a transaction. Returns whether or not any rows have been updated and potentially 197 // an error. If ErrTooManyRows is returned, it's important to note that due to https://github.com/golang/go/issues/7898, 198 // the Update HAS BEEN APPLIED. Its on the developer using UPDATE to ensure his tags are correct and/or execute it in a 199 // transaction and roll back on this error 200 func (i *Invocation) Update(object DatabaseMapped) (updated bool, err error) { 201 var queryBody, label string 202 var pks, updateCols *ColumnCollection 203 var res sql.Result 204 defer func() { err = i.finish(queryBody, recover(), res, err) }() 205 206 label, queryBody, pks, updateCols = i.generateUpdate(object) 207 i.maybeSetLabel(label) 208 209 queryBody, err = i.start(queryBody) 210 if err != nil { 211 return 212 } 213 res, err = i.DB.ExecContext( 214 i.Context, 215 queryBody, 216 append(updateCols.ColumnValues(object), pks.ColumnValues(object)...)..., 217 ) 218 if err != nil { 219 err = Error(err) 220 return 221 } 222 223 var rowCount int64 224 rowCount, err = res.RowsAffected() 225 if err != nil { 226 err = Error(err) 227 return 228 } 229 if rowCount > 0 { 230 updated = true 231 } 232 if rowCount > 1 { 233 err = Error(ErrTooManyRows) 234 } 235 return 236 } 237 238 // Upsert inserts the object if it doesn't exist already (as defined by its primary keys) or updates it atomically. 239 // It returns `found` as true if the effect was an upsert, i.e. the pk was found. 240 func (i *Invocation) Upsert(object DatabaseMapped) (err error) { 241 var queryBody, label string 242 var autos, upsertCols *ColumnCollection 243 defer func() { err = i.finish(queryBody, recover(), nil, err) }() 244 245 i.Label, queryBody, autos, upsertCols = i.generateUpsert(object) 246 i.maybeSetLabel(label) 247 248 queryBody, err = i.start(queryBody) 249 if err != nil { 250 return 251 } 252 if autos.Len() == 0 { 253 if _, err = i.DB.ExecContext(i.Context, queryBody, upsertCols.ColumnValues(object)...); err != nil { 254 return 255 } 256 return 257 } 258 259 autoValues := i.autoValues(autos) 260 if err = i.DB.QueryRowContext(i.Context, queryBody, upsertCols.ColumnValues(object)...).Scan(autoValues...); err != nil { 261 err = Error(err) 262 return 263 } 264 if err = i.setAutos(object, autos, autoValues); err != nil { 265 err = Error(err) 266 return 267 } 268 return 269 } 270 271 // Exists returns a bool if a given object exists (utilizing the primary key columns if they exist) wrapped in a transaction. 272 func (i *Invocation) Exists(object DatabaseMapped) (exists bool, err error) { 273 var queryBody, label string 274 var pks *ColumnCollection 275 defer func() { err = i.finish(queryBody, recover(), nil, err) }() 276 277 if label, queryBody, pks, err = i.generateExists(object); err != nil { 278 err = Error(err) 279 return 280 } 281 i.maybeSetLabel(label) 282 queryBody, err = i.start(queryBody) 283 if err != nil { 284 return 285 } 286 var value int 287 if queryErr := i.DB.QueryRowContext(i.Context, queryBody, pks.ColumnValues(object)...).Scan(&value); queryErr != nil && !ex.Is(queryErr, sql.ErrNoRows) { 288 err = Error(queryErr) 289 return 290 } 291 exists = value != 0 292 return 293 } 294 295 // Delete deletes an object from the database wrapped in a transaction. Returns whether or not any rows have been deleted 296 // and potentially an error. If ErrTooManyRows is returned, it's important to note that due to 297 // https://github.com/golang/go/issues/7898, the Delete HAS BEEN APPLIED on the current transaction. Its on the 298 // developer using Delete to ensure their tags are correct and/or ensure theit Tx rolls back on this error. 299 func (i *Invocation) Delete(object DatabaseMapped) (deleted bool, err error) { 300 var queryBody, label string 301 var pks *ColumnCollection 302 var res sql.Result 303 defer func() { err = i.finish(queryBody, recover(), res, err) }() 304 305 if label, queryBody, pks, err = i.generateDelete(object); err != nil { 306 return 307 } 308 309 i.maybeSetLabel(label) 310 queryBody, err = i.start(queryBody) 311 if err != nil { 312 return 313 } 314 res, err = i.DB.ExecContext(i.Context, queryBody, pks.ColumnValues(object)...) 315 if err != nil { 316 err = Error(err) 317 return 318 } 319 320 var rowCount int64 321 rowCount, err = res.RowsAffected() 322 if err != nil { 323 err = Error(err) 324 return 325 } 326 if rowCount > 0 { 327 deleted = true 328 } 329 if rowCount > 1 { 330 err = Error(ErrTooManyRows) 331 } 332 return 333 } 334 335 // -------------------------------------------------------------------------------- 336 // query body generators 337 // -------------------------------------------------------------------------------- 338 339 func (i *Invocation) generateGet(object DatabaseMapped) (cachePlan, queryBody string, err error) { 340 tableName := TableName(object) 341 342 cols := Columns(object).NotReadOnly() 343 pks := cols.PrimaryKeys() 344 if pks.Len() == 0 { 345 err = Error(ErrNoPrimaryKey) 346 return 347 } 348 349 queryBodyBuffer := i.BufferPool.Get() 350 defer i.BufferPool.Put(queryBodyBuffer) 351 352 queryBodyBuffer.WriteString("SELECT ") 353 for i, name := range cols.ColumnNames() { 354 queryBodyBuffer.WriteString(name) 355 if i < (cols.Len() - 1) { 356 queryBodyBuffer.WriteRune(',') 357 } 358 } 359 360 queryBodyBuffer.WriteString(" FROM ") 361 queryBodyBuffer.WriteString(tableName) 362 queryBodyBuffer.WriteString(" WHERE ") 363 364 for i, pk := range pks.Columns() { 365 queryBodyBuffer.WriteString(pk.ColumnName) 366 queryBodyBuffer.WriteString(" = ") 367 queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1)) 368 369 if i < (pks.Len() - 1) { 370 queryBodyBuffer.WriteString(" AND ") 371 } 372 } 373 374 cachePlan = fmt.Sprintf("%s_get", tableName) 375 queryBody = queryBodyBuffer.String() 376 return 377 } 378 379 func (i *Invocation) generateGetAll(collection interface{}) (statementLabel, queryBody string) { 380 collectionType := ReflectSliceType(collection) 381 tableName := TableNameByType(collectionType) 382 383 cols := ColumnsFromType(tableName, ReflectSliceType(collection)).NotReadOnly() 384 385 queryBodyBuffer := i.BufferPool.Get() 386 defer i.BufferPool.Put(queryBodyBuffer) 387 388 queryBodyBuffer.WriteString("SELECT ") 389 for i, name := range cols.ColumnNames() { 390 queryBodyBuffer.WriteString(name) 391 if i < (cols.Len() - 1) { 392 queryBodyBuffer.WriteRune(',') 393 } 394 } 395 queryBodyBuffer.WriteString(" FROM ") 396 queryBodyBuffer.WriteString(tableName) 397 398 queryBody = queryBodyBuffer.String() 399 statementLabel = tableName + "_get_all" 400 return 401 } 402 403 func (i *Invocation) generateCreate(object DatabaseMapped) (statementLabel, queryBody string, insertCols, autos *ColumnCollection) { 404 tableName := TableName(object) 405 406 cols := Columns(object) 407 insertCols = cols.InsertColumns().ConcatWith(cols.Autos().NotZero(object)) 408 autos = cols.Autos() 409 410 queryBodyBuffer := i.BufferPool.Get() 411 defer i.BufferPool.Put(queryBodyBuffer) 412 413 queryBodyBuffer.WriteString("INSERT INTO ") 414 queryBodyBuffer.WriteString(tableName) 415 queryBodyBuffer.WriteString(" (") 416 for i, name := range insertCols.ColumnNames() { 417 queryBodyBuffer.WriteString(name) 418 if i < (insertCols.Len() - 1) { 419 queryBodyBuffer.WriteRune(',') 420 } 421 } 422 queryBodyBuffer.WriteString(") VALUES (") 423 for x := 0; x < insertCols.Len(); x++ { 424 queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1)) 425 if x < (insertCols.Len() - 1) { 426 queryBodyBuffer.WriteRune(',') 427 } 428 } 429 queryBodyBuffer.WriteString(")") 430 431 if autos.Len() > 0 { 432 queryBodyBuffer.WriteString(" RETURNING ") 433 queryBodyBuffer.WriteString(autos.ColumnNamesCSV()) 434 } 435 436 queryBody = queryBodyBuffer.String() 437 statementLabel = tableName + "_create" 438 return 439 } 440 441 func (i *Invocation) generateCreateIfNotExists(object DatabaseMapped) (statementLabel, queryBody string, insertCols *ColumnCollection) { 442 cols := Columns(object) 443 444 insertCols = cols.InsertColumns().ConcatWith(cols.Autos().NotZero(object)) 445 446 pks := cols.PrimaryKeys() 447 tableName := TableName(object) 448 449 queryBodyBuffer := i.BufferPool.Get() 450 defer i.BufferPool.Put(queryBodyBuffer) 451 452 queryBodyBuffer.WriteString("INSERT INTO ") 453 queryBodyBuffer.WriteString(tableName) 454 queryBodyBuffer.WriteString(" (") 455 for i, name := range insertCols.ColumnNames() { 456 queryBodyBuffer.WriteString(name) 457 if i < (insertCols.Len() - 1) { 458 queryBodyBuffer.WriteRune(',') 459 } 460 } 461 queryBodyBuffer.WriteString(") VALUES (") 462 for x := 0; x < insertCols.Len(); x++ { 463 queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1)) 464 if x < (insertCols.Len() - 1) { 465 queryBodyBuffer.WriteRune(',') 466 } 467 } 468 queryBodyBuffer.WriteString(")") 469 470 if pks.Len() > 0 { 471 queryBodyBuffer.WriteString(" ON CONFLICT (") 472 pkColumnNames := pks.ColumnNames() 473 for i, name := range pkColumnNames { 474 queryBodyBuffer.WriteString(name) 475 if i < len(pkColumnNames)-1 { 476 queryBodyBuffer.WriteRune(',') 477 } 478 } 479 queryBodyBuffer.WriteString(") DO NOTHING") 480 } 481 482 queryBody = queryBodyBuffer.String() 483 statementLabel = tableName + "_create_if_not_exists" 484 return 485 } 486 487 func (i *Invocation) generateUpsertMany(objects interface{}) (queryBody string, insertCols *ColumnCollection, sliceValue reflect.Value) { 488 queryBodyInsertMany, insertCols, sliceValue := i.generateCreateMany(objects) 489 490 queryBodyBuffer := i.BufferPool.Get() 491 defer i.BufferPool.Put(queryBodyBuffer) 492 queryBodyBuffer.WriteString(queryBodyInsertMany) 493 494 uks := insertCols.UniqueKeys() 495 if uks.Len() > 0 { 496 queryBodyBuffer.WriteString(" ON CONFLICT (") 497 ukColumnNames := uks.ColumnNames() 498 for i, name := range ukColumnNames { 499 queryBodyBuffer.WriteString(name) 500 if i < len(ukColumnNames)-1 { 501 queryBodyBuffer.WriteRune(',') 502 } 503 } 504 queryBodyBuffer.WriteString(") DO UPDATE SET ") 505 506 for i, name := range insertCols.ColumnNames() { 507 queryBodyBuffer.WriteString(fmt.Sprintf("%s=Excluded.%s", name, name)) 508 if i < (insertCols.Len() - 1) { 509 queryBodyBuffer.WriteRune(',') 510 } 511 } 512 } 513 queryBody = queryBodyBuffer.String() 514 return 515 } 516 517 func (i *Invocation) generateCreateMany(objects interface{}) (queryBody string, insertCols *ColumnCollection, sliceValue reflect.Value) { 518 sliceValue = ReflectValue(objects) 519 sliceType := ReflectSliceType(objects) 520 tableName := TableNameByType(sliceType) 521 522 cols := ColumnsFromType(tableName, sliceType) 523 insertCols = cols.InsertColumns() 524 525 queryBodyBuffer := i.BufferPool.Get() 526 defer i.BufferPool.Put(queryBodyBuffer) 527 528 queryBodyBuffer.WriteString("INSERT INTO ") 529 queryBodyBuffer.WriteString(tableName) 530 queryBodyBuffer.WriteString(" (") 531 for i, name := range insertCols.ColumnNames() { 532 queryBodyBuffer.WriteString(name) 533 if i < (insertCols.Len() - 1) { 534 queryBodyBuffer.WriteRune(',') 535 } 536 } 537 538 queryBodyBuffer.WriteString(") VALUES ") 539 540 metaIndex := 1 541 for x := 0; x < sliceValue.Len(); x++ { 542 queryBodyBuffer.WriteString("(") 543 for y := 0; y < insertCols.Len(); y++ { 544 queryBodyBuffer.WriteString(fmt.Sprintf("$%d", metaIndex)) 545 metaIndex = metaIndex + 1 546 if y < insertCols.Len()-1 { 547 queryBodyBuffer.WriteRune(',') 548 } 549 } 550 queryBodyBuffer.WriteString(")") 551 if x < sliceValue.Len()-1 { 552 queryBodyBuffer.WriteRune(',') 553 } 554 } 555 556 queryBody = queryBodyBuffer.String() 557 return 558 } 559 560 func (i *Invocation) generateUpdate(object DatabaseMapped) (statementLabel, queryBody string, pks, updateCols *ColumnCollection) { 561 tableName := TableName(object) 562 563 cols := Columns(object) 564 565 pks = cols.PrimaryKeys() 566 updateCols = cols.UpdateColumns() 567 568 queryBodyBuffer := i.BufferPool.Get() 569 defer i.BufferPool.Put(queryBodyBuffer) 570 571 queryBodyBuffer.WriteString("UPDATE ") 572 queryBodyBuffer.WriteString(tableName) 573 queryBodyBuffer.WriteString(" SET ") 574 575 var updateColIndex int 576 var col Column 577 for ; updateColIndex < updateCols.Len(); updateColIndex++ { 578 col = updateCols.Columns()[updateColIndex] 579 queryBodyBuffer.WriteString(col.ColumnName) 580 queryBodyBuffer.WriteString(" = $" + strconv.Itoa(updateColIndex+1)) 581 if updateColIndex != (updateCols.Len() - 1) { 582 queryBodyBuffer.WriteRune(',') 583 } 584 } 585 586 queryBodyBuffer.WriteString(" WHERE ") 587 for i, pk := range pks.Columns() { 588 queryBodyBuffer.WriteString(pk.ColumnName) 589 queryBodyBuffer.WriteString(" = ") 590 queryBodyBuffer.WriteString("$" + strconv.Itoa(i+(updateColIndex+1))) 591 592 if i < (pks.Len() - 1) { 593 queryBodyBuffer.WriteString(" AND ") 594 } 595 } 596 597 queryBody = queryBodyBuffer.String() 598 statementLabel = tableName + "_update" 599 return 600 } 601 602 func (i *Invocation) generateUpsert(object DatabaseMapped) (statementLabel, queryBody string, autos, insertsWithAutos *ColumnCollection) { 603 tableName := TableName(object) 604 cols := Columns(object) 605 updates := cols.UpdateColumns() 606 updateCols := updates.Columns() 607 608 // We add in all the autos columns to start 609 insertsWithAutos = cols.InsertColumns().ConcatWith(cols.Autos()) 610 pks := insertsWithAutos.PrimaryKeys() 611 612 // But we exclude auto primary keys that are not set. Auto primary keys that ARE set must be included in the insert 613 // clause so that there is a collision. But keys that are not set must be excluded from insertsWithAutos so that 614 // they are not passed as an extra parameter to ExecInContext later and are properly auto-generated 615 for _, col := range pks.Columns() { 616 if col.IsAuto && !cols.NotZero(object).HasColumn(col.ColumnName) { 617 insertsWithAutos.Remove(col.ColumnName) 618 } 619 } 620 621 insertCols := insertsWithAutos.Columns() 622 tokenMap := map[string]string{} 623 for i, col := range insertCols { 624 tokenMap[col.ColumnName] = "$" + strconv.Itoa(i+1) 625 } 626 627 // autos are read out on insert (but only if unset) 628 autos = cols.Autos().Zero(object) 629 pkNames := pks.ColumnNames() 630 631 queryBodyBuffer := i.BufferPool.Get() 632 defer i.BufferPool.Put(queryBodyBuffer) 633 634 queryBodyBuffer.WriteString("INSERT INTO ") 635 queryBodyBuffer.WriteString(tableName) 636 queryBodyBuffer.WriteString(" (") 637 638 skipComma := true 639 for _, col := range insertCols { 640 if !col.IsAuto || cols.NotZero(object).HasColumn(col.ColumnName) { 641 if !skipComma { 642 queryBodyBuffer.WriteRune(',') 643 } 644 skipComma = false 645 queryBodyBuffer.WriteString(col.ColumnName) 646 } 647 } 648 649 queryBodyBuffer.WriteString(") VALUES (") 650 skipComma = true 651 for _, col := range insertsWithAutos.Columns() { 652 if !col.IsAuto || cols.NotZero(object).HasColumn(col.ColumnName) { 653 if !skipComma { 654 queryBodyBuffer.WriteRune(',') 655 } 656 skipComma = false 657 queryBodyBuffer.WriteString(tokenMap[col.ColumnName]) 658 } 659 } 660 661 queryBodyBuffer.WriteString(")") 662 663 if pks.Len() > 0 { 664 queryBodyBuffer.WriteString(" ON CONFLICT (") 665 666 for i, name := range pkNames { 667 queryBodyBuffer.WriteString(name) 668 if i < len(pkNames)-1 { 669 queryBodyBuffer.WriteRune(',') 670 } 671 } 672 queryBodyBuffer.WriteString(") DO UPDATE SET ") 673 674 for i, col := range updateCols { 675 queryBodyBuffer.WriteString(col.ColumnName + " = " + tokenMap[col.ColumnName]) 676 if i < (len(updateCols) - 1) { 677 queryBodyBuffer.WriteRune(',') 678 } 679 } 680 } 681 if autos.Len() > 0 { 682 queryBodyBuffer.WriteString(" RETURNING ") 683 queryBodyBuffer.WriteString(autos.ColumnNamesCSV()) 684 } 685 686 queryBody = queryBodyBuffer.String() 687 statementLabel = tableName + "_upsert" 688 return 689 } 690 691 func (i *Invocation) generateExists(object DatabaseMapped) (statementLabel, queryBody string, pks *ColumnCollection, err error) { 692 tableName := TableName(object) 693 pks = Columns(object).PrimaryKeys() 694 if pks.Len() == 0 { 695 err = Error(ErrNoPrimaryKey) 696 return 697 } 698 queryBodyBuffer := i.BufferPool.Get() 699 defer i.BufferPool.Put(queryBodyBuffer) 700 701 queryBodyBuffer.WriteString("SELECT 1 FROM ") 702 queryBodyBuffer.WriteString(tableName) 703 queryBodyBuffer.WriteString(" WHERE ") 704 for i, pk := range pks.Columns() { 705 queryBodyBuffer.WriteString(pk.ColumnName) 706 queryBodyBuffer.WriteString(" = ") 707 queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1)) 708 709 if i < (pks.Len() - 1) { 710 queryBodyBuffer.WriteString(" AND ") 711 } 712 } 713 statementLabel = tableName + "_exists" 714 queryBody = queryBodyBuffer.String() 715 return 716 } 717 718 func (i *Invocation) generateDelete(object DatabaseMapped) (statementLabel, queryBody string, pks *ColumnCollection, err error) { 719 tableName := TableName(object) 720 pks = Columns(object).PrimaryKeys() 721 if len(pks.Columns()) == 0 { 722 err = Error(ErrNoPrimaryKey) 723 return 724 } 725 queryBodyBuffer := i.BufferPool.Get() 726 defer i.BufferPool.Put(queryBodyBuffer) 727 728 queryBodyBuffer.WriteString("DELETE FROM ") 729 queryBodyBuffer.WriteString(tableName) 730 queryBodyBuffer.WriteString(" WHERE ") 731 for i, pk := range pks.Columns() { 732 queryBodyBuffer.WriteString(pk.ColumnName) 733 queryBodyBuffer.WriteString(" = ") 734 queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1)) 735 736 if i < (pks.Len() - 1) { 737 queryBodyBuffer.WriteString(" AND ") 738 } 739 } 740 statementLabel = tableName + "_delete" 741 queryBody = queryBodyBuffer.String() 742 return 743 } 744 745 // -------------------------------------------------------------------------------- 746 // helpers 747 // -------------------------------------------------------------------------------- 748 749 // autoValues returns references to the auto updatd fields for a given column collection. 750 func (i *Invocation) autoValues(autos *ColumnCollection) []interface{} { 751 autoValues := make([]interface{}, autos.Len()) 752 for i, autoCol := range autos.Columns() { 753 autoValues[i] = reflect.New(reflect.PtrTo(autoCol.FieldType)).Interface() 754 } 755 return autoValues 756 } 757 758 // setAutos sets the automatic values for a given object. 759 func (i *Invocation) setAutos(object DatabaseMapped, autos *ColumnCollection, autoValues []interface{}) (err error) { 760 for index := 0; index < len(autoValues); index++ { 761 err = autos.Columns()[index].SetValue(object, autoValues[index]) 762 if err != nil { 763 err = Error(err) 764 return 765 } 766 } 767 return 768 } 769 770 // start runs on start steps. 771 func (i *Invocation) start(statement string) (string, error) { 772 if i.DB == nil { 773 return "", ex.New(ErrConnectionClosed) 774 } 775 i.StartTime = time.Now() 776 if i.StatementInterceptor != nil { 777 var err error 778 statement, err = i.StatementInterceptor(i.Context, i.Label, statement) 779 if err != nil { 780 return statement, err 781 } 782 } 783 if i.Log != nil && !IsSkipQueryLogging(i.Context) { 784 qse := NewQueryStartEvent(statement) 785 qse.Username = i.Config.Username 786 qse.Database = i.Config.DatabaseOrDefault() 787 qse.Label = i.Label 788 qse.Engine = i.Config.EngineOrDefault() 789 i.Log.TriggerContext(i.Context, qse) 790 } 791 if i.Tracer != nil && !IsSkipQueryLogging(i.Context) { 792 i.TraceFinisher = i.Tracer.Query(i.Context, i.Config, i.Label, statement) 793 } 794 return statement, nil 795 } 796 797 // finish runs on complete steps. 798 func (i *Invocation) finish(statement string, r interface{}, res sql.Result, err error) error { 799 if i.Cancel != nil { 800 i.Cancel() 801 } 802 if r != nil { 803 err = ex.Nest(err, ex.New(r)) 804 } 805 if i.Log != nil && !IsSkipQueryLogging(i.Context) { 806 qe := NewQueryEvent(statement, time.Now().UTC().Sub(i.StartTime)) 807 qe.Username = i.Config.Username 808 qe.Database = i.Config.DatabaseOrDefault() 809 qe.Label = i.Label 810 qe.Engine = i.Config.EngineOrDefault() 811 qe.Err = err 812 i.Log.TriggerContext(i.Context, qe) 813 } 814 if i.TraceFinisher != nil && !IsSkipQueryLogging(i.Context) { 815 i.TraceFinisher.FinishQuery(i.Context, res, err) 816 } 817 if err != nil { 818 err = Error(err, ex.OptMessage(statement)) 819 } 820 return err 821 }