github.com/jackc/pgx/v5@v5.5.5/internal/sanitize/sanitize.go (about) 1 package sanitize 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "fmt" 7 "strconv" 8 "strings" 9 "time" 10 "unicode/utf8" 11 ) 12 13 // Part is either a string or an int. A string is raw SQL. An int is a 14 // argument placeholder. 15 type Part any 16 17 type Query struct { 18 Parts []Part 19 } 20 21 // utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement 22 // character. utf8.RuneError is not an error if it is also width 3. 23 // 24 // https://github.com/jackc/pgx/issues/1380 25 const replacementcharacterwidth = 3 26 27 func (q *Query) Sanitize(args ...any) (string, error) { 28 argUse := make([]bool, len(args)) 29 buf := &bytes.Buffer{} 30 31 for _, part := range q.Parts { 32 var str string 33 switch part := part.(type) { 34 case string: 35 str = part 36 case int: 37 argIdx := part - 1 38 39 if argIdx < 0 { 40 return "", fmt.Errorf("first sql argument must be > 0") 41 } 42 43 if argIdx >= len(args) { 44 return "", fmt.Errorf("insufficient arguments") 45 } 46 arg := args[argIdx] 47 switch arg := arg.(type) { 48 case nil: 49 str = "null" 50 case int64: 51 str = strconv.FormatInt(arg, 10) 52 case float64: 53 str = strconv.FormatFloat(arg, 'f', -1, 64) 54 case bool: 55 str = strconv.FormatBool(arg) 56 case []byte: 57 str = QuoteBytes(arg) 58 case string: 59 str = QuoteString(arg) 60 case time.Time: 61 str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") 62 default: 63 return "", fmt.Errorf("invalid arg type: %T", arg) 64 } 65 argUse[argIdx] = true 66 67 // Prevent SQL injection via Line Comment Creation 68 // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p 69 str = " " + str + " " 70 default: 71 return "", fmt.Errorf("invalid Part type: %T", part) 72 } 73 buf.WriteString(str) 74 } 75 76 for i, used := range argUse { 77 if !used { 78 return "", fmt.Errorf("unused argument: %d", i) 79 } 80 } 81 return buf.String(), nil 82 } 83 84 func NewQuery(sql string) (*Query, error) { 85 l := &sqlLexer{ 86 src: sql, 87 stateFn: rawState, 88 } 89 90 for l.stateFn != nil { 91 l.stateFn = l.stateFn(l) 92 } 93 94 query := &Query{Parts: l.parts} 95 96 return query, nil 97 } 98 99 func QuoteString(str string) string { 100 return "'" + strings.ReplaceAll(str, "'", "''") + "'" 101 } 102 103 func QuoteBytes(buf []byte) string { 104 return `'\x` + hex.EncodeToString(buf) + "'" 105 } 106 107 type sqlLexer struct { 108 src string 109 start int 110 pos int 111 nested int // multiline comment nesting level. 112 stateFn stateFn 113 parts []Part 114 } 115 116 type stateFn func(*sqlLexer) stateFn 117 118 func rawState(l *sqlLexer) stateFn { 119 for { 120 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 121 l.pos += width 122 123 switch r { 124 case 'e', 'E': 125 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 126 if nextRune == '\'' { 127 l.pos += width 128 return escapeStringState 129 } 130 case '\'': 131 return singleQuoteState 132 case '"': 133 return doubleQuoteState 134 case '$': 135 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 136 if '0' <= nextRune && nextRune <= '9' { 137 if l.pos-l.start > 0 { 138 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 139 } 140 l.start = l.pos 141 return placeholderState 142 } 143 case '-': 144 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 145 if nextRune == '-' { 146 l.pos += width 147 return oneLineCommentState 148 } 149 case '/': 150 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 151 if nextRune == '*' { 152 l.pos += width 153 return multilineCommentState 154 } 155 case utf8.RuneError: 156 if width != replacementcharacterwidth { 157 if l.pos-l.start > 0 { 158 l.parts = append(l.parts, l.src[l.start:l.pos]) 159 l.start = l.pos 160 } 161 return nil 162 } 163 } 164 } 165 } 166 167 func singleQuoteState(l *sqlLexer) stateFn { 168 for { 169 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 170 l.pos += width 171 172 switch r { 173 case '\'': 174 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 175 if nextRune != '\'' { 176 return rawState 177 } 178 l.pos += width 179 case utf8.RuneError: 180 if width != replacementcharacterwidth { 181 if l.pos-l.start > 0 { 182 l.parts = append(l.parts, l.src[l.start:l.pos]) 183 l.start = l.pos 184 } 185 return nil 186 } 187 } 188 } 189 } 190 191 func doubleQuoteState(l *sqlLexer) stateFn { 192 for { 193 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 194 l.pos += width 195 196 switch r { 197 case '"': 198 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 199 if nextRune != '"' { 200 return rawState 201 } 202 l.pos += width 203 case utf8.RuneError: 204 if width != replacementcharacterwidth { 205 if l.pos-l.start > 0 { 206 l.parts = append(l.parts, l.src[l.start:l.pos]) 207 l.start = l.pos 208 } 209 return nil 210 } 211 } 212 } 213 } 214 215 // placeholderState consumes a placeholder value. The $ must have already has 216 // already been consumed. The first rune must be a digit. 217 func placeholderState(l *sqlLexer) stateFn { 218 num := 0 219 220 for { 221 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 222 l.pos += width 223 224 if '0' <= r && r <= '9' { 225 num *= 10 226 num += int(r - '0') 227 } else { 228 l.parts = append(l.parts, num) 229 l.pos -= width 230 l.start = l.pos 231 return rawState 232 } 233 } 234 } 235 236 func escapeStringState(l *sqlLexer) stateFn { 237 for { 238 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 239 l.pos += width 240 241 switch r { 242 case '\\': 243 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 244 l.pos += width 245 case '\'': 246 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 247 if nextRune != '\'' { 248 return rawState 249 } 250 l.pos += width 251 case utf8.RuneError: 252 if width != replacementcharacterwidth { 253 if l.pos-l.start > 0 { 254 l.parts = append(l.parts, l.src[l.start:l.pos]) 255 l.start = l.pos 256 } 257 return nil 258 } 259 } 260 } 261 } 262 263 func oneLineCommentState(l *sqlLexer) stateFn { 264 for { 265 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 266 l.pos += width 267 268 switch r { 269 case '\\': 270 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 271 l.pos += width 272 case '\n', '\r': 273 return rawState 274 case utf8.RuneError: 275 if width != replacementcharacterwidth { 276 if l.pos-l.start > 0 { 277 l.parts = append(l.parts, l.src[l.start:l.pos]) 278 l.start = l.pos 279 } 280 return nil 281 } 282 } 283 } 284 } 285 286 func multilineCommentState(l *sqlLexer) stateFn { 287 for { 288 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 289 l.pos += width 290 291 switch r { 292 case '/': 293 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 294 if nextRune == '*' { 295 l.pos += width 296 l.nested++ 297 } 298 case '*': 299 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 300 if nextRune != '/' { 301 continue 302 } 303 304 l.pos += width 305 if l.nested == 0 { 306 return rawState 307 } 308 l.nested-- 309 310 case utf8.RuneError: 311 if width != replacementcharacterwidth { 312 if l.pos-l.start > 0 { 313 l.parts = append(l.parts, l.src[l.start:l.pos]) 314 l.start = l.pos 315 } 316 return nil 317 } 318 } 319 } 320 } 321 322 // SanitizeSQL replaces placeholder values with args. It quotes and escapes args 323 // as necessary. This function is only safe when standard_conforming_strings is 324 // on. 325 func SanitizeSQL(sql string, args ...any) (string, error) { 326 query, err := NewQuery(sql) 327 if err != nil { 328 return "", err 329 } 330 return query.Sanitize(args...) 331 }