gitee.com/go-genie/sqlx@v1.0.3/connectors/mysql/interpolate_params.go (about)

     1  package mysql
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"strconv"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
    13  	dargs := make([]driver.Value, len(named))
    14  	for n, param := range named {
    15  		if len(param.Name) > 0 {
    16  			// TODO: support the use of Named Parameters #561
    17  			return nil, errors.New("mysql: driver does not support the use of Named Parameters")
    18  		}
    19  		dargs[n] = param.Value
    20  	}
    21  	return dargs, nil
    22  }
    23  
    24  func interpolateParams(query string, args []driver.Value, loc *time.Location, maxAllowedPacket int) (string, error) {
    25  	if strings.Count(query, "?") != len(args) {
    26  		return "", driver.ErrSkip
    27  	}
    28  
    29  	buf := make([]byte, 0)
    30  	buf = buf[:0]
    31  	argPos := 0
    32  
    33  	data := []byte(query)
    34  
    35  	for i := range data {
    36  		q := query[i]
    37  		switch q {
    38  		case '?':
    39  			arg := args[argPos]
    40  			argPos++
    41  
    42  			if arg == nil {
    43  				buf = append(buf, "NULL"...)
    44  				continue
    45  			}
    46  
    47  			switch v := arg.(type) {
    48  			case int64:
    49  				buf = strconv.AppendInt(buf, v, 10)
    50  			case float64:
    51  				buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
    52  			case bool:
    53  				if v {
    54  					buf = append(buf, '1')
    55  				} else {
    56  					buf = append(buf, '0')
    57  				}
    58  			case time.Time:
    59  				if v.IsZero() {
    60  					buf = append(buf, "'0000-00-00'"...)
    61  				} else {
    62  					v := v.In(loc)
    63  					v = v.Add(time.Nanosecond * 500) // Write round under microsecond
    64  					year := v.Year()
    65  					year100 := year / 100
    66  					year1 := year % 100
    67  					month := v.Month()
    68  					day := v.Day()
    69  					hour := v.Hour()
    70  					minute := v.Minute()
    71  					second := v.Second()
    72  					micro := v.Nanosecond() / 1000
    73  
    74  					buf = append(buf, []byte{
    75  						'\'',
    76  						digits10[year100], digits01[year100],
    77  						digits10[year1], digits01[year1],
    78  						'-',
    79  						digits10[month], digits01[month],
    80  						'-',
    81  						digits10[day], digits01[day],
    82  						' ',
    83  						digits10[hour], digits01[hour],
    84  						':',
    85  						digits10[minute], digits01[minute],
    86  						':',
    87  						digits10[second], digits01[second],
    88  					}...)
    89  
    90  					if micro != 0 {
    91  						micro10000 := micro / 10000
    92  						micro100 := micro / 100 % 100
    93  						micro1 := micro % 100
    94  						buf = append(buf, []byte{
    95  							'.',
    96  							digits10[micro10000], digits01[micro10000],
    97  							digits10[micro100], digits01[micro100],
    98  							digits10[micro1], digits01[micro1],
    99  						}...)
   100  					}
   101  					buf = append(buf, '\'')
   102  				}
   103  			case []byte:
   104  				if v == nil {
   105  					buf = append(buf, "NULL"...)
   106  				} else {
   107  					buf = append(buf, "_binary'"...)
   108  					buf = escapeBytesBackslash(buf, v)
   109  					buf = append(buf, '\'')
   110  				}
   111  			case string:
   112  				buf = append(buf, '\'')
   113  				buf = escapeBytesBackslash(buf, []byte(v))
   114  				buf = append(buf, '\'')
   115  			default:
   116  				return "", driver.ErrSkip
   117  			}
   118  
   119  		case '\n':
   120  			buf = append(buf, ' ')
   121  		default:
   122  			buf = append(buf, q)
   123  		}
   124  
   125  		if len(buf)+4 > maxAllowedPacket {
   126  			return "", driver.ErrSkip
   127  		}
   128  	}
   129  	if argPos != len(args) {
   130  		return "", driver.ErrSkip
   131  	}
   132  	return string(buf), nil
   133  }
   134  
   135  // copy from mysql driver
   136  
   137  const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
   138  const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
   139  
   140  func escapeBytesBackslash(buf, v []byte) []byte {
   141  	pos := len(buf)
   142  	buf = reserveBuffer(buf, len(v)*2)
   143  
   144  	for _, c := range v {
   145  		switch c {
   146  		case '\x00':
   147  			buf[pos] = '\\'
   148  			buf[pos+1] = '0'
   149  			pos += 2
   150  		case '\n':
   151  			buf[pos] = '\\'
   152  			buf[pos+1] = 'n'
   153  			pos += 2
   154  		case '\r':
   155  			buf[pos] = '\\'
   156  			buf[pos+1] = 'r'
   157  			pos += 2
   158  		case '\x1a':
   159  			buf[pos] = '\\'
   160  			buf[pos+1] = 'Z'
   161  			pos += 2
   162  		case '\'':
   163  			buf[pos] = '\\'
   164  			buf[pos+1] = '\''
   165  			pos += 2
   166  		case '"':
   167  			buf[pos] = '\\'
   168  			buf[pos+1] = '"'
   169  			pos += 2
   170  		case '\\':
   171  			buf[pos] = '\\'
   172  			buf[pos+1] = '\\'
   173  			pos += 2
   174  		default:
   175  			buf[pos] = c
   176  			pos++
   177  		}
   178  	}
   179  
   180  	return buf[:pos]
   181  }
   182  
   183  func reserveBuffer(buf []byte, appendSize int) []byte {
   184  	newSize := len(buf) + appendSize
   185  	if cap(buf) < newSize {
   186  		// Grow buffer exponentially
   187  		newBuf := make([]byte, len(buf)*2+appendSize)
   188  		copy(newBuf, buf)
   189  		buf = newBuf
   190  	}
   191  	return buf[:newSize]
   192  }