github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/sqlbuilder/expr.go (about) 1 package sqlbuilder 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "database/sql/driver" 8 "fmt" 9 "reflect" 10 "strings" 11 "text/scanner" 12 13 reflectx "github.com/octohelm/x/reflect" 14 ) 15 16 func IsNilExpr(e SqlExpr) bool { 17 return e == nil || e.IsNil() 18 } 19 20 func RangeNotNilExpr(exprs []SqlExpr, each func(e SqlExpr, i int)) { 21 count := 0 22 23 for i := range exprs { 24 e := exprs[i] 25 if IsNilExpr(e) { 26 continue 27 } 28 each(e, count) 29 count++ 30 } 31 } 32 33 func ExactlyExpr(query string, args ...any) *Ex { 34 if query != "" { 35 return &Ex{b: *bytes.NewBufferString(query), args: args, exactly: true} 36 } 37 return &Ex{args: args, exactly: true} 38 } 39 40 func Expr(query string, args ...any) *Ex { 41 if query != "" { 42 return &Ex{b: *bytes.NewBufferString(query), args: args} 43 } 44 return &Ex{args: args} 45 } 46 47 func ResolveExpr(v any) *Ex { 48 return ResolveExprContext(context.Background(), v) 49 } 50 51 func ResolveExprContext(ctx context.Context, v any) *Ex { 52 switch e := v.(type) { 53 case nil: 54 return nil 55 case SqlExpr: 56 if IsNilExpr(e) { 57 return nil 58 } 59 return e.Ex(ctx) 60 } 61 return nil 62 } 63 64 func Multi(exprs ...SqlExpr) SqlExpr { 65 return MultiWith(" ", exprs...) 66 } 67 68 func MultiWith(connector string, exprs ...SqlExpr) SqlExpr { 69 return ExprBy(func(ctx context.Context) *Ex { 70 e := Expr("") 71 e.Grow(len(exprs)) 72 73 for i := range exprs { 74 if i != 0 { 75 e.WriteQuery(connector) 76 } 77 e.WriteExpr(exprs[i]) 78 } 79 return e.Ex(ctx) 80 }) 81 } 82 83 func ExprBy(build func(ctx context.Context) *Ex) SqlExpr { 84 return &exBy{build: build} 85 } 86 87 type exBy struct { 88 build func(ctx context.Context) *Ex 89 } 90 91 func (c *exBy) IsNil() bool { 92 return c == nil || c.build == nil 93 } 94 95 func (c *exBy) Ex(ctx context.Context) *Ex { 96 return c.build(ctx) 97 } 98 99 type SqlExpr interface { 100 IsNil() bool 101 Ex(ctx context.Context) *Ex 102 } 103 104 // ValuerExpr 105 // replace ? as some query snippet 106 // 107 // examples: 108 // ? => ST_GeomFromText(?) 109 type ValuerExpr interface { 110 ValueEx() string 111 } 112 113 type Ex struct { 114 b bytes.Buffer 115 args []any 116 err error 117 exactly bool 118 } 119 120 func (e *Ex) IsNil() bool { 121 return e == nil || e.b.Len() == 0 122 } 123 124 func (e *Ex) Query() string { 125 if e == nil { 126 return "" 127 } 128 return e.b.String() 129 } 130 131 func (e *Ex) Args() []any { 132 if len(e.args) == 0 { 133 return nil 134 } 135 return e.args 136 } 137 138 func (e *Ex) Err() error { 139 return e.err 140 } 141 142 func (e *Ex) AppendArgs(args ...any) { 143 e.args = append(e.args, args...) 144 } 145 146 func (e *Ex) ArgsLen() int { 147 return len(e.args) 148 } 149 150 func (e *Ex) WriteString(s string) (int, error) { 151 return e.b.WriteString(s) 152 } 153 154 func (e *Ex) WriteByte(b byte) error { 155 return e.b.WriteByte(b) 156 } 157 158 func (e *Ex) QueryGrow(n int) { 159 e.b.Grow(n) 160 } 161 162 func (e *Ex) Grow(n int) { 163 if n > 0 && cap(e.args)-len(e.args) < n { 164 args := make([]any, len(e.args), 2*cap(e.args)+n) 165 copy(args, e.args) 166 e.args = args 167 } 168 } 169 170 func (e *Ex) WriteQuery(s string) { 171 _, _ = e.b.WriteString(s) 172 } 173 174 func (e *Ex) WriteQueryByte(b byte) { 175 _ = e.b.WriteByte(b) 176 } 177 178 func (e *Ex) WriteGroup(fn func(e *Ex)) { 179 e.WriteQueryByte('(') 180 fn(e) 181 e.WriteQueryByte(')') 182 } 183 184 func (e *Ex) WhiteComments(comments []byte) { 185 _, _ = e.b.WriteString("/* ") 186 _, _ = e.b.Write(comments) 187 _, _ = e.b.WriteString(" */") 188 } 189 190 func (e *Ex) WriteExpr(expr SqlExpr) { 191 if IsNilExpr(expr) { 192 return 193 } 194 195 e.WriteHolder(0) 196 e.AppendArgs(expr) 197 } 198 199 func (e *Ex) WriteEnd() { 200 e.WriteQueryByte(';') 201 } 202 203 func (e *Ex) WriteHolder(idx int) { 204 if idx > 0 { 205 e.b.WriteByte(',') 206 } 207 e.b.WriteByte('?') 208 } 209 210 func (e *Ex) SetExactly(exactly bool) { 211 e.exactly = exactly 212 } 213 214 type NamedArg = sql.NamedArg 215 216 type NamedArgSet map[string]any 217 218 func (e *Ex) Ex(ctx context.Context) *Ex { 219 if e.IsNil() { 220 return nil 221 } 222 223 allArgs, n := e.args, len(e.args) 224 225 eb := Expr("") 226 eb.Grow(n) 227 228 query := e.Query() 229 230 if e.exactly { 231 eb.WriteQuery(query) 232 eb.AppendArgs(allArgs...) 233 eb.exactly = true 234 return eb 235 } 236 237 namedArgSet, args, shouldResolveArgs := preprocessArgs(allArgs) 238 239 if !shouldResolveArgs { 240 eb.WriteQuery(query) 241 eb.AppendArgs(args...) 242 eb.SetExactly(true) 243 return eb 244 } 245 246 argIndex := 0 247 248 s := &scanner.Scanner{} 249 s.Init(bytes.NewBuffer([]byte(query))) 250 s.Error = func(s *scanner.Scanner, msg string) {} 251 252 for c := s.Next(); c != scanner.EOF; c = s.Next() { 253 switch c { 254 case '@': 255 named := bytes.NewBuffer(nil) 256 257 for { 258 c = s.Next() 259 260 if c == scanner.EOF { 261 break 262 } 263 264 if (c >= 'A' && c <= 'Z') || 265 (c >= 'a' && c <= 'z') || 266 (c >= '0' && c <= '9') || 267 c == '_' { 268 269 named.WriteRune(c) 270 continue 271 } 272 break 273 } 274 275 if named.Len() > 0 { 276 name := named.String() 277 278 if v, ok := namedArgSet[name]; ok { 279 switch arg := v.(type) { 280 case SqlExpr: 281 if !IsNilExpr(arg) { 282 subExpr := arg.Ex(ctx) 283 284 if subExpr != eb && !IsNilExpr(subExpr) { 285 eb.WriteQuery(subExpr.Query()) 286 eb.AppendArgs(subExpr.Args()...) 287 } 288 } 289 default: 290 eb.WriteHolder(0) 291 eb.AppendArgs(arg) 292 } 293 } else { 294 panic(fmt.Sprintf("missing named arg `%s`", name)) 295 } 296 } 297 298 if c != scanner.EOF { 299 eb.WriteQueryByte(byte(c)) 300 } 301 case '?': 302 if argIndex >= n { 303 panic(fmt.Errorf("missing arg %d of %s", argIndex, query)) 304 } 305 306 switch arg := args[argIndex].(type) { 307 case SqlExpr: 308 if !IsNilExpr(arg) { 309 subExpr := arg.Ex(ctx) 310 311 if subExpr != eb && !IsNilExpr(subExpr) { 312 eb.WriteQuery(subExpr.Query()) 313 eb.AppendArgs(subExpr.Args()...) 314 } 315 } 316 default: 317 eb.WriteHolder(0) 318 eb.AppendArgs(arg) 319 } 320 argIndex++ 321 default: 322 eb.WriteQueryByte(byte(c)) 323 } 324 } 325 326 eb.SetExactly(true) 327 328 return eb 329 } 330 331 func exactlyExprFromSlice(values []any) *Ex { 332 if n := len(values); n > 0 { 333 return ExactlyExpr(strings.Repeat(",?", n)[1:], values...) 334 } 335 return ExactlyExpr("") 336 } 337 338 func preprocessArgs(args []any) (NamedArgSet, []any, bool) { 339 namedArgSet := NamedArgSet{} 340 finalArgs := make([]any, 0, len(args)) 341 342 shouldResolve := false 343 344 for i := range args { 345 switch arg := args[i].(type) { 346 case NamedArgSet: 347 for k := range arg { 348 namedArgSet[k] = arg[k] 349 } 350 shouldResolve = true 351 case NamedArg: 352 namedArgSet[arg.Name] = arg.Value 353 shouldResolve = true 354 case ValuerExpr: 355 finalArgs = append(finalArgs, ExactlyExpr(arg.ValueEx(), arg)) 356 shouldResolve = true 357 case SqlExpr: 358 finalArgs = append(finalArgs, arg) 359 shouldResolve = true 360 case driver.Valuer: 361 finalArgs = append(finalArgs, arg) 362 case []any: 363 finalArgs = append(finalArgs, exactlyExprFromSlice(arg)) 364 shouldResolve = true 365 default: 366 if typ := reflect.TypeOf(arg); typ.Kind() == reflect.Slice { 367 if !reflectx.IsBytes(typ) { 368 finalArgs = append(finalArgs, exactlyExprFromSlice(toInterfaceSlice(arg))) 369 shouldResolve = true 370 continue 371 } 372 } 373 finalArgs = append(finalArgs, arg) 374 } 375 } 376 377 return namedArgSet, finalArgs, shouldResolve 378 } 379 380 func toInterfaceSlice(arg any) []any { 381 switch x := (arg).(type) { 382 case []bool: 383 values := make([]any, len(x)) 384 for i := range values { 385 values[i] = x[i] 386 } 387 return values 388 case []string: 389 values := make([]any, len(x)) 390 for i := range values { 391 values[i] = x[i] 392 } 393 return values 394 case []float32: 395 values := make([]any, len(x)) 396 for i := range values { 397 values[i] = x[i] 398 } 399 return values 400 case []float64: 401 values := make([]any, len(x)) 402 for i := range values { 403 values[i] = x[i] 404 } 405 return values 406 case []int: 407 values := make([]any, len(x)) 408 for i := range values { 409 values[i] = x[i] 410 } 411 return values 412 case []int8: 413 values := make([]any, len(x)) 414 for i := range values { 415 values[i] = x[i] 416 } 417 return values 418 case []int16: 419 values := make([]any, len(x)) 420 for i := range values { 421 values[i] = x[i] 422 } 423 return values 424 case []int32: 425 values := make([]any, len(x)) 426 for i := range values { 427 values[i] = x[i] 428 } 429 return values 430 case []int64: 431 values := make([]any, len(x)) 432 for i := range values { 433 values[i] = x[i] 434 } 435 return values 436 case []uint: 437 values := make([]any, len(x)) 438 for i := range values { 439 values[i] = x[i] 440 } 441 return values 442 case []uint8: 443 values := make([]any, len(x)) 444 for i := range values { 445 values[i] = x[i] 446 } 447 return values 448 case []uint16: 449 values := make([]any, len(x)) 450 for i := range values { 451 values[i] = x[i] 452 } 453 return values 454 case []uint32: 455 values := make([]any, len(x)) 456 for i := range values { 457 values[i] = x[i] 458 } 459 return values 460 case []uint64: 461 values := make([]any, len(x)) 462 for i := range values { 463 values[i] = x[i] 464 } 465 return values 466 case []any: 467 return x 468 } 469 sliceRv := reflect.ValueOf(arg) 470 values := make([]any, sliceRv.Len()) 471 for i := range values { 472 values[i] = sliceRv.Index(i).Interface() 473 } 474 return values 475 }