go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/invocation.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"context"
    12  	"database/sql"
    13  	"errors"
    14  	"fmt"
    15  	"reflect"
    16  	"strconv"
    17  	"time"
    18  
    19  	"go.charczuk.com/sdk/errutil"
    20  )
    21  
// InvocationOption is an option for invocations.
//
// An option is a function that receives the invocation and mutates its
// unexported fields (label, context, cancel func, underlying db).
type InvocationOption func(*Invocation)
    24  
    25  // OptLabel sets the Label on the invocation.
    26  func OptLabel(label string) InvocationOption {
    27  	return func(i *Invocation) {
    28  		i.label = label
    29  	}
    30  }
    31  
    32  // OptContext sets a context on an invocation.
    33  func OptContext(ctx context.Context) InvocationOption {
    34  	return func(i *Invocation) {
    35  		i.ctx = ctx
    36  	}
    37  }
    38  
    39  // OptCancel sets the context cancel func.
    40  func OptCancel(cancel context.CancelFunc) InvocationOption {
    41  	return func(i *Invocation) {
    42  		i.cancel = cancel
    43  	}
    44  }
    45  
    46  // OptTimeout sets a command timeout for the invocation.
    47  func OptTimeout(d time.Duration) InvocationOption {
    48  	return func(i *Invocation) {
    49  		i.ctx, i.cancel = context.WithTimeout(i.ctx, d)
    50  	}
    51  }
    52  
    53  // OptTx is an invocation option that sets the invocation transaction.
    54  func OptTx(tx *sql.Tx) InvocationOption {
    55  	return func(i *Invocation) {
    56  		if tx != nil {
    57  			i.db = tx
    58  		}
    59  	}
    60  }
    61  
    62  // OptInvocationDB is an invocation option that sets the underlying invocation db.
    63  func OptInvocationDB(db DB) InvocationOption {
    64  	return func(i *Invocation) {
    65  		i.db = db
    66  	}
    67  }
    68  
// Invocation is a specific operation against a context.
//
// It carries everything a single statement needs: where to run it, under
// which context, and the metadata reported to query listeners on finish.
type Invocation struct {
	// conn supplies type metadata, the buffer pool, config and query listeners.
	conn *Connection
	// db is what statements are executed against; start fails if it is nil.
	db DB
	// label is reported on the query event; generated labels only apply when it is empty.
	label string
	// ctx is passed to every db call; start fails if it is nil.
	ctx context.Context
	// cancel, if set, is called when the invocation finishes.
	cancel func()
	// started is the reference time used to compute the elapsed time on finish.
	started time.Time
}
    78  
    79  // Exec executes a sql statement with a given set of arguments and returns the rows affected.
    80  func (i *Invocation) Exec(statement string, args ...interface{}) (res sql.Result, err error) {
    81  	statement, err = i.start(statement)
    82  	if err != nil {
    83  		return
    84  	}
    85  	defer func() { err = i.finish(statement, recover(), res, err) }()
    86  	res, err = i.db.ExecContext(i.ctx, statement, args...)
    87  	return
    88  }
    89  
    90  // Query returns a new query object for a given sql query and arguments.
    91  func (i *Invocation) Query(statement string, args ...interface{}) *Query {
    92  	q := &Query{
    93  		inv:  i,
    94  		args: args,
    95  	}
    96  	q.stmt, q.err = i.start(statement)
    97  	return q
    98  }
    99  
   100  // Get returns a given object based on a group of primary key ids within a transaction.
   101  func (i *Invocation) Get(object any, ids ...any) (found bool, err error) {
   102  	if len(ids) == 0 {
   103  		err = ErrInvalidIDs
   104  		return
   105  	}
   106  
   107  	var queryBody, label string
   108  	if label, queryBody, err = i.generateGet(object); err != nil {
   109  		return
   110  	}
   111  	i.maybeSetLabel(label)
   112  	return i.Query(queryBody, ids...).Out(object)
   113  }
   114  
   115  // GetMany returns objects matching a given array of keys.
   116  //
   117  // The order of the results will match the order of the keys.
   118  func (i *Invocation) GetMany(collection any, ids ...any) (err error) {
   119  	if len(ids) == 0 {
   120  		err = ErrInvalidIDs
   121  		return
   122  	}
   123  	var queryBody, label string
   124  	if label, queryBody, err = i.generateGetMany(collection, len(ids)); err != nil {
   125  		return
   126  	}
   127  	i.maybeSetLabel(label)
   128  	return i.Query(queryBody, ids...).OutMany(collection)
   129  }
   130  
   131  // All returns all rows of an object mapped table wrapped in a transaction.
   132  func (i *Invocation) All(collection interface{}) (err error) {
   133  	label, queryBody := i.generateGetAll(collection)
   134  	i.maybeSetLabel(label)
   135  	return i.Query(queryBody).OutMany(collection)
   136  }
   137  
   138  // Create writes an object to the database within a transaction.
   139  func (i *Invocation) Create(object any) (err error) {
   140  	var queryBody, label string
   141  	var insertCols, autos []*Column
   142  	var res sql.Result
   143  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   144  
   145  	label, queryBody, insertCols, autos = i.generateCreate(object)
   146  	i.maybeSetLabel(label)
   147  
   148  	queryBody, err = i.start(queryBody)
   149  	if err != nil {
   150  		return
   151  	}
   152  	if len(autos) == 0 {
   153  		if res, err = i.db.ExecContext(i.ctx, queryBody, ColumnValues(insertCols, object)...); err != nil {
   154  			return
   155  		}
   156  		return
   157  	}
   158  
   159  	autoValues := i.autoValues(autos)
   160  	if err = i.db.QueryRowContext(i.ctx, queryBody, ColumnValues(insertCols, object)...).Scan(autoValues...); err != nil {
   161  		return
   162  	}
   163  	if err = i.setAutos(object, autos, autoValues); err != nil {
   164  		return
   165  	}
   166  	return
   167  }
   168  
   169  // CreateIfNotExists writes an object to the database if it does not already exist within a transaction.
   170  // This will _ignore_ auto columns, as they will always invalidate the assertion that there already exists
   171  // a row with a given primary key set.
   172  func (i *Invocation) CreateIfNotExists(object any) (err error) {
   173  	var queryBody, label string
   174  	var insertCols []*Column
   175  	var res sql.Result
   176  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   177  
   178  	label, queryBody, insertCols = i.generateCreateIfNotExists(object)
   179  	i.maybeSetLabel(label)
   180  
   181  	queryBody, err = i.start(queryBody)
   182  	if err != nil {
   183  		return
   184  	}
   185  	res, err = i.db.ExecContext(i.ctx, queryBody, ColumnValues(insertCols, object)...)
   186  	return
   187  }
   188  
   189  // Update updates an object wrapped in a transaction. Returns whether or not any rows have been updated and potentially
   190  // an error. If ErrTooManyRows is returned, it's important to note that due to https://github.com/golang/go/issues/7898,
   191  // the Update HAS BEEN APPLIED. Its on the developer using UPDATE to ensure his tags are correct and/or execute it in a
   192  // transaction and roll back on this error
   193  func (i *Invocation) Update(object any) (updated bool, err error) {
   194  	var queryBody, label string
   195  	var pks, updateCols []*Column
   196  	var res sql.Result
   197  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   198  
   199  	label, queryBody, pks, updateCols = i.generateUpdate(object)
   200  	i.maybeSetLabel(label)
   201  
   202  	queryBody, err = i.start(queryBody)
   203  	if err != nil {
   204  		return
   205  	}
   206  	res, err = i.db.ExecContext(
   207  		i.ctx,
   208  		queryBody,
   209  		append(ColumnValues(updateCols, object), ColumnValues(pks, object)...)...,
   210  	)
   211  	if err != nil {
   212  		return
   213  	}
   214  
   215  	var rowCount int64
   216  	rowCount, err = res.RowsAffected()
   217  	if err != nil {
   218  		return
   219  	}
   220  	if rowCount > 0 {
   221  		updated = true
   222  	}
   223  	if rowCount > 1 {
   224  		err = ErrTooManyRows
   225  	}
   226  	return
   227  }
   228  
   229  // Upsert inserts the object if it doesn't exist already (as defined by its primary keys) or updates it atomically.
   230  // It returns `found` as true if the effect was an upsert, i.e. the pk was found.
   231  func (i *Invocation) Upsert(object any) (err error) {
   232  	var queryBody, label string
   233  	var autos, upsertCols []*Column
   234  	defer func() { err = i.finish(queryBody, recover(), nil, err) }()
   235  
   236  	i.label, queryBody, autos, upsertCols = i.generateUpsert(object)
   237  	i.maybeSetLabel(label)
   238  
   239  	queryBody, err = i.start(queryBody)
   240  	if err != nil {
   241  		return
   242  	}
   243  	if len(autos) == 0 {
   244  		if _, err = i.db.ExecContext(i.ctx, queryBody, ColumnValues(upsertCols, object)...); err != nil {
   245  			return
   246  		}
   247  		return
   248  	}
   249  	autoValues := i.autoValues(autos)
   250  	if err = i.db.QueryRowContext(i.ctx, queryBody, ColumnValues(upsertCols, object)...).Scan(autoValues...); err != nil {
   251  		return
   252  	}
   253  	if err = i.setAutos(object, autos, autoValues); err != nil {
   254  		return
   255  	}
   256  	return
   257  }
   258  
   259  // Exists returns a bool if a given object exists (utilizing the primary key columns if they exist) wrapped in a transaction.
   260  func (i *Invocation) Exists(object any) (exists bool, err error) {
   261  	var queryBody, label string
   262  	var pks []*Column
   263  	defer func() { err = i.finish(queryBody, recover(), nil, err) }()
   264  
   265  	if label, queryBody, pks, err = i.generateExists(object); err != nil {
   266  		return
   267  	}
   268  	i.maybeSetLabel(label)
   269  	queryBody, err = i.start(queryBody)
   270  	if err != nil {
   271  		return
   272  	}
   273  	var value int
   274  	if queryErr := i.db.QueryRowContext(i.ctx, queryBody, ColumnValues(pks, object)...).Scan(&value); queryErr != nil && !errors.Is(queryErr, sql.ErrNoRows) {
   275  		err = queryErr
   276  		return
   277  	}
   278  	exists = value != 0
   279  	return
   280  }
   281  
   282  // Delete deletes an object from the database wrapped in a transaction. Returns whether or not any rows have been deleted
   283  // and potentially an error.
   284  //
   285  // If ErrTooManyRows is returned, it's important to note that due to https://github.com/golang/go/issues/7898
   286  // the Delete HAS BEEN APPLIED on the current transaction. Its on the developer using Delete to ensure their
   287  // tags are correct and/or ensure theit Tx rolls back on this error.
   288  func (i *Invocation) Delete(object any) (deleted bool, err error) {
   289  	var queryBody, label string
   290  	var pks []*Column
   291  	var res sql.Result
   292  	defer func() { err = i.finish(queryBody, recover(), res, err) }()
   293  
   294  	if label, queryBody, pks, err = i.generateDelete(object); err != nil {
   295  		return
   296  	}
   297  
   298  	i.maybeSetLabel(label)
   299  	queryBody, err = i.start(queryBody)
   300  	if err != nil {
   301  		return
   302  	}
   303  	res, err = i.db.ExecContext(i.ctx, queryBody, ColumnValues(pks, object)...)
   304  	if err != nil {
   305  		return
   306  	}
   307  
   308  	var rowCount int64
   309  	rowCount, err = res.RowsAffected()
   310  	if err != nil {
   311  		return
   312  	}
   313  	if rowCount > 0 {
   314  		deleted = true
   315  	}
   316  	if rowCount > 1 {
   317  		err = ErrTooManyRows
   318  	}
   319  	return
   320  }
   321  
   322  func (i *Invocation) generateGet(object any) (cachePlan, queryBody string, err error) {
   323  	tableName := TableName(object)
   324  
   325  	cols := i.conn.TypeMeta(object)
   326  	getCols := cols.NotReadOnly()
   327  	pks := cols.PrimaryKeys()
   328  	if len(pks) == 0 {
   329  		err = ErrNoPrimaryKey
   330  		return
   331  	}
   332  
   333  	queryBodyBuffer := i.conn.bp.Get()
   334  	defer i.conn.bp.Put(queryBodyBuffer)
   335  
   336  	queryBodyBuffer.WriteString("SELECT ")
   337  	for i, name := range ColumnNamesWithPrefix(getCols, cols.columnPrefix) {
   338  		queryBodyBuffer.WriteString(name)
   339  		if i < (cols.Len() - 1) {
   340  			queryBodyBuffer.WriteRune(',')
   341  		}
   342  	}
   343  
   344  	queryBodyBuffer.WriteString(" FROM ")
   345  	queryBodyBuffer.WriteString(tableName)
   346  	queryBodyBuffer.WriteString(" WHERE ")
   347  
   348  	for i, pk := range pks {
   349  		queryBodyBuffer.WriteString(pk.ColumnName)
   350  		queryBodyBuffer.WriteString(" = ")
   351  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   352  
   353  		if i < (len(pks) - 1) {
   354  			queryBodyBuffer.WriteString(" AND ")
   355  		}
   356  	}
   357  	cachePlan = fmt.Sprintf("%s_get", tableName)
   358  	queryBody = queryBodyBuffer.String()
   359  	return
   360  }
   361  
   362  func (i *Invocation) generateGetMany(collection any, keys int) (
   363  	cachePlan, queryBody string,
   364  	err error,
   365  ) {
   366  	collectionType := reflectSliceType(collection)
   367  	tableName := TableNameByType(collectionType)
   368  
   369  	cols := i.conn.TypeMetaFromType(tableName, reflectSliceType(collection))
   370  
   371  	getCols := cols.NotReadOnly()
   372  	pks := cols.PrimaryKeys()
   373  	if len(pks) == 0 {
   374  		err = ErrNoPrimaryKey
   375  		return
   376  	}
   377  	pk := pks[0]
   378  
   379  	queryBodyBuffer := i.conn.bp.Get()
   380  	defer i.conn.bp.Put(queryBodyBuffer)
   381  
   382  	queryBodyBuffer.WriteString("SELECT ")
   383  	for i, name := range ColumnNamesWithPrefix(getCols, cols.columnPrefix) {
   384  		queryBodyBuffer.WriteString(name)
   385  		if i < (cols.Len() - 1) {
   386  			queryBodyBuffer.WriteRune(',')
   387  		}
   388  	}
   389  
   390  	queryBodyBuffer.WriteString(" FROM ")
   391  	queryBodyBuffer.WriteString(tableName)
   392  
   393  	queryBodyBuffer.WriteString(" WHERE ")
   394  	queryBodyBuffer.WriteString(pk.ColumnName)
   395  	queryBodyBuffer.WriteString(" IN (")
   396  
   397  	for x := 0; x < keys; x++ {
   398  		paramIndex := strconv.Itoa(x + 1)
   399  		queryBodyBuffer.WriteString("$" + paramIndex)
   400  		if x < (keys - 1) {
   401  			queryBodyBuffer.WriteRune(',')
   402  		}
   403  	}
   404  	queryBodyBuffer.WriteString(")")
   405  	cachePlan = fmt.Sprintf("%s_get_many", tableName)
   406  	queryBody = queryBodyBuffer.String()
   407  	return
   408  }
   409  
   410  func (i *Invocation) generateGetAll(collection interface{}) (statementLabel, queryBody string) {
   411  	collectionType := reflectSliceType(collection)
   412  	tableName := TableNameByType(collectionType)
   413  
   414  	cols := i.conn.TypeMetaFromType(tableName, reflectSliceType(collection))
   415  
   416  	// using `NotReadOnly` may seem confusing, but we don't want read only columns
   417  	// because they are typically the result of a select clause
   418  	// and not columns on the table represented by the type.
   419  	getCols := cols.NotReadOnly()
   420  
   421  	queryBodyBuffer := i.conn.bp.Get()
   422  	defer i.conn.bp.Put(queryBodyBuffer)
   423  
   424  	queryBodyBuffer.WriteString("SELECT ")
   425  	for i, name := range ColumnNamesWithPrefix(getCols, cols.columnPrefix) {
   426  		queryBodyBuffer.WriteString(name)
   427  		if i < (len(getCols) - 1) {
   428  			queryBodyBuffer.WriteRune(',')
   429  		}
   430  	}
   431  	queryBodyBuffer.WriteString(" FROM ")
   432  	queryBodyBuffer.WriteString(tableName)
   433  
   434  	queryBody = queryBodyBuffer.String()
   435  	statementLabel = tableName + "_get_all"
   436  	return
   437  }
   438  
   439  func (i *Invocation) generateCreate(object any) (statementLabel, queryBody string, insertCols, autos []*Column) {
   440  	tableName := TableName(object)
   441  
   442  	cols := i.conn.TypeMeta(object)
   443  	insertCols = append(cols.InsertColumns(), ColumnsNotZero(cols.Autos(), object)...)
   444  	autos = cols.Autos()
   445  
   446  	queryBodyBuffer := i.conn.bp.Get()
   447  	defer i.conn.bp.Put(queryBodyBuffer)
   448  
   449  	queryBodyBuffer.WriteString("INSERT INTO ")
   450  	queryBodyBuffer.WriteString(tableName)
   451  	queryBodyBuffer.WriteString(" (")
   452  	for i, name := range ColumnNamesWithPrefix(insertCols, cols.columnPrefix) {
   453  		queryBodyBuffer.WriteString(name)
   454  		if i < (len(insertCols) - 1) {
   455  			queryBodyBuffer.WriteRune(',')
   456  		}
   457  	}
   458  	queryBodyBuffer.WriteString(") VALUES (")
   459  	for x := 0; x < len(insertCols); x++ {
   460  		queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1))
   461  		if x < (len(insertCols) - 1) {
   462  			queryBodyBuffer.WriteRune(',')
   463  		}
   464  	}
   465  	queryBodyBuffer.WriteString(")")
   466  
   467  	if len(autos) > 0 {
   468  		queryBodyBuffer.WriteString(" RETURNING ")
   469  		queryBodyBuffer.WriteString(ColumnNamesWithPrefixCSV(autos, cols.columnPrefix))
   470  	}
   471  
   472  	queryBody = queryBodyBuffer.String()
   473  	statementLabel = tableName + "_create"
   474  	return
   475  }
   476  
   477  func (i *Invocation) generateCreateIfNotExists(object any) (statementLabel, queryBody string, insertCols []*Column) {
   478  	cols := i.conn.TypeMeta(object)
   479  	insertCols = append(cols.InsertColumns(), ColumnsNotZero(cols.Autos(), object)...)
   480  
   481  	pks := cols.PrimaryKeys()
   482  	tableName := TableName(object)
   483  
   484  	queryBodyBuffer := i.conn.bp.Get()
   485  	defer i.conn.bp.Put(queryBodyBuffer)
   486  
   487  	queryBodyBuffer.WriteString("INSERT INTO ")
   488  	queryBodyBuffer.WriteString(tableName)
   489  	queryBodyBuffer.WriteString(" (")
   490  	for i, name := range ColumnNamesWithPrefix(insertCols, cols.columnPrefix) {
   491  		queryBodyBuffer.WriteString(name)
   492  		if i < (len(insertCols) - 1) {
   493  			queryBodyBuffer.WriteRune(',')
   494  		}
   495  	}
   496  	queryBodyBuffer.WriteString(") VALUES (")
   497  	for x := 0; x < len(insertCols); x++ {
   498  		queryBodyBuffer.WriteString("$" + strconv.Itoa(x+1))
   499  		if x < (len(insertCols) - 1) {
   500  			queryBodyBuffer.WriteRune(',')
   501  		}
   502  	}
   503  	queryBodyBuffer.WriteString(")")
   504  
   505  	if len(pks) > 0 {
   506  		queryBodyBuffer.WriteString(" ON CONFLICT (")
   507  		pkColumnNames := ColumnNamesWithPrefix(pks, cols.columnPrefix)
   508  		for i, name := range pkColumnNames {
   509  			queryBodyBuffer.WriteString(name)
   510  			if i < len(pkColumnNames)-1 {
   511  				queryBodyBuffer.WriteRune(',')
   512  			}
   513  		}
   514  		queryBodyBuffer.WriteString(") DO NOTHING")
   515  	}
   516  
   517  	queryBody = queryBodyBuffer.String()
   518  	statementLabel = tableName + "_create_if_not_exists"
   519  	return
   520  }
   521  
   522  func (i *Invocation) generateUpdate(object any) (statementLabel, queryBody string, pks, updateCols []*Column) {
   523  	tableName := TableName(object)
   524  
   525  	cols := i.conn.TypeMeta(object)
   526  
   527  	pks = cols.PrimaryKeys()
   528  	updateCols = cols.UpdateColumns()
   529  
   530  	queryBodyBuffer := i.conn.bp.Get()
   531  	defer i.conn.bp.Put(queryBodyBuffer)
   532  
   533  	queryBodyBuffer.WriteString("UPDATE ")
   534  	queryBodyBuffer.WriteString(tableName)
   535  	queryBodyBuffer.WriteString(" SET ")
   536  
   537  	var updateColIndex int
   538  	var col *Column
   539  	for ; updateColIndex < len(updateCols); updateColIndex++ {
   540  		col = updateCols[updateColIndex]
   541  		queryBodyBuffer.WriteString(col.ColumnName)
   542  		queryBodyBuffer.WriteString(" = $" + strconv.Itoa(updateColIndex+1))
   543  		if updateColIndex != (len(updateCols) - 1) {
   544  			queryBodyBuffer.WriteRune(',')
   545  		}
   546  	}
   547  
   548  	queryBodyBuffer.WriteString(" WHERE ")
   549  	for i, pk := range pks {
   550  		queryBodyBuffer.WriteString(pk.ColumnName)
   551  		queryBodyBuffer.WriteString(" = ")
   552  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+(updateColIndex+1)))
   553  
   554  		if i < (len(pks) - 1) {
   555  			queryBodyBuffer.WriteString(" AND ")
   556  		}
   557  	}
   558  
   559  	queryBody = queryBodyBuffer.String()
   560  	statementLabel = tableName + "_update"
   561  	return
   562  }
   563  
   564  func (i *Invocation) generateUpsert(object any) (statementLabel, queryBody string, autos, insertsWithAutos []*Column) {
   565  	tableName := TableName(object)
   566  	cols := i.conn.TypeMeta(object)
   567  
   568  	inserts := cols.InsertColumns()
   569  	updates := cols.UpdateColumns()
   570  	insertsWithAutos = append(inserts, cols.Autos()...)
   571  
   572  	pks := filter(insertsWithAutos, func(c *Column) bool { return c.IsPrimaryKey })
   573  	notZero := ColumnsNotZero(cols.Columns(), object)
   574  
   575  	// But we exclude auto primary keys that are not set. Auto primary keys that ARE set must be included in the insert
   576  	// clause so that there is a collision. But keys that are not set must be excluded from insertsWithAutos so that
   577  	// they are not passed as an extra parameter to ExecInContext later and are properly auto-generated
   578  	for _, pk := range pks {
   579  		if pk.IsAuto && !HasColumn(notZero, pk.ColumnName) {
   580  			insertsWithAutos = filter(insertsWithAutos, func(c *Column) bool { return c.ColumnName == pk.ColumnName })
   581  		}
   582  	}
   583  
   584  	tokenMap := map[string]string{}
   585  	for i, col := range insertsWithAutos {
   586  		tokenMap[col.ColumnName] = "$" + strconv.Itoa(i+1)
   587  	}
   588  
   589  	// autos are read out on insert (but only if unset)
   590  	autos = ColumnsZero(cols.Autos(), object)
   591  	pkNames := ColumnNames(pks)
   592  
   593  	queryBodyBuffer := i.conn.bp.Get()
   594  	defer i.conn.bp.Put(queryBodyBuffer)
   595  
   596  	queryBodyBuffer.WriteString("INSERT INTO ")
   597  	queryBodyBuffer.WriteString(tableName)
   598  	queryBodyBuffer.WriteString(" (")
   599  
   600  	skipComma := true
   601  	for _, col := range insertsWithAutos {
   602  		if !col.IsAuto || HasColumn(notZero, col.ColumnName) {
   603  			if !skipComma {
   604  				queryBodyBuffer.WriteRune(',')
   605  			}
   606  			skipComma = false
   607  			queryBodyBuffer.WriteString(col.ColumnName)
   608  		}
   609  	}
   610  
   611  	queryBodyBuffer.WriteString(") VALUES (")
   612  	skipComma = true
   613  	for _, col := range insertsWithAutos {
   614  		if !col.IsAuto || HasColumn(notZero, col.ColumnName) {
   615  			if !skipComma {
   616  				queryBodyBuffer.WriteRune(',')
   617  			}
   618  			skipComma = false
   619  			queryBodyBuffer.WriteString(tokenMap[col.ColumnName])
   620  		}
   621  	}
   622  
   623  	queryBodyBuffer.WriteString(")")
   624  
   625  	if len(pks) > 0 {
   626  		queryBodyBuffer.WriteString(" ON CONFLICT (")
   627  
   628  		for i, name := range pkNames {
   629  			queryBodyBuffer.WriteString(name)
   630  			if i < len(pkNames)-1 {
   631  				queryBodyBuffer.WriteRune(',')
   632  			}
   633  		}
   634  		if len(updates) > 0 {
   635  			queryBodyBuffer.WriteString(") DO UPDATE SET ")
   636  
   637  			for i, col := range updates {
   638  				queryBodyBuffer.WriteString(col.ColumnName + " = " + tokenMap[col.ColumnName])
   639  				if i < (len(updates) - 1) {
   640  					queryBodyBuffer.WriteRune(',')
   641  				}
   642  			}
   643  		} else {
   644  			queryBodyBuffer.WriteString(") DO NOTHING ")
   645  		}
   646  	}
   647  	if len(autos) > 0 {
   648  		queryBodyBuffer.WriteString(" RETURNING ")
   649  		queryBodyBuffer.WriteString(ColumnNamesCSV(autos))
   650  	}
   651  
   652  	queryBody = queryBodyBuffer.String()
   653  	statementLabel = tableName + "_upsert"
   654  	return
   655  }
   656  
   657  func (i *Invocation) generateExists(object any) (statementLabel, queryBody string, pks []*Column, err error) {
   658  	tableName := TableName(object)
   659  	pks = i.conn.TypeMeta(object).PrimaryKeys()
   660  	if len(pks) == 0 {
   661  		err = ErrNoPrimaryKey
   662  		return
   663  	}
   664  
   665  	queryBodyBuffer := i.conn.bp.Get()
   666  	defer i.conn.bp.Put(queryBodyBuffer)
   667  
   668  	queryBodyBuffer.WriteString("SELECT 1 FROM ")
   669  	queryBodyBuffer.WriteString(tableName)
   670  	queryBodyBuffer.WriteString(" WHERE ")
   671  	for i, pk := range pks {
   672  		queryBodyBuffer.WriteString(pk.ColumnName)
   673  		queryBodyBuffer.WriteString(" = ")
   674  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   675  
   676  		if i < (len(pks) - 1) {
   677  			queryBodyBuffer.WriteString(" AND ")
   678  		}
   679  	}
   680  	statementLabel = tableName + "_exists"
   681  	queryBody = queryBodyBuffer.String()
   682  	return
   683  }
   684  
   685  func (i *Invocation) generateDelete(object any) (statementLabel, queryBody string, pks []*Column, err error) {
   686  	tableName := TableName(object)
   687  	pks = i.conn.TypeMeta(object).PrimaryKeys()
   688  	if len(pks) == 0 {
   689  		err = ErrNoPrimaryKey
   690  		return
   691  	}
   692  	queryBodyBuffer := i.conn.bp.Get()
   693  	defer i.conn.bp.Put(queryBodyBuffer)
   694  
   695  	queryBodyBuffer.WriteString("DELETE FROM ")
   696  	queryBodyBuffer.WriteString(tableName)
   697  	queryBodyBuffer.WriteString(" WHERE ")
   698  	for i, pk := range pks {
   699  		queryBodyBuffer.WriteString(pk.ColumnName)
   700  		queryBodyBuffer.WriteString(" = ")
   701  		queryBodyBuffer.WriteString("$" + strconv.Itoa(i+1))
   702  
   703  		if i < (len(pks) - 1) {
   704  			queryBodyBuffer.WriteString(" AND ")
   705  		}
   706  	}
   707  	statementLabel = tableName + "_delete"
   708  	queryBody = queryBodyBuffer.String()
   709  	return
   710  }
   711  
   712  // --------------------------------------------------------------------------------
   713  // helpers
   714  // --------------------------------------------------------------------------------
   715  
   716  func (i *Invocation) maybeSetLabel(label string) {
   717  	if i.label != "" {
   718  		return
   719  	}
   720  	i.label = label
   721  }
   722  
   723  // autoValues returns references to the auto updatd fields for a given column collection.
   724  func (i *Invocation) autoValues(autos []*Column) []interface{} {
   725  	autoValues := make([]interface{}, len(autos))
   726  	for i, autoCol := range autos {
   727  		autoValues[i] = reflect.New(reflect.PtrTo(autoCol.FieldType)).Interface()
   728  	}
   729  	return autoValues
   730  }
   731  
   732  // setAutos sets the automatic values for a given object.
   733  func (i *Invocation) setAutos(object any, autos []*Column, autoValues []any) (err error) {
   734  	for index := 0; index < len(autoValues); index++ {
   735  		err = autos[index].SetValue(object, autoValues[index])
   736  		if err != nil {
   737  			return
   738  		}
   739  	}
   740  	return
   741  }
   742  
   743  // start runs on start steps.
   744  func (i *Invocation) start(statement string) (string, error) {
   745  	if i.db == nil {
   746  		return "", ErrConnectionClosed
   747  	}
   748  	if i.ctx == nil {
   749  		return "", ErrContextUnset
   750  	}
   751  	// there are a lot of steps here typically but we removed
   752  	// most of the logger and statement interceptor stuff for simplicity.
   753  	// we can add that back later.
   754  	return statement, nil
   755  }
   756  
   757  // finish runs on complete steps.
   758  func (i *Invocation) finish(statement string, r any, res sql.Result, err error) error {
   759  	if i.cancel != nil {
   760  		i.cancel()
   761  	}
   762  	if r != nil {
   763  		err = errutil.Append(err, errutil.New(r))
   764  	}
   765  	if i.conn != nil && len(i.conn.onQuery) > 0 {
   766  		qe := QueryEvent{
   767  			Body:     statement,
   768  			Elapsed:  time.Now().UTC().Sub(i.started),
   769  			Username: i.conn.Config.Username,
   770  			Database: i.conn.Config.Database,
   771  			Engine:   i.conn.Config.Engine,
   772  			Label:    i.label,
   773  			Err:      errutil.New(err),
   774  		}
   775  		if res != nil {
   776  			qe.RowsAffected, _ = res.RowsAffected()
   777  		}
   778  		for _, l := range i.conn.onQuery {
   779  			l(qe)
   780  		}
   781  	}
   782  	return err
   783  }