github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/pagination/gorm/paginate.go (about) 1 package gorm 2 3 import ( 4 "crypto/md5" 5 "encoding/json" 6 "fmt" 7 "github.com/pkg/errors" 8 "github.com/unionj-cloud/go-doudou/v2/toolkit/sliceutils" 9 "log" 10 "math" 11 "reflect" 12 "regexp" 13 "strconv" 14 "strings" 15 16 "github.com/morkid/gocache" 17 "gorm.io/gorm" 18 ) 19 20 // ResponseContext interface 21 type ResponseContext interface { 22 Cache(string) ResponseContext 23 Fields([]string) ResponseContext 24 Distinct([]string) ResponseContext 25 Response(interface{}) Page 26 Error() error 27 } 28 29 // RequestContext interface 30 type RequestContext interface { 31 Request(IParameter) ResponseContext 32 } 33 34 // Pagination gorm paginate struct 35 type Pagination struct { 36 Config *Config 37 } 38 39 // With func 40 func (p *Pagination) With(stmt *gorm.DB) RequestContext { 41 return reqContext{ 42 Statement: stmt, 43 Pagination: p, 44 } 45 } 46 47 // ClearCache clear cache contains prefix 48 func (p Pagination) ClearCache(keyPrefixes ...string) { 49 if len(keyPrefixes) > 0 && nil != p.Config && nil != p.Config.CacheAdapter { 50 adapter := *p.Config.CacheAdapter 51 for i := range keyPrefixes { 52 if err := adapter.ClearPrefix(keyPrefixes[i]); nil != err { 53 log.Println(err) 54 } 55 } 56 } 57 } 58 59 // ClearAllCache clear all existing cache 60 func (p Pagination) ClearAllCache() { 61 if nil != p.Config && nil != p.Config.CacheAdapter { 62 adapter := *p.Config.CacheAdapter 63 if err := adapter.ClearAll(); nil != err { 64 log.Println(err) 65 } 66 } 67 } 68 69 type reqContext struct { 70 Statement *gorm.DB 71 Pagination *Pagination 72 } 73 74 func (r reqContext) Request(parameter IParameter) ResponseContext { 75 var response ResponseContext = &resContext{ 76 Statement: r.Statement, 77 Parameter: parameter, 78 Pagination: r.Pagination, 79 } 80 81 return response 82 } 83 84 type resContext struct { 85 Pagination *Pagination 86 Statement *gorm.DB 87 Parameter IParameter 88 cachePrefix string 89 fieldList []string 90 customSelect string 91 distinct bool 92 error error 93 } 94 95 func (r *resContext) Error() error { 96 return r.error 97 } 98 99 func (r *resContext) Cache(prefix string) ResponseContext { 100 r.cachePrefix = prefix 101 return r 102 } 103 104 func (r *resContext) Fields(fields []string) ResponseContext { 105 r.fieldList = fields 106 return r 107 } 108 109 // CustomSelect currently used for distinct on clause 110 func (r *resContext) Distinct(fields []string) ResponseContext { 111 r.fieldList = fields 112 r.distinct = true 113 return r 114 } 115 116 func (r *resContext) Response(res interface{}) Page { 117 p := r.Pagination 118 query := r.Statement 119 p.Config = defaultConfig(p.Config) 120 p.Config.Statement = query.Statement 121 if p.Config.DefaultSize == 0 { 122 p.Config.DefaultSize = 10 123 } 124 125 if p.Config.FieldWrapper == "" && p.Config.ValueWrapper == "" { 126 defaultWrapper := "LOWER(%s)" 127 wrappers := map[string]string{ 128 "sqlite": defaultWrapper, 129 "mysql": defaultWrapper, 130 "postgres": "LOWER((%s)::text)", 131 } 132 p.Config.FieldWrapper = defaultWrapper 133 if wrapper, ok := wrappers[query.Dialector.Name()]; ok { 134 p.Config.FieldWrapper = wrapper 135 } 136 } 137 138 page := Page{} 139 pr, err := parseRequest(r.Parameter, *p.Config) 140 if err != nil { 141 r.error = err 142 return page 143 } 144 causes := createCauses(pr) 145 cKey := "" 146 var adapter gocache.AdapterInterface 147 var hasAdapter bool = false 148 149 if nil != p.Config.CacheAdapter { 150 cKey = createCacheKey(r.cachePrefix, pr) 151 adapter = *p.Config.CacheAdapter 152 hasAdapter = true 153 if cKey != "" && adapter.IsValid(cKey) { 154 if cache, err := adapter.Get(cKey); nil == err { 155 if err := p.Config.JSONUnmarshal([]byte(cache), &page); nil == err { 156 return page 157 } 158 } 159 } 160 } 161 162 dbs := query.Statement.DB.Session(&gorm.Session{NewDB: true}) 163 var selects []string 164 if len(r.fieldList) > 0 { 165 if len(pr.Fields) > 0 && p.Config.FieldSelectorEnabled { 166 for i := range pr.Fields { 167 for j := range r.fieldList { 168 if r.fieldList[j] == pr.Fields[i] { 169 fname := query.Statement.Quote("s." + fieldName(pr.Fields[i])) 170 if !contains(selects, fname) { 171 selects = append(selects, fname) 172 } 173 break 174 } 175 } 176 } 177 } else { 178 for i := range r.fieldList { 179 fname := query.Statement.Quote("s." + fieldName(r.fieldList[i])) 180 if !contains(selects, fname) { 181 selects = append(selects, fname) 182 } 183 } 184 } 185 } else if len(pr.Fields) > 0 && p.Config.FieldSelectorEnabled { 186 for i := range pr.Fields { 187 fname := query.Statement.Quote("s." + fieldName(pr.Fields[i])) 188 if !contains(selects, fname) { 189 selects = append(selects, fname) 190 } 191 } 192 } 193 194 result := dbs. 195 Unscoped(). 196 Table("(?) AS s", query) 197 198 if len(selects) > 0 { 199 if r.distinct { 200 result = result.Distinct(selects) 201 } else { 202 result = result.Select(selects) 203 } 204 } 205 206 if len(causes.Params) > 0 || len(causes.WhereString) > 0 { 207 result = result.Where(causes.WhereString, causes.Params...) 208 } 209 210 dbs = query.Statement.DB.Session(&gorm.Session{NewDB: true}) 211 result = dbs. 212 Unscoped(). 213 Table("(?) AS s1", result) 214 215 result = result.Count(&page.Total). 216 Limit(int(causes.Limit)). 217 Offset(int(causes.Offset)) 218 if result.Error != nil { 219 r.error = result.Error 220 return page 221 } 222 223 if nil != query.Statement.Preloads { 224 for table, args := range query.Statement.Preloads { 225 result = result.Preload(table, args...) 226 } 227 } 228 if len(causes.Sorts) > 0 { 229 for _, sort := range causes.Sorts { 230 result = result.Order(sort.Column + " " + sort.Direction) 231 } 232 } 233 234 rs := result.Find(res) 235 if result.Error != nil { 236 r.error = result.Error 237 return page 238 } 239 240 page.Items, _ = sliceutils.ConvertAny2Interface(res) 241 f := float64(page.Total) / float64(causes.Limit) 242 if math.Mod(f, 1.0) > 0 { 243 f = f + 1 244 } 245 page.TotalPages = int64(f) 246 page.Page = int64(pr.Page) 247 page.Size = int64(pr.Size) 248 page.MaxPage = 0 249 page.Visible = rs.RowsAffected 250 if page.TotalPages > 0 { 251 page.MaxPage = page.TotalPages - 1 252 } 253 if page.TotalPages < 1 { 254 page.TotalPages = 1 255 } 256 if page.Total < 1 { 257 page.MaxPage = 0 258 page.TotalPages = 0 259 } 260 page.First = causes.Offset < 1 261 page.Last = page.MaxPage == page.Page 262 263 if hasAdapter && cKey != "" { 264 if cache, err := p.Config.JSONMarshal(page); nil == err { 265 if err := adapter.Set(cKey, string(cache)); err != nil { 266 log.Println(err) 267 } 268 } 269 } 270 271 return page 272 } 273 274 // New Pagination instance 275 func New(params ...interface{}) *Pagination { 276 if len(params) >= 1 { 277 var config *Config 278 for _, param := range params { 279 c, isConfig := param.(*Config) 280 if isConfig { 281 config = c 282 continue 283 } 284 } 285 286 return &Pagination{Config: defaultConfig(config)} 287 } 288 289 return &Pagination{Config: defaultConfig(nil)} 290 } 291 292 // parseRequest func 293 func parseRequest(param IParameter, config Config) (pageRequest, error) { 294 pr := pageRequest{ 295 Config: *defaultConfig(&config), 296 } 297 err := parsingQueryString(param, &pr) 298 if err != nil { 299 return pageRequest{}, err 300 } 301 return pr, nil 302 } 303 304 // createFilters func 305 func createFilters(filterParams interface{}, p *pageRequest) error { 306 s, ok2 := filterParams.(string) 307 if reflect.ValueOf(filterParams).Kind() == reflect.Slice { 308 f, err := sliceutils.ConvertAny2Interface(filterParams) 309 if err != nil { 310 return errors.WithStack(err) 311 } 312 p.Filters = arrayToFilter(f, p.Config) 313 p.Filters.Fields = p.Fields 314 } else if ok2 { 315 iface := []interface{}{} 316 if e := p.Config.JSONUnmarshal([]byte(s), &iface); nil == e && len(iface) > 0 { 317 p.Filters = arrayToFilter(iface, p.Config) 318 } 319 p.Filters.Fields = p.Fields 320 } 321 return nil 322 } 323 324 // createCauses func 325 func createCauses(p pageRequest) requestQuery { 326 query := requestQuery{} 327 wheres, params := generateWhereCauses(p.Filters, p.Config) 328 sorts := []sortOrder{} 329 330 for _, so := range p.Sorts { 331 so.Column = fieldName(so.Column) 332 if nil != p.Config.Statement { 333 so.Column = p.Config.Statement.Quote(so.Column) 334 } 335 sorts = append(sorts, so) 336 } 337 338 query.Limit = p.Size 339 query.Offset = p.Page * p.Size 340 query.Wheres = wheres 341 query.WhereString = strings.Join(wheres, " ") 342 query.Sorts = sorts 343 query.Params = params 344 345 return query 346 } 347 348 func parsingQueryString(param IParameter, p *pageRequest) error { 349 p.Size = param.GetSize() 350 351 if p.Size == 0 { 352 if p.Config.DefaultSize > 0 { 353 p.Size = p.Config.DefaultSize 354 } else { 355 p.Size = 10 356 } 357 } 358 359 p.Page = param.GetPage() 360 361 if param.GetSort() != "" { 362 sorts := strings.Split(param.GetSort(), ",") 363 for _, col := range sorts { 364 if col == "" { 365 continue 366 } 367 368 so := sortOrder{ 369 Column: col, 370 Direction: "ASC", 371 } 372 if strings.ToUpper(param.GetOrder()) == "DESC" { 373 so.Direction = "DESC" 374 } 375 376 if string(col[0]) == "-" { 377 so.Column = string(col[1:]) 378 so.Direction = "DESC" 379 } 380 381 p.Sorts = append(p.Sorts, so) 382 } 383 } 384 385 if param.GetFields() != "" { 386 re := regexp.MustCompile(`[^A-z0-9_\.,]+`) 387 if fields := strings.Split(param.GetFields(), ","); len(fields) > 0 { 388 for i := range fields { 389 fieldByte := re.ReplaceAll([]byte(fields[i]), []byte("")) 390 if field := string(fieldByte); field != "" { 391 p.Fields = append(p.Fields, field) 392 } 393 } 394 } 395 } 396 397 return createFilters(param.GetFilters(), p) 398 } 399 400 //gocyclo:ignore 401 func arrayToFilter(arr []interface{}, config Config) pageFilters { 402 filters := pageFilters{ 403 Single: false, 404 } 405 406 operatorEscape := regexp.MustCompile(`[^A-z=\<\>\-\+\^/\*%&! ]+`) 407 arrayLen := len(arr) 408 409 if len(arr) > 0 { 410 subFilters := []pageFilters{} 411 for k, i := range arr { 412 iface, ok := i.([]interface{}) 413 if ok && !filters.Single { 414 subFilters = append(subFilters, arrayToFilter(iface, config)) 415 } else if arrayLen == 1 { 416 operator, ok := i.(string) 417 if ok { 418 operator = operatorEscape.ReplaceAllString(operator, "") 419 filters.Operator = strings.ToUpper(operator) 420 filters.IsOperator = true 421 filters.Single = true 422 } 423 } else if arrayLen == 2 { 424 if k == 0 { 425 if column, ok := i.(string); ok { 426 filters.Column = column 427 filters.Operator = "=" 428 filters.Single = true 429 } 430 } else if k == 1 { 431 filters.Value = i 432 if nil == i { 433 filters.Operator = "IS" 434 } 435 } 436 } else if arrayLen == 3 { 437 if k == 0 { 438 if column, ok := i.(string); ok { 439 filters.Column = column 440 filters.Single = true 441 } 442 } else if k == 1 { 443 if operator, ok := i.(string); ok { 444 operator = operatorEscape.ReplaceAllString(operator, "") 445 filters.Operator = strings.ToUpper(operator) 446 filters.Single = true 447 } 448 } else if k == 2 { 449 switch filters.Operator { 450 case "LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE": 451 escapeString := "" 452 escapePattern := `(%|\\)` 453 if nil != config.Statement { 454 driverName := config.Statement.Dialector.Name() 455 switch driverName { 456 case "sqlite", "sqlserver", "postgres": 457 escapeString = `\` 458 filters.ValueSuffix = "ESCAPE '\\'" 459 case "mysql": 460 escapeString = `\` 461 filters.ValueSuffix = `ESCAPE '\\'` 462 } 463 } 464 value := fmt.Sprintf("%v", i) 465 re := regexp.MustCompile(escapePattern) 466 value = string(re.ReplaceAll([]byte(value), []byte(escapeString+`$1`))) 467 if config.SmartSearch { 468 re := regexp.MustCompile(`[\s]+`) 469 byt := re.ReplaceAll([]byte(value), []byte("%")) 470 value = string(byt) 471 } 472 filters.Value = fmt.Sprintf("%s%s%s", "%", value, "%") 473 default: 474 filters.Value = i 475 } 476 } 477 } 478 } 479 if len(subFilters) > 0 { 480 separatedSubFilters := []pageFilters{} 481 hasOperator := false 482 defaultOperator := config.Operator 483 if "" == defaultOperator { 484 defaultOperator = "OR" 485 } 486 for k, s := range subFilters { 487 if s.IsOperator && len(subFilters) == (k+1) { 488 break 489 } 490 if !hasOperator && !s.IsOperator && k > 0 { 491 separatedSubFilters = append(separatedSubFilters, pageFilters{ 492 Operator: defaultOperator, 493 IsOperator: true, 494 Single: true, 495 }) 496 } 497 hasOperator = s.IsOperator 498 separatedSubFilters = append(separatedSubFilters, s) 499 } 500 filters.Value = separatedSubFilters 501 filters.Single = false 502 } 503 } 504 505 return filters 506 } 507 508 //gocyclo:ignore 509 func generateWhereCauses(f pageFilters, config Config) ([]string, []interface{}) { 510 wheres := []string{} 511 params := []interface{}{} 512 513 if !f.Single && !f.IsOperator { 514 ifaces, ok := f.Value.([]pageFilters) 515 if ok && len(ifaces) > 0 { 516 wheres = append(wheres, "(") 517 hasOpen := false 518 for _, i := range ifaces { 519 subs, isSub := i.Value.([]pageFilters) 520 regular, isNotSub := i.Value.(pageFilters) 521 if isSub && len(subs) > 0 { 522 wheres = append(wheres, "(") 523 for _, s := range subs { 524 subWheres, subParams := generateWhereCauses(s, config) 525 wheres = append(wheres, subWheres...) 526 params = append(params, subParams...) 527 } 528 wheres = append(wheres, ")") 529 } else if isNotSub { 530 subWheres, subParams := generateWhereCauses(regular, config) 531 wheres = append(wheres, subWheres...) 532 params = append(params, subParams...) 533 } else { 534 if !hasOpen && !i.IsOperator { 535 wheres = append(wheres, "(") 536 hasOpen = true 537 } 538 subWheres, subParams := generateWhereCauses(i, config) 539 wheres = append(wheres, subWheres...) 540 params = append(params, subParams...) 541 } 542 } 543 if hasOpen { 544 wheres = append(wheres, ")") 545 } 546 wheres = append(wheres, ")") 547 } 548 } else if f.Single { 549 if f.IsOperator { 550 wheres = append(wheres, f.Operator) 551 } else { 552 fname := fieldName(f.Column) 553 if nil != config.Statement { 554 fname = config.Statement.Quote(fname) 555 } 556 switch f.Operator { 557 case "IS", "IS NOT": 558 if nil == f.Value { 559 wheres = append(wheres, fname, f.Operator, "NULL") 560 } else { 561 if strValue, isStr := f.Value.(string); isStr && strings.ToLower(strValue) == "null" { 562 wheres = append(wheres, fname, f.Operator, "NULL") 563 } else { 564 wheres = append(wheres, fname, f.Operator, "?") 565 params = append(params, f.Value) 566 } 567 } 568 case "BETWEEN": 569 if values, ok := f.Value.([]interface{}); ok && len(values) >= 2 { 570 wheres = append(wheres, "(", fname, f.Operator, "? AND ?", ")") 571 params = append(params, valueFixer(values[0]), valueFixer(values[1])) 572 } 573 case "IN", "NOT IN": 574 if values, ok := f.Value.([]interface{}); ok { 575 wheres = append(wheres, fname, f.Operator, "?") 576 params = append(params, valueFixer(values)) 577 } 578 case "LIKE", "NOT LIKE", "ILIKE", "NOT ILIKE": 579 if config.FieldWrapper != "" { 580 fname = fmt.Sprintf(config.FieldWrapper, fname) 581 } 582 wheres = append(wheres, fname, f.Operator, "?") 583 if f.ValueSuffix != "" { 584 wheres = append(wheres, f.ValueSuffix) 585 } 586 value, isStrValue := f.Value.(string) 587 if isStrValue { 588 if config.ValueWrapper != "" { 589 value = fmt.Sprintf(config.ValueWrapper, value) 590 } else { 591 value = strings.ToLower(value) 592 } 593 params = append(params, value) 594 } else { 595 params = append(params, f.Value) 596 } 597 default: 598 wheres = append(wheres, fname, f.Operator, "?") 599 params = append(params, valueFixer(f.Value)) 600 } 601 } 602 } 603 604 return wheres, params 605 } 606 607 func valueFixer(n interface{}) interface{} { 608 var values []interface{} 609 if rawValues, ok := n.([]interface{}); ok { 610 for i := range rawValues { 611 values = append(values, valueFixer(rawValues[i])) 612 } 613 614 return values 615 } 616 if nil != n && reflect.TypeOf(n).Name() == "float64" { 617 strValue := fmt.Sprintf("%v", n) 618 if match, e := regexp.Match(`^[0-9]+$`, []byte(strValue)); nil == e && match { 619 v, err := strconv.ParseInt(strValue, 10, 64) 620 if nil == err { 621 return v 622 } 623 } 624 } 625 626 return n 627 } 628 629 func contains(source []string, value string) bool { 630 found := false 631 for i := range source { 632 if source[i] == value { 633 found = true 634 break 635 } 636 } 637 638 return found 639 } 640 641 func FieldAs(tableName, colName string) string { 642 return fmt.Sprintf("%s_%s", strings.ToLower(tableName), strings.ToLower(colName)) 643 } 644 645 func GetLowerColNameFromAlias(alias, tableName string) string { 646 return strings.TrimPrefix(alias, strings.ToLower(tableName)+"_") 647 } 648 649 func fieldName(field string) string { 650 slices := strings.Split(field, ".") 651 if len(slices) == 1 { 652 return field 653 } 654 return FieldAs(slices[0], slices[1]) 655 } 656 657 // Config for customize pagination result 658 type Config struct { 659 Operator string 660 FieldWrapper string 661 ValueWrapper string 662 DefaultSize int64 663 SmartSearch bool 664 Statement *gorm.Statement `json:"-"` 665 CustomParamEnabled bool 666 SortParams []string 667 PageParams []string 668 OrderParams []string 669 SizeParams []string 670 FilterParams []string 671 FieldsParams []string 672 FieldSelectorEnabled bool 673 CacheAdapter *gocache.AdapterInterface `json:"-"` 674 JSONMarshal func(v interface{}) ([]byte, error) `json:"-"` 675 JSONUnmarshal func(data []byte, v interface{}) error `json:"-"` 676 } 677 678 // pageFilters struct 679 type pageFilters struct { 680 Column string 681 Operator string 682 Value interface{} 683 ValuePrefix string 684 ValueSuffix string 685 Single bool 686 IsOperator bool 687 Fields []string 688 } 689 690 // Page result wrapper 691 type Page struct { 692 Items []interface{} `json:"items"` 693 Page int64 `json:"page"` 694 Size int64 `json:"size"` 695 MaxPage int64 `json:"max_page"` 696 TotalPages int64 `json:"total_pages"` 697 Total int64 `json:"total"` 698 Last bool `json:"last"` 699 First bool `json:"first"` 700 Visible int64 `json:"visible"` 701 } 702 703 type IParameter interface { 704 GetPage() int64 705 GetSize() int64 706 GetSort() string 707 GetOrder() string 708 GetFields() string 709 GetFilters() interface{} 710 IParameterInstance() 711 } 712 713 // query struct 714 type requestQuery struct { 715 WhereString string 716 Wheres []string 717 Params []interface{} 718 Sorts []sortOrder 719 Limit int64 720 Offset int64 721 } 722 723 // pageRequest struct 724 type pageRequest struct { 725 Size int64 726 Page int64 727 Sorts []sortOrder 728 Filters pageFilters 729 Config Config `json:"-"` 730 Fields []string 731 } 732 733 // sortOrder struct 734 type sortOrder struct { 735 Column string 736 Direction string 737 } 738 739 func createCacheKey(cachePrefix string, pr pageRequest) string { 740 key := "" 741 if bte, err := pr.Config.JSONMarshal(pr); nil == err && cachePrefix != "" { 742 key = fmt.Sprintf("%s%x", cachePrefix, md5.Sum(bte)) 743 } 744 745 return key 746 } 747 748 func defaultConfig(c *Config) *Config { 749 if nil == c { 750 return &Config{ 751 JSONMarshal: json.Marshal, 752 JSONUnmarshal: json.Unmarshal, 753 } 754 } 755 756 if nil == c.JSONMarshal { 757 c.JSONMarshal = json.Marshal 758 } 759 760 if nil == c.JSONUnmarshal { 761 c.JSONUnmarshal = json.Unmarshal 762 } 763 764 return c 765 }