github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/sqlparser/dbtype.go (about) 1 package sqlparser 2 3 import ( 4 "bytes" 5 "database/sql" 6 "errors" 7 "fmt" 8 "log" 9 "math/bits" 10 "reflect" 11 "regexp" 12 "strconv" 13 "strings" 14 15 "github.com/bingoohuang/gg/pkg/reflector" 16 "github.com/bingoohuang/gg/pkg/ss" 17 ) 18 19 type DBType string 20 21 const ( 22 Mysql DBType = "mysql" 23 Sqlite3 DBType = "sqlite3" 24 Dm DBType = "dm" // dm数据库 25 Gbase DBType = "gbase" // 南大通用 26 Clickhouse DBType = "clickhouse" 27 Postgresql DBType = "postgresql" 28 Kingbase DBType = "kingbase" // 金仓 29 Shentong DBType = "shentong" // 神通 30 Mssql DBType = "mssql" // sqlserver 2012+ 31 Oracle DBType = "oracle" // oracle 12c+ 32 ) 33 34 // ToDBType converts driverName to different DBType. 35 func ToDBType(driverName string) DBType { 36 switch strings.ToLower(driverName) { 37 case "pgx", "opengauss": 38 return Postgresql 39 default: 40 return DBType(driverName) 41 } 42 } 43 44 var ErrUnsupportedDBType = errors.New("unsupported database type") 45 46 // Paging is pagination object. 47 type Paging struct { 48 PageSeq int // Current page number, starting from 1 49 PageSize int // How many items per page, 20 items by default 50 RowsCount int // Total number of data 51 PageCount int // How many pages 52 FirstPage bool // Is it the first page 53 HasPrev bool // Whether there is a previous page 54 HasNext bool // Is there a next page 55 LastPage bool // Is it the last page 56 } 57 58 // NewPaging creates a Paging object. 59 func NewPaging() *Paging { return &Paging{PageSeq: 1, PageSize: 20} } 60 61 // SetRowsCount Set the total number of rows, calculate other values. 62 func (p *Paging) SetRowsCount(total int) { 63 p.RowsCount = total 64 p.PageCount = (p.RowsCount + p.PageSize - 1) / p.PageSize 65 if p.PageSeq >= p.PageCount { 66 p.LastPage = true 67 } else { 68 p.HasNext = true 69 } 70 if p.PageSeq > 1 { 71 p.HasPrev = true 72 } else { 73 p.FirstPage = true 74 } 75 } 76 77 // CompatibleLimit represents a LIMIT clause. 78 type CompatibleLimit struct { 79 *Limit 80 SwapArgs func(args []interface{}) 81 DBType 82 } 83 84 // Format formats the node. 85 func (n *CompatibleLimit) Format(buf *TrackedBuffer) { 86 if n == nil { 87 return 88 } 89 switch n.DBType { 90 case Mysql, Sqlite3, Dm, Gbase, Clickhouse: 91 buf.Myprintf(" limit ") 92 if n.Offset != nil { 93 buf.Myprintf("%v, ", n.Offset) 94 } 95 buf.Myprintf("%v", n.Rowcount) 96 if n.Offset != nil && n.Rowcount != nil { 97 offsetVar, ok1 := n.Offset.(*SQLVal) 98 rowcount, ok2 := n.Rowcount.(*SQLVal) 99 if ok1 && ok2 && offsetVar.Seq > rowcount.Seq { 100 i := offsetVar.Seq - 1 101 j := rowcount.Seq - 1 102 n.SwapArgs = func(args []interface{}) { 103 args[i], args[j] = args[j], args[i] 104 } 105 } 106 } 107 case Postgresql, Kingbase, Shentong: 108 // https://www.postgresql.org/docs/9.3/queries-limit.html 109 // SELECT select_list 110 // FROM table_expression 111 // [ ORDER BY ... ] 112 // [ LIMIT { number | ALL } ] [ OFFSET number ] 113 buf.Myprintf(" limit %v", n.Rowcount) 114 if n.Offset != nil { 115 buf.Myprintf("offset %v", n.Offset) 116 } 117 case Mssql, Oracle: 118 if n.Offset != nil { 119 buf.Myprintf(" offset %v rows", n.Offset) 120 } 121 buf.Myprintf(" fetch next %v rows only", n.Rowcount) 122 default: 123 panic(ErrUnsupportedDBType) 124 } 125 } 126 127 func (t DBType) createPagingClause(plFormatter PlaceholderFormatter, p *Paging, placeholder bool) (page string, bindArgs []interface{}) { 128 var s strings.Builder 129 start := p.PageSize * (p.PageSeq - 1) 130 plf := plFormatter.FormatPlaceholder 131 switch t { 132 case Mysql, Sqlite3, Dm, Gbase, Clickhouse: 133 if placeholder { 134 s.WriteString(fmt.Sprintf("limit %s,%s", plf(), plf())) 135 bindArgs = []interface{}{start, p.PageSize} 136 } else { 137 s.WriteString(fmt.Sprintf("limit %d,%d", start, p.PageSize)) 138 } 139 case Postgresql, Kingbase, Shentong: 140 if placeholder { 141 s.WriteString(fmt.Sprintf("limit %s offset %s", plf(), plf())) 142 bindArgs = []interface{}{p.PageSize, start} 143 } else { 144 s.WriteString(fmt.Sprintf("limit %d offset %d", p.PageSize, start)) 145 } 146 case Mssql, Oracle: 147 if placeholder { 148 s.WriteString(fmt.Sprintf("offset %s rows fetch next %s rows only", plf(), plf())) 149 bindArgs = []interface{}{start, p.PageSize} 150 } else { 151 s.WriteString(fmt.Sprintf("offset %d rows fetch next %d rows only", start, p.PageSize)) 152 } 153 default: 154 panic(ErrUnsupportedDBType) 155 } 156 157 page = s.String() 158 return 159 } 160 161 type IdQuoter interface { 162 Quote(string) string 163 } 164 165 type MySQLIdQuoter struct{} 166 167 func (MySQLIdQuoter) Quote(s string) string { 168 b := new(bytes.Buffer) 169 b.WriteByte('`') 170 for _, c := range s { 171 b.WriteRune(c) 172 if c == '`' { 173 b.WriteRune('`') 174 } 175 } 176 b.WriteByte('`') 177 return b.String() 178 } 179 180 type DoubleQuoteIdQuoter struct{} 181 182 func (DoubleQuoteIdQuoter) Quote(s string) string { 183 return strconv.Quote(s) 184 } 185 186 type PlaceholderFormatter interface { 187 FormatPlaceholder() string 188 ResetPlaceholder() 189 } 190 191 type QuestionPlaceholderFormatter struct{} 192 193 func (QuestionPlaceholderFormatter) FormatPlaceholder() string { return "?" } 194 func (QuestionPlaceholderFormatter) ResetPlaceholder() {} 195 196 type PrefixPlaceholderFormatter struct { 197 Prefix string 198 Pos int // 1-based 199 } 200 201 func (p *PrefixPlaceholderFormatter) ResetPlaceholder() { p.Pos = 0 } 202 func (p *PrefixPlaceholderFormatter) FormatPlaceholder() string { 203 p.Pos++ 204 return fmt.Sprintf("%s%d", p.Prefix, p.Pos) 205 } 206 207 type ConvertConfig struct { 208 Paging *Paging 209 AutoIncrementField string 210 } 211 212 type ConvertOption func(*ConvertConfig) 213 214 func WithLimit(v int) ConvertOption { 215 return func(c *ConvertConfig) { c.Paging = &Paging{PageSeq: 1, PageSize: v} } 216 } 217 func WithPaging(v *Paging) ConvertOption { return func(c *ConvertConfig) { c.Paging = v } } 218 func WithAutoIncrement(v string) ConvertOption { 219 return func(c *ConvertConfig) { c.AutoIncrementField = v } 220 } 221 222 type ConvertResult struct { 223 ExtraArgs []interface{} 224 CountingQuery string 225 ScanValues []interface{} 226 VarPoses []int // []var pos) 227 BindMode BindMode 228 VarNames []string 229 230 InPlaceholder *InPlaceholder 231 Placeholders int 232 233 ConvertQuery func() string 234 SwapArgs func(args []interface{}) 235 } 236 237 type BindMode uint 238 239 const ( 240 ByPlaceholder BindMode = 1 << iota 241 BySeq 242 ByName 243 ) 244 245 func (r *ConvertResult) PickArgs(args []interface{}) (q string, bindArgs []interface{}) { 246 switch r.BindMode { 247 case ByName: 248 arg := args[0] 249 if IsStructOrPtrToStruct(arg) { 250 obj := reflector.New(arg) 251 for _, name := range r.VarNames { 252 name2 := strings.ToLower(ss.Strip(name, func(r rune) bool { return r == '-' || r == '_' })) 253 v, err := obj.Field(name2).Get() 254 if err != nil { 255 v, err = obj.FieldByTag("db", name).Get() 256 } 257 258 if err != nil { 259 panic(err) 260 } 261 262 bindArgs = append(bindArgs, v) 263 } 264 265 } else if IsMap(arg) { 266 f := func(s string) string { 267 return strings.ToLower(ss.Strip(s, func(r rune) bool { return r == '-' || r == '_' })) 268 } 269 vmap := reflect.ValueOf(arg) 270 for _, name := range r.VarNames { 271 if v := vmap.MapIndex(reflect.ValueOf(name)); v.IsValid() { 272 bindArgs = append(bindArgs, v.Interface()) 273 } else { 274 bindArg, _ := findInMap(vmap, name, f) 275 bindArgs = append(bindArgs, bindArg) 276 } 277 } 278 } else { 279 bindArgs = args 280 } 281 282 case BySeq: 283 for _, p := range r.VarPoses { 284 bindArgs = append(bindArgs, args[p-1]) 285 } 286 default: 287 if r.IsInPlaceholders() { 288 if len(args) == 1 && IsSlice(args[0]) { 289 r.ResetInVars(SliceLen(args[0])) 290 bindArgs = CreateSlice(args[0]) 291 } else { 292 r.ResetInVars(len(args)) 293 bindArgs = args 294 } 295 } else { 296 bindArgs = append(bindArgs, args...) 297 } 298 } 299 300 resultArgs := append(bindArgs, r.ExtraArgs...) 301 query := r.ConvertQuery() 302 if r.SwapArgs != nil { 303 r.SwapArgs(resultArgs) 304 } 305 return query, resultArgs 306 } 307 308 func CreateSlice(i interface{}) []interface{} { 309 ti := reflect.ValueOf(i) 310 elements := make([]interface{}, ti.Len()) 311 312 for i := 0; i < ti.Len(); i++ { 313 elements[i] = ti.Index(i).Interface() 314 } 315 316 return elements 317 } 318 319 func SliceLen(i interface{}) int { 320 ti := reflect.ValueOf(i) 321 return ti.Len() 322 } 323 324 func IsSlice(i interface{}) bool { 325 return reflect.TypeOf(i).Kind() == reflect.Slice 326 } 327 328 func (r ConvertResult) IsInPlaceholders() bool { 329 return r.InPlaceholder != nil && r.Placeholders == r.InPlaceholder.Num 330 } 331 332 func (r ConvertResult) ResetInVars(varsNum int) { 333 if varsNum == r.InPlaceholder.Num { 334 return 335 } 336 337 var exprs ValTuple 338 339 for i := 0; i < varsNum; i++ { 340 exprs = append(exprs, &SQLVal{Type: ValArg, Val: []byte("?")}) 341 } 342 343 if varsNum == 0 { 344 exprs = append(exprs, &NullVal{}) 345 } 346 347 r.InPlaceholder.Expr.Right = exprs 348 } 349 350 func findInMap(vmap reflect.Value, name string, f func(s string) string) (interface{}, bool) { 351 name = f(name) 352 353 for iter := vmap.MapRange(); iter.Next(); { 354 k := iter.Key().Interface() 355 if kk, ok := k.(string); ok { 356 if f(kk) == name { 357 return iter.Value().Interface(), true 358 } 359 } 360 } 361 362 return nil, false 363 } 364 365 func IsMap(arg interface{}) bool { 366 t := reflect.TypeOf(arg) 367 if t.Kind() == reflect.Ptr { 368 t = t.Elem() 369 } 370 return t.Kind() == reflect.Map 371 } 372 373 func IsStructOrPtrToStruct(arg interface{}) bool { 374 t := reflect.TypeOf(arg) 375 if t.Kind() == reflect.Ptr { 376 t = t.Elem() 377 } 378 return t.Kind() == reflect.Struct 379 } 380 381 var ErrSyntax = errors.New("syntax not supported") 382 383 const CreateCountingQuery = -1 384 385 var numReg = regexp.MustCompile(`^[1-9]\d*$`) 386 387 // Convert converts query to target db type. 388 // 1. adjust the SQL variable symbols by different type, such as ?,? $1,$2. 389 // 1. quote table name, field names. 390 func (t DBType) Convert(query string, options ...ConvertOption) (*ConvertResult, error) { 391 stmt, err := Parse(query) 392 if err != nil { 393 return nil, err 394 } 395 396 insertStmt, _ := stmt.(*Insert) 397 if err := t.checkMySQLOnDuplicateKey(insertStmt); err != nil { 398 return nil, fmt.Errorf("on duplicate key is not supported directly in SQL, error %w", ErrSyntax) 399 } 400 401 fixInsertPlaceholders(insertStmt) 402 cr := &ConvertResult{} 403 404 insertPos := -1 405 lastColName := "" 406 407 _ = stmt.WalkSubtree(func(node SQLNode) (kontinue bool, err error) { 408 if cr.InPlaceholder == nil { 409 cr.InPlaceholder = ParseInPlaceholder(node) 410 } 411 412 if cn, cnOk := node.(*ColName); cnOk { 413 lastColName = cn.Name.Lowered() 414 return true, nil 415 } 416 if _, ok := node.(*Limit); ok { 417 return true, err 418 } 419 420 if v, ok := node.(*SQLVal); ok { 421 switch v.Type { 422 case ValArg, StrVal: // 转换 :a :b :c 或者 :1 :2 :3的占位符形式 423 if string(v.Val) == "?" { 424 cr.Placeholders++ 425 } else { 426 convertCustomBinding(insertStmt, &insertPos, &lastColName, v, cr) 427 } 428 } 429 } 430 431 return true, nil 432 }) 433 434 if len(cr.VarPoses) > 0 { 435 cr.BindMode |= BySeq 436 } 437 if len(cr.VarNames) > 0 { 438 cr.BindMode |= ByName 439 } 440 if cr.Placeholders > 0 { 441 cr.BindMode |= ByPlaceholder 442 } 443 if bits.OnesCount(uint(cr.BindMode)) > 1 { 444 return nil, fmt.Errorf("mixed bind modes are not supported, error %w", ErrSyntax) 445 } 446 447 buf := &TrackedBuffer{Buffer: new(bytes.Buffer)} 448 449 switch t { 450 case Postgresql, Kingbase: 451 buf.PlaceholderFormatter = &PrefixPlaceholderFormatter{Prefix: "$"} 452 case Mssql: 453 buf.PlaceholderFormatter = &PrefixPlaceholderFormatter{Prefix: "@p"} 454 case Oracle, Shentong: 455 buf.PlaceholderFormatter = &PrefixPlaceholderFormatter{Prefix: ":"} 456 default: 457 buf.PlaceholderFormatter = &QuestionPlaceholderFormatter{} 458 } 459 460 switch t { 461 case Mysql, Sqlite3, Gbase, Clickhouse: 462 // https://www.sqlite.org/lang_keywords.html 463 buf.IdQuoter = &MySQLIdQuoter{} 464 default: 465 buf.IdQuoter = &DoubleQuoteIdQuoter{} 466 } 467 468 config := &ConvertConfig{} 469 for _, f := range options { 470 f(config) 471 } 472 473 selectStmt, _ := stmt.(*Select) 474 var limit *Limit 475 if selectStmt != nil { 476 limit = selectStmt.Limit 477 selectStmt.SetLimit(nil) 478 } 479 480 var compatibleLimit *CompatibleLimit 481 p := config.Paging 482 isPaging := selectStmt != nil && p != nil 483 if !isPaging && limit != nil { 484 compatibleLimit = &CompatibleLimit{Limit: limit, DBType: t} 485 selectStmt.SetLimitSQLNode(compatibleLimit) 486 } 487 488 cr.ConvertQuery = func() string { 489 buf.Myprintf("%v", stmt) 490 q := buf.String() 491 buf.Reset() 492 493 if compatibleLimit != nil { 494 cr.SwapArgs = compatibleLimit.SwapArgs 495 } 496 497 if isPaging { 498 pagingClause, bindArgs := t.createPagingClause(buf.PlaceholderFormatter, p, cr.BindMode > 0) 499 cr.ExtraArgs = append(cr.ExtraArgs, bindArgs...) 500 if p.RowsCount == CreateCountingQuery { 501 cr.CountingQuery = t.createCountingQuery(stmt, buf, q) 502 } 503 q += " " + pagingClause 504 } 505 506 if f := config.AutoIncrementField; f != "" { 507 q += " " + t.createAutoIncrementPK(cr, f) 508 } 509 510 return q 511 } 512 513 return cr, nil 514 } 515 516 type InPlaceholder struct { 517 Expr *ComparisonExpr 518 Num int 519 } 520 521 func ParseInPlaceholder(node SQLNode) *InPlaceholder { 522 v, ok := node.(*ComparisonExpr) 523 if !ok { 524 return nil 525 } 526 527 if v.Operator != "in" { 528 return nil 529 } 530 531 t, tOk := v.Right.(ValTuple) 532 if !tOk { 533 return nil 534 } 535 536 for _, tv := range t { 537 if tw, twOK := tv.(*SQLVal); twOK { 538 if !(tw.Type == ValArg && bytes.Equal(tw.Val, []byte("?"))) { 539 return nil 540 } 541 } else { 542 return nil 543 } 544 } 545 546 return &InPlaceholder{Expr: v, Num: len(t)} 547 } 548 549 func convertCustomBinding(insert *Insert, insertPos *int, lastColName *string, v *SQLVal, cr *ConvertResult) { 550 if len(v.Val) == 0 || !bytes.HasPrefix(v.Val, []byte(":")) { 551 return 552 } 553 name := strings.TrimSpace(string(v.Val[1:])) 554 if name == "" { 555 return 556 } 557 558 if numReg.MatchString(name) { 559 num, _ := strconv.Atoi(name) 560 cr.VarPoses = append(cr.VarPoses, num) 561 } else { 562 if name == "?" { // 从上下文推断变量名称 563 if insert != nil { 564 *insertPos++ 565 col := insert.Columns[*insertPos] 566 name = col.Lowered() 567 } else if *lastColName != "" { 568 name = *lastColName 569 *lastColName = "" 570 } 571 } 572 cr.VarNames = append(cr.VarNames, name) 573 } 574 575 v.Type = ValArg 576 v.Val = []byte("?") 577 } 578 579 func (t DBType) checkMySQLOnDuplicateKey(insertStmt *Insert) error { 580 // 只有MySQL 的 ON DUPLICATE KEY被支持 581 // eg. INSERT INTO table (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b); 582 if insertStmt != nil && len(insertStmt.OnDup) > 0 { 583 switch t { 584 case Mysql: 585 default: 586 return ErrSyntax 587 } 588 } 589 return nil 590 } 591 592 func fixInsertPlaceholders(insertStmt *Insert) { 593 if insertStmt == nil { 594 return 595 } 596 597 // 是insert into t(a,b,c) values(...)的格式 598 insertRows, ok := insertStmt.Rows.(Values) 599 if !ok { 600 return 601 } 602 603 // 只有一个values列表 604 if len(insertRows) != 1 { 605 return 606 } 607 608 questionVals := 0 609 inferVals := 0 610 others := 0 611 insertRow := insertRows[0] 612 for _, node := range insertRow { 613 if v, ok := (node).(*SQLVal); ok && v.Type == ValArg { 614 switch vs := string(v.Val); vs { 615 case "?": 616 questionVals++ 617 continue 618 case ":?": 619 inferVals++ 620 continue 621 } 622 } 623 624 others++ 625 break 626 } 627 628 // 不全是?占位符 629 if others > 0 || questionVals > 0 && inferVals > 0 { 630 return 631 } 632 633 diff := len(insertStmt.Columns) - ss.Ifi(questionVals > 0, questionVals, inferVals) 634 if diff == 0 { 635 return 636 } 637 638 if diff > 0 { 639 pl := []byte(ss.If(questionVals > 0, "?", ":?")) 640 appendVarArgs := make([]Expr, 0, diff) 641 for i := 0; i < diff; i++ { 642 appendVarArgs = append(appendVarArgs, NewValArg(pl)) 643 } 644 insertRows[0] = append(insertRows[0], appendVarArgs...) 645 } else { 646 insertRows[0] = insertRows[0][:len(insertStmt.Columns)] 647 } 648 } 649 650 func (t DBType) createCountingQuery(stmt Statement, buf *TrackedBuffer, q string) string { 651 buf.PlaceholderFormatter.ResetPlaceholder() 652 653 countWrapRequired := func() bool { 654 if _, ok := stmt.(*Union); ok { 655 return true 656 } 657 if v, ok := stmt.(*Select); ok && v.Distinct != "" || len(v.GroupBy) > 0 { 658 v.OrderBy = nil 659 return true 660 } 661 return false 662 } 663 664 if countWrapRequired() { 665 query := "select count(*) cnt from (" + q + ") t_gg_cnt" 666 cr, err := t.Convert(query) 667 if err != nil { 668 log.Printf("failed to convert query %s, err: %v", query, err) 669 return "" 670 } 671 672 buf.Myprintf("%v", cr.ConvertQuery()) 673 return buf.String() 674 } 675 676 s, ok := stmt.(*Select) 677 if !ok { 678 return "" 679 } 680 681 s.OrderBy = nil 682 p, _ := Parse(`select count(*) cnt`) 683 s.SelectExprs = p.(*Select).SelectExprs 684 buf.Myprintf("%v", s) 685 return buf.String() 686 } 687 688 func (t DBType) createAutoIncrementPK(cr *ConvertResult, autoIncrementField string) string { 689 switch t { 690 case Postgresql, Kingbase: 691 // https://gist.github.com/miguelmota/d54814683346c4c98cec432cf99506c0 692 return "returning " + autoIncrementField 693 case Oracle, Shentong: 694 // https://forum.golangbridge.org/t/returning-values-with-insert-query-using-oracle-database-in-golang/13099/5 695 var p int64 = 0 696 cr.ScanValues = append(cr.ScanValues, sql.Named(autoIncrementField, sql.Out{Dest: &p})) 697 return "returning " + autoIncrementField + " into :" + autoIncrementField 698 default: 699 return "" 700 } 701 }