github.com/jackc/pgx/v5@v5.5.5/named_args.go (about) 1 package pgx 2 3 import ( 4 "context" 5 "strconv" 6 "strings" 7 "unicode/utf8" 8 ) 9 10 // NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' 11 // ordinal placeholder and construct the appropriate arguments. 12 // 13 // For example, the following two queries are equivalent: 14 // 15 // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) 16 // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) 17 // 18 // Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be 19 // letters, numbers, or underscores. 20 type NamedArgs map[string]any 21 22 // RewriteQuery implements the QueryRewriter interface. 23 func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { 24 l := &sqlLexer{ 25 src: sql, 26 stateFn: rawState, 27 nameToOrdinal: make(map[namedArg]int, len(na)), 28 } 29 30 for l.stateFn != nil { 31 l.stateFn = l.stateFn(l) 32 } 33 34 sb := strings.Builder{} 35 for _, p := range l.parts { 36 switch p := p.(type) { 37 case string: 38 sb.WriteString(p) 39 case namedArg: 40 sb.WriteRune('$') 41 sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) 42 } 43 } 44 45 newArgs = make([]any, len(l.nameToOrdinal)) 46 for name, ordinal := range l.nameToOrdinal { 47 newArgs[ordinal-1] = na[string(name)] 48 } 49 50 return sb.String(), newArgs, nil 51 } 52 53 type namedArg string 54 55 type sqlLexer struct { 56 src string 57 start int 58 pos int 59 nested int // multiline comment nesting level. 60 stateFn stateFn 61 parts []any 62 63 nameToOrdinal map[namedArg]int 64 } 65 66 type stateFn func(*sqlLexer) stateFn 67 68 func rawState(l *sqlLexer) stateFn { 69 for { 70 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 71 l.pos += width 72 73 switch r { 74 case 'e', 'E': 75 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 76 if nextRune == '\'' { 77 l.pos += width 78 return escapeStringState 79 } 80 case '\'': 81 return singleQuoteState 82 case '"': 83 return doubleQuoteState 84 case '@': 85 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 86 if isLetter(nextRune) || nextRune == '_' { 87 if l.pos-l.start > 0 { 88 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 89 } 90 l.start = l.pos 91 return namedArgState 92 } 93 case '-': 94 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 95 if nextRune == '-' { 96 l.pos += width 97 return oneLineCommentState 98 } 99 case '/': 100 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 101 if nextRune == '*' { 102 l.pos += width 103 return multilineCommentState 104 } 105 case utf8.RuneError: 106 if l.pos-l.start > 0 { 107 l.parts = append(l.parts, l.src[l.start:l.pos]) 108 l.start = l.pos 109 } 110 return nil 111 } 112 } 113 } 114 115 func isLetter(r rune) bool { 116 return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') 117 } 118 119 func namedArgState(l *sqlLexer) stateFn { 120 for { 121 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 122 l.pos += width 123 124 if r == utf8.RuneError { 125 if l.pos-l.start > 0 { 126 na := namedArg(l.src[l.start:l.pos]) 127 if _, found := l.nameToOrdinal[na]; !found { 128 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 129 } 130 l.parts = append(l.parts, na) 131 l.start = l.pos 132 } 133 return nil 134 } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { 135 l.pos -= width 136 na := namedArg(l.src[l.start:l.pos]) 137 if _, found := l.nameToOrdinal[na]; !found { 138 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 139 } 140 l.parts = append(l.parts, namedArg(na)) 141 l.start = l.pos 142 return rawState 143 } 144 } 145 } 146 147 func singleQuoteState(l *sqlLexer) stateFn { 148 for { 149 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 150 l.pos += width 151 152 switch r { 153 case '\'': 154 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 155 if nextRune != '\'' { 156 return rawState 157 } 158 l.pos += width 159 case utf8.RuneError: 160 if l.pos-l.start > 0 { 161 l.parts = append(l.parts, l.src[l.start:l.pos]) 162 l.start = l.pos 163 } 164 return nil 165 } 166 } 167 } 168 169 func doubleQuoteState(l *sqlLexer) stateFn { 170 for { 171 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 172 l.pos += width 173 174 switch r { 175 case '"': 176 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 177 if nextRune != '"' { 178 return rawState 179 } 180 l.pos += width 181 case utf8.RuneError: 182 if l.pos-l.start > 0 { 183 l.parts = append(l.parts, l.src[l.start:l.pos]) 184 l.start = l.pos 185 } 186 return nil 187 } 188 } 189 } 190 191 func escapeStringState(l *sqlLexer) stateFn { 192 for { 193 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 194 l.pos += width 195 196 switch r { 197 case '\\': 198 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 199 l.pos += width 200 case '\'': 201 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 202 if nextRune != '\'' { 203 return rawState 204 } 205 l.pos += width 206 case utf8.RuneError: 207 if l.pos-l.start > 0 { 208 l.parts = append(l.parts, l.src[l.start:l.pos]) 209 l.start = l.pos 210 } 211 return nil 212 } 213 } 214 } 215 216 func oneLineCommentState(l *sqlLexer) stateFn { 217 for { 218 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 219 l.pos += width 220 221 switch r { 222 case '\\': 223 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 224 l.pos += width 225 case '\n', '\r': 226 return rawState 227 case utf8.RuneError: 228 if l.pos-l.start > 0 { 229 l.parts = append(l.parts, l.src[l.start:l.pos]) 230 l.start = l.pos 231 } 232 return nil 233 } 234 } 235 } 236 237 func multilineCommentState(l *sqlLexer) stateFn { 238 for { 239 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 240 l.pos += width 241 242 switch r { 243 case '/': 244 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 245 if nextRune == '*' { 246 l.pos += width 247 l.nested++ 248 } 249 case '*': 250 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 251 if nextRune != '/' { 252 continue 253 } 254 255 l.pos += width 256 if l.nested == 0 { 257 return rawState 258 } 259 l.nested-- 260 261 case utf8.RuneError: 262 if l.pos-l.start > 0 { 263 l.parts = append(l.parts, l.src[l.start:l.pos]) 264 l.start = l.pos 265 } 266 return nil 267 } 268 } 269 }