github.com/eden-framework/sqlx@v0.0.2/postgresqlconnector/interpolate_params.go (about)

     1  package postgresqlconnector
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  )
    10  
    11  func interpolateParams(query string, args []driver.NamedValue) fmt.Stringer {
    12  	return &SqlPrinter{
    13  		query: query,
    14  		args:  args,
    15  	}
    16  }
    17  
    18  type SqlPrinter struct {
    19  	query string
    20  	args  []driver.NamedValue
    21  }
    22  
    23  func (p *SqlPrinter) String() string {
    24  	s, err := InterpolateParams(p.query, p.args, time.Local)
    25  	if err != nil {
    26  		panic(err)
    27  	}
    28  	return s
    29  }
    30  
    31  func InterpolateParams(query string, args []driver.NamedValue, loc *time.Location) (string, error) {
    32  	if strings.Count(query, "?") != len(args) {
    33  		return "", driver.ErrSkip
    34  	}
    35  
    36  	if len(args) >= 65535 {
    37  		return "", fmt.Errorf("too many args: %d", len(args))
    38  	}
    39  
    40  	buf := make([]byte, 0)
    41  	buf = buf[:0]
    42  
    43  	argPos := 0
    44  
    45  	data := []byte(query)
    46  
    47  	for i := range data {
    48  		q := query[i]
    49  		switch q {
    50  		case '?':
    51  			arg := args[argPos].Value
    52  			argPos++
    53  
    54  			if arg == nil {
    55  				buf = append(buf, "NULL"...)
    56  				continue
    57  			}
    58  
    59  			switch v := arg.(type) {
    60  			case int64:
    61  				buf = strconv.AppendInt(buf, v, 10)
    62  			case float64:
    63  				buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
    64  			case bool:
    65  				if v {
    66  					buf = append(buf, '1')
    67  				} else {
    68  					buf = append(buf, '0')
    69  				}
    70  			case time.Time:
    71  				if v.IsZero() {
    72  					buf = append(buf, "'0000-00-00'"...)
    73  				} else {
    74  					v := v.In(loc)
    75  					v = v.Add(time.Nanosecond * 500) // Write round under microsecond
    76  					year := v.Year()
    77  					year100 := year / 100
    78  					year1 := year % 100
    79  					month := v.Month()
    80  					day := v.Day()
    81  					hour := v.Hour()
    82  					minute := v.Minute()
    83  					second := v.Second()
    84  					micro := v.Nanosecond() / 1000
    85  
    86  					buf = append(buf, []byte{
    87  						'\'',
    88  						digits10[year100], digits01[year100],
    89  						digits10[year1], digits01[year1],
    90  						'-',
    91  						digits10[month], digits01[month],
    92  						'-',
    93  						digits10[day], digits01[day],
    94  						' ',
    95  						digits10[hour], digits01[hour],
    96  						':',
    97  						digits10[minute], digits01[minute],
    98  						':',
    99  						digits10[second], digits01[second],
   100  					}...)
   101  
   102  					if micro != 0 {
   103  						micro10000 := micro / 10000
   104  						micro100 := micro / 100 % 100
   105  						micro1 := micro % 100
   106  						buf = append(buf, []byte{
   107  							'.',
   108  							digits10[micro10000], digits01[micro10000],
   109  							digits10[micro100], digits01[micro100],
   110  							digits10[micro1], digits01[micro1],
   111  						}...)
   112  					}
   113  					buf = append(buf, '\'')
   114  				}
   115  			case []byte:
   116  				if v == nil {
   117  					buf = append(buf, "NULL"...)
   118  				} else {
   119  					buf = append(buf, "E'"...)
   120  					buf = escapeBytesBackslash(buf, v)
   121  					buf = append(buf, '\'')
   122  				}
   123  			case string:
   124  				buf = append(buf, '\'')
   125  				buf = escapeBytesBackslash(buf, []byte(v))
   126  				buf = append(buf, '\'')
   127  			default:
   128  				return "", fmt.Errorf("unsupported type %T: %v", v, v)
   129  			}
   130  		case '\n':
   131  			buf = append(buf, ' ')
   132  		default:
   133  			buf = append(buf, q)
   134  		}
   135  	}
   136  
   137  	if argPos != len(args) {
   138  		return "", driver.ErrSkip
   139  	}
   140  
   141  	return string(buf), nil
   142  }
   143  
   144  const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
   145  const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
   146  
   147  func escapeBytesBackslash(buf, v []byte) []byte {
   148  	pos := len(buf)
   149  	buf = reserveBuffer(buf, len(v)*2)
   150  
   151  	for _, c := range v {
   152  		switch c {
   153  		case '\x00':
   154  			buf[pos] = '\\'
   155  			buf[pos+1] = '0'
   156  			pos += 2
   157  		case '\n':
   158  			buf[pos] = '\\'
   159  			buf[pos+1] = 'n'
   160  			pos += 2
   161  		case '\r':
   162  			buf[pos] = '\\'
   163  			buf[pos+1] = 'r'
   164  			pos += 2
   165  		case '\x1a':
   166  			buf[pos] = '\\'
   167  			buf[pos+1] = 'Z'
   168  			pos += 2
   169  		case '\'':
   170  			buf[pos] = '\\'
   171  			buf[pos+1] = '\''
   172  			pos += 2
   173  		case '"':
   174  			buf[pos] = '\\'
   175  			buf[pos+1] = '"'
   176  			pos += 2
   177  		case '\\':
   178  			buf[pos] = '\\'
   179  			buf[pos+1] = '\\'
   180  			pos += 2
   181  		default:
   182  			buf[pos] = c
   183  			pos++
   184  		}
   185  	}
   186  
   187  	return buf[:pos]
   188  }
   189  
   190  func reserveBuffer(buf []byte, appendSize int) []byte {
   191  	newSize := len(buf) + appendSize
   192  	if cap(buf) < newSize {
   193  		// Grow buffer exponentially
   194  		newBuf := make([]byte, len(buf)*2+appendSize)
   195  		copy(newBuf, buf)
   196  		buf = newBuf
   197  	}
   198  	return buf[:newSize]
   199  }