github.com/blend/go-sdk@v1.20220411.3/db/util.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"database/sql"
    12  	"fmt"
    13  	"reflect"
    14  	"strconv"
    15  )
    16  
    17  // --------------------------------------------------------------------------------
    18  // Utility Methods
    19  // --------------------------------------------------------------------------------
    20  
    21  // TableNameByType returns the table name for a given reflect.Type by instantiating it and calling o.TableName().
    22  // The type must implement DatabaseMapped or an exception will be returned.
    23  func TableNameByType(t reflect.Type) string {
    24  	instance := reflect.New(t).Interface()
    25  	if typed, isTyped := instance.(TableNameProvider); isTyped {
    26  		return typed.TableName()
    27  	}
    28  	if t.Kind() == reflect.Ptr {
    29  		t = t.Elem()
    30  		instance = reflect.New(t).Interface()
    31  		if typed, isTyped := instance.(TableNameProvider); isTyped {
    32  			return typed.TableName()
    33  		}
    34  	}
    35  	return t.Name()
    36  }
    37  
    38  // TableName returns the mapped table name for a given instance; it will sniff for the `TableName()` function on the type.
    39  func TableName(obj DatabaseMapped) string {
    40  	if typed, isTyped := obj.(TableNameProvider); isTyped {
    41  		return typed.TableName()
    42  	}
    43  	return ReflectType(obj).Name()
    44  }
    45  
    46  // --------------------------------------------------------------------------------
    47  // String Utility Methods
    48  // --------------------------------------------------------------------------------
    49  
    50  // ParamTokens returns a csv token string in the form "$1,$2,$3...$N" if passed (1, N).
    51  func ParamTokens(startAt, count int) string {
    52  	if count < 1 {
    53  		return ""
    54  	}
    55  	var str string
    56  	for i := startAt; i < startAt+count; i++ {
    57  		str = str + fmt.Sprintf("$%d", i)
    58  		if i < (startAt + count - 1) {
    59  			str = str + ","
    60  		}
    61  	}
    62  	return str
    63  }
    64  
    65  // --------------------------------------------------------------------------------
    66  // Result utility methods
    67  // --------------------------------------------------------------------------------
    68  
    69  // IgnoreExecResult is a helper for use with .Exec() (sql.Result, error)
    70  // that ignores the result return.
    71  func IgnoreExecResult(_ sql.Result, err error) error {
    72  	return err
    73  }
    74  
    75  // ExecRowsAffected is a helper for use with .Exec() (sql.Result, error)
    76  // that returns the rows affected.
    77  func ExecRowsAffected(i sql.Result, inputErr error) (int64, error) {
    78  	if inputErr != nil {
    79  		return 0, inputErr
    80  	}
    81  	ra, err := i.RowsAffected()
    82  	if err != nil {
    83  		return 0, Error(err)
    84  	}
    85  	return ra, nil
    86  }
    87  
    88  // --------------------------------------------------------------------------------
    89  // Internal / Reflection Utility Methods
    90  // --------------------------------------------------------------------------------
    91  
    92  // AsPopulatable casts an object as populatable.
    93  func AsPopulatable(object interface{}) Populatable {
    94  	return object.(Populatable)
    95  }
    96  
    97  // IsPopulatable returns if an object is populatable
    98  func IsPopulatable(object interface{}) bool {
    99  	_, isPopulatable := object.(Populatable)
   100  	return isPopulatable
   101  }
   102  
   103  // MakeWhereClause returns the sql `where` clause for a column collection, starting at a given index (used in sql $1 parameterization).
   104  func MakeWhereClause(pks *ColumnCollection, startAt int) string {
   105  	whereClause := " WHERE "
   106  	for i, pk := range pks.Columns() {
   107  		whereClause = whereClause + fmt.Sprintf("%s = %s", pk.ColumnName, "$"+strconv.Itoa(i+startAt))
   108  		if i < (pks.Len() - 1) {
   109  			whereClause = whereClause + " AND "
   110  		}
   111  	}
   112  
   113  	return whereClause
   114  }
   115  
   116  // ParamTokensCSV returns a csv token string in the form "$1,$2,$3...$N"
   117  func ParamTokensCSV(num int) string {
   118  	str := ""
   119  	for i := 1; i <= num; i++ {
   120  		str = str + fmt.Sprintf("$%d", i)
   121  		if i != num {
   122  			str = str + ","
   123  		}
   124  	}
   125  	return str
   126  }