github.com/acoshift/pgsql@v0.15.3/pgstmt/build.go (about)

     1  package pgstmt
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/lib/pq"
    10  )
    11  
    12  type buffer struct {
    13  	q []any
    14  }
    15  
    16  func (b *buffer) push(q ...any) {
    17  	b.q = append(b.q, q...)
    18  }
    19  
    20  func (b *buffer) pushFront(q ...any) {
    21  	b.q = append(q, b.q...)
    22  }
    23  
    24  func (b *buffer) popFront() any {
    25  	if b.empty() {
    26  		return nil
    27  	}
    28  	p := b.q[0]
    29  	b.q = b.q[1:]
    30  	return p
    31  }
    32  
    33  func (b *buffer) empty() bool {
    34  	return len(b.q) == 0
    35  }
    36  
    37  func (b *buffer) build() []any {
    38  	return b.q
    39  }
    40  
// builder is implemented by values that can expose their contents as a flat
// list of query fragments. The build function renders such a value by
// recursing into the list it returns, joining the pieces with a space.
type builder interface {
	// build returns the fragments that make up this value, in order.
	build() []any
}
    44  
    45  func build(b *buffer) (string, []any) {
    46  	var args []any
    47  	var i int
    48  
    49  	var f func(p []any, sep string) string
    50  	f = func(p []any, sep string) string {
    51  		var q []string
    52  		for _, x := range p {
    53  			switch x := x.(type) {
    54  			default:
    55  				q = append(q, convertToString(x, false))
    56  			case builder:
    57  				q = append(q, f(x.build(), " "))
    58  			case arg:
    59  				i++
    60  				q = append(q, "$"+strconv.Itoa(i))
    61  				args = append(args, x.value)
    62  			case _any:
    63  				switch x := x.value.(type) {
    64  				case raw, notArg:
    65  					q = append(q, fmt.Sprintf("any(%s)", convertToString(x, false)))
    66  				default:
    67  					i++
    68  					q = append(q, fmt.Sprintf("any($%d)", i))
    69  					args = append(args, x)
    70  				}
    71  			case all:
    72  				switch x := x.value.(type) {
    73  				case raw, notArg:
    74  					q = append(q, fmt.Sprintf("all(%s)", convertToString(x, false)))
    75  				default:
    76  					i++
    77  					q = append(q, fmt.Sprintf("all($%d)", i))
    78  					args = append(args, x)
    79  				}
    80  			case *group:
    81  				if !x.empty() {
    82  					q = append(q, f(x.q, x.getSep()))
    83  				}
    84  			case *parenGroup:
    85  				if !x.empty() {
    86  					q = append(q, x.prefix+"("+f(x.q, x.getSep())+")")
    87  				}
    88  			}
    89  		}
    90  		return strings.Join(q, sep)
    91  	}
    92  	query := f(b.q, " ")
    93  	return query, args
    94  }
    95  
    96  func convertToString(x any, quoteStr bool) string {
    97  	switch x := x.(type) {
    98  	default:
    99  		return fmt.Sprint(x)
   100  	case string:
   101  		if quoteStr {
   102  			return pq.QuoteLiteral(x)
   103  		}
   104  		return x
   105  	case int:
   106  		return strconv.Itoa(x)
   107  	case int32:
   108  		return strconv.FormatInt(int64(x), 10)
   109  	case int64:
   110  		return strconv.FormatInt(x, 10)
   111  	case bool:
   112  		return strconv.FormatBool(x)
   113  	case time.Time:
   114  		return convertToString(string(pq.FormatTimestamp(x)), true)
   115  	case notArg:
   116  		return convertToString(x.value, true)
   117  	case raw:
   118  		return fmt.Sprint(x.value)
   119  	case defaultValue:
   120  		return "default"
   121  	}
   122  }