github.com/jtzjtz/kit@v1.0.2/sql/sqlquery.go (about)

     1  package sql
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  )
    10  
    11  // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
    12  // If cap(buf) is not enough, reallocate new buffer.
    13  func reserveBuffer(buf []byte, appendSize int) []byte {
    14  	newSize := len(buf) + appendSize
    15  	if cap(buf) < newSize {
    16  		// Grow buffer exponentially
    17  		newBuf := make([]byte, len(buf)*2+appendSize)
    18  		copy(newBuf, buf)
    19  		buf = newBuf
    20  	}
    21  	return buf[:newSize]
    22  }
    23  
    24  // escapeBytesBackslash escapes []byte with backslashes (\)
    25  // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
    26  func escapeBytesBackslash(buf, v []byte) []byte {
    27  	pos := len(buf)
    28  	buf = reserveBuffer(buf, len(v)*2)
    29  
    30  	for _, c := range v {
    31  		switch c {
    32  		case '\x00':
    33  			buf[pos] = '\\'
    34  			buf[pos+1] = '0'
    35  			pos += 2
    36  		case '\n':
    37  			buf[pos] = '\\'
    38  			buf[pos+1] = 'n'
    39  			pos += 2
    40  		case '\r':
    41  			buf[pos] = '\\'
    42  			buf[pos+1] = 'r'
    43  			pos += 2
    44  		case '\x1a':
    45  			buf[pos] = '\\'
    46  			buf[pos+1] = 'Z'
    47  			pos += 2
    48  		case '\'':
    49  			buf[pos] = '\\'
    50  			buf[pos+1] = '\''
    51  			pos += 2
    52  		case '"':
    53  			buf[pos] = '\\'
    54  			buf[pos+1] = '"'
    55  			pos += 2
    56  		case '\\':
    57  			buf[pos] = '\\'
    58  			buf[pos+1] = '\\'
    59  			pos += 2
    60  		default:
    61  			buf[pos] = c
    62  			pos++
    63  		}
    64  	}
    65  
    66  	return buf[:pos]
    67  }
    68  
    69  // escapeStringBackslash is similar to escapeBytesBackslash but for string.
    70  func escapeStringBackslash(buf []byte, v string) []byte {
    71  	pos := len(buf)
    72  	buf = reserveBuffer(buf, len(v)*2)
    73  
    74  	for i := 0; i < len(v); i++ {
    75  		c := v[i]
    76  		switch c {
    77  		case '\x00':
    78  			buf[pos] = '\\'
    79  			buf[pos+1] = '0'
    80  			pos += 2
    81  		case '\n':
    82  			buf[pos] = '\\'
    83  			buf[pos+1] = 'n'
    84  			pos += 2
    85  		case '\r':
    86  			buf[pos] = '\\'
    87  			buf[pos+1] = 'r'
    88  			pos += 2
    89  		case '\x1a':
    90  			buf[pos] = '\\'
    91  			buf[pos+1] = 'Z'
    92  			pos += 2
    93  		case '\'':
    94  			buf[pos] = '\\'
    95  			buf[pos+1] = '\''
    96  			pos += 2
    97  		case '"':
    98  			buf[pos] = '\\'
    99  			buf[pos+1] = '"'
   100  			pos += 2
   101  		case '\\':
   102  			buf[pos] = '\\'
   103  			buf[pos+1] = '\\'
   104  			pos += 2
   105  		default:
   106  			buf[pos] = c
   107  			pos++
   108  		}
   109  	}
   110  
   111  	return buf[:pos]
   112  }
   113  
   114  // Query 拼接 sql 语句
   115  func Query(query string, args ...interface{}) (sql string, err error) {
   116  	if len(sql) == 0 && len(args) == 0 {
   117  		return "1 = 1", nil
   118  	}
   119  
   120  	if strings.Count(query, "?") != len(args) {
   121  		return "", errors.New(`匹配符("?")的数量和参数数量不匹配`)
   122  	}
   123  
   124  	buf := make([]byte, 0)
   125  	argPos := 0
   126  
   127  	for i := 0; i < len(query); i++ {
   128  		q := strings.IndexByte(query[i:], '?')
   129  		if q == -1 {
   130  			buf = append(buf, query[i:]...)
   131  			break
   132  		}
   133  
   134  		buf = append(buf, query[i:i+q]...)
   135  		i += q
   136  
   137  		arg := args[argPos]
   138  		argPos++
   139  
   140  		if arg == nil {
   141  			buf = append(buf, "NULL"...)
   142  			continue
   143  		}
   144  
   145  		switch v := arg.(type) {
   146  		case int32:
   147  			buf = strconv.AppendInt(buf, int64(v), 10)
   148  		case int64:
   149  			buf = strconv.AppendInt(buf, v, 10)
   150  		case int:
   151  			buf = strconv.AppendInt(buf, int64(v), 10)
   152  		case uint64:
   153  			buf = strconv.AppendUint(buf, v, 10)
   154  		case float32:
   155  			buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 64)
   156  		case float64:
   157  			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
   158  		case bool:
   159  			if v {
   160  				buf = append(buf, '1')
   161  			} else {
   162  				buf = append(buf, '0')
   163  			}
   164  		case time.Time:
   165  			if v.IsZero() {
   166  				buf = append(buf, "'0000-00-00'"...)
   167  			} else {
   168  				loc, _ := time.LoadLocation("Asia/Shanghai")
   169  				v.In(loc)
   170  				strtime := v.Format("2006-01-02 15:04:05")
   171  				buf = append(buf, '\'')
   172  				buf = append(buf, strtime...)
   173  				buf = append(buf, '\'')
   174  			}
   175  		case []byte:
   176  			if v == nil {
   177  				buf = append(buf, "NULL"...)
   178  			} else {
   179  				buf = append(buf, '\'')
   180  				buf = escapeBytesBackslash(buf, v)
   181  				buf = append(buf, '\'')
   182  			}
   183  		case string:
   184  			buf = append(buf, '\'')
   185  			buf = escapeStringBackslash(buf, v)
   186  			buf = append(buf, '\'')
   187  		default:
   188  			return "", fmt.Errorf("参数类型错误: %v", v)
   189  		}
   190  	}
   191  
   192  	if argPos != len(args) {
   193  		return "", errors.New("未知错误")
   194  	}
   195  
   196  	return string(buf), nil
   197  }