// reserveBuffer extends buf by appendSize bytes and returns the extended
// slice, whose length is len(buf)+appendSize.
//
// When cap(buf) is too small a new backing array of len(buf)*2+appendSize
// bytes is allocated and the existing contents are copied into it, so the
// buffer grows geometrically. The appended bytes are left for the caller to
// fill in.
func reserveBuffer(buf []byte, appendSize int) []byte {
	newSize := len(buf) + appendSize
	if cap(buf) < newSize {
		// Grow buffer exponentially.
		newBuf := make([]byte, len(buf)*2+appendSize)
		copy(newBuf, buf)
		buf = newBuf
	}
	return buf[:newSize]
}

// backslashEscape returns the byte that has to follow a backslash in order to
// represent c inside a quoted string literal, or 0 when c can be copied
// verbatim. No escape sequence uses a NUL byte as its second character, so 0
// is a safe "no escaping needed" marker.
func backslashEscape(c byte) byte {
	switch c {
	case '\x00':
		return '0'
	case '\n':
		return 'n'
	case '\r':
		return 'r'
	case '\x1a':
		return 'Z'
	case '\'', '"', '\\':
		return c
	}
	return 0
}

// escapeBytesBackslash appends v to buf, escaping NUL, newline, carriage
// return, Ctrl-Z (0x1a), single quote, double quote and backslash with a
// leading backslash, and returns the extended buffer.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
	pos := len(buf)
	// Worst case: every input byte expands to two output bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range v {
		if e := backslashEscape(c); e != 0 {
			buf[pos], buf[pos+1] = '\\', e
			pos += 2
			continue
		}
		buf[pos] = c
		pos++
	}

	return buf[:pos]
}

// escapeStringBackslash is similar to escapeBytesBackslash but for string.
func escapeStringBackslash(buf []byte, v string) []byte {
	pos := len(buf)
	// Worst case: every input byte expands to two output bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for i := 0; i < len(v); i++ {
		c := v[i]
		if e := backslashEscape(c); e != 0 {
			buf[pos], buf[pos+1] = '\\', e
			pos += 2
			continue
		}
		buf[pos] = c
		pos++
	}

	return buf[:pos]
}

// shanghaiLocation returns the Asia/Shanghai time zone. When the zone
// database cannot be loaded it falls back to a fixed UTC+8 offset instead of
// handing a nil location to Time.In, which would panic.
func shanghaiLocation() *time.Location {
	loc, err := time.LoadLocation("Asia/Shanghai")
	if err != nil {
		return time.FixedZone("CST", 8*60*60)
	}
	return loc
}

// appendArg appends the SQL literal form of arg to buf and returns the
// extended buffer. It returns an error for argument types it does not know
// how to render.
func appendArg(buf []byte, arg interface{}) ([]byte, error) {
	switch v := arg.(type) {
	case nil:
		return append(buf, "NULL"...), nil
	case int:
		return strconv.AppendInt(buf, int64(v), 10), nil
	case int8:
		return strconv.AppendInt(buf, int64(v), 10), nil
	case int16:
		return strconv.AppendInt(buf, int64(v), 10), nil
	case int32:
		return strconv.AppendInt(buf, int64(v), 10), nil
	case int64:
		return strconv.AppendInt(buf, v, 10), nil
	case uint:
		return strconv.AppendUint(buf, uint64(v), 10), nil
	case uint8:
		return strconv.AppendUint(buf, uint64(v), 10), nil
	case uint16:
		return strconv.AppendUint(buf, uint64(v), 10), nil
	case uint32:
		return strconv.AppendUint(buf, uint64(v), 10), nil
	case uint64:
		return strconv.AppendUint(buf, v, 10), nil
	case float32:
		// bitSize 32 yields the shortest text that round-trips a float32
		// (0.1 instead of 0.10000000149011612).
		return strconv.AppendFloat(buf, float64(v), 'g', -1, 32), nil
	case float64:
		return strconv.AppendFloat(buf, v, 'g', -1, 64), nil
	case bool:
		if v {
			return append(buf, '1'), nil
		}
		return append(buf, '0'), nil
	case time.Time:
		if v.IsZero() {
			return append(buf, "'0000-00-00'"...), nil
		}
		buf = append(buf, '\'')
		// Time.In returns a new value; format that one, not the original.
		buf = v.In(shanghaiLocation()).AppendFormat(buf, "2006-01-02 15:04:05")
		return append(buf, '\''), nil
	case []byte:
		if v == nil {
			return append(buf, "NULL"...), nil
		}
		buf = append(buf, '\'')
		buf = escapeBytesBackslash(buf, v)
		return append(buf, '\''), nil
	case string:
		buf = append(buf, '\'')
		buf = escapeStringBackslash(buf, v)
		return append(buf, '\''), nil
	default:
		return nil, fmt.Errorf("参数类型错误: %v", v)
	}
}

// Query builds a SQL fragment by substituting args for the "?" placeholders
// in query, in order.
//
// An empty query without arguments yields "1 = 1". A query without
// placeholders and without arguments is returned unchanged.
//
// Supported argument types are nil (rendered as NULL), all built-in signed
// and unsigned integer types, float32, float64, bool (1 or 0), time.Time
// (formatted as 'YYYY-MM-DD hh:mm:ss' in the Asia/Shanghai zone, or
// '0000-00-00' for the zero time), []byte and string (both quoted and
// backslash-escaped; a nil []byte is rendered as NULL).
//
// An error is returned when the number of "?" bytes in query differs from
// len(args), or when an argument has an unsupported type.
//
// Every "?" byte counts as a placeholder, including one that appears inside a
// quoted literal in query. String escaping uses backslash sequences, so the
// result is only safe on servers that interpret backslash escapes; prefer
// real prepared statements for untrusted input.
func Query(query string, args ...interface{}) (sql string, err error) {
	if len(query) == 0 && len(args) == 0 {
		return "1 = 1", nil
	}

	if strings.Count(query, "?") != len(args) {
		return "", errors.New(`匹配符("?")的数量和参数数量不匹配`)
	}

	buf := make([]byte, 0, len(query)+8*len(args))
	rest := query

	for _, arg := range args {
		// The count check above guarantees one placeholder per argument,
		// so q is never -1 here.
		q := strings.IndexByte(rest, '?')
		buf = append(buf, rest[:q]...)
		rest = rest[q+1:]

		if buf, err = appendArg(buf, arg); err != nil {
			return "", err
		}
	}

	// Whatever follows the last placeholder is copied through verbatim.
	buf = append(buf, rest...)

	return string(buf), nil
}