github.com/blend/go-sdk@v1.20220411.3/db/invocation.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"context"
    12  	"database/sql"
    13  	"fmt"
    14  	"reflect"
    15  	"strconv"
    16  	"time"
    17  
    18  	"github.com/blend/go-sdk/bufferutil"
    19  	"github.com/blend/go-sdk/ex"
    20  	"github.com/blend/go-sdk/logger"
    21  )
    22  
// Invocation is a specific operation against a context.
//
// The Invocation methods call start before a statement is sent to DB and
// finish once it completes (for Query, presumably via the returned *Query);
// those two helpers consult the logging, tracing, interception and
// cancellation fields below.
type Invocation struct {
	// DB executes statements; start rejects the invocation with
	// ErrConnectionClosed when it is nil.
	DB                   DB
	// Label names the statement for the interceptor, log events and traces.
	// Methods only fill it in (via maybeSetLabel) while it is still empty.
	Label                string
	// Context is passed to every DB call, log event and trace.
	Context              context.Context
	// Cancel, when non-nil, is called by finish.
	Cancel               func()
	// Config supplies the username, database and engine reported in log
	// events and is handed to Tracer.Query.
	Config               Config
	// Log receives query-start and query events unless query logging is
	// skipped on Context (IsSkipQueryLogging).
	Log                  logger.Triggerable
	// BufferPool provides the scratch buffers used to build query text.
	BufferPool           *bufferutil.Pool
	// StatementInterceptor, when set, may rewrite or reject the statement in start.
	StatementInterceptor StatementInterceptor
	// Tracer, when set, is asked in start for a TraceFinisher for the statement.
	Tracer               Tracer
	// StartTime is recorded by start and used by finish to compute elapsed time.
	StartTime            time.Time
	// TraceFinisher is produced by Tracer in start and completed in finish.
	TraceFinisher        TraceFinisher
}
    37  
    38  // Exec executes a sql statement with a given set of arguments and returns the rows affected.
    39  func (i *Invocation) Exec(statement string, args ...interface{}) (res sql.Result, err error) {
    40  	statement, err = i.start(statement)
    41  	if err != nil {
    42  		return
    43  	}
    44  	defer func() { err = i.finish(statement, recover(), res, err) }()
    45  
    46  	res, err = i.DB.ExecContext(i.Context, statement, args...)
    47  	if err != nil {
    48  		err = Error(err)
    49  		return
    50  	}
    51  	return
    52  }
    53  
    54  // Query returns a new query object for a given sql query and arguments.
    55  func (i *Invocation) Query(statement string, args ...interface{}) *Query {
    56  	q := &Query{
    57  		Invocation: i,
    58  		Args:       args,
    59  	}
    60  	q.Statement, q.Err = i.start(statement)
    61  	return q
    62  }
    63  
    64  func (i *Invocation) maybeSetLabel(label string) {
    65  	if i.Label != "" {
    66  		return
    67  	}
    68  	i.Label = label
    69  }
    70  
    71  // Get returns a given object based on a group of primary key ids within a transaction.
    72  func (i *Invocation) Get(object DatabaseMapped, ids ...interface{}) (found bool, err error) {
    73  	if len(ids) == 0 {
    74  		err = Error(ErrInvalidIDs)
    75  		return
    76  	}
    77  
    78  	var queryBody, label string
    79  	if label, queryBody, err = i.generateGet(object); err != nil {
    80  		err = Error(err)
    81  		return
    82  	}
    83  	i.maybeSetLabel(label)
    84  	return i.Query(queryBody, ids...).Out(object)
    85  }
    86  
    87  // All returns all rows of an object mapped table wrapped in a transaction.
    88  func (i *Invocation) All(collection interface{}) (err error) {
    89  	label, queryBody := i.generateGetAll(collection)
    90  	i.maybeSetLabel(label)
    91  	return i.Query(queryBody).OutMany(collection)
    92  }
    93  
    94  // Create writes an object to the database within a transaction.
    95  func (i *Invocation) Create(object DatabaseMapped) (err error) {
    96  	var queryBody, label string
    97  	var insertCols, autos *ColumnCollection
    98  	var res sql.Result
    99  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   100  
   101  	label, queryBody, insertCols, autos = i.generateCreate(object)
   102  	i.maybeSetLabel(label)
   103  
   104  	queryBody, err = i.start(queryBody)
   105  	if err != nil {
   106  		return
   107  	}
   108  	if autos.Len() == 0 {
   109  		if res, err = i.DB.ExecContext(i.Context, queryBody, insertCols.ColumnValues(object)...); err != nil {
   110  			err = Error(err)
   111  			return
   112  		}
   113  		return
   114  	}
   115  
   116  	autoValues := i.autoValues(autos)
   117  	if err = i.DB.QueryRowContext(i.Context, queryBody, insertCols.ColumnValues(object)...).Scan(autoValues...); err != nil {
   118  		err = Error(err)
   119  		return
   120  	}
   121  	if err = i.setAutos(object, autos, autoValues); err != nil {
   122  		err = Error(err)
   123  		return
   124  	}
   125  
   126  	return
   127  }
   128  
   129  // CreateIfNotExists writes an object to the database if it does not already exist within a transaction.
   130  // This will _ignore_ auto columns, as they will always invalidate the assertion that there already exists
   131  // a row with a given primary key set.
   132  func (i *Invocation) CreateIfNotExists(object DatabaseMapped) (err error) {
   133  	var queryBody, label string
   134  	var insertCols *ColumnCollection
   135  	var res sql.Result
   136  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   137  
   138  	label, queryBody, insertCols = i.generateCreateIfNotExists(object)
   139  	i.maybeSetLabel(label)
   140  
   141  	queryBody, err = i.start(queryBody)
   142  	if err != nil {
   143  		return
   144  	}
   145  	if res, err = i.DB.ExecContext(i.Context, queryBody, insertCols.ColumnValues(object)...); err != nil {
   146  		err = Error(err)
   147  	}
   148  	return
   149  }
   150  
   151  // CreateMany writes many objects to the database in a single insert.
   152  func (i *Invocation) CreateMany(objects interface{}) (err error) {
   153  	return i.insertOrUpsertMany(objects, false)
   154  }
   155  
   156  // UpsertMany writes many objects to the database in a single upsert.
   157  func (i *Invocation) UpsertMany(objects interface{}) (err error) {
   158  	return i.insertOrUpsertMany(objects, true)
   159  }
   160  
   161  // insertOrUpsertManinsertOrUpsertMany writes many objects to the database in a single insert or upsert.
   162  func (i *Invocation) insertOrUpsertMany(objects interface{}, overwrite bool) (err error) {
   163  	var queryBody string
   164  	var insertCols *ColumnCollection
   165  	var sliceValue reflect.Value
   166  	var res sql.Result
   167  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   168  
   169  	if overwrite {
   170  		queryBody, insertCols, sliceValue = i.generateUpsertMany(objects)
   171  	} else {
   172  		queryBody, insertCols, sliceValue = i.generateCreateMany(objects)
   173  	}
   174  	if sliceValue.Len() == 0 {
   175  		// If there is nothing to create, then we're done here
   176  		return
   177  	}
   178  
   179  	queryBody, err = i.start(queryBody)
   180  	if err != nil {
   181  		return
   182  	}
   183  	var colValues []interface{}
   184  	for row := 0; row < sliceValue.Len(); row++ {
   185  		colValues = append(colValues, insertCols.ColumnValues(sliceValue.Index(row).Interface())...)
   186  	}
   187  
   188  	res, err = i.DB.ExecContext(i.Context, queryBody, colValues...)
   189  	if err != nil {
   190  		err = Error(err)
   191  		return
   192  	}
   193  	return
   194  }
   195  
   196  // Update updates an object wrapped in a transaction. Returns whether or not any rows have been updated and potentially
   197  // an error. If ErrTooManyRows is returned, it's important to note that due to https://github.com/golang/go/issues/7898,
   198  // the Update HAS BEEN APPLIED. Its on the developer using UPDATE to ensure his tags are correct and/or execute it in a
   199  // transaction and roll back on this error
   200  func (i *Invocation) Update(object DatabaseMapped) (updated bool, err error) {
   201  	var queryBody, label string
   202  	var pks, updateCols *ColumnCollection
   203  	var res sql.Result
   204  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   205  
   206  	label, queryBody, pks, updateCols = i.generateUpdate(object)
   207  	i.maybeSetLabel(label)
   208  
   209  	queryBody, err = i.start(queryBody)
   210  	if err != nil {
   211  		return
   212  	}
   213  	res, err = i.DB.ExecContext(
   214  		i.Context,
   215  		queryBody,
   216  		append(updateCols.ColumnValues(object), pks.ColumnValues(object)...)...,
   217  	)
   218  	if err != nil {
   219  		err = Error(err)
   220  		return
   221  	}
   222  
   223  	var rowCount int64
   224  	rowCount, err = res.RowsAffected()
   225  	if err != nil {
   226  		err = Error(err)
   227  		return
   228  	}
   229  	if rowCount > 0 {
   230  		updated = true
   231  	}
   232  	if rowCount > 1 {
   233  		err = Error(ErrTooManyRows)
   234  	}
   235  	return
   236  }
   237  
   238  // Upsert inserts the object if it doesn't exist already (as defined by its primary keys) or updates it atomically.
   239  // It returns `found` as true if the effect was an upsert, i.e. the pk was found.
   240  func (i *Invocation) Upsert(object DatabaseMapped) (err error) {
   241  	var queryBody, label string
   242  	var autos, upsertCols *ColumnCollection
   243  	defer func() { err = i.finish(queryBody, recover(), nil, err) }()
   244  
   245  	i.Label, queryBody, autos, upsertCols = i.generateUpsert(object)
   246  	i.maybeSetLabel(label)
   247  
   248  	queryBody, err = i.start(queryBody)
   249  	if err != nil {
   250  		return
   251  	}
   252  	if autos.Len() == 0 {
   253  		if _, err = i.DB.ExecContext(i.Context, queryBody, upsertCols.ColumnValues(object)...); err != nil {
   254  			return
   255  		}
   256  		return
   257  	}
   258  
   259  	autoValues := i.autoValues(autos)
   260  	if err = i.DB.QueryRowContext(i.Context, queryBody, upsertCols.ColumnValues(object)...).Scan(autoValues...); err != nil {
   261  		err = Error(err)
   262  		return
   263  	}
   264  	if err = i.setAutos(object, autos, autoValues); err != nil {
   265  		err = Error(err)
   266  		return
   267  	}
   268  	return
   269  }
   270  
   271  // Exists returns a bool if a given object exists (utilizing the primary key columns if they exist) wrapped in a transaction.
   272  func (i *Invocation) Exists(object DatabaseMapped) (exists bool, err error) {
   273  	var queryBody, label string
   274  	var pks *ColumnCollection
   275  	defer func() { err = i.finish(queryBody, recover(), nil, err) }()
   276  
   277  	if label, queryBody, pks, err = i.generateExists(object); err != nil {
   278  		err = Error(err)
   279  		return
   280  	}
   281  	i.maybeSetLabel(label)
   282  	queryBody, err = i.start(queryBody)
   283  	if err != nil {
   284  		return
   285  	}
   286  	var value int
   287  	if queryErr := i.DB.QueryRowContext(i.Context, queryBody, pks.ColumnValues(object)...).Scan(&value); queryErr != nil && !ex.Is(queryErr, sql.ErrNoRows) {
   288  		err = Error(queryErr)
   289  		return
   290  	}
   291  	exists = value != 0
   292  	return
   293  }
   294  
   295  // Delete deletes an object from the database wrapped in a transaction. Returns whether or not any rows have been deleted
   296  // and potentially an error. If ErrTooManyRows is returned, it's important to note that due to
   297  // https://github.com/golang/go/issues/7898, the Delete HAS BEEN APPLIED on the current transaction. Its on the
   298  // developer using Delete to ensure their tags are correct and/or ensure theit Tx rolls back on this error.
   299  func (i *Invocation) Delete(object DatabaseMapped) (deleted bool, err error) {
   300  	var queryBody, label string
   301  	var pks *ColumnCollection
   302  	var res sql.Result
   303  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   304  
   305  	if label, queryBody, pks, err = i.generateDelete(object); err != nil {
   306  		return
   307  	}
   308  
   309  	i.maybeSetLabel(label)
   310  	queryBody, err = i.start(queryBody)
   311  	if err != nil {
   312  		return
   313  	}
   314  	res, err = i.DB.ExecContext(i.Context, queryBody, pks.ColumnValues(object)...)
   315  	if err != nil {
   316  		err = Error(err)
   317  		return
   318  	}
   319  
   320  	var rowCount int64
   321  	rowCount, err = res.RowsAffected()
   322  	if err != nil {
   323  		err = Error(err)
   324  		return
   325  	}
   326  	if rowCount > 0 {
   327  		deleted = true
   328  	}
   329  	if rowCount > 1 {
   330  		err = Error(ErrTooManyRows)
   331  	}
   332  	return
   333  }
   334  
   335  // --------------------------------------------------------------------------------
   336  // query body generators
   337  // --------------------------------------------------------------------------------
   338  
   339  func (i *Invocation) generateGet(object DatabaseMapped) (cachePlan, queryBody string, err error) {
   340  	tableName := TableName(object)
   341  
   342  	cols := Columns(object).NotReadOnly()
   343  	pks := cols.PrimaryKeys()
   344  	if pks.Len() == 0 {
   345  		err = Error(ErrNoPrimaryKey)
   346  		return
   347  	}
   348  
   349  	queryBodyBuffer := i.BufferPool.Get()
   350  	defer i.BufferPool.Put(queryBodyBuffer)
   351  
   352  	queryBodyBuffer.WriteString("SELECT ")
   353  	for i, name := range cols.ColumnNames() {
   354  		queryBodyBuffer.WriteString(name)
   355  		if i < (cols.Len() - 1) {
   356  			queryBodyBuffer.WriteRune(',')
   357  		}
   358  	}
   359  
   360  	queryBodyBuffer.WriteString(" FROM ")
   361  	queryBodyBuffer.WriteString(tableName)
   362  	queryBodyBuffer.WriteString(" WHERE ")
   363  
   364  	for i, pk := range pks.Columns() {
   365  		queryBodyBuffer.WriteString(pk.ColumnName)
   366  		queryBodyBuffer.WriteString(" = ")
   367  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   368  
   369  		if i < (pks.Len() - 1) {
   370  			queryBodyBuffer.WriteString(" AND ")
   371  		}
   372  	}
   373  
   374  	cachePlan = fmt.Sprintf("%s_get", tableName)
   375  	queryBody = queryBodyBuffer.String()
   376  	return
   377  }
   378  
   379  func (i *Invocation) generateGetAll(collection interface{}) (statementLabel, queryBody string) {
   380  	collectionType := ReflectSliceType(collection)
   381  	tableName := TableNameByType(collectionType)
   382  
   383  	cols := ColumnsFromType(tableName, ReflectSliceType(collection)).NotReadOnly()
   384  
   385  	queryBodyBuffer := i.BufferPool.Get()
   386  	defer i.BufferPool.Put(queryBodyBuffer)
   387  
   388  	queryBodyBuffer.WriteString("SELECT ")
   389  	for i, name := range cols.ColumnNames() {
   390  		queryBodyBuffer.WriteString(name)
   391  		if i < (cols.Len() - 1) {
   392  			queryBodyBuffer.WriteRune(',')
   393  		}
   394  	}
   395  	queryBodyBuffer.WriteString(" FROM ")
   396  	queryBodyBuffer.WriteString(tableName)
   397  
   398  	queryBody = queryBodyBuffer.String()
   399  	statementLabel = tableName + "_get_all"
   400  	return
   401  }
   402  
   403  func (i *Invocation) generateCreate(object DatabaseMapped) (statementLabel, queryBody string, insertCols, autos *ColumnCollection) {
   404  	tableName := TableName(object)
   405  
   406  	cols := Columns(object)
   407  	insertCols = cols.InsertColumns().ConcatWith(cols.Autos().NotZero(object))
   408  	autos = cols.Autos()
   409  
   410  	queryBodyBuffer := i.BufferPool.Get()
   411  	defer i.BufferPool.Put(queryBodyBuffer)
   412  
   413  	queryBodyBuffer.WriteString("INSERT INTO ")
   414  	queryBodyBuffer.WriteString(tableName)
   415  	queryBodyBuffer.WriteString(" (")
   416  	for i, name := range insertCols.ColumnNames() {
   417  		queryBodyBuffer.WriteString(name)
   418  		if i < (insertCols.Len() - 1) {
   419  			queryBodyBuffer.WriteRune(',')
   420  		}
   421  	}
   422  	queryBodyBuffer.WriteString(") VALUES (")
   423  	for x := 0; x < insertCols.Len(); x++ {
   424  		queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1))
   425  		if x < (insertCols.Len() - 1) {
   426  			queryBodyBuffer.WriteRune(',')
   427  		}
   428  	}
   429  	queryBodyBuffer.WriteString(")")
   430  
   431  	if autos.Len() > 0 {
   432  		queryBodyBuffer.WriteString(" RETURNING ")
   433  		queryBodyBuffer.WriteString(autos.ColumnNamesCSV())
   434  	}
   435  
   436  	queryBody = queryBodyBuffer.String()
   437  	statementLabel = tableName + "_create"
   438  	return
   439  }
   440  
   441  func (i *Invocation) generateCreateIfNotExists(object DatabaseMapped) (statementLabel, queryBody string, insertCols *ColumnCollection) {
   442  	cols := Columns(object)
   443  
   444  	insertCols = cols.InsertColumns().ConcatWith(cols.Autos().NotZero(object))
   445  
   446  	pks := cols.PrimaryKeys()
   447  	tableName := TableName(object)
   448  
   449  	queryBodyBuffer := i.BufferPool.Get()
   450  	defer i.BufferPool.Put(queryBodyBuffer)
   451  
   452  	queryBodyBuffer.WriteString("INSERT INTO ")
   453  	queryBodyBuffer.WriteString(tableName)
   454  	queryBodyBuffer.WriteString(" (")
   455  	for i, name := range insertCols.ColumnNames() {
   456  		queryBodyBuffer.WriteString(name)
   457  		if i < (insertCols.Len() - 1) {
   458  			queryBodyBuffer.WriteRune(',')
   459  		}
   460  	}
   461  	queryBodyBuffer.WriteString(") VALUES (")
   462  	for x := 0; x < insertCols.Len(); x++ {
   463  		queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1))
   464  		if x < (insertCols.Len() - 1) {
   465  			queryBodyBuffer.WriteRune(',')
   466  		}
   467  	}
   468  	queryBodyBuffer.WriteString(")")
   469  
   470  	if pks.Len() > 0 {
   471  		queryBodyBuffer.WriteString(" ON CONFLICT (")
   472  		pkColumnNames := pks.ColumnNames()
   473  		for i, name := range pkColumnNames {
   474  			queryBodyBuffer.WriteString(name)
   475  			if i < len(pkColumnNames)-1 {
   476  				queryBodyBuffer.WriteRune(',')
   477  			}
   478  		}
   479  		queryBodyBuffer.WriteString(") DO NOTHING")
   480  	}
   481  
   482  	queryBody = queryBodyBuffer.String()
   483  	statementLabel = tableName + "_create_if_not_exists"
   484  	return
   485  }
   486  
   487  func (i *Invocation) generateUpsertMany(objects interface{}) (queryBody string, insertCols *ColumnCollection, sliceValue reflect.Value) {
   488  	queryBodyInsertMany, insertCols, sliceValue := i.generateCreateMany(objects)
   489  
   490  	queryBodyBuffer := i.BufferPool.Get()
   491  	defer i.BufferPool.Put(queryBodyBuffer)
   492  	queryBodyBuffer.WriteString(queryBodyInsertMany)
   493  
   494  	uks := insertCols.UniqueKeys()
   495  	if uks.Len() > 0 {
   496  		queryBodyBuffer.WriteString(" ON CONFLICT (")
   497  		ukColumnNames := uks.ColumnNames()
   498  		for i, name := range ukColumnNames {
   499  			queryBodyBuffer.WriteString(name)
   500  			if i < len(ukColumnNames)-1 {
   501  				queryBodyBuffer.WriteRune(',')
   502  			}
   503  		}
   504  		queryBodyBuffer.WriteString(") DO UPDATE SET ")
   505  
   506  		for i, name := range insertCols.ColumnNames() {
   507  			queryBodyBuffer.WriteString(fmt.Sprintf("%s=Excluded.%s", name, name))
   508  			if i < (insertCols.Len() - 1) {
   509  				queryBodyBuffer.WriteRune(',')
   510  			}
   511  		}
   512  	}
   513  	queryBody = queryBodyBuffer.String()
   514  	return
   515  }
   516  
   517  func (i *Invocation) generateCreateMany(objects interface{}) (queryBody string, insertCols *ColumnCollection, sliceValue reflect.Value) {
   518  	sliceValue = ReflectValue(objects)
   519  	sliceType := ReflectSliceType(objects)
   520  	tableName := TableNameByType(sliceType)
   521  
   522  	cols := ColumnsFromType(tableName, sliceType)
   523  	insertCols = cols.InsertColumns()
   524  
   525  	queryBodyBuffer := i.BufferPool.Get()
   526  	defer i.BufferPool.Put(queryBodyBuffer)
   527  
   528  	queryBodyBuffer.WriteString("INSERT INTO ")
   529  	queryBodyBuffer.WriteString(tableName)
   530  	queryBodyBuffer.WriteString(" (")
   531  	for i, name := range insertCols.ColumnNames() {
   532  		queryBodyBuffer.WriteString(name)
   533  		if i < (insertCols.Len() - 1) {
   534  			queryBodyBuffer.WriteRune(',')
   535  		}
   536  	}
   537  
   538  	queryBodyBuffer.WriteString(") VALUES ")
   539  
   540  	metaIndex := 1
   541  	for x := 0; x < sliceValue.Len(); x++ {
   542  		queryBodyBuffer.WriteString("(")
   543  		for y := 0; y < insertCols.Len(); y++ {
   544  			queryBodyBuffer.WriteString(fmt.Sprintf("$%d", metaIndex))
   545  			metaIndex = metaIndex + 1
   546  			if y < insertCols.Len()-1 {
   547  				queryBodyBuffer.WriteRune(',')
   548  			}
   549  		}
   550  		queryBodyBuffer.WriteString(")")
   551  		if x < sliceValue.Len()-1 {
   552  			queryBodyBuffer.WriteRune(',')
   553  		}
   554  	}
   555  
   556  	queryBody = queryBodyBuffer.String()
   557  	return
   558  }
   559  
   560  func (i *Invocation) generateUpdate(object DatabaseMapped) (statementLabel, queryBody string, pks, updateCols *ColumnCollection) {
   561  	tableName := TableName(object)
   562  
   563  	cols := Columns(object)
   564  
   565  	pks = cols.PrimaryKeys()
   566  	updateCols = cols.UpdateColumns()
   567  
   568  	queryBodyBuffer := i.BufferPool.Get()
   569  	defer i.BufferPool.Put(queryBodyBuffer)
   570  
   571  	queryBodyBuffer.WriteString("UPDATE ")
   572  	queryBodyBuffer.WriteString(tableName)
   573  	queryBodyBuffer.WriteString(" SET ")
   574  
   575  	var updateColIndex int
   576  	var col Column
   577  	for ; updateColIndex < updateCols.Len(); updateColIndex++ {
   578  		col = updateCols.Columns()[updateColIndex]
   579  		queryBodyBuffer.WriteString(col.ColumnName)
   580  		queryBodyBuffer.WriteString(" = $" + strconv.Itoa(updateColIndex+1))
   581  		if updateColIndex != (updateCols.Len() - 1) {
   582  			queryBodyBuffer.WriteRune(',')
   583  		}
   584  	}
   585  
   586  	queryBodyBuffer.WriteString(" WHERE ")
   587  	for i, pk := range pks.Columns() {
   588  		queryBodyBuffer.WriteString(pk.ColumnName)
   589  		queryBodyBuffer.WriteString(" = ")
   590  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+(updateColIndex+1)))
   591  
   592  		if i < (pks.Len() - 1) {
   593  			queryBodyBuffer.WriteString(" AND ")
   594  		}
   595  	}
   596  
   597  	queryBody = queryBodyBuffer.String()
   598  	statementLabel = tableName + "_update"
   599  	return
   600  }
   601  
// generateUpsert builds an INSERT ... ON CONFLICT (primary keys) DO UPDATE statement for object.
//
// Returns:
//   - statementLabel: "<table>_upsert".
//   - queryBody: the SQL text.
//   - autos: the auto columns whose value on object is zero. When non-empty they are listed
//     in a RETURNING clause so the caller can read the generated values back.
//   - insertsWithAutos: the columns whose values the caller binds. Placeholder $k refers to
//     the k-th column of this collection (see tokenMap).
//
// NOTE(review): only zero-valued auto PRIMARY KEYS are removed from insertsWithAutos. A
// zero-valued auto column that is not a primary key stays in the collection. It therefore
// still consumes a $N in tokenMap and a bound argument, yet the column and VALUES loops below
// skip it. Confirm no mapped type hits this, otherwise the statement references fewer
// placeholders than the arguments supplied.
// NOTE(review): when there are primary keys but no update columns, the statement ends in
// "DO UPDATE SET " with nothing after it; confirm callers never reach that case.
func (i *Invocation) generateUpsert(object DatabaseMapped) (statementLabel, queryBody string, autos, insertsWithAutos *ColumnCollection) {
	tableName := TableName(object)
	cols := Columns(object)
	updates := cols.UpdateColumns()
	updateCols := updates.Columns()

	// Start from the insert columns plus every auto column.
	insertsWithAutos = cols.InsertColumns().ConcatWith(cols.Autos())
	pks := insertsWithAutos.PrimaryKeys()

	// Then drop auto primary keys that are not set. Auto primary keys that ARE set must stay in
	// the insert clause so that a conflict can occur. Unset ones are removed so they are not
	// bound as extra arguments by the caller (ExecContext / QueryRowContext in Upsert) and the
	// database generates them instead.
	for _, col := range pks.Columns() {
		if col.IsAuto && !cols.NotZero(object).HasColumn(col.ColumnName) {
			insertsWithAutos.Remove(col.ColumnName)
		}
	}

	// Map each remaining column to its placeholder, numbered by position in insertsWithAutos.
	insertCols := insertsWithAutos.Columns()
	tokenMap := map[string]string{}
	for i, col := range insertCols {
		tokenMap[col.ColumnName] = "$" + strconv.Itoa(i+1)
	}

	// Auto columns are read back via RETURNING, but only those that are unset on object.
	autos = cols.Autos().Zero(object)
	pkNames := pks.ColumnNames()

	queryBodyBuffer := i.BufferPool.Get()
	defer i.BufferPool.Put(queryBodyBuffer)

	queryBodyBuffer.WriteString("INSERT INTO ")
	queryBodyBuffer.WriteString(tableName)
	queryBodyBuffer.WriteString(" (")

	// Column list: skip auto columns that are zero on object.
	// NOTE(review): cols.NotZero(object) is recomputed per column here and below; hoisting it
	// would avoid the repeated work if it is pure.
	skipComma := true
	for _, col := range insertCols {
		if !col.IsAuto || cols.NotZero(object).HasColumn(col.ColumnName) {
			if !skipComma {
				queryBodyBuffer.WriteRune(',')
			}
			skipComma = false
			queryBodyBuffer.WriteString(col.ColumnName)
		}
	}

	// VALUES list: same filter as the column list, so columns and placeholders line up.
	queryBodyBuffer.WriteString(") VALUES (")
	skipComma = true
	for _, col := range insertsWithAutos.Columns() {
		if !col.IsAuto || cols.NotZero(object).HasColumn(col.ColumnName) {
			if !skipComma {
				queryBodyBuffer.WriteRune(',')
			}
			skipComma = false
			queryBodyBuffer.WriteString(tokenMap[col.ColumnName])
		}
	}

	queryBodyBuffer.WriteString(")")

	if pks.Len() > 0 {
		queryBodyBuffer.WriteString(" ON CONFLICT (")

		for i, name := range pkNames {
			queryBodyBuffer.WriteString(name)
			if i < len(pkNames)-1 {
				queryBodyBuffer.WriteRune(',')
			}
		}
		queryBodyBuffer.WriteString(") DO UPDATE SET ")

		// SET reuses the insert placeholders, so each update column receives the value bound for
		// its insert column. NOTE(review): an update column missing from tokenMap would render
		// as an empty token; presumably UpdateColumns is a subset of the insert columns (confirm).
		for i, col := range updateCols {
			queryBodyBuffer.WriteString(col.ColumnName + " = " + tokenMap[col.ColumnName])
			if i < (len(updateCols) - 1) {
				queryBodyBuffer.WriteRune(',')
			}
		}
	}
	if autos.Len() > 0 {
		queryBodyBuffer.WriteString(" RETURNING ")
		queryBodyBuffer.WriteString(autos.ColumnNamesCSV())
	}

	queryBody = queryBodyBuffer.String()
	statementLabel = tableName + "_upsert"
	return
}
   690  
   691  func (i *Invocation) generateExists(object DatabaseMapped) (statementLabel, queryBody string, pks *ColumnCollection, err error) {
   692  	tableName := TableName(object)
   693  	pks = Columns(object).PrimaryKeys()
   694  	if pks.Len() == 0 {
   695  		err = Error(ErrNoPrimaryKey)
   696  		return
   697  	}
   698  	queryBodyBuffer := i.BufferPool.Get()
   699  	defer i.BufferPool.Put(queryBodyBuffer)
   700  
   701  	queryBodyBuffer.WriteString("SELECT 1 FROM ")
   702  	queryBodyBuffer.WriteString(tableName)
   703  	queryBodyBuffer.WriteString(" WHERE ")
   704  	for i, pk := range pks.Columns() {
   705  		queryBodyBuffer.WriteString(pk.ColumnName)
   706  		queryBodyBuffer.WriteString(" = ")
   707  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   708  
   709  		if i < (pks.Len() - 1) {
   710  			queryBodyBuffer.WriteString(" AND ")
   711  		}
   712  	}
   713  	statementLabel = tableName + "_exists"
   714  	queryBody = queryBodyBuffer.String()
   715  	return
   716  }
   717  
   718  func (i *Invocation) generateDelete(object DatabaseMapped) (statementLabel, queryBody string, pks *ColumnCollection, err error) {
   719  	tableName := TableName(object)
   720  	pks = Columns(object).PrimaryKeys()
   721  	if len(pks.Columns()) == 0 {
   722  		err = Error(ErrNoPrimaryKey)
   723  		return
   724  	}
   725  	queryBodyBuffer := i.BufferPool.Get()
   726  	defer i.BufferPool.Put(queryBodyBuffer)
   727  
   728  	queryBodyBuffer.WriteString("DELETE FROM ")
   729  	queryBodyBuffer.WriteString(tableName)
   730  	queryBodyBuffer.WriteString(" WHERE ")
   731  	for i, pk := range pks.Columns() {
   732  		queryBodyBuffer.WriteString(pk.ColumnName)
   733  		queryBodyBuffer.WriteString(" = ")
   734  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   735  
   736  		if i < (pks.Len() - 1) {
   737  			queryBodyBuffer.WriteString(" AND ")
   738  		}
   739  	}
   740  	statementLabel = tableName + "_delete"
   741  	queryBody = queryBodyBuffer.String()
   742  	return
   743  }
   744  
   745  // --------------------------------------------------------------------------------
   746  // helpers
   747  // --------------------------------------------------------------------------------
   748  
   749  // autoValues returns references to the auto updatd fields for a given column collection.
   750  func (i *Invocation) autoValues(autos *ColumnCollection) []interface{} {
   751  	autoValues := make([]interface{}, autos.Len())
   752  	for i, autoCol := range autos.Columns() {
   753  		autoValues[i] = reflect.New(reflect.PtrTo(autoCol.FieldType)).Interface()
   754  	}
   755  	return autoValues
   756  }
   757  
   758  // setAutos sets the automatic values for a given object.
   759  func (i *Invocation) setAutos(object DatabaseMapped, autos *ColumnCollection, autoValues []interface{}) (err error) {
   760  	for index := 0; index < len(autoValues); index++ {
   761  		err = autos.Columns()[index].SetValue(object, autoValues[index])
   762  		if err != nil {
   763  			err = Error(err)
   764  			return
   765  		}
   766  	}
   767  	return
   768  }
   769  
   770  // start runs on start steps.
   771  func (i *Invocation) start(statement string) (string, error) {
   772  	if i.DB == nil {
   773  		return "", ex.New(ErrConnectionClosed)
   774  	}
   775  	i.StartTime = time.Now()
   776  	if i.StatementInterceptor != nil {
   777  		var err error
   778  		statement, err = i.StatementInterceptor(i.Context, i.Label, statement)
   779  		if err != nil {
   780  			return statement, err
   781  		}
   782  	}
   783  	if i.Log != nil && !IsSkipQueryLogging(i.Context) {
   784  		qse := NewQueryStartEvent(statement)
   785  		qse.Username = i.Config.Username
   786  		qse.Database = i.Config.DatabaseOrDefault()
   787  		qse.Label = i.Label
   788  		qse.Engine = i.Config.EngineOrDefault()
   789  		i.Log.TriggerContext(i.Context, qse)
   790  	}
   791  	if i.Tracer != nil && !IsSkipQueryLogging(i.Context) {
   792  		i.TraceFinisher = i.Tracer.Query(i.Context, i.Config, i.Label, statement)
   793  	}
   794  	return statement, nil
   795  }
   796  
   797  // finish runs on complete steps.
   798  func (i *Invocation) finish(statement string, r interface{}, res sql.Result, err error) error {
   799  	if i.Cancel != nil {
   800  		i.Cancel()
   801  	}
   802  	if r != nil {
   803  		err = ex.Nest(err, ex.New(r))
   804  	}
   805  	if i.Log != nil && !IsSkipQueryLogging(i.Context) {
   806  		qe := NewQueryEvent(statement, time.Now().UTC().Sub(i.StartTime))
   807  		qe.Username = i.Config.Username
   808  		qe.Database = i.Config.DatabaseOrDefault()
   809  		qe.Label = i.Label
   810  		qe.Engine = i.Config.EngineOrDefault()
   811  		qe.Err = err
   812  		i.Log.TriggerContext(i.Context, qe)
   813  	}
   814  	if i.TraceFinisher != nil && !IsSkipQueryLogging(i.Context) {
   815  		i.TraceFinisher.FinishQuery(i.Context, res, err)
   816  	}
   817  	if err != nil {
   818  		err = Error(err, ex.OptMessage(statement))
   819  	}
   820  	return err
   821  }